commit c585f9899be5a90e96e20d69aad533ae763e80ac Author: wangbing Date: Mon Sep 1 13:52:37 2025 +0800 上传备份 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..13e058c --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +target/ +pom.xml.tag +pom.xml.releaseBackup +pom.xml.versionsBackup +pom.xml.next +release.properties +/.idea +*.iml +/.settings +/bin +/gen +/build +/gradle +/classes +.classpath +.project +*.gradle +gradlew +local.properties +node_modules/ +data/ diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..bfc5764 --- /dev/null +++ b/pom.xml @@ -0,0 +1,81 @@ + + + 4.0.0 + + org.springframework.boot + spring-boot-starter-parent + 2.7.13 + + + xyz.wbsite.achat + starter-achat + 0.0.1-SNAPSHOT + jar + starter-achat + An abstract chat starter + + + UTF-8 + UTF-8 + 8 + true + 2.7.13 + + + + + + aliyun + Aliyun Repository + default + https://maven.aliyun.com/repository/public + + + + + + aliyun + Aliyun Repository + https://maven.aliyun.com/repository/public + default + + + + + + org.springframework.boot + spring-boot-starter-web + + + + org.springframework.boot + spring-boot-starter-test + test + + + + + + + + + + org.springframework.boot + spring-boot-dependencies + ${spring-boot.version} + pom + import + + + + + + + ${project.artifactId} + + src/main/java + + src/test/java + + diff --git a/src/main/java/xyz/wbsite/Application.java b/src/main/java/xyz/wbsite/Application.java new file mode 100644 index 0000000..d82caa6 --- /dev/null +++ b/src/main/java/xyz/wbsite/Application.java @@ -0,0 +1,23 @@ +package xyz.wbsite; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +/** + * 应用启动入口 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@SpringBootApplication +public class Application { + + /** + * 程序入口 + */ + public static void main(String[] args) { + SpringApplication application = new SpringApplication(Application.class); + application.run(args); + } +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/ChatController.java b/src/main/java/xyz/wbsite/achat/ChatController.java new file mode 100644 index 0000000..48d5282 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/ChatController.java @@ -0,0 +1,112 @@ +package xyz.wbsite.achat; + +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import xyz.wbsite.achat.core.base.Message; +import xyz.wbsite.achat.core.base.Result; +import xyz.wbsite.achat.core.base.Session; +import xyz.wbsite.achat.core.message.UserMessage; +import xyz.wbsite.achat.core.prompt.MessagePrompt; +import xyz.wbsite.achat.core.service.SessionService; + +import javax.annotation.Resource; +import java.util.List; +import java.util.UUID; + +/** + * AI会话接口. + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@RestController +public class ChatController { + + @Resource + private SessionService sessionService; + + /** + * 门户 + */ + @GetMapping("/chat") + public String chat() { + return "AChat is running!"; + } + + /** + * 发送消息(流式响应) + * + * @return SSE发射器,用于流式响应 + */ + @PostMapping(value = "/chat/send", produces = "text/event-stream;charset=UTF-8") + public SseEmitter send(@RequestBody MessagePrompt message) { + return sessionService.sendMessage(message); + } + + /** + * 创建会话 + * + * @return 创建会话响应 + */ + @PostMapping("/chat/session") + public Result createSession() { + String sid = UUID.randomUUID().toString(); + return sessionService.createSession(sid); + } + + /** + * 获取会话列表 + * + * @return 会话列表响应 + */ + @ResponseBody + @GetMapping("/chat/session/list") + public Result> listSession() { + String sid = UUID.randomUUID().toString(); + return sessionService.listSessions(sid); + } + + /** + * 删除会话 + * + * @param sid 会话ID + * @return 删除会话响应 + */ + @ResponseBody + @DeleteMapping("/chat/session/{sid}") + public Result deleteSession(@PathVariable("sid") String sid) { + return sessionService.deleteSession(sid); + } + + + /** + * 获取会话历史消息 + * + * @param sid 会话ID + * @return 历史消息响应 + */ + @ResponseBody + @GetMapping("/chat/session/{sid}/history") + public Result> getSessionHistory(@PathVariable("sid") String sid) { + return sessionService.listMessage(sid); + } + + /** + * 停止AI回答 + * + * @param sid 会话ID + * @return 停止回答响应 + */ + @ResponseBody + @PostMapping("/chat/session/{sid}/stop") + public Result stopSession(@PathVariable("sid") String sid) { + return sessionService.stopSession(sid); + } +} diff --git a/src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java b/src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java new file mode 100644 index 0000000..86e7d62 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java @@ -0,0 +1,35 @@ +package xyz.wbsite.achat.config; + +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.CorsRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +/** + * Web相关配置(跨域配置,序列化配置). + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@Configuration +public class WebMvcConfig implements WebMvcConfigurer { + + /** + * 跨域配置 + * + * @param registry + */ + @Override + public void addCorsMappings(CorsRegistry registry) { + // 注意,如果授权认认证未通过会直接返回,此跨域配置则不会生效,前端仍然会提示跨域 + registry.addMapping("/chat/**") + //允许的域,不要写*,否则cookie就无法使用了 + .allowedOriginPatterns("http://localhost:5173") + .allowedHeaders("*") + //是否发送Cookie + .allowCredentials(true) + .allowedMethods("GET", "POST", "DELETE", "PUT") + .exposedHeaders("*") + .maxAge(3600); + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/base/Attachment.java b/src/main/java/xyz/wbsite/achat/core/base/Attachment.java new file mode 100644 index 0000000..0549f73 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/base/Attachment.java @@ -0,0 +1,30 @@ +package xyz.wbsite.achat.core.base; + +/** + * 附件 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Attachment { + private String filename; + + private String fid; + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public String getFid() { + return fid; + } + + public void setFid(String fid) { + this.fid = fid; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/base/Event.java b/src/main/java/xyz/wbsite/achat/core/base/Event.java new file mode 100644 index 0000000..234a06d --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/base/Event.java @@ -0,0 +1,59 @@ +package xyz.wbsite.achat.core.base; + + +import java.util.Date; + +/** + * 服务器推送事件 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Event { + + private String sid; + + private Date time; + + private String text; + + private String type; + + public Event(String sid) { + this.sid = sid; + this.time = new Date(); + } + + public String getSid() { + return sid; + } + + public void setSid(String sid) { + this.sid = sid; + } + + public Date getTime() { + return time; + } + + public void setTime(Date time) { + this.time = time; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/base/Message.java b/src/main/java/xyz/wbsite/achat/core/base/Message.java new file mode 100644 index 0000000..1a7031e --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/base/Message.java @@ -0,0 +1,52 @@ +package xyz.wbsite.achat.core.base; + +import xyz.wbsite.achat.enums.Role; + +/** + * 消息 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Message { + + /** + * 角色 + */ + private Role role; + + /** + * 会话ID + */ + private String sid; + + /** + * 消息 + */ + private String content; + + public Role getRole() { + return role; + } + + public void setRole(Role role) { + this.role = role; + } + + public String getSid() { + return sid; + } + + public void setSid(String sid) { + this.sid = sid; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/base/Prompt.java b/src/main/java/xyz/wbsite/achat/core/base/Prompt.java new file mode 100644 index 0000000..7241df2 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/base/Prompt.java @@ -0,0 +1,17 @@ +package xyz.wbsite.achat.core.base; + +public class Prompt { + + /** + * 提示词 + */ + private String prompt; + + public String getPrompt() { + return prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/base/Result.java b/src/main/java/xyz/wbsite/achat/core/base/Result.java new file mode 100644 index 0000000..4ab91f7 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/base/Result.java @@ -0,0 +1,209 @@ +package xyz.wbsite.achat.core.base; + +import java.util.HashMap; +import java.util.Map; + +/** + * 接口响应结果基类 + * 泛型支持的数据响应封装,提供统一的响应格式和错误处理 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Result { + + /** + * 响应状态码 + */ + private int code = 200; + + /** + * 响应消息 + */ + private String message = "success"; + + /** + * 响应数据 + */ + private T data; + + /** + * 是否成功 + */ + private boolean success = true; + + /** + * 错误详情,用于表单验证等场景 + */ + private Map errors; + + /** + * 请求ID,用于问题追踪 + */ + private String requestId; + + /** + * 响应时间戳 + */ + private long timestamp = System.currentTimeMillis(); + + public int getCode() { + return code; + } + + public Result setCode(int code) { + this.code = code; + return this; + } + + public String getMessage() { + return message; + } + + public Result setMessage(String message) { + this.message = message; + return this; + } + + public T getData() { + return data; + } + + public Result setData(T data) { + this.data = data; + return this; + } + + public boolean isSuccess() { + return success; + } + + public Result setSuccess(boolean success) { + this.success = success; + return this; + } + + public Map getErrors() { + return errors; + } + + public Result setErrors(Map errors) { + this.errors = errors; + return this; + } + + public String getRequestId() { + return requestId; + } + + public Result setRequestId(String requestId) { + this.requestId = requestId; + return this; + } + + public long getTimestamp() { + return timestamp; + } + + public Result setTimestamp(long timestamp) { + this.timestamp = timestamp; + return this; + } + + /** + * 添加字段级别的错误信息 + * + * @param field 字段名 + * @param error 错误信息 + * @return 当前结果对象,支持链式调用 + */ + public Result addError(String field, String error) { + if (errors == null) { + errors = new HashMap<>(); + } + errors.put(field, error); + return this; + } + + /** + * 返回成功信息 + * + * @return 结果 + */ + public static Result success() { + return new Result<>(); + } + + /** + * 返回带数据的成功信息 + * + * @param data 响应数据 + * @return 结果 + */ + public static Result success(T data) { + Result result = new Result<>(); + result.setData(data); + return result; + } + + /** + * 返回错误信息 + * + * @param message 错误信息 + * @return 错误信息对象 + */ + public static Result error(String message) { + Result result = new Result<>(); + result.message = message; + result.code = 500; + result.success = false; + return result; + } + + /** + * 返回错误信息 + * + * @param code 错误码 + * @param message 错误信息 + * @return 错误信息对象 + */ + public static Result error(int code, String message) { + Result result = new Result<>(); + result.code = code; + result.message = message; + result.success = false; + return result; + } + + /** + * 返回带数据的错误信息 + * + * @param code 错误码 + * @param message 错误信息 + * @param data 错误相关数据 + * @return 错误信息对象 + */ + public static Result error(int code, String message, T data) { + Result result = new Result<>(); + result.code = code; + result.message = message; + result.success = false; + result.data = data; + return result; + } + + /** + * 从异常创建错误响应 + * + * @param e 异常 + * @return 错误信息对象 + */ + public static Result error(Exception e) { + Result result = new Result<>(); + result.code = 500; + result.message = e.getMessage() != null ? e.getMessage() : "系统异常"; + result.success = false; + return result; + } +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/base/Session.java b/src/main/java/xyz/wbsite/achat/core/base/Session.java new file mode 100644 index 0000000..d7be327 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/base/Session.java @@ -0,0 +1,126 @@ +package xyz.wbsite.achat.core.base; + + +/** + * 会话 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class Session { + /** + * 主键 + */ + private String id; + /** + * 用户ID + */ + private String uid; + private String title; + private String model; + private String prompt; + private String temperature; + private String topP; + private String frequencyPenalty; + private String presencePenalty; + private String maxTokens; + private String lastTime; + private String lastMessage; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getUid() { + return uid; + } + + public void setUid(String uid) { + this.uid = uid; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getPrompt() { + return prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + + public String getTemperature() { + return temperature; + } + + public void setTemperature(String temperature) { + this.temperature = temperature; + } + + public String getTopP() { + return topP; + } + + public void setTopP(String topP) { + this.topP = topP; + } + + public String getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(String frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(String presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public String getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(String maxTokens) { + this.maxTokens = maxTokens; + } + + public String getLastTime() { + return lastTime; + } + + public void setLastTime(String lastTime) { + this.lastTime = lastTime; + } + + public String getLastMessage() { + return lastMessage; + } + + public void setLastMessage(String lastMessage) { + this.lastMessage = lastMessage; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/event/CompleteEvent.java b/src/main/java/xyz/wbsite/achat/core/event/CompleteEvent.java new file mode 100644 index 0000000..3b9864d --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/event/CompleteEvent.java @@ -0,0 +1,22 @@ +package xyz.wbsite.achat.core.event; + +import xyz.wbsite.achat.core.base.Event; + +/** + * 完成事件 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class CompleteEvent extends Event { + + public CompleteEvent(String sid) { + super(sid); + setType("complete"); + } + + public static CompleteEvent of(String sid) { + return new CompleteEvent(sid); + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/event/PartialEvent.java b/src/main/java/xyz/wbsite/achat/core/event/PartialEvent.java new file mode 100644 index 0000000..dd6ac2b --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/event/PartialEvent.java @@ -0,0 +1,29 @@ +package xyz.wbsite.achat.core.event; + +import xyz.wbsite.achat.core.base.Event; + +/** + * 部分消息事件 + *

