diff --git a/src/main/java/xyz/wbsite/achat/ChatController.java b/src/main/java/xyz/wbsite/achat/ChatController.java deleted file mode 100644 index 3a07edd..0000000 --- a/src/main/java/xyz/wbsite/achat/ChatController.java +++ /dev/null @@ -1,110 +0,0 @@ -//package xyz.wbsite.achat; -// -//import org.springframework.web.bind.annotation.DeleteMapping; -//import org.springframework.web.bind.annotation.GetMapping; -//import org.springframework.web.bind.annotation.PathVariable; -//import org.springframework.web.bind.annotation.PostMapping; -//import org.springframework.web.bind.annotation.RequestBody; -//import org.springframework.web.bind.annotation.ResponseBody; -//import org.springframework.web.bind.annotation.RestController; -//import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -//import xyz.wbsite.achat.core.chat.Message; -//import xyz.wbsite.achat.core.Result; -//import xyz.wbsite.achat.core.Session; -//import xyz.wbsite.achat.core.prompt.MessagePrompt; -//import xyz.wbsite.achat.core.service.SessionService; -// -//import javax.annotation.Resource; -//import java.util.List; -//import java.util.UUID; -// -///** -// * AI会话接口. -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//@RestController -//public class ChatController { -// -// @Resource -// private SessionService sessionService; -// -// /** -// * 门户 -// */ -// @GetMapping("/") -// public String chat() { -// return "AChat is running!"; -// } -// -// /** -// * 发送消息(流式响应) -// * -// * @return SSE发射器,用于流式响应 -// */ -// @PostMapping(value = "/chat/send", produces = "text/event-stream;charset=UTF-8") -// public SseEmitter send(@RequestBody MessagePrompt message) { -// return sessionService.sendMessage(message); -// } -// -// /** -// * 创建会话 -// * -// * @return 创建会话响应 -// */ -// @PostMapping("/chat/session") -// public Result createSession() { -// String sid = UUID.randomUUID().toString(); -// return sessionService.createSession(sid); -// } -// -// /** -// * 获取会话列表 -// * -// * @return 会话列表响应 -// */ -// @ResponseBody -// @GetMapping("/chat/session/list") -// public Result> listSession() { -// String sid = UUID.randomUUID().toString(); -// return sessionService.listSessions(sid); -// } -// -// /** -// * 删除会话 -// * -// * @param sid 会话ID -// * @return 删除会话响应 -// */ -// @ResponseBody -// @DeleteMapping("/chat/session/{sid}") -// public Result deleteSession(@PathVariable("sid") String sid) { -// return sessionService.deleteSession(sid); -// } -// -// /** -// * 获取会话历史消息 -// * -// * @param sid 会话ID -// * @return 历史消息响应 -// */ -// @ResponseBody -// @GetMapping("/chat/session/{sid}/history") -// public Result> getSessionHistory(@PathVariable("sid") String sid) { -// return sessionService.listMessage(sid); -// } -// -// /** -// * 停止AI回答 -// * -// * @param sid 会话ID -// * @return 停止回答响应 -// */ -// @ResponseBody -// @PostMapping("/chat/session/{sid}/stop") -// public Result stopSession(@PathVariable("sid") String sid) { -// return sessionService.stopSession(sid); -// } -//} diff --git a/src/main/java/xyz/wbsite/achat/HelloController.java b/src/main/java/xyz/wbsite/achat/HelloController.java new file mode 100644 index 0000000..c46ac18 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/HelloController.java @@ -0,0 +1,23 @@ +package xyz.wbsite.achat; + +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * 门户入口 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@RestController +public class HelloController { + + /** + * 门户 + */ + @GetMapping("/") + public String chat() { + return "AChat is running!"; + } +} diff --git a/src/main/java/xyz/wbsite/achat/OpenAiController.java b/src/main/java/xyz/wbsite/achat/OpenAiController.java index 8545578..c23a954 100644 --- a/src/main/java/xyz/wbsite/achat/OpenAiController.java +++ b/src/main/java/xyz/wbsite/achat/OpenAiController.java @@ -12,7 +12,7 @@ import xyz.wbsite.achat.core.chat.CompletionResponse; import xyz.wbsite.achat.core.chat.EmbeddingsRequest; import xyz.wbsite.achat.core.chat.EmbeddingsResponse; import xyz.wbsite.achat.core.model.ModelListResponse; -import xyz.wbsite.achat.core.service.ChatService; +import xyz.wbsite.achat.core.chat.ChatService; import javax.annotation.Resource; diff --git a/src/main/java/xyz/wbsite/achat/SessionController.java b/src/main/java/xyz/wbsite/achat/SessionController.java new file mode 100644 index 0000000..a216028 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/SessionController.java @@ -0,0 +1,86 @@ +package xyz.wbsite.achat; + +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RestController; +import xyz.wbsite.achat.core.session.Result; +import xyz.wbsite.achat.core.session.Session; +import xyz.wbsite.achat.core.session.SessionService; + +import javax.annotation.Resource; +import java.util.List; + +/** + * 会话服务接口 + * 提供会话、消息的创建、删除、查询等功能 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@RestController +public class SessionController { + + @Resource + private SessionService sessionService; + + /** + * 创建会话 + * + * @param uid 用户标识 + * @return 创建会话响应 + */ + @PostMapping("/{uid}/session") + public Result createSession(@PathVariable("uid") String uid) { + return sessionService.createSession(uid); + } + + /** + * 删除会话 + * + * @param uid 用户标识 + * @param sid 会话ID + * @return 删除会话响应 + */ + @DeleteMapping("/{uid}/session/{sid}") + public Result deleteSession(@PathVariable("uid") String uid, @PathVariable("sid") String sid) { + return sessionService.deleteSession(uid, sid); + } + + /** + * 会话详情 + * + * @param uid 用户标识 + * @param sid 会话ID + * @return 删除会话响应 + */ + @GetMapping("/{uid}/session/{sid}") + public Result getSession(@PathVariable("uid") String uid, @PathVariable("sid") String sid) { + return sessionService.getSession(uid, sid); + } + + /** + * 获取会话列表 + * + * @param uid 用户标识 + * @return 会话列表响应 + */ + @GetMapping("/{uid}/session/list") + public Result> listSession(@PathVariable("uid") String uid) { + return sessionService.listSessions(uid); + } + + /** + * 删除会话历史消息 + * + * @param uid 用户标识 + * @param sid 会话ID + * @return 删除会话响应 + */ + @DeleteMapping("/{uid}/session/{sid}/messages/{mid}") + public Result deleteMessage(@PathVariable("uid") String uid, @PathVariable("sid") String sid, @PathVariable("mid") String mid) { + return sessionService.deleteMessage(uid, sid, mid); + } +} diff --git a/src/main/java/xyz/wbsite/achat/config/ChatConfig.java b/src/main/java/xyz/wbsite/achat/config/ChatConfig.java index 903d648..1394e03 100644 --- a/src/main/java/xyz/wbsite/achat/config/ChatConfig.java +++ b/src/main/java/xyz/wbsite/achat/config/ChatConfig.java @@ -3,8 +3,8 @@ package xyz.wbsite.achat.config; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; -import xyz.wbsite.achat.core.service.ChatService; -import xyz.wbsite.achat.core.service.impl.ChatServiceSampleImpl; +import xyz.wbsite.achat.core.chat.ChatService; +import xyz.wbsite.achat.core.chat.ChatServiceSampleImpl; /** * 对话配置 diff --git a/src/main/java/xyz/wbsite/achat/config/ThreadPoolConfig.java b/src/main/java/xyz/wbsite/achat/config/ThreadPoolConfig.java new file mode 100644 index 0000000..1d98945 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/config/ThreadPoolConfig.java @@ -0,0 +1,66 @@ +package xyz.wbsite.achat.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; + +import java.util.concurrent.ThreadPoolExecutor; + +/** + * 线程池配置. + * 该类用于配置全局的线程池,用于执行异步任务。 + * 根据系统资源自动调整线程池参数,以获得最优性能。 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@Configuration +public class ThreadPoolConfig { + + // 获取CPU核心数 + private final int cpuCount = Runtime.getRuntime().availableProcessors(); + + /** + * 线程池维护线程的最少数量 + * 对于IO密集型任务,核心线程数设置为CPU核心数,确保基础处理能力 + */ + private final int corePoolSize = Math.max(1, cpuCount); + + /** + * 线程池维护线程的最大数量 + * 动态获取CPU核数+1,最小为1,确保在高负载时有足够的线程处理请求 + */ + private final int maxPoolSize = Math.max(1, cpuCount + 1); + + /** + * 缓存队列容量 + * 对于IO密集型任务,使用较小的队列容量来更好地利用线程 + */ + private final int queueCapacity = 100; + + /** + * 允许的空闲时间 + * 空闲线程的最大存活时间,单位:秒 + */ + private final int keepAliveSeconds = 30; + + /** + * 注册全局线程池执行器 + */ + @Bean + public ThreadPoolTaskExecutor threadPoolTaskExecutor() { + ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor(); + threadPoolTaskExecutor.setCorePoolSize(corePoolSize); + threadPoolTaskExecutor.setMaxPoolSize(maxPoolSize); + threadPoolTaskExecutor.setQueueCapacity(queueCapacity); + threadPoolTaskExecutor.setKeepAliveSeconds(keepAliveSeconds); + threadPoolTaskExecutor.setThreadNamePrefix("ThreadPool-"); + // rejection-policy:当pool已经达到max size的时候,如何处理新任务 + // CALLER_RUNS:不在新线程中公式任务,而是由调用者所在的线程来执行 + // 对拒绝task的处理策略 + threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy()); + threadPoolTaskExecutor.initialize(); + return threadPoolTaskExecutor; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/Attachment.java b/src/main/java/xyz/wbsite/achat/core/Attachment.java deleted file mode 100644 index 39d3487..0000000 --- a/src/main/java/xyz/wbsite/achat/core/Attachment.java +++ /dev/null @@ -1,30 +0,0 @@ -//package xyz.wbsite.achat.core; -// -///** -// * 附件 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class Attachment { -// private String filename; -// -// private String fid; -// -// public String getFilename() { -// return filename; -// } -// -// public void setFilename(String filename) { -// this.filename = filename; -// } -// -// public String getFid() { -// return fid; -// } -// -// public void setFid(String fid) { -// this.fid = fid; -// } -//} diff --git a/src/main/java/xyz/wbsite/achat/core/Event.java b/src/main/java/xyz/wbsite/achat/core/Event.java deleted file mode 100644 index 37327fa..0000000 --- a/src/main/java/xyz/wbsite/achat/core/Event.java +++ /dev/null @@ -1,75 +0,0 @@ -//package xyz.wbsite.achat.core; -// -///** -// * 服务器推送事件 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class Event { -// -// private String id; -// -// private String object; -// -// private String model; -// -// private Long created; -// -// private String sid; -// -// private String uid; -// -// public Event() { -// this.created = System.currentTimeMillis(); -// } -// -// public String getId() { -// return id; -// } -// -// public void setId(String id) { -// this.id = id; -// } -// -// public String getObject() { -// return object; -// } -// -// public void setObject(String object) { -// this.object = object; -// } -// -// public String getModel() { -// return model; -// } -// -// public void setModel(String model) { -// this.model = model; -// } -// -// public Long getCreated() { -// return created; -// } -// -// public void setCreated(Long created) { -// this.created = created; -// } -// -// public String getSid() { -// return sid; -// } -// -// public void setSid(String sid) { -// this.sid = sid; -// } -// -// public String getUid() { -// return uid; -// } -// -// public void setUid(String uid) { -// this.uid = uid; -// } -//} diff --git a/src/main/java/xyz/wbsite/achat/core/Prompt.java b/src/main/java/xyz/wbsite/achat/core/Prompt.java deleted file mode 100644 index 657d56d..0000000 --- a/src/main/java/xyz/wbsite/achat/core/Prompt.java +++ /dev/null @@ -1,17 +0,0 @@ -//package xyz.wbsite.achat.core; -// -//public class Prompt { -// -// /** -// * 提示词 -// */ -// private String prompt; -// -// public String getPrompt() { -// return prompt; -// } -// -// public void setPrompt(String prompt) { -// this.prompt = prompt; -// } -//} diff --git a/src/main/java/xyz/wbsite/achat/core/Result.java b/src/main/java/xyz/wbsite/achat/core/Result.java deleted file mode 100644 index 1d37924..0000000 --- a/src/main/java/xyz/wbsite/achat/core/Result.java +++ /dev/null @@ -1,209 +0,0 @@ -//package xyz.wbsite.achat.core; -// -//import java.util.HashMap; -//import java.util.Map; -// -///** -// * 接口响应结果基类 -// * 泛型支持的数据响应封装,提供统一的响应格式和错误处理 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class Result { -// -// /** -// * 响应状态码 -// */ -// private int code = 200; -// -// /** -// * 响应消息 -// */ -// private String message = "success"; -// -// /** -// * 响应数据 -// */ -// private T data; -// -// /** -// * 是否成功 -// */ -// private boolean success = true; -// -// /** -// * 错误详情,用于表单验证等场景 -// */ -// private Map errors; -// -// /** -// * 请求ID,用于问题追踪 -// */ -// private String requestId; -// -// /** -// * 响应时间戳 -// */ -// private long timestamp = System.currentTimeMillis(); -// -// public int getCode() { -// return code; -// } -// -// public Result setCode(int code) { -// this.code = code; -// return this; -// } -// -// public String getMessage() { -// return message; -// } -// -// public Result setMessage(String message) { -// this.message = message; -// return this; -// } -// -// public T getData() { -// return data; -// } -// -// public Result setData(T data) { -// this.data = data; -// return this; -// } -// -// public boolean isSuccess() { -// return success; -// } -// -// public Result setSuccess(boolean success) { -// this.success = success; -// return this; -// } -// -// public Map getErrors() { -// return errors; -// } -// -// public Result setErrors(Map errors) { -// this.errors = errors; -// return this; -// } -// -// public String getRequestId() { -// return requestId; -// } -// -// public Result setRequestId(String requestId) { -// this.requestId = requestId; -// return this; -// } -// -// public long getTimestamp() { -// return timestamp; -// } -// -// public Result setTimestamp(long timestamp) { -// this.timestamp = timestamp; -// return this; -// } -// -// /** -// * 添加字段级别的错误信息 -// * -// * @param field 字段名 -// * @param error 错误信息 -// * @return 当前结果对象,支持链式调用 -// */ -// public Result addError(String field, String error) { -// if (errors == null) { -// errors = new HashMap<>(); -// } -// errors.put(field, error); -// return this; -// } -// -// /** -// * 返回成功信息 -// * -// * @return 结果 -// */ -// public static Result success() { -// return new Result<>(); -// } -// -// /** -// * 返回带数据的成功信息 -// * -// * @param data 响应数据 -// * @return 结果 -// */ -// public static Result success(T data) { -// Result result = new Result<>(); -// result.setData(data); -// return result; -// } -// -// /** -// * 返回错误信息 -// * -// * @param message 错误信息 -// * @return 错误信息对象 -// */ -// public static Result error(String message) { -// Result result = new Result<>(); -// result.message = message; -// result.code = 500; -// result.success = false; -// return result; -// } -// -// /** -// * 返回错误信息 -// * -// * @param code 错误码 -// * @param message 错误信息 -// * @return 错误信息对象 -// */ -// public static Result error(int code, String message) { -// Result result = new Result<>(); -// result.code = code; -// result.message = message; -// result.success = false; -// return result; -// } -// -// /** -// * 返回带数据的错误信息 -// * -// * @param code 错误码 -// * @param message 错误信息 -// * @param data 错误相关数据 -// * @return 错误信息对象 -// */ -// public static Result error(int code, String message, T data) { -// Result result = new Result<>(); -// result.code = code; -// result.message = message; -// result.success = false; -// result.data = data; -// return result; -// } -// -// /** -// * 从异常创建错误响应 -// * -// * @param e 异常 -// * @return 错误信息对象 -// */ -// public static Result error(Exception e) { -// Result result = new Result<>(); -// result.code = 500; -// result.message = e.getMessage() != null ? e.getMessage() : "系统异常"; -// result.success = false; -// return result; -// } -//} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/Session.java b/src/main/java/xyz/wbsite/achat/core/Session.java deleted file mode 100644 index 9e862e8..0000000 --- a/src/main/java/xyz/wbsite/achat/core/Session.java +++ /dev/null @@ -1,126 +0,0 @@ -//package xyz.wbsite.achat.core; -// -// -///** -// * 会话 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class Session { -// /** -// * 主键 -// */ -// private String id; -// /** -// * 用户ID -// */ -// private String uid; -// private String title; -// private String model; -// private String prompt; -// private String temperature; -// private String topP; -// private String frequencyPenalty; -// private String presencePenalty; -// private String maxTokens; -// private String lastTime; -// private String lastMessage; -// -// public String getId() { -// return id; -// } -// -// public void setId(String id) { -// this.id = id; -// } -// -// public String getUid() { -// return uid; -// } -// -// public void setUid(String uid) { -// this.uid = uid; -// } -// -// public String getTitle() { -// return title; -// } -// -// public void setTitle(String title) { -// this.title = title; -// } -// -// public String getModel() { -// return model; -// } -// -// public void setModel(String model) { -// this.model = model; -// } -// -// public String getPrompt() { -// return prompt; -// } -// -// public void setPrompt(String prompt) { -// this.prompt = prompt; -// } -// -// public String getTemperature() { -// return temperature; -// } -// -// public void setTemperature(String temperature) { -// this.temperature = temperature; -// } -// -// public String getTopP() { -// return topP; -// } -// -// public void setTopP(String topP) { -// this.topP = topP; -// } -// -// public String getFrequencyPenalty() { -// return frequencyPenalty; -// } -// -// public void setFrequencyPenalty(String frequencyPenalty) { -// this.frequencyPenalty = frequencyPenalty; -// } -// -// public String getPresencePenalty() { -// return presencePenalty; -// } -// -// public void setPresencePenalty(String presencePenalty) { -// this.presencePenalty = presencePenalty; -// } -// -// public String getMaxTokens() { -// return maxTokens; -// } -// -// public void setMaxTokens(String maxTokens) { -// this.maxTokens = maxTokens; -// } -// -// public String getLastTime() { -// return lastTime; -// } -// -// public void setLastTime(String lastTime) { -// this.lastTime = lastTime; -// } -// -// public String getLastMessage() { -// return lastMessage; -// } -// -// public void setLastMessage(String lastMessage) { -// this.lastMessage = lastMessage; -// } -//} diff --git a/src/main/java/xyz/wbsite/achat/core/service/ChatService.java b/src/main/java/xyz/wbsite/achat/core/chat/ChatService.java similarity index 78% rename from src/main/java/xyz/wbsite/achat/core/service/ChatService.java rename to src/main/java/xyz/wbsite/achat/core/chat/ChatService.java index b44250f..c34f13c 100644 --- a/src/main/java/xyz/wbsite/achat/core/service/ChatService.java +++ b/src/main/java/xyz/wbsite/achat/core/chat/ChatService.java @@ -1,5 +1,6 @@ -package xyz.wbsite.achat.core.service; +package xyz.wbsite.achat.core.chat; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import xyz.wbsite.achat.core.chat.ChatCompletionRequest; import xyz.wbsite.achat.core.chat.ChatCompletionResponse; import xyz.wbsite.achat.core.chat.CompletionRequest; @@ -14,7 +15,7 @@ public interface ChatService { ChatCompletionResponse chat(ChatCompletionRequest request); - StreamEmitter streamChat(ChatCompletionRequest request); + SseEmitter streamChat(ChatCompletionRequest request); EmbeddingsResponse embeddings(EmbeddingsRequest request); } diff --git a/src/main/java/xyz/wbsite/achat/core/service/impl/ChatServiceSampleImpl.java b/src/main/java/xyz/wbsite/achat/core/chat/ChatServiceSampleImpl.java similarity index 50% rename from src/main/java/xyz/wbsite/achat/core/service/impl/ChatServiceSampleImpl.java rename to src/main/java/xyz/wbsite/achat/core/chat/ChatServiceSampleImpl.java index 12a5c99..9f46286 100644 --- a/src/main/java/xyz/wbsite/achat/core/service/impl/ChatServiceSampleImpl.java +++ b/src/main/java/xyz/wbsite/achat/core/chat/ChatServiceSampleImpl.java @@ -1,5 +1,7 @@ -package xyz.wbsite.achat.core.service.impl; +package xyz.wbsite.achat.core.chat; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import xyz.wbsite.achat.core.chat.ChatCompletionRequest; import xyz.wbsite.achat.core.chat.ChatCompletionResponse; import xyz.wbsite.achat.core.chat.CompletionRequest; @@ -9,20 +11,31 @@ import xyz.wbsite.achat.core.chat.EmbeddingsResponse; import xyz.wbsite.achat.core.chat.Role; import xyz.wbsite.achat.core.chat.StreamEmitter; import xyz.wbsite.achat.core.chat.Usage; -import xyz.wbsite.achat.core.service.ChatService; +import xyz.wbsite.achat.core.chat.ChatService; +import javax.annotation.Resource; +import java.util.Collections; + +/** + * 该类用于测试AI服务的基本功能,返回固定的响应内容。 + * 该类实现了ChatService接口,提供了prompt、chat、streamChat方法。 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ public class ChatServiceSampleImpl implements ChatService { + + @Resource + private ThreadPoolTaskExecutor taskExecutor; + + /** + * 默认提示语 + */ + private final String DEFAULT_PROMPT = "您好,我还没有接入AI,请接入后再试!"; + @Override public CompletionResponse prompt(CompletionRequest request) { -// CompletionResponse response = new CompletionResponse(); -// response.setObject("chat.completion"); -// response.setCreated(System.currentTimeMillis() / 1000); -// response.setModel(request.getModel()); -// List choices = new ArrayList<>(); -// choices.add(Choice.builder().index(0).message(Message.builder().role(Role.ASSISTANT).content("您好,我还没有接入AI,请接入后再试!").build()).finish_reason("stop").build()); -// response.setChoices(choices); -// response.setUsage(Usage.builder().prompt_tokens(10).completion_tokens(20).total_tokens(30).build()); - return CompletionResponse.builder() .id("chatcmpl-" + System.currentTimeMillis()) .object("chat.completion") @@ -32,7 +45,7 @@ public class ChatServiceSampleImpl implements ChatService { choices.add(CompletionResponse.choiceBuilder() .index(0) .role(Role.ASSISTANT) - .content("您好,我还没有接入AI,请接入后再试!") + .content(DEFAULT_PROMPT) .finish_reason("stop").build()); }) .usage(Usage.builder() @@ -54,7 +67,7 @@ public class ChatServiceSampleImpl implements ChatService { choices.add(ChatCompletionResponse.choiceBuilder() .index(0) .role(Role.ASSISTANT) - .content("您好,我还没有接入AI,请接入后再试!") + .content(DEFAULT_PROMPT) .finish_reason("stop") .build()); }) @@ -67,7 +80,7 @@ public class ChatServiceSampleImpl implements ChatService { } @Override - public StreamEmitter streamChat(ChatCompletionRequest request) { + public SseEmitter streamChat(ChatCompletionRequest request) { // 验证请求参数 if (request.getModel() == null) { throw new IllegalArgumentException("模型不能为空"); @@ -75,45 +88,54 @@ public class ChatServiceSampleImpl implements ChatService { if (request.getMessages() == null || request.getMessages().isEmpty()) { throw new IllegalArgumentException("消息不能为空"); } - StreamEmitter streamEmitter = StreamEmitter.builder() - .chatCompletionRequest(request) - .build(); + StreamEmitter streamEmitter = new StreamEmitter(request); - return streamEmitter; + taskExecutor.execute(() -> { + String chatId = "chatcmpl-" + System.currentTimeMillis(); + streamEmitter.onStart(chatId); + char[] charArray = DEFAULT_PROMPT.toCharArray(); + for (char c : charArray) { + streamEmitter.onPartial(String.valueOf(c)); + } + streamEmitter.onComplete(); + }); -// SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); - // 在单独的线程中处理流式响应 -// new Thread(() -> { -// try { -// // 模拟流式响应的逻辑 -// // 实际应用中应从服务层获取流式数据并发送 -// String id = "chatcmpl-" + System.currentTimeMillis(); -// long created = System.currentTimeMillis() / 1000; -// String model = request.getModel(); -// -// // 发送初始数据块 -// ChatCompletionChunk chunk = new ChatCompletionChunk(); -// chunk.setId(id); -// chunk.setObject("chat.completion.chunk"); -// chunk.setCreated(created); -// chunk.setModel(model); -// // chunk.setChoices(/* 实际的选择项列表 */); -// emitter.send(chunk, MediaType.APPLICATION_JSON); -// -// // 发送更多数据块... -// -// // 发送完成信号 -// emitter.complete(); -// } catch (Exception e) { -// emitter.completeWithError(e); -// } -// }).start(); -// -// return emitter; + return streamEmitter; } @Override public EmbeddingsResponse embeddings(EmbeddingsRequest request) { - return null; + // 验证请求参数 + if (request == null) { + throw new IllegalArgumentException("请求参数不能为空"); + } + if (request.getModel() == null || request.getModel().isEmpty()) { + throw new IllegalArgumentException("模型不能为空"); + } + if (request.getInput() == null || request.getInput().isEmpty()) { + throw new IllegalArgumentException("输入文本不能为空"); + } + + // 模拟生成嵌入向量响应 + return EmbeddingsResponse.builder() + .id("embedding-" + System.currentTimeMillis()) + .object("list") + .created(System.currentTimeMillis() / 1000) + .model(request.getModel()) + .withData(data -> { + for (int i = 0; i < request.getInput().size(); i++) { + data.add(EmbeddingsResponse.dataBuilder() + .index(i) + .embedding(Collections.emptyList()) + .object("embedding") + .build()); + } + }) + .usage(Usage.builder() + .prompt_tokens(10) + .completion_tokens(10) + .total_tokens(10) + .build()) + .build(); } } diff --git a/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsRequest.java b/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsRequest.java index faf5d26..ed0461f 100644 --- a/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsRequest.java +++ b/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsRequest.java @@ -1,12 +1,89 @@ package xyz.wbsite.achat.core.chat; +import java.util.ArrayList; +import java.util.List; + /** - * 嵌入请求 + * 嵌入请求 - 符合OpenAI官方API规范 * * @author wangbing * @version 0.0.1 * @since 1.8 */ public class EmbeddingsRequest { + private String model; + private List input; + private String user; + + // 无参构造函数 + public EmbeddingsRequest() { + } + + // 私有构造函数,用于Builder模式 + private EmbeddingsRequest(Builder builder) { + this.model = builder.model; + this.input = builder.input; + this.user = builder.user; + } + + // 静态builder方法,返回Builder实例 + public static Builder builder() { + return new Builder(); + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public List getInput() { + return input; + } + + public void setInput(List input) { + this.input = input; + } + + public String getUser() { + return user; + } + + public void setUser(String user) { + this.user = user; + } + + // Builder内部类 + public static class Builder { + private String model; + private List input = new ArrayList<>(); + private String user; + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder input(List input) { + this.input = input; + return this; + } + + public Builder addInput(String text) { + this.input.add(text); + return this; + } + + public Builder user(String user) { + this.user = user; + return this; + } + // 构建EmbeddingsRequest对象 + public EmbeddingsRequest build() { + return new EmbeddingsRequest(this); + } + } } \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsResponse.java b/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsResponse.java index aa80fd8..18a222b 100644 --- a/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsResponse.java +++ b/src/main/java/xyz/wbsite/achat/core/chat/EmbeddingsResponse.java @@ -4,7 +4,8 @@ import java.util.ArrayList; import java.util.List; /** - * OpenAI聊天完成响应 - 符合OpenAI官方API规范 + * 嵌入响应 - 符合OpenAI官方API规范 + * 用于存储嵌入模型生成的向量表示 * * @author wangbing * @version 0.0.1 @@ -15,18 +16,30 @@ public class EmbeddingsResponse { private String object; private long created; private String model; - private List choices; + private List data = new ArrayList<>(); private Usage usage; + /** + * 无参构造函数 + */ + public EmbeddingsResponse() { + } + + /** + * 私有构造函数,用于Builder模式 + */ private EmbeddingsResponse(Builder builder) { this.id = builder.id; this.object = builder.object; this.created = builder.created; this.model = builder.model; - this.choices = builder.choices; + this.data = builder.data; this.usage = builder.usage; } + /** + * 静态builder方法,返回Builder实例 + */ public static Builder builder() { return new Builder(); } @@ -63,12 +76,15 @@ public class EmbeddingsResponse { this.model = model; } - public List getChoices() { - return choices; + /** + * 获取嵌入向量数据列表 + */ + public List getData() { + return data; } - public void setChoices(List choices) { - this.choices = choices; + public void setData(List data) { + this.data = data; } public Usage getUsage() { @@ -79,12 +95,15 @@ public class EmbeddingsResponse { this.usage = usage; } + /** + * Builder内部类,用于构建EmbeddingsResponse对象 + */ public static class Builder { private String id; private String object; private long created; private String model; - private List choices = new ArrayList<>(); + private List data = new ArrayList<>(); private Usage usage; public Builder id(String id) { @@ -107,13 +126,21 @@ public class EmbeddingsResponse { return this; } - public Builder choices(List choices) { - this.choices = choices; + public Builder data(List data) { + this.data = data; return this; } - public Builder withChoices(java.util.function.Consumer> choicesConsumer) { - choicesConsumer.accept(this.choices); + public Builder addData(Data data) { + this.data.add(data); + return this; + } + + /** + * 提供data列表的函数式访问,用于在lambda表达式中操作data列表 + */ + public Builder withData(java.util.function.Consumer> dataConsumer) { + dataConsumer.accept(this.data); return this; } @@ -122,83 +149,115 @@ public class EmbeddingsResponse { return this; } + /** + * 构建EmbeddingsResponse对象 + */ public EmbeddingsResponse build() { return new EmbeddingsResponse(this); } } - public static class Choice { - private int index = 0; - private Message message; - private String finish_reason; - - public Integer getIndex() { - return index; - } + /** + * 嵌入数据项 - 包含索引和嵌入向量 + */ + public static class Data { + private int index; + private List embedding; + private String object; - public void setIndex(Integer index) { - this.index = index; + /** + * 无参构造函数 + */ + public Data() { } - public Message getMessage() { - return message; + /** + * 私有构造函数,用于Builder模式 + */ + private Data(Builder builder) { + this.index = builder.index; + this.embedding = builder.embedding; + this.object = builder.object; } - public void setMessage(Message message) { - this.message = message; + /** + * 静态builder方法,返回Builder实例 + */ + public static Builder builder() { + return new Builder(); } - public String getFinish_reason() { - return finish_reason; + public int getIndex() { + return index; } - public void setFinish_reason(String finish_reason) { - this.finish_reason = finish_reason; - } - } - - public static ChoiceBuilder choiceBuilder() { - return new ChoiceBuilder(); - } - - public static class ChoiceBuilder { - private Integer index; - private Role role; - private String content; - private String name; - private String finish_reason; - - public ChoiceBuilder index(Integer index) { + public void setIndex(int index) { this.index = index; - return this; } - public ChoiceBuilder role(Role role) { - this.role = role; - return this; + /** + * 获取嵌入向量 + */ + public List getEmbedding() { + return embedding; } - public ChoiceBuilder content(String content) { - this.content = content; - return this; + public void setEmbedding(List embedding) { + this.embedding = embedding; } - public ChoiceBuilder name(String name) { - this.name = name; - return this; + public String getObject() { + return object; } - public ChoiceBuilder finish_reason(String finish_reason) { - this.finish_reason = finish_reason; - return this; + public void setObject(String object) { + this.object = object; } - public Choice build() { - Choice choice = new Choice(); - choice.setIndex(index); - choice.setMessage(Message.builder().role(role).content(content).name(name).build()); - choice.setFinish_reason(finish_reason); - return choice; + /** + * Data的Builder内部类 + */ + public static class Builder { + private int index; + private List embedding; + private String object = "embedding"; + + public Builder index(int index) { + this.index = index; + return this; + } + + public Builder embedding(List embedding) { + this.embedding = embedding; + return this; + } + + public Builder addEmbeddingValue(double value) { + if (this.embedding == null) { + this.embedding = new ArrayList<>(); + } + this.embedding.add(value); + return this; + } + + public Builder object(String object) { + this.object = object; + return this; + } + + /** + * 构建Data对象 + */ + public Data build() { + return new Data(this); + } } } + + /** + * 创建DataBuilder实例 + */ + public static Data.Builder dataBuilder() { + return Data.builder(); + } } \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/chat/StreamEmitter.java b/src/main/java/xyz/wbsite/achat/core/chat/StreamEmitter.java index df5d4bc..6261e79 100644 --- a/src/main/java/xyz/wbsite/achat/core/chat/StreamEmitter.java +++ b/src/main/java/xyz/wbsite/achat/core/chat/StreamEmitter.java @@ -1,10 +1,7 @@ package xyz.wbsite.achat.core.chat; +import org.springframework.util.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -import xyz.wbsite.achat.core.service.ChatCompletionGenerator; - -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; /** * 流式对话生成器 @@ -16,123 +13,54 @@ import java.util.concurrent.Executors; public class StreamEmitter extends SseEmitter { /** - * 用户消息 + * 流式输出默认超时时间 */ - private ChatCompletionRequest chatCompletionRequest; + private static final Long DEFAULT_TIMEOUT = 5 * 60 * 1000L; /** - * 是否完成 + * 对话请求 */ - private boolean complete; + private ChatCompletionRequest request; /** - * AI回答 + * 当前对话状态 */ - private final StringBuilder answer = new StringBuilder(); + private Status status; /** - * 消息处理器 + * 当前对话ID */ - private ChatCompletionGenerator chatCompletionGenerator; + private String chatId; /** - * 线程池,用于执行消息生成器 + * 是否完成 */ - private static final ExecutorService messageExecutor = Executors.newFixedThreadPool( - Math.max(4, Runtime.getRuntime().availableProcessors()), - r -> { - Thread thread = new Thread(r, "message-generator-thread"); - thread.setDaemon(true); - return thread; - } - ); - - private StreamEmitter(Builder builder) { - super(Long.MAX_VALUE); - this.chatCompletionRequest = builder.chatCompletionRequest; - this.chatCompletionGenerator = builder.chatCompletionGenerator; - - // 使用线程池执行MessageGenerator - if (this.chatCompletionGenerator != null && this.chatCompletionRequest != null) { - messageExecutor.execute(() -> { - try { - this.chatCompletionGenerator.on(this, this.chatCompletionRequest); - } catch (Exception e) { - onError(e); - } finally { - if (!isComplete()) { - onCompleteResponse(null); - } - } - }); - } - } - - public static Builder builder() { - return new Builder(); - } - - // Builder内部类 - public static class Builder { - private ChatCompletionRequest chatCompletionRequest; - private ChatCompletionGenerator chatCompletionGenerator; - - public Builder chatCompletionRequest(ChatCompletionRequest chatCompletionRequest) { - this.chatCompletionRequest = chatCompletionRequest; - return this; - } - - public Builder chatCompletionGenerator(ChatCompletionGenerator chatCompletionGenerator) { - this.chatCompletionGenerator = chatCompletionGenerator; - return this; - } - - // 构建StreamEmitter对象 - public StreamEmitter build() { - return new StreamEmitter(this); - } - } - - public StreamEmitter(ChatCompletionRequest chatCompletionRequest, ChatCompletionGenerator chatCompletionGenerator) { - super(Long.MAX_VALUE); - this.chatCompletionRequest = chatCompletionRequest; - this.chatCompletionGenerator = chatCompletionGenerator; - - // 使用线程池执行MessageGenerator - if (chatCompletionGenerator != null && chatCompletionRequest != null) { - messageExecutor.execute(() -> { - try { - chatCompletionGenerator.on(this, chatCompletionRequest); - } catch (Exception e) { - onError(e); - } finally { - if (!isComplete()) { - onCompleteResponse(null); - } - } - }); - } - } + private boolean complete; - /** - * 错误处理 - */ - private void onError(Throwable e) { -// this.sendMessage(createPartialMessage("" + e.getMessage())); - this.answer.append("" + e.getMessage()); - this.onCompleteResponse(null); + public StreamEmitter(ChatCompletionRequest request) { + super(DEFAULT_TIMEOUT); + this.request = request; } - /** - * 部分响应处理 - */ - public void onPartialResponse(String msg) { - if (complete) { - return; - } -// this.sendMessage(createPartialMessage(msg)); - this.answer.append(msg); - } +// /** +// * 错误处理 +// */ +// private void onError(Throwable e) { +//// this.sendMessage(createPartialMessage("" + e.getMessage())); +// this.answer.append("" + e.getMessage()); +// this.onCompleteResponse(null); +// } +// +// /** +// * 部分响应处理 +// */ +// public void onPartialResponse(String msg) { +// if (complete) { +// return; +// } +//// this.sendMessage(createPartialMessage(msg)); +// this.answer.append(msg); +// } /** * 完成响应处理 @@ -147,20 +75,6 @@ public class StreamEmitter extends SseEmitter { this.complete(); } -// /** -// * 创建部分消息事件 -// */ -// private Event createPartialMessage(String partial) { -// return new PartialEvent(messagePrompt.getSid(), partial); -// } -// -// /** -// * 创建完成消息事件 -// */ -// private Event createCompleteMessage() { -// return new CompleteEvent(messagePrompt.getSid()); -// } - /** * 重写send方法,处理异常 */ @@ -173,20 +87,87 @@ public class StreamEmitter extends SseEmitter { } } - public void callStart() { - } - /** - * 发送消息 + * 发送片段 */ - private void sendMessage(Object message) { + private void pushChunk(ChatCompletionChunk chunk) { try { - this.send(message); + this.send(chunk); } catch (Exception e) { complete = true; } } + /** + * 发送开始片段 + * + * @param chatId 对话id + */ + public void onStart(String chatId) { + if (!StringUtils.hasText(chatId)) { + throw new IllegalArgumentException("chatId is empty!"); + } + if (this.status != null) { + throw new IllegalArgumentException("chunk has been started!"); + } + ChatCompletionChunk chunk = ChatCompletionChunk.builder() + .id(chatId) + .object("chat.completion.chunk") + .created(System.currentTimeMillis() / 1000) + .model(request.getModel()) + .withChoices(choices -> { + choices.add(ChatCompletionChunk.choiceBuilder().index(0) + .role(Role.ASSISTANT) + .build()); + }) + .build(); + this.pushChunk(chunk); + this.status = Status.PENDING; + this.chatId = chatId; + } + + /** + * 发送部分片段 + * + * @param text 片段内容 + */ + public void onPartial(String text) { + ChatCompletionChunk chunk = ChatCompletionChunk.builder() + .id(chatId) + .object("chat.completion.chunk") + .created(System.currentTimeMillis() / 1000) + .model(request.getModel()) + .withChoices(choices -> { + choices.add(ChatCompletionChunk.choiceBuilder().index(0) + .role(Role.ASSISTANT) + .content(text) + .build()); + }) + .build(); + this.pushChunk(chunk); + } + + /** + * 发送完成信号,根据OpenAI API规范,最后一个片段需要包含finish_reason字段 + */ + public void onComplete() { + ChatCompletionChunk chunk = ChatCompletionChunk.builder() + .id(chatId) + .object("chat.completion.chunk") + .created(System.currentTimeMillis() / 1000) + .model(request.getModel()) + .withChoices(choices -> { + choices.add(ChatCompletionChunk.choiceBuilder().index(0) + .role(Role.ASSISTANT) + // 符合OpenAI规范,最后一个片段需要设置finish_reason + .finish_reason("stop") + .build()); + }) + .build(); + this.pushChunk(chunk); + this.status = Status.SUCCESS; + } + /** * 获取完成状态 */ diff --git a/src/main/java/xyz/wbsite/achat/core/event/CompleteEvent.java b/src/main/java/xyz/wbsite/achat/core/event/CompleteEvent.java deleted file mode 100644 index e00f277..0000000 --- a/src/main/java/xyz/wbsite/achat/core/event/CompleteEvent.java +++ /dev/null @@ -1,103 +0,0 @@ -//package xyz.wbsite.achat.core.event; -// -//import xyz.wbsite.achat.core.Event; -// -///** -// * 完成事件 -// * 用于标识流式响应的结束 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class CompleteEvent extends Event { -// -// /** -// * 完整响应内容 -// */ -// private String content; -// -// /** -// * 构造函数 -// * -// * @param sid 会话ID -// */ -// public CompleteEvent(String sid) { -// super(); -// this.setSid(sid); -// this.setObject("chat.completion"); -// } -// -// /** -// * 生成的令牌总数 -// */ -// private Integer completionTokens; -// -// /** -// * 提示词的令牌数量 -// */ -// private Integer promptTokens; -// -// /** -// * 总令牌数量 -// */ -// private Integer totalTokens; -// -// /** -// * 完成状态 -// */ -// private String finishReason; -// -// /** -// * 生成用时(毫秒) -// */ -// private Long generationTime; -// -// public String getContent() { -// return content; -// } -// -// public void setContent(String content) { -// this.content = content; -// } -// -// public Integer getCompletionTokens() { -// return completionTokens; -// } -// -// public void setCompletionTokens(Integer completionTokens) { -// this.completionTokens = completionTokens; -// } -// -// public Integer getPromptTokens() { -// return promptTokens; -// } -// -// public void setPromptTokens(Integer promptTokens) { -// this.promptTokens = promptTokens; -// } -// -// public Integer getTotalTokens() { -// return totalTokens; -// } -// -// public void setTotalTokens(Integer totalTokens) { -// this.totalTokens = totalTokens; -// } -// -// public String getFinishReason() { -// return finishReason; -// } -// -// public void setFinishReason(String finishReason) { -// this.finishReason = finishReason; -// } -// -// public Long getGenerationTime() { -// return generationTime; -// } -// -// public void setGenerationTime(Long generationTime) { -// this.generationTime = generationTime; -// } -//} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/event/PartialEvent.java b/src/main/java/xyz/wbsite/achat/core/event/PartialEvent.java deleted file mode 100644 index 0e01846..0000000 --- a/src/main/java/xyz/wbsite/achat/core/event/PartialEvent.java +++ /dev/null @@ -1,80 +0,0 @@ -//package xyz.wbsite.achat.core.event; -// -// -//import xyz.wbsite.achat.core.Event; -// -///** -// * 分段内容事件 -// * 用于推送流式响应的部分内容 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class PartialEvent extends Event { -// -// /** -// * 分段内容 -// */ -// private String content; -// -// /** -// * 构造函数 -// * -// * @param sid 会话ID -// * @param partial 部分内容 -// */ -// public PartialEvent(String sid, String partial) { -// super(); -// this.setSid(sid); -// this.content = partial; -// this.setObject("chat.completion.chunk"); -// } -// -// /** -// * 是否是最后一段 -// */ -// private Boolean isFinal; -// -// /** -// * 完成率,范围0-1 -// */ -// private Double finishRate; -// -// /** -// * 当前累积的生成令牌数量 -// */ -// private Integer completionTokens; -// -// public String getContent() { -// return content; -// } -// -// public void setContent(String content) { -// this.content = content; -// } -// -// public Boolean getIsFinal() { -// return isFinal; -// } -// -// public void setIsFinal(Boolean isFinal) { -// this.isFinal = isFinal; -// } -// -// public Double getFinishRate() { -// return finishRate; -// } -// -// public void setFinishRate(Double finishRate) { -// this.finishRate = finishRate; -// } -// -// public Integer getCompletionTokens() { -// return completionTokens; -// } -// -// public void setCompletionTokens(Integer completionTokens) { -// this.completionTokens = completionTokens; -// } -//} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/event/StartEvent.java b/src/main/java/xyz/wbsite/achat/core/event/StartEvent.java deleted file mode 100644 index fadeae2..0000000 --- a/src/main/java/xyz/wbsite/achat/core/event/StartEvent.java +++ /dev/null @@ -1,66 +0,0 @@ -//package xyz.wbsite.achat.core.event; -// -//import xyz.wbsite.achat.core.Event; -// -///** -// * 开始推送事件 -// * 用于标识流式响应的开始 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public class StartEvent extends Event { -// -// /** -// * 提示词令牌数量 -// */ -// private Integer promptTokens; -// -// /** -// * 最大令牌数量限制 -// */ -// private Integer maxTokens; -// -// /** -// * 温度参数 -// */ -// private Double temperature; -// -// /** -// * 随机种子 -// */ -// private Integer seed; -// -// public Integer getPromptTokens() { -// return promptTokens; -// } -// -// public void setPromptTokens(Integer promptTokens) { -// this.promptTokens = promptTokens; -// } -// -// public Integer getMaxTokens() { -// return maxTokens; -// } -// -// public void setMaxTokens(Integer maxTokens) { -// this.maxTokens = maxTokens; -// } -// -// public Double getTemperature() { -// return temperature; -// } -// -// public void setTemperature(Double temperature) { -// this.temperature = temperature; -// } -// -// public Integer getSeed() { -// return seed; -// } -// -// public void setSeed(Integer seed) { -// this.seed = seed; -// } -//} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java b/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java deleted file mode 100644 index 009a9be..0000000 --- a/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java +++ /dev/null @@ -1,105 +0,0 @@ -//package xyz.wbsite.achat.core.prompt; -// -//import xyz.wbsite.achat.core.chat.Message; -//import xyz.wbsite.achat.core.Prompt; -//import xyz.wbsite.achat.core.message.UserMessage; -// -//import java.util.List; -//import java.util.Map; -//import java.util.stream.Collectors; -// -//public class MessagePrompt extends Prompt { -// -// /** -// * 模型 -// */ -// private String model; -// -// /** -// * 消息列表 -// */ -// private List messages; -// -// /** -// * 是否流式返回 -// */ -// private Boolean stream; -// -// /** -// * 最大token数 -// */ -// private Integer maxTokens; -// -// /** -// * 温度参 -// */ -// private Double temperature; -// -// /** -// * 额外参数 -// */ -// private Map extraParams; -// -// public String getModel() { -// return model; -// } -// -// public void setModel(String model) { -// this.model = model; -// } -// -// public List getMessages() { -// return messages; -// } -// -// public void setMessages(List messages) { -// this.messages = messages; -// } -// -// public Boolean getStream() { -// return stream; -// } -// -// public void setStream(Boolean stream) { -// this.stream = stream; -// } -// -// public Integer getMaxTokens() { -// return maxTokens; -// } -// -// public void setMaxTokens(Integer maxTokens) { -// this.maxTokens = maxTokens; -// } -// -// public Double getTemperature() { -// return temperature; -// } -// -// public void setTemperature(Double temperature) { -// this.temperature = temperature; -// } -// -// public Map getExtraParams() { -// return extraParams; -// } -// -// public void setExtraParams(Map extraParams) { -// this.extraParams = extraParams; -// } -// -// public UserMessage getLastUserMessage() { -// List messageList = messages.stream().filter(message -> message instanceof UserMessage).collect(Collectors.toList()); -// UserMessage userMessage = (UserMessage)messageList.get(messageList.size() - 1); -// return userMessage; -// } -// -// public String getUid() { -// return getLastUserMessage().getUid(); -// } -// -// public String getSid(){ -// return getLastUserMessage().getSid(); -// } -// -//} diff --git a/src/main/java/xyz/wbsite/achat/core/service/ChatCompletionGenerator.java b/src/main/java/xyz/wbsite/achat/core/service/ChatCompletionGenerator.java deleted file mode 100644 index 690e107..0000000 --- a/src/main/java/xyz/wbsite/achat/core/service/ChatCompletionGenerator.java +++ /dev/null @@ -1,18 +0,0 @@ -package xyz.wbsite.achat.core.service; - -import xyz.wbsite.achat.core.chat.ChatCompletionRequest; -import xyz.wbsite.achat.core.chat.StreamEmitter; - -/** - * 推理生成器 - *

