From 5bd40072ba0f3f7d963fe8dddc99da99e59fd039 Mon Sep 17 00:00:00 2001 From: wangbing Date: Mon, 1 Sep 2025 14:34:15 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E5=A4=87=E4=BB=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 2 +- .../java/xyz/wbsite/achat/ChatController.java | 47 ++++--- .../wbsite/achat/config/WebFluxConfig.java | 44 +++++++ .../xyz/wbsite/achat/config/WebMvcConfig.java | 35 ----- .../achat/core/message/MessageSseEmitter.java | 122 +++++++++--------- .../achat/core/prompt/MessagePrompt.java | 11 +- .../achat/core/service/SessionService.java | 7 +- .../impl/SessionServiceMemoryImpl.java | 28 ++-- 8 files changed, 163 insertions(+), 133 deletions(-) create mode 100644 src/main/java/xyz/wbsite/achat/config/WebFluxConfig.java delete mode 100644 src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java diff --git a/pom.xml b/pom.xml index bfc5764..9bd8e9a 100644 --- a/pom.xml +++ b/pom.xml @@ -45,7 +45,7 @@ org.springframework.boot - spring-boot-starter-web + spring-boot-starter-webflux diff --git a/src/main/java/xyz/wbsite/achat/ChatController.java b/src/main/java/xyz/wbsite/achat/ChatController.java index 48d5282..002e97e 100644 --- a/src/main/java/xyz/wbsite/achat/ChatController.java +++ b/src/main/java/xyz/wbsite/achat/ChatController.java @@ -5,13 +5,14 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.RestController; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import xyz.wbsite.achat.core.base.Message; import xyz.wbsite.achat.core.base.Result; import xyz.wbsite.achat.core.base.Session; -import xyz.wbsite.achat.core.message.UserMessage; import xyz.wbsite.achat.core.prompt.MessagePrompt; import xyz.wbsite.achat.core.service.SessionService; @@ -36,17 +37,19 @@ public class ChatController { * 门户 */ @GetMapping("/chat") - public String chat() { - return "AChat is running!"; + public Mono chat() { + return Mono.just("AChat is running!"); } /** * 发送消息(流式响应) * - * @return SSE发射器,用于流式响应 + * @return 响应式流,用于流式响应 */ - @PostMapping(value = "/chat/send", produces = "text/event-stream;charset=UTF-8") - public SseEmitter send(@RequestBody MessagePrompt message) { + @PostMapping(value = "/chat/send", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux> send(@RequestBody MessagePrompt message) { + // 注意:这里需要SessionService的sendMessage方法返回Flux> + // 由于我们没有修改SessionService,暂时返回一个模拟的流 return sessionService.sendMessage(message); } @@ -56,9 +59,11 @@ public class ChatController { * @return 创建会话响应 */ @PostMapping("/chat/session") - public Result createSession() { + public Mono> createSession() { String sid = UUID.randomUUID().toString(); - return sessionService.createSession(sid); + // 注意:这里需要SessionService的createSession方法返回Mono> + // 由于我们没有修改SessionService,暂时包装为Mono + return Mono.just(sessionService.createSession(sid)); } /** @@ -66,11 +71,12 @@ public class ChatController { * * @return 会话列表响应 */ - @ResponseBody @GetMapping("/chat/session/list") - public Result> listSession() { + public Mono>> listSession() { String sid = UUID.randomUUID().toString(); - return sessionService.listSessions(sid); + // 注意:这里需要SessionService的listSessions方法返回Mono>> + // 由于我们没有修改SessionService,暂时包装为Mono + return Mono.just(sessionService.listSessions(sid)); } /** @@ -79,10 +85,11 @@ public class ChatController { * @param sid 会话ID * @return 删除会话响应 */ - @ResponseBody @DeleteMapping("/chat/session/{sid}") - public Result deleteSession(@PathVariable("sid") String sid) { - return sessionService.deleteSession(sid); + public Mono deleteSession(@PathVariable("sid") String sid) { + // 注意:这里需要SessionService的deleteSession方法返回Mono + // 由于我们没有修改SessionService,暂时包装为Mono + return Mono.just(sessionService.deleteSession(sid)); } @@ -92,10 +99,11 @@ public class ChatController { * @param sid 会话ID * @return 历史消息响应 */ - @ResponseBody @GetMapping("/chat/session/{sid}/history") - public Result> getSessionHistory(@PathVariable("sid") String sid) { - return sessionService.listMessage(sid); + public Mono>> getSessionHistory(@PathVariable("sid") String sid) { + // 注意:这里需要SessionService的listMessage方法返回Mono>> + // 由于我们没有修改SessionService,暂时包装为Mono + return Mono.just(sessionService.listMessage(sid)); } /** @@ -104,7 +112,6 @@ public class ChatController { * @param sid 会话ID * @return 停止回答响应 */ - @ResponseBody @PostMapping("/chat/session/{sid}/stop") public Result stopSession(@PathVariable("sid") String sid) { return sessionService.stopSession(sid); diff --git a/src/main/java/xyz/wbsite/achat/config/WebFluxConfig.java b/src/main/java/xyz/wbsite/achat/config/WebFluxConfig.java new file mode 100644 index 0000000..86ad5f1 --- /dev/null +++ b/src/main/java/xyz/wbsite/achat/config/WebFluxConfig.java @@ -0,0 +1,44 @@ +package xyz.wbsite.achat.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.CorsWebFilter; +import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource; +import org.springframework.web.reactive.config.WebFluxConfigurer; + +/** + * WebFlux相关配置(跨域配置,序列化配置). + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ +@Configuration +public class WebFluxConfig implements WebFluxConfigurer { + + /** + * 跨域配置 + * 注意:在WebFlux中,我们使用CorsWebFilter来配置跨域 + */ + @Bean + public CorsWebFilter corsWebFilter() { + CorsConfiguration config = new CorsConfiguration(); + // 允许的域,不要写*,否则cookie就无法使用了 + config.addAllowedOriginPattern("http://localhost:5173"); + config.addAllowedHeader("*"); + // 是否发送Cookie + config.setAllowCredentials(true); + config.addAllowedMethod("GET"); + config.addAllowedMethod("POST"); + config.addAllowedMethod("DELETE"); + config.addAllowedMethod("PUT"); + config.addExposedHeader("*"); + config.setMaxAge(3600L); + + UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); + source.registerCorsConfiguration("/chat/**", config); + + return new CorsWebFilter(source); + } +} diff --git a/src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java b/src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java deleted file mode 100644 index 86e7d62..0000000 --- a/src/main/java/xyz/wbsite/achat/config/WebMvcConfig.java +++ /dev/null @@ -1,35 +0,0 @@ -package xyz.wbsite.achat.config; - -import org.springframework.context.annotation.Configuration; -import org.springframework.web.servlet.config.annotation.CorsRegistry; -import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; - -/** - * Web相关配置(跨域配置,序列化配置). - * - * @author wangbing - * @version 0.0.1 - * @since 1.8 - */ -@Configuration -public class WebMvcConfig implements WebMvcConfigurer { - - /** - * 跨域配置 - * - * @param registry - */ - @Override - public void addCorsMappings(CorsRegistry registry) { - // 注意,如果授权认认证未通过会直接返回,此跨域配置则不会生效,前端仍然会提示跨域 - registry.addMapping("/chat/**") - //允许的域,不要写*,否则cookie就无法使用了 - .allowedOriginPatterns("http://localhost:5173") - .allowedHeaders("*") - //是否发送Cookie - .allowCredentials(true) - .allowedMethods("GET", "POST", "DELETE", "PUT") - .exposedHeaders("*") - .maxAge(3600); - } -} diff --git a/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java b/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java index c873a05..cab76a4 100644 --- a/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java +++ b/src/main/java/xyz/wbsite/achat/core/message/MessageSseEmitter.java @@ -1,33 +1,35 @@ package xyz.wbsite.achat.core.message; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import org.springframework.http.codec.ServerSentEvent; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import xyz.wbsite.achat.core.base.Event; import xyz.wbsite.achat.core.event.CompleteEvent; import xyz.wbsite.achat.core.event.PartialEvent; import xyz.wbsite.achat.core.prompt.MessagePrompt; import xyz.wbsite.achat.core.service.MessageGenerator; -import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; /** - * 对话SSE发射器 + * 对话响应式SSE处理器 * 只负责SSE通信,不包含业务逻辑 * * @author wangbing * @version 0.0.1 * @since 1.8 */ -public class MessageSseEmitter extends SseEmitter { +public class MessageSseEmitter { /** * 用户消息 */ - private MessagePrompt messagePrompt; + private final MessagePrompt messagePrompt; /** * 是否完成 */ - private boolean complete; + private final AtomicBoolean complete = new AtomicBoolean(false); /** * AI回答 @@ -37,58 +39,39 @@ public class MessageSseEmitter extends SseEmitter { /** * 消息处理器 */ - private MessageGenerator messageGenerator; + private final MessageGenerator messageGenerator; // 可能为null,使用时需要检查 + + /** + * Flux接收器 + */ + private FluxSink> sink; /** * 构造函数 * * @param message 消息对象 + * @param processor 消息处理器 */ public MessageSseEmitter(MessagePrompt message, MessageGenerator processor) { - super(0L); + if (message == null) { + throw new IllegalArgumentException("MessagePrompt cannot be null"); + } this.messagePrompt = message; - this.messageGenerator = processor; -// TaskUtil.taskAsync(task); + this.messageGenerator = processor; // 可能为null,使用时需要检查 } /** - * 消息处理任务 - */ -// public Runnable task = () -> { -// try { -// // 检查会话是否存在 -// if (!messageProcessor.checkSessionExists(message.getChatId(), uid)) { -// this.sendMessage(createPartialMessage("当前会话不存在,请刷新后再试!")); -// return; -// } -// -// // 更新新会话的标题(如果为空) -// messageProcessor.updateSessionTitleIfEmpty(message.getChatId(), message.getText(), uid); -// -// // 保存本次用户消息 -// messageProcessor.saveUserMessage(message.getChatId(), message.getText(), uid); -// -// // 根据是否有附件选择不同的处理方式 -// TokenStream tokenStream; -// if (this.hasAttachment()) { -// String attachment = messageProcessor.parseAttachment(message.getAttachments(), uid); -// tokenStream = messageProcessor.createAssistantStreamWithAttachment( -// message.getChatId(), message.getText(), attachment, uid); -// } else { -// tokenStream = messageProcessor.createAssistantStream( -// message.getChatId(), message.getText(), uid); -// } -// -// // 设置流回调 -// tokenStream -// .onPartialResponse(this::onPartialResponse) -// .onCompleteResponse(this::onCompleteResponse) -// .onError(this::onError) -// .start(); -// } catch (Exception e) { -// onError(e); -// } -// }; + * 创建响应式流 + */ + public Flux> createFlux() { + return Flux.create(emitter -> { + this.sink = emitter; + // 设置取消回调 + emitter.onCancel(() -> complete.set(true)); + // 设置释放回调(当sink完成或取消时调用) + emitter.onDispose(() -> complete.set(true)); + }); + } /** * 错误处理 @@ -103,7 +86,7 @@ public class MessageSseEmitter extends SseEmitter { * 部分响应处理 */ public void onPartialResponse(String msg) { - if (complete) { + if (complete.get()) { return; } this.sendMessage(createPartialMessage(msg)); @@ -114,7 +97,7 @@ public class MessageSseEmitter extends SseEmitter { * 完成响应处理 */ public void onCompleteResponse(Object chatResponse) { - if (this.complete) { + if (this.complete.get()) { return; } // 推送结束 @@ -127,6 +110,9 @@ public class MessageSseEmitter extends SseEmitter { * 创建部分消息事件 */ private Event createPartialMessage(String partial) { + if (messagePrompt == null) { + throw new IllegalStateException("MessagePrompt is null"); + } return new PartialEvent(messagePrompt.getSid(), partial); } @@ -134,29 +120,34 @@ public class MessageSseEmitter extends SseEmitter { * 创建完成消息事件 */ private Event createCompleteMessage() { + if (messagePrompt == null) { + throw new IllegalStateException("MessagePrompt is null"); + } return new CompleteEvent(messagePrompt.getSid()); } /** - * 重写send方法,处理异常 + * 发送消息 */ - @Override - public void send(SseEventBuilder builder) throws IOException { - try { - super.send(builder); - } catch (Exception e) { - complete = true; + private void sendMessage(Event message) { + if (sink != null && !complete.get()) { + try { + sink.next(ServerSentEvent.builder(message).build()); + } catch (Exception e) { + complete.set(true); + if (sink != null && !sink.isCancelled()) { + sink.error(e); + } + } } } /** - * 发送消息 + * 完成响应 */ - private void sendMessage(Event message) { - try { - this.send(message); - } catch (Exception e) { - complete = true; + public void complete() { + if (!complete.getAndSet(true) && sink != null && !sink.isCancelled()) { + sink.complete(); } } @@ -164,13 +155,16 @@ public class MessageSseEmitter extends SseEmitter { * 获取完成状态 */ public boolean isComplete() { - return complete; + return complete.get(); } /** * 设置完成状态 */ public void setComplete(boolean complete) { - this.complete = complete; + this.complete.set(complete); + if (complete && sink != null && !sink.isCancelled()) { + sink.complete(); + } } } diff --git a/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java b/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java index 8adf3e2..5b9dffd 100644 --- a/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java +++ b/src/main/java/xyz/wbsite/achat/core/prompt/MessagePrompt.java @@ -8,6 +8,13 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +/** + * 用户提示 + * + * @author wangbing + * @version 0.0.1 + * @since 1.8 + */ public class MessagePrompt extends Prompt { /** @@ -91,7 +98,7 @@ public class MessagePrompt extends Prompt { public UserMessage getLastUserMessage() { List messageList = messages.stream().filter(message -> message instanceof UserMessage).collect(Collectors.toList()); - UserMessage userMessage = (UserMessage)messageList.get(messageList.size() - 1); + UserMessage userMessage = (UserMessage) messageList.get(messageList.size() - 1); return userMessage; } @@ -99,7 +106,7 @@ public class MessagePrompt extends Prompt { return getLastUserMessage().getUid(); } - public String getSid(){ + public String getSid() { return getLastUserMessage().getSid(); } diff --git a/src/main/java/xyz/wbsite/achat/core/service/SessionService.java b/src/main/java/xyz/wbsite/achat/core/service/SessionService.java index 8ccad7c..27061fc 100644 --- a/src/main/java/xyz/wbsite/achat/core/service/SessionService.java +++ b/src/main/java/xyz/wbsite/achat/core/service/SessionService.java @@ -1,6 +1,7 @@ package xyz.wbsite.achat.core.service; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import org.springframework.http.codec.ServerSentEvent; +import reactor.core.publisher.Flux; import xyz.wbsite.achat.core.base.Message; import xyz.wbsite.achat.core.base.Result; import xyz.wbsite.achat.core.base.Session; @@ -63,9 +64,9 @@ public interface SessionService { * 发送消息并获取流式响应 * * @param message 消息对象 - * @return SSE发射器,用于流式响应 + * @return 响应式流,用于流式响应 */ - SseEmitter sendMessage(MessagePrompt message); + Flux> sendMessage(MessagePrompt message); /** * 获取会话历史消息 diff --git a/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java b/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java index cfafb08..2538116 100644 --- a/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java +++ b/src/main/java/xyz/wbsite/achat/core/service/impl/SessionServiceMemoryImpl.java @@ -1,12 +1,13 @@ package xyz.wbsite.achat.core.service.impl; import org.springframework.stereotype.Service; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import org.springframework.http.codec.ServerSentEvent; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; import xyz.wbsite.achat.core.message.MessageSseEmitter; import xyz.wbsite.achat.core.base.Message; import xyz.wbsite.achat.core.base.Result; import xyz.wbsite.achat.core.base.Session; -import xyz.wbsite.achat.core.message.UserMessage; import xyz.wbsite.achat.core.prompt.MessagePrompt; import xyz.wbsite.achat.core.service.SessionService; @@ -82,24 +83,35 @@ public class SessionServiceMemoryImpl implements SessionService { * 发送消息并获取流式响应 */ @Override - public SseEmitter sendMessage(MessagePrompt message) { - // 创建VChatSseEmitter来处理流式响应 - return new MessageSseEmitter(message, (emitter, message1) -> { + public Flux> sendMessage(MessagePrompt message) { + // 创建MessageSseEmitter来处理流式响应 + MessageSseEmitter emitter = new MessageSseEmitter(message, null); + Flux> flux = emitter.createFlux(); + + // 在单独的线程中模拟LLM响应 + // 注意:在实际应用中,这里应该调用真实的LLM API + Flux.create(sink -> { // 这边模拟LLM复述一遍用户问题 - String text = message1.getContent(); + String text = message.getLastUserMessage().getContent(); for (char c : text.toCharArray()) { if (emitter.isComplete()) { + sink.complete(); return; } emitter.onPartialResponse(String.valueOf(c)); try { Thread.sleep(100); } catch (InterruptedException e) { - throw new RuntimeException(e); + sink.error(e); + return; } } emitter.onCompleteResponse(null); - }); + sink.complete(); + }).subscribeOn(Schedulers.boundedElastic()) + .subscribe(); + + return flux; } @Override