+ * 用于流式会话下表示AI回复的部分消息 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class PartialEvent extends Event { + public PartialEvent(String sid) { + super(sid); + setType("partial"); + } + + public PartialEvent(String sid, String partial) { + super(sid); + setType("partial"); + setText(partial); + } + + public static PartialEvent of(String sid, String partial) { + return new PartialEvent(sid, partial); + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/message/AiMessage.java b/src/main/java/xyz/wbsite/achat/core/message/AiMessage.java new file mode 100644 index 0000000..65a4fc2 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/message/AiMessage.java @@ -0,0 +1,15 @@ +package xyz.wbsite.achat.core.message; + + +import xyz.wbsite.achat.core.base.Message; + +/** + * 用户消息 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class AiMessage extends Message { + +} diff --git a/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java b/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java new file mode 100644 index 0000000..c873a05 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java @@ -0,0 +1,176 @@ +package xyz.wbsite.achat.core.message; + +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import xyz.wbsite.achat.core.base.Event; +import xyz.wbsite.achat.core.event.CompleteEvent; +import xyz.wbsite.achat.core.event.PartialEvent; +import xyz.wbsite.achat.core.prompt.MessagePrompt; +import xyz.wbsite.achat.core.service.MessageGenerator; + +import java.io.IOException; + +/** + * 对话SSE发射器 + * 只负责SSE通信,不包含业务逻辑 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class MessageSseEmitter extends SseEmitter { + + /** + * 用户消息 + */ + private MessagePrompt messagePrompt; + + /** + * 是否完成 + */ + private boolean complete; + + /** + * AI回答 + */ + private final StringBuilder answer = new StringBuilder(); + + /** + * 消息处理器 + */ + private MessageGenerator messageGenerator; + + /** + * 构造函数 + * + * @param message 消息对象 + */ + public MessageSseEmitter(MessagePrompt message, MessageGenerator processor) { + super(0L); + this.messagePrompt = message; + this.messageGenerator = processor; +// TaskUtil.taskAsync(task); + } + + /** + * 消息处理任务 + */ +// public Runnable task = () -> { +// try { +// // 检查会话是否存在 +// if (!messageProcessor.checkSessionExists(message.getChatId(), uid)) { +// this.sendMessage(createPartialMessage("当前会话不存在,请刷新后再试!")); +// return; +// } +// +// // 更新新会话的标题(如果为空) +// messageProcessor.updateSessionTitleIfEmpty(message.getChatId(), message.getText(), uid); +// +// // 保存本次用户消息 +// messageProcessor.saveUserMessage(message.getChatId(), message.getText(), uid); +// +// // 根据是否有附件选择不同的处理方式 +// TokenStream tokenStream; +// if (this.hasAttachment()) { +// String attachment = messageProcessor.parseAttachment(message.getAttachments(), uid); +// tokenStream = messageProcessor.createAssistantStreamWithAttachment( +// message.getChatId(), message.getText(), attachment, uid); +// } else { +// tokenStream = messageProcessor.createAssistantStream( +// message.getChatId(), message.getText(), uid); +// } +// +// // 设置流回调 +// tokenStream +// .onPartialResponse(this::onPartialResponse) +// .onCompleteResponse(this::onCompleteResponse) +// .onError(this::onError) +// .start(); +// } catch (Exception e) { +// onError(e); +// } +// }; + + /** + * 错误处理 + */ + private void onError(Throwable e) { + this.sendMessage(createPartialMessage("" + e.getMessage())); + this.answer.append("" + e.getMessage()); + this.onCompleteResponse(null); + } + + /** + * 部分响应处理 + */ + public void onPartialResponse(String msg) { + if (complete) { + return; + } + this.sendMessage(createPartialMessage(msg)); + this.answer.append(msg); + } + + /** + * 完成响应处理 + */ + public void onCompleteResponse(Object chatResponse) { + if (this.complete) { + return; + } + // 推送结束 + this.sendMessage(createCompleteMessage()); + // 关闭链接 + this.complete(); + } + + /** + * 创建部分消息事件 + */ + private Event createPartialMessage(String partial) { + return new PartialEvent(messagePrompt.getSid(), partial); + } + + /** + * 创建完成消息事件 + */ + private Event createCompleteMessage() { + return new CompleteEvent(messagePrompt.getSid()); + } + + /** + * 重写send方法,处理异常 + */ + @Override + public void send(SseEventBuilder builder) throws IOException { + try { + super.send(builder); + } catch (Exception e) { + complete = true; + } + } + + /** + * 发送消息 + */ + private void sendMessage(Event message) { + try { + this.send(message); + } catch (Exception e) { + complete = true; + } + } + + /** + * 获取完成状态 + */ + public boolean isComplete() { + return complete; + } + + /** + * 设置完成状态 + */ + public void setComplete(boolean complete) { + this.complete = complete; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/message/UserMessage.java b/src/main/java/xyz/wbsite/achat/core/message/UserMessage.java new file mode 100644 index 0000000..ea6172f --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/message/UserMessage.java @@ -0,0 +1,42 @@ +package xyz.wbsite.achat.core.message; + +import xyz.wbsite.achat.core.base.Attachment; +import xyz.wbsite.achat.core.base.Message; + +import java.util.List; + +/** + * 用户消息 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public class UserMessage extends Message { + + /** + * 用户ID + */ + private String uid; + + /** + * 附件 + */ + private List attachments; + + public String getUid() { + return uid; + } + + public void setUid(String uid) { + this.uid = uid; + } + + public List getAttachments() { + return attachments; + } + + public void setAttachments(List attachments) { + this.attachments = attachments; + } +} diff --git a/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java b/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java new file mode 100644 index 0000000..8adf3e2 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java @@ -0,0 +1,106 @@ +package xyz.wbsite.achat.core.prompt; + +import xyz.wbsite.achat.core.base.Message; +import xyz.wbsite.achat.core.base.Prompt; +import xyz.wbsite.achat.core.message.UserMessage; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class MessagePrompt extends Prompt { + + /** + * 模型 + */ + private String model; + + /** + * 消息列表 + */ + private List messages; + + /** + * 是否流式返回 + */ + private Boolean stream; + + /** + * 最大token数 + */ + private Integer maxTokens; + + /** + * 温度参 + */ + private Double temperature; + + /** + * 额外参数 + */ + private Map extraParams; + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public List getMessages() { + return messages; + } + + public void setMessages(List messages) { + this.messages = messages; + } + + public Boolean getStream() { + return stream; + } + + public void setStream(Boolean stream) { + this.stream = stream; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + public Map getExtraParams() { + return extraParams; + } + + public void setExtraParams(Map extraParams) { + this.extraParams = extraParams; + } + + + public UserMessage getLastUserMessage() { + List messageList = messages.stream().filter(message -> message instanceof UserMessage).collect(Collectors.toList()); + UserMessage userMessage = (UserMessage)messageList.get(messageList.size() - 1); + return userMessage; + } + + public String getUid() { + return getLastUserMessage().getUid(); + } + + public String getSid(){ + return getLastUserMessage().getSid(); + } + +} diff --git a/src/main/java/xyz/wbsite/achat/core/service/MessageGenerator.java b/src/main/java/xyz/wbsite/achat/core/service/MessageGenerator.java new file mode 100644 index 0000000..c08ef2c --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/service/MessageGenerator.java @@ -0,0 +1,18 @@ +package xyz.wbsite.achat.core.service; + +import xyz.wbsite.achat.core.message.MessageSseEmitter; +import xyz.wbsite.achat.core.base.Message; + +/** + * 消息生成器 + *

+ * 抽象出来用于生成消息的实现层 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public interface MessageGenerator { + + void on(MessageSseEmitter emitter, Message message); +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/service/SessionService.java b/src/main/java/xyz/wbsite/achat/core/service/SessionService.java new file mode 100644 index 0000000..8ccad7c --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/service/SessionService.java @@ -0,0 +1,85 @@ +package xyz.wbsite.achat.core.service; + +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import xyz.wbsite.achat.core.base.Message; +import xyz.wbsite.achat.core.base.Result; +import xyz.wbsite.achat.core.base.Session; +import xyz.wbsite.achat.core.message.UserMessage; +import xyz.wbsite.achat.core.prompt.MessagePrompt; + +import java.util.List; + +/** + * 会话管理服务接口 + * 提供会话的创建、删除、查询、消息收发等功能 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public interface SessionService { + + /** + * 创建新会话 + * + * @param uid 用户编号 + * @return 创建的会话对象 + */ + Result createSession(String uid); + + /** + * 删除会话 + * + * @param sid 会话ID + * @return 删除是否成功 + */ + Result deleteSession(String sid); + + /** + * 查询会话列表 + * + * @param uid 用户编号 + * @return 会话列表 + */ + Result> listSessions(String uid); + + /** + * 获取会话详情 + * + * @param sid 会话ID + * @return 会话对象 + */ + Result getSession(String sid); + + /** + * 停止会话 + * + * @param sid 会话ID + * @return 会话对象 + */ + Result stopSession(String sid); + + /** + * 发送消息并获取流式响应 + * + * @param message 消息对象 + * @return SSE发射器,用于流式响应 + */ + SseEmitter sendMessage(MessagePrompt message); + + /** + * 获取会话历史消息 + * + * @param sid 会话ID + * @return 消息列表 + */ + Result> listMessage(String sid); + + /** + * 获取会话历史消息 + * + * @param sid 会话ID + * @return 消息列表 + */ + Result deleteMessage(String sid); +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java b/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java new file mode 100644 index 0000000..cfafb08 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java @@ -0,0 +1,122 @@ +package xyz.wbsite.achat.core.service.impl; + +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import xyz.wbsite.achat.core.message.MessageSseEmitter; +import xyz.wbsite.achat.core.base.Message; +import xyz.wbsite.achat.core.base.Result; +import xyz.wbsite.achat.core.base.Session; +import xyz.wbsite.achat.core.message.UserMessage; +import xyz.wbsite.achat.core.prompt.MessagePrompt; +import xyz.wbsite.achat.core.service.SessionService; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * 会话管理服务实现类 + * 实现会话的创建、删除、查询、消息收发等功能 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@Service +public class SessionServiceMemoryImpl implements SessionService { + + private final Map sessionStore = new HashMap<>(); + + private final List messageStore = new ArrayList<>(); + + /** + * 创建新会话 + */ + @Override + public Result createSession(String uid) { + Session session = new Session(); + session.setId(String.valueOf(UUID.randomUUID().toString())); + session.setUid(uid); + session.setTitle("新对话"); + sessionStore.put(session.getId(), session); + return Result.success(session); + } + + /** + * 删除会话 + */ + @Override + public Result deleteSession(String sid) { + sessionStore.remove(sid); + return Result.success(); + } + + /** + * 查询会话列表 + */ + @Override + public Result> listSessions(String uid) { + List collect = sessionStore.values() + .stream() + .filter(item -> item.getUid().equals(uid)) + .collect(Collectors.toList()); + return Result.success(collect); + } + + /** + * 获取会话详情 + */ + @Override + public Result getSession(String sid) { + Session session = sessionStore.get(sid); + if (session == null) { + return Result.error("会话不存在"); + } + return Result.success(session); + } + + /** + * 发送消息并获取流式响应 + */ + @Override + public SseEmitter sendMessage(MessagePrompt message) { + // 创建VChatSseEmitter来处理流式响应 + return new MessageSseEmitter(message, (emitter, message1) -> { + // 这边模拟LLM复述一遍用户问题 + String text = message1.getContent(); + for (char c : text.toCharArray()) { + if (emitter.isComplete()) { + return; + } + emitter.onPartialResponse(String.valueOf(c)); + try { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + emitter.onCompleteResponse(null); + }); + } + + @Override + public Result stopSession(String sid) { + // todo + return Result.success(); + } + + @Override + public Result> listMessage(String sid) { + List messages = messageStore.stream().filter(item -> item.getSid().equals(sid)).collect(Collectors.toList()); + return Result.success(messages); + } + + @Override + public Result deleteMessage(String sid) { + messageStore.forEach(item -> messageStore.remove(item)); + return null; + } +} \ No newline at end of file diff --git a/src/main/java/xyz/wbsite/achat/enums/Role.java b/src/main/java/xyz/wbsite/achat/enums/Role.java new file mode 100644 index 0000000..1d24e85 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/enums/Role.java @@ -0,0 +1,41 @@ +package xyz.wbsite.achat.enums; + +import com.fasterxml.jackson.annotation.JsonValue; + +/** + * 消息角色 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +public enum Role { + USER("user"), + ASSISTANT("assistant"), + SYSTEM("system"), + FUNCTION("function"), + TOOL("tool"), + UNKNOWN("unknown"); + + private final String value; + + Role(String value) { + this.value = value; + } + + // 序列化时返回小写字符串 + @JsonValue + public String getValue() { + return value; + } + + // 反序列化时根据字符串匹配枚举 + public static Role fromValue(String value) { + for (Role role : Role.values()) { + if (role.value.equalsIgnoreCase(value)) { + return role; + } + } + throw new IllegalArgumentException("Invalid Role value: " + value); + } +} diff --git a/src/main/java/xyz/wbsite/achat/enums/Status.java b/src/main/java/xyz/wbsite/achat/enums/Status.java new file mode 100644 index 0000000..6c564f0 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/enums/Status.java @@ -0,0 +1,23 @@ +package xyz.wbsite.achat.enums; + +/** + * 消息状态枚举 + */ +public enum Status { + /** + * 待处理 + */ + PENDING, + /** + * 成功 + */ + SUCCESS, + /** + * 错误 + */ + ERROR, + /** + * 取消 + */ + CANCELLED +} diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties new file mode 100644 index 0000000..132a9e5 --- /dev/null +++ b/src/main/resources/application.properties @@ -0,0 +1,2 @@ +server.port=8080 +spring.application.name=achat diff --git a/src/test/java/xyz/wbsite/achat/TestConfig.java b/src/test/java/xyz/wbsite/achat/TestConfig.java new file mode 100644 index 0000000..278d71d --- /dev/null +++ b/src/test/java/xyz/wbsite/achat/TestConfig.java @@ -0,0 +1,19 @@ +package xyz.wbsite.achat; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Configuration; + +import javax.annotation.PostConstruct; + +@Configuration +public class TestConfig { + + @Autowired + private ApplicationContext applicationContext; + + @PostConstruct + public void initLocalData() { + + } +} \ No newline at end of file diff --git a/src/test/resources/application.properties b/src/test/resources/application.properties new file mode 100644 index 0000000..1586721 --- /dev/null +++ b/src/test/resources/application.properties @@ -0,0 +1,96 @@ +# 激活环境 dev:开发,prod:生产 +spring.profiles.active=dev +# 启动完成打开游览器 +spring.profiles.open.web=false + +# 应用名称 +spring.application.name=aipro_sv + +# 停机配置 +# - 停机模式 graceful:优雅关机;immediate:立即关机 +server.shutdown=graceful +# - 停机最大等待时间 +spring.lifecycle.timeout-per-shutdown-phase=30s +# servlet配置 +server.servlet.encoding.enabled=true +server.servlet.encoding.force=true +# - 编码配置 +server.servlet.encoding.charset=UTF-8 +# - 上下文配置 +server.servlet.context-path= +# tomcat配置 +server.tomcat.uri-encoding=UTF-8 +server.tomcat.max-http-form-post-size=-1 + +# 允许循环引用,部分依赖因为存在循环依赖,暂时不能全局解决,暂时关闭循环检查 +spring.main.allow-circular-references=true +# 路径匹配模式,默认path_pattern_parser虽然高效,但与ant不兼容,由于项目中用到了ant语法,暂时保留 +spring.mvc.pathmatch.matching-strategy=ant_path_matcher + +# 性能优化配置 +# - 开启Gzip压缩 +server.compression.enabled=true +# - 路径匹配配置 +spring.mvc.static-path-pattern=/static/** +# - 资源目录配置 +spring.web.resources.static-locations=classpath:/META-INF/resources/,classpath:/resources/,classpath:/static/,classpath:/public/,file:${app.static.custom.path} +# - 启用静态资源 +spring.web.resources.add-mappings=true +# - 启用客户端缓存 +spring.web.resources.chain.enabled=true +spring.web.resources.chain.cache=true +spring.web.resources.chain.compressed=true +# - 客户端缓存时间 +spring.web.resources.cache.period= 3600 + +# 序列化配置 +spring.jackson.date-format=yyyy-MM-dd HH:mm:ss +# - 时区配置 +spring.jackson.time-zone=GMT+8 +# - 属性为null时不序列化 +spring.jackson.default-property-inclusion=non_null +# - 排序 +spring.jackson.mapper.sort-properties-alphabetically=true +# - 忽略不对应属性错误 +spring.jackson.deserialization.fail-on-unknown-properties=false + +# 文件上传配置 +# - 文件上传时禁止懒加载 +spring.servlet.multipart.resolveLazily=false +# - 文件上传大小限制 +spring.servlet.multipart.max-file-size=100MB +spring.servlet.multipart.max-request-size=100MB + +# 分页工具配置 +pagehelper.autoRuntimeDialect=true +pagehelper.reasonable=false +pagehelper.supportMethodsArguments=true +pagehelper.params=count=countSql + +# 页面模板配置 +spring.freemarker.enabled=true +# - 模板资源路径 +spring.freemarker.template-loader-path=classpath:/views/ +# - 模板后缀 +spring.freemarker.suffix=.ftl +spring.freemarker.allow-request-override=false +# - 模板可缓存 +spring.freemarker.cache=true +spring.freemarker.check-template-location=true +spring.freemarker.charset=UTF-8 +spring.freemarker.content-type=text/html +spring.freemarker.expose-request-attributes=false +spring.freemarker.expose-session-attributes=false +spring.freemarker.expose-spring-macro-helpers=false +spring.freemarker.settings.template_update_delay=1 +spring.freemarker.settings.locale=zh_CN +# - 日期时间格式化 +spring.freemarker.settings.datetime_format=yyyy-MM-dd HH:mm:ss +# - 日期格式化 +spring.freemarker.settings.date_format=yyyy-MM-dd +# - 数字格式化 +spring.freemarker.settings.number_format=#.## +# - 启用兼容模式 +spring.freemarker.settings.classic_compatible=true +spring.freemarker.settings.whitespace_stripping=true +spring.freemarker.settings.url_escaping_charset=utf-8 diff --git a/src/test/resources/logback-spring.xml b/src/test/resources/logback-spring.xml new file mode 100644 index 0000000..5d80b9f --- /dev/null +++ b/src/test/resources/logback-spring.xml @@ -0,0 +1,22 @@ + + + + + + debug级别及以上 + + DEBUG + + + %highlight(%d{yyyy-MM-dd HH:mm:ss.SSS} [%-4level] [%thread] [%logger{36}-%method] %ex %msg%n) + + + + + + + + + + +