- * 抽象出来用于生成消息的实现层 - * - * @author wangbing - * @version 0.0.1 - * @since 1.8 - */ -public interface ChatCompletionGenerator { - - void on(StreamEmitter emitter, ChatCompletionRequest chatCompletionRequest); -} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/service/SessionService.java b/src/main/java/xyz/wbsite/achat/core/service/SessionService.java deleted file mode 100644 index 13b97c7..0000000 --- a/src/main/java/xyz/wbsite/achat/core/service/SessionService.java +++ /dev/null @@ -1,84 +0,0 @@ -//package xyz.wbsite.achat.core.service; -// -//import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -//import xyz.wbsite.achat.core.chat.Message; -//import xyz.wbsite.achat.core.Result; -//import xyz.wbsite.achat.core.Session; -//import xyz.wbsite.achat.core.prompt.MessagePrompt; -// -//import java.util.List; -// -///** -// * 会话管理服务接口 -// * 提供会话的创建、删除、查询、消息收发等功能 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//public interface SessionService { -// -// /** -// * 创建新会话 -// * -// * @param uid 用户编号 -// * @return 创建的会话对象 -// */ -// Result createSession(String uid); -// -// /** -// * 删除会话 -// * -// * @param sid 会话ID -// * @return 删除是否成功 -// */ -// Result deleteSession(String sid); -// -// /** -// * 查询会话列表 -// * -// * @param uid 用户编号 -// * @return 会话列表 -// */ -// Result> listSessions(String uid); -// -// /** -// * 获取会话详情 -// * -// * @param sid 会话ID -// * @return 会话对象 -// */ -// Result getSession(String sid); -// -// /** -// * 停止会话 -// * -// * @param sid 会话ID -// * @return 会话对象 -// */ -// Result stopSession(String sid); -// -// /** -// * 发送消息并获取流式响应 -// * -// * @param message 消息对象 -// * @return SSE发射器,用于流式响应 -// */ -// SseEmitter sendMessage(MessagePrompt message); -// -// /** -// * 获取会话历史消息 -// * -// * @param sid 会话ID -// * @return 消息列表 -// */ -// Result> listMessage(String sid); -// -// /** -// * 获取会话历史消息 -// * -// * @param sid 会话ID -// * @return 消息列表 -// */ -// Result deleteMessage(String sid); -//} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java b/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java deleted file mode 100644 index 9a2f1a1..0000000 --- a/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java +++ /dev/null @@ -1,121 +0,0 @@ -//package xyz.wbsite.achat.core.service.impl; -// -//import org.springframework.stereotype.Service; -//import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -//import xyz.wbsite.achat.core.message.MessageSseEmitter; -//import xyz.wbsite.achat.core.chat.Message; -//import xyz.wbsite.achat.core.Result; -//import xyz.wbsite.achat.core.Session; -//import xyz.wbsite.achat.core.prompt.MessagePrompt; -//import xyz.wbsite.achat.core.service.SessionService; -// -//import java.util.ArrayList; -//import java.util.HashMap; -//import java.util.List; -//import java.util.Map; -//import java.util.UUID; -//import java.util.stream.Collectors; -// -///** -// * 会话管理服务实现类 -// * 实现会话的创建、删除、查询、消息收发等功能 -// * -// * @author wangbing -// * @version 0.0.1 -// * @since 1.8 -// */ -//@Service -//public class SessionServiceMemoryImpl implements SessionService { -// -// private final Map sessionStore = new HashMap<>(); -// -// private final List messageStore = new ArrayList<>(); -// -// /** -// * 创建新会话 -// */ -// @Override -// public Result createSession(String uid) { -// Session session = new Session(); -// session.setId(String.valueOf(UUID.randomUUID().toString())); -// session.setUid(uid); -// session.setTitle("新对话"); -// sessionStore.put(session.getId(), session); -// return Result.success(session); -// } -// -// /** -// * 删除会话 -// */ -// @Override -// public Result deleteSession(String sid) { -// sessionStore.remove(sid); -// return Result.success(); -// } -// -// /** -// * 查询会话列表 -// */ -// @Override -// public Result> listSessions(String uid) { -// List collect = sessionStore.values() -// .stream() -// .filter(item -> item.getUid().equals(uid)) -// .collect(Collectors.toList()); -// return Result.success(collect); -// } -// -// /** -// * 获取会话详情 -// */ -// @Override -// public Result getSession(String sid) { -// Session session = sessionStore.get(sid); -// if (session == null) { -// return Result.error("会话不存在"); -// } -// return Result.success(session); -// } -// -// /** -// * 发送消息并获取流式响应 -// */ -// @Override -// public SseEmitter sendMessage(MessagePrompt message) { -// // 创建VChatSseEmitter来处理流式响应 -// return new MessageSseEmitter(message, (emitter, message1) -> { -// // 这边模拟LLM复述一遍用户问题 -// String text = message1.getContent(); -// for (char c : text.toCharArray()) { -// if (emitter.isComplete()) { -// return; -// } -// emitter.onPartialResponse(String.valueOf(c)); -// try { -// Thread.sleep(100); -// } catch (InterruptedException e) { -// throw new RuntimeException(e); -// } -// } -// emitter.onCompleteResponse(null); -// }); -// } -// -// @Override -// public Result stopSession(String sid) { -// // todo -// return Result.success(); -// } -// -// @Override -// public Result> listMessage(String sid) { -// List messages = messageStore.stream().filter(item -> item.getSid().equals(sid)).collect(Collectors.toList()); -// return Result.success(messages); -// } -// -// @Override -// public Result deleteMessage(String sid) { -// messageStore.forEach(item -> messageStore.remove(item)); -// return null; -// } -//} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/session/Message.java b/src/main/java/xyz/wbsite/achat/core/session/Message.java new file mode 100644 index 0000000..b1f38fb --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/session/Message.java @@ -0,0 +1,34 @@ +package xyz.wbsite.achat.core.session; + +/** + * + */ +public class Message extends xyz.wbsite.achat.core.chat.Message { + private String id; + private String uid; + private String sid; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getUid() { + return uid; + } + + public void setUid(String uid) { + this.uid = uid; + } + + public String getSid() { + return sid; + } + + public void setSid(String sid) { + this.sid = sid; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/session/Result.java b/src/main/java/xyz/wbsite/achat/core/session/Result.java new file mode 100644 index 0000000..2c1cb04 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/session/Result.java @@ -0,0 +1,163 @@ +package xyz.wbsite.achat.core.session; + +/** + * 接口响应结果基类 + * 泛型支持的数据响应封装,提供统一的响应格式和错误处理 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Result { + + /** + * 响应状态码 + */ + private int code = 200; + + /** + * 响应消息 + */ + private String message = "success"; + + /** + * 响应数据 + */ + private T data; + + /** + * 是否成功 + */ + private boolean success = true; + + /** + * 响应时间戳 + */ + private long timestamp = System.currentTimeMillis(); + + public int getCode() { + return code; + } + + public Result setCode(int code) { + this.code = code; + return this; + } + + public String getMessage() { + return message; + } + + public Result setMessage(String message) { + this.message = message; + return this; + } + + public T getData() { + return data; + } + + public Result setData(T data) { + this.data = data; + return this; + } + + public boolean isSuccess() { + return success; + } + + public Result setSuccess(boolean success) { + this.success = success; + return this; + } + + public long getTimestamp() { + return timestamp; + } + + public Result setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** + * 返回成功信息 + * + * @return 结果 + */ + public static Result success() { + return new Result<>(); + } + + /** + * 返回带数据的成功信息 + * + * @param data 响应数据 + * @return 结果 + */ + public static Result success(T data) { + Result result = new Result<>(); + result.setData(data); + return result; + } + + /** + * 返回错误信息 + * + * @param message 错误信息 + * @return 错误信息对象 + */ + public static Result error(String message) { + Result result = new Result<>(); + result.message = message; + result.code = 500; + result.success = false; + return result; + } + + /** + * 返回错误信息 + * + * @param code 错误码 + * @param message 错误信息 + * @return 错误信息对象 + */ + public static Result error(int code, String message) { + Result result = new Result<>(); + result.code = code; + result.message = message; + result.success = false; + return result; + } + + /** + * 返回带数据的错误信息 + * + * @param code 错误码 + * @param message 错误信息 + * @param data 错误相关数据 + * @return 错误信息对象 + */ + public static Result error(int code, String message, T data) { + Result result = new Result<>(); + result.code = code; + result.message = message; + result.success = false; + result.data = data; + return result; + } + + /** + * 从异常创建错误响应 + * + * @param e 异常 + * @return 错误信息对象 + */ + public static Result error(Exception e) { + Result result = new Result<>(); + result.code = 500; + result.message = e.getMessage() != null ? e.getMessage() : "系统异常"; + result.success = false; + return result; + } +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/session/Session.java b/src/main/java/xyz/wbsite/achat/core/session/Session.java new file mode 100644 index 0000000..cde66ff --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/session/Session.java @@ -0,0 +1,138 @@ +package xyz.wbsite.achat.core.session; + + +import java.util.List; + +/** + * 会话 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Session { + /** + * 主键 + */ + private String id; + /** + * 用户ID + */ + private String uid; + private String title; + private String model; + private String prompt; + private String temperature; + private String topP; + private String frequencyPenalty; + private String presencePenalty; + private String maxTokens; + private String lastTime; + private String lastMessage; + + private List messages; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getUid() { + return uid; + } + + public void setUid(String uid) { + this.uid = uid; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getPrompt() { + return prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + + public String getTemperature() { + return temperature; + } + + public void setTemperature(String temperature) { + this.temperature = temperature; + } + + public String getTopP() { + return topP; + } + + public void setTopP(String topP) { + this.topP = topP; + } + + public String getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(String frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(String presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public String getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(String maxTokens) { + this.maxTokens = maxTokens; + } + + public String getLastTime() { + return lastTime; + } + + public void setLastTime(String lastTime) { + this.lastTime = lastTime; + } + + public String getLastMessage() { + return lastMessage; + } + + public void setLastMessage(String lastMessage) { + this.lastMessage = lastMessage; + } + + public List getMessages() { + return messages; + } + + public void setMessages(List messages) { + this.messages = messages; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/session/SessionService.java b/src/main/java/xyz/wbsite/achat/core/session/SessionService.java new file mode 100644 index 0000000..93d7fd5 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/session/SessionService.java @@ -0,0 +1,56 @@ +package xyz.wbsite.achat.core.session; + +import xyz.wbsite.achat.core.chat.Message; + +import java.util.List; + +/** + * 会话管理服务接口 + * 提供会话的创建、删除、查询、消息收发等功能 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public interface SessionService { + + /** + * 创建新会话 + * + * @param uid 用户编号 + * @return 创建的会话对象 + */ + Result createSession(String uid); + + /** + * 删除会话 + * + * @param sid 会话ID + * @return 删除是否成功 + */ + Result deleteSession(String uid, String sid); + + /** + * 查询会话列表 + * + * @param uid 用户编号 + * @return 会话列表 + */ + Result> listSessions(String uid); + + /** + * 获取会话详情 + * + * @param sid 会话ID + * @return 会话对象 + */ + Result getSession(String uid,String sid); + + /** + * 获取会话历史消息 + * + * @param sid 会话ID + * @return 消息列表 + */ + Result deleteMessage(String uid, String sid, String mid); +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/session/SessionServiceMemoryImpl.java b/src/main/java/xyz/wbsite/achat/core/session/SessionServiceMemoryImpl.java new file mode 100644 index 0000000..7dbebbb --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/session/SessionServiceMemoryImpl.java @@ -0,0 +1,176 @@ +package xyz.wbsite.achat.core.session; + +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * 会话管理服务实现类 + * 实现会话的创建、删除、查询、消息收发等功能 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@Service +public class SessionServiceMemoryImpl implements SessionService { + + /** + * 会话存储 + */ + private final List sessionStore = new ArrayList<>(); + + /** + * 消息存储 + */ + private final List messageStore = new ArrayList<>(); + + /** + * 创建新会话 + * + * @param uid 用户ID + * @return 结果 + */ + @Override + public Result createSession(String uid) { + // 检查参数合法性 + if (uid == null || uid.isEmpty()) { + return Result.error("用户ID不能为空"); + } + + Session session = new Session(); + session.setId(String.valueOf(UUID.randomUUID().toString())); + session.setUid(uid); + session.setTitle("新对话"); + sessionStore.add(session); + return Result.success(session); + } + + /** + * 删除会话 + * + * @param uid 用户ID + * @param sid 会话ID + * @return 结果 + */ + @Override + public Result deleteSession(String uid, String sid) { + // 检查参数合法性 + if (uid == null || uid.isEmpty() || sid == null || sid.isEmpty()) { + return Result.error("用户ID或会话ID不能为空"); + } + + // 从sessionStore查找比对uid和sid然后删除 + boolean sessionRemoved = sessionStore.removeIf(session -> + session.getId().equals(sid) && session.getUid().equals(uid) + ); + + // 从messageStore查找比对uid和sid然后删除 + messageStore.removeIf(message -> + message.getUid().equals(uid) && message.getSid().equals(sid) + ); + + if (!sessionRemoved) { + return Result.error("会话不存在或无权限删除"); + } + + return Result.success(); + } + + /** + * 查询会话列表 + * + * @param uid 用户ID + * @return + */ + /** + * 查询会话列表 + * + * @param uid 用户ID + * @return 会话列表 + */ + @Override + public Result> listSessions(String uid) { + // 检查参数合法性 + if (uid == null || uid.isEmpty()) { + return Result.error("用户ID不能为空"); + } + + // 从sessionStore查找比对uid返回所有会话 + List userSessions = sessionStore.stream() + .filter(session -> session.getUid().equals(uid)) + .collect(Collectors.toList()); + + return Result.success(userSessions); + } + + /** + * 获取会话详情 + * + * @param uid 用户ID + * @param sid 会话ID + * @return 会话详情 + */ + @Override + public Result getSession(String uid, String sid) { + // 检查参数合法性 + if (uid == null || uid.isEmpty() || sid == null || sid.isEmpty()) { + return Result.error("用户ID或会话ID不能为空"); + } + + // 从sessionStore查找比对uid和sid返回会话 + Session session = sessionStore.stream() + .filter(s -> s.getId().equals(sid) && s.getUid().equals(uid)) + .findFirst() + .orElse(null); + + if (session == null) { + return Result.error("会话不存在或无权限查看"); + } + + // 创建会话副本 + Session sessionCopy = new Session(); + sessionCopy.setId(session.getId()); + sessionCopy.setUid(session.getUid()); + sessionCopy.setTitle(session.getTitle()); + + // 从messageStore检索出当前会话的聊天记录 + List sessionMessages = messageStore.stream() + .filter(message -> message.getUid().equals(uid) && message.getSid().equals(sid)) + .collect(Collectors.toList()); + + sessionCopy.setMessages(sessionMessages); + + return Result.success(sessionCopy); + } + + /** + * 删除会话中的消息 + * + * @param uid 用户ID + * @param sid 会话ID + * @param mid 消息ID + * @return 结果 + */ + @Override + public Result deleteMessage(String uid, String sid, String mid) { + // 检查参数合法性 + if (uid == null || uid.isEmpty() || sid == null || sid.isEmpty() || mid == null || mid.isEmpty()) { + return Result.error("用户ID、会话ID或消息ID不能为空"); + } + + // 从messageStore查找比对uid、sid和mid然后删除 + boolean mr = messageStore.removeIf(message -> + message.getId().equals(mid) && message.getUid().equals(uid) && message.getSid().equals(sid) + ); + + if (!mr) { + return Result.error("消息不存在或无权限删除"); + } + + return Result.success(); + } +} \ No newline at end of file