Compare commits

...

1 Commits
master ... flux

Author SHA1 Message Date
王兵 5bd40072ba 上传备份
6 days ago

@ -45,7 +45,7 @@
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId> <artifactId>spring-boot-starter-webflux</artifactId>
</dependency> </dependency>
<dependency> <dependency>

@ -5,13 +5,14 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import xyz.wbsite.achat.core.base.Message; import xyz.wbsite.achat.core.base.Message;
import xyz.wbsite.achat.core.base.Result; import xyz.wbsite.achat.core.base.Result;
import xyz.wbsite.achat.core.base.Session; import xyz.wbsite.achat.core.base.Session;
import xyz.wbsite.achat.core.message.UserMessage;
import xyz.wbsite.achat.core.prompt.MessagePrompt; import xyz.wbsite.achat.core.prompt.MessagePrompt;
import xyz.wbsite.achat.core.service.SessionService; import xyz.wbsite.achat.core.service.SessionService;
@ -36,17 +37,19 @@ public class ChatController {
* *
*/ */
@GetMapping("/chat") @GetMapping("/chat")
public String chat() { public Mono<String> chat() {
return "AChat is running!"; return Mono.just("AChat is running!");
} }
/** /**
* *
* *
* @return SSE * @return
*/ */
@PostMapping(value = "/chat/send", produces = "text/event-stream;charset=UTF-8") @PostMapping(value = "/chat/send", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter send(@RequestBody MessagePrompt message) { public Flux<ServerSentEvent<?>> send(@RequestBody MessagePrompt message) {
// 注意这里需要SessionService的sendMessage方法返回Flux<ServerSentEvent<?>>
// 由于我们没有修改SessionService暂时返回一个模拟的流
return sessionService.sendMessage(message); return sessionService.sendMessage(message);
} }
@ -56,9 +59,11 @@ public class ChatController {
* @return * @return
*/ */
@PostMapping("/chat/session") @PostMapping("/chat/session")
public Result<Session> createSession() { public Mono<Result<Session>> createSession() {
String sid = UUID.randomUUID().toString(); String sid = UUID.randomUUID().toString();
return sessionService.createSession(sid); // 注意这里需要SessionService的createSession方法返回Mono<Result<Session>>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.createSession(sid));
} }
/** /**
@ -66,11 +71,12 @@ public class ChatController {
* *
* @return * @return
*/ */
@ResponseBody
@GetMapping("/chat/session/list") @GetMapping("/chat/session/list")
public Result<List<Session>> listSession() { public Mono<Result<List<Session>>> listSession() {
String sid = UUID.randomUUID().toString(); String sid = UUID.randomUUID().toString();
return sessionService.listSessions(sid); // 注意这里需要SessionService的listSessions方法返回Mono<Result<List<Session>>>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.listSessions(sid));
} }
/** /**
@ -79,10 +85,11 @@ public class ChatController {
* @param sid ID * @param sid ID
* @return * @return
*/ */
@ResponseBody
@DeleteMapping("/chat/session/{sid}") @DeleteMapping("/chat/session/{sid}")
public Result deleteSession(@PathVariable("sid") String sid) { public Mono<Result> deleteSession(@PathVariable("sid") String sid) {
return sessionService.deleteSession(sid); // 注意这里需要SessionService的deleteSession方法返回Mono<Result>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.deleteSession(sid));
} }
@ -92,10 +99,11 @@ public class ChatController {
* @param sid ID * @param sid ID
* @return * @return
*/ */
@ResponseBody
@GetMapping("/chat/session/{sid}/history") @GetMapping("/chat/session/{sid}/history")
public Result<List<Message>> getSessionHistory(@PathVariable("sid") String sid) { public Mono<Result<List<Message>>> getSessionHistory(@PathVariable("sid") String sid) {
return sessionService.listMessage(sid); // 注意这里需要SessionService的listMessage方法返回Mono<Result<List<Message>>>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.listMessage(sid));
} }
/** /**
@ -104,7 +112,6 @@ public class ChatController {
* @param sid ID * @param sid ID
* @return * @return
*/ */
@ResponseBody
@PostMapping("/chat/session/{sid}/stop") @PostMapping("/chat/session/{sid}/stop")
public Result<Void> stopSession(@PathVariable("sid") String sid) { public Result<Void> stopSession(@PathVariable("sid") String sid) {
return sessionService.stopSession(sid); return sessionService.stopSession(sid);

@ -0,0 +1,44 @@
package xyz.wbsite.achat.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.reactive.CorsWebFilter;
import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource;
import org.springframework.web.reactive.config.WebFluxConfigurer;
/**
* WebFlux(,).
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
@Configuration
public class WebFluxConfig implements WebFluxConfigurer {
/**
*
* WebFlux使CorsWebFilter
*/
@Bean
public CorsWebFilter corsWebFilter() {
CorsConfiguration config = new CorsConfiguration();
// 允许的域,不要写*否则cookie就无法使用了
config.addAllowedOriginPattern("http://localhost:5173");
config.addAllowedHeader("*");
// 是否发送Cookie
config.setAllowCredentials(true);
config.addAllowedMethod("GET");
config.addAllowedMethod("POST");
config.addAllowedMethod("DELETE");
config.addAllowedMethod("PUT");
config.addExposedHeader("*");
config.setMaxAge(3600L);
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
source.registerCorsConfiguration("/chat/**", config);
return new CorsWebFilter(source);
}
}

@ -1,35 +0,0 @@
package xyz.wbsite.achat.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* Web(,).
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {
/**
*
*
* @param registry
*/
@Override
public void addCorsMappings(CorsRegistry registry) {
// 注意,如果授权认认证未通过会直接返回,此跨域配置则不会生效,前端仍然会提示跨域
registry.addMapping("/chat/**")
//允许的域,不要写*否则cookie就无法使用了
.allowedOriginPatterns("http://localhost:5173")
.allowedHeaders("*")
//是否发送Cookie
.allowCredentials(true)
.allowedMethods("GET", "POST", "DELETE", "PUT")
.exposedHeaders("*")
.maxAge(3600);
}
}

@ -1,33 +1,35 @@
package xyz.wbsite.achat.core.message; package xyz.wbsite.achat.core.message;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import xyz.wbsite.achat.core.base.Event; import xyz.wbsite.achat.core.base.Event;
import xyz.wbsite.achat.core.event.CompleteEvent; import xyz.wbsite.achat.core.event.CompleteEvent;
import xyz.wbsite.achat.core.event.PartialEvent; import xyz.wbsite.achat.core.event.PartialEvent;
import xyz.wbsite.achat.core.prompt.MessagePrompt; import xyz.wbsite.achat.core.prompt.MessagePrompt;
import xyz.wbsite.achat.core.service.MessageGenerator; import xyz.wbsite.achat.core.service.MessageGenerator;
import java.io.IOException; import java.util.concurrent.atomic.AtomicBoolean;
/** /**
* SSE * SSE
* SSE * SSE
* *
* @author wangbing * @author wangbing
* @version 0.0.1 * @version 0.0.1
* @since 1.8 * @since 1.8
*/ */
public class MessageSseEmitter extends SseEmitter { public class MessageSseEmitter {
/** /**
* *
*/ */
private MessagePrompt messagePrompt; private final MessagePrompt messagePrompt;
/** /**
* *
*/ */
private boolean complete; private final AtomicBoolean complete = new AtomicBoolean(false);
/** /**
* AI * AI
@ -37,58 +39,39 @@ public class MessageSseEmitter extends SseEmitter {
/** /**
* *
*/ */
private MessageGenerator messageGenerator; private final MessageGenerator messageGenerator; // 可能为null使用时需要检查
/**
* Flux
*/
private FluxSink<ServerSentEvent<?>> sink;
/** /**
* *
* *
* @param message * @param message
* @param processor
*/ */
public MessageSseEmitter(MessagePrompt message, MessageGenerator processor) { public MessageSseEmitter(MessagePrompt message, MessageGenerator processor) {
super(0L); if (message == null) {
throw new IllegalArgumentException("MessagePrompt cannot be null");
}
this.messagePrompt = message; this.messagePrompt = message;
this.messageGenerator = processor; this.messageGenerator = processor; // 可能为null使用时需要检查
// TaskUtil.taskAsync(task);
} }
/** /**
* *
*/ */
// public Runnable task = () -> { public Flux<ServerSentEvent<?>> createFlux() {
// try { return Flux.create(emitter -> {
// // 检查会话是否存在 this.sink = emitter;
// if (!messageProcessor.checkSessionExists(message.getChatId(), uid)) { // 设置取消回调
// this.sendMessage(createPartialMessage("当前会话不存在,请刷新后再试!")); emitter.onCancel(() -> complete.set(true));
// return; // 设置释放回调当sink完成或取消时调用
// } emitter.onDispose(() -> complete.set(true));
// });
// // 更新新会话的标题(如果为空) }
// messageProcessor.updateSessionTitleIfEmpty(message.getChatId(), message.getText(), uid);
//
// // 保存本次用户消息
// messageProcessor.saveUserMessage(message.getChatId(), message.getText(), uid);
//
// // 根据是否有附件选择不同的处理方式
// TokenStream tokenStream;
// if (this.hasAttachment()) {
// String attachment = messageProcessor.parseAttachment(message.getAttachments(), uid);
// tokenStream = messageProcessor.createAssistantStreamWithAttachment(
// message.getChatId(), message.getText(), attachment, uid);
// } else {
// tokenStream = messageProcessor.createAssistantStream(
// message.getChatId(), message.getText(), uid);
// }
//
// // 设置流回调
// tokenStream
// .onPartialResponse(this::onPartialResponse)
// .onCompleteResponse(this::onCompleteResponse)
// .onError(this::onError)
// .start();
// } catch (Exception e) {
// onError(e);
// }
// };
/** /**
* *
@ -103,7 +86,7 @@ public class MessageSseEmitter extends SseEmitter {
* *
*/ */
public void onPartialResponse(String msg) { public void onPartialResponse(String msg) {
if (complete) { if (complete.get()) {
return; return;
} }
this.sendMessage(createPartialMessage(msg)); this.sendMessage(createPartialMessage(msg));
@ -114,7 +97,7 @@ public class MessageSseEmitter extends SseEmitter {
* *
*/ */
public void onCompleteResponse(Object chatResponse) { public void onCompleteResponse(Object chatResponse) {
if (this.complete) { if (this.complete.get()) {
return; return;
} }
// 推送结束 // 推送结束
@ -127,6 +110,9 @@ public class MessageSseEmitter extends SseEmitter {
* *
*/ */
private Event createPartialMessage(String partial) { private Event createPartialMessage(String partial) {
if (messagePrompt == null) {
throw new IllegalStateException("MessagePrompt is null");
}
return new PartialEvent(messagePrompt.getSid(), partial); return new PartialEvent(messagePrompt.getSid(), partial);
} }
@ -134,29 +120,34 @@ public class MessageSseEmitter extends SseEmitter {
* *
*/ */
private Event createCompleteMessage() { private Event createCompleteMessage() {
if (messagePrompt == null) {
throw new IllegalStateException("MessagePrompt is null");
}
return new CompleteEvent(messagePrompt.getSid()); return new CompleteEvent(messagePrompt.getSid());
} }
/** /**
* send *
*/ */
@Override private void sendMessage(Event message) {
public void send(SseEventBuilder builder) throws IOException { if (sink != null && !complete.get()) {
try { try {
super.send(builder); sink.next(ServerSentEvent.builder(message).build());
} catch (Exception e) { } catch (Exception e) {
complete = true; complete.set(true);
if (sink != null && !sink.isCancelled()) {
sink.error(e);
}
}
} }
} }
/** /**
* *
*/ */
private void sendMessage(Event message) { public void complete() {
try { if (!complete.getAndSet(true) && sink != null && !sink.isCancelled()) {
this.send(message); sink.complete();
} catch (Exception e) {
complete = true;
} }
} }
@ -164,13 +155,16 @@ public class MessageSseEmitter extends SseEmitter {
* *
*/ */
public boolean isComplete() { public boolean isComplete() {
return complete; return complete.get();
} }
/** /**
* *
*/ */
public void setComplete(boolean complete) { public void setComplete(boolean complete) {
this.complete = complete; this.complete.set(complete);
if (complete && sink != null && !sink.isCancelled()) {
sink.complete();
}
} }
} }

@ -8,6 +8,13 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/**
*
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
public class MessagePrompt extends Prompt { public class MessagePrompt extends Prompt {
/** /**
@ -91,7 +98,7 @@ public class MessagePrompt extends Prompt {
public UserMessage getLastUserMessage() { public UserMessage getLastUserMessage() {
List<Message> messageList = messages.stream().filter(message -> message instanceof UserMessage).collect(Collectors.toList()); List<Message> messageList = messages.stream().filter(message -> message instanceof UserMessage).collect(Collectors.toList());
UserMessage userMessage = (UserMessage)messageList.get(messageList.size() - 1); UserMessage userMessage = (UserMessage) messageList.get(messageList.size() - 1);
return userMessage; return userMessage;
} }
@ -99,7 +106,7 @@ public class MessagePrompt extends Prompt {
return getLastUserMessage().getUid(); return getLastUserMessage().getUid();
} }
public String getSid(){ public String getSid() {
return getLastUserMessage().getSid(); return getLastUserMessage().getSid();
} }

@ -1,6 +1,7 @@
package xyz.wbsite.achat.core.service; package xyz.wbsite.achat.core.service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
import xyz.wbsite.achat.core.base.Message; import xyz.wbsite.achat.core.base.Message;
import xyz.wbsite.achat.core.base.Result; import xyz.wbsite.achat.core.base.Result;
import xyz.wbsite.achat.core.base.Session; import xyz.wbsite.achat.core.base.Session;
@ -63,9 +64,9 @@ public interface SessionService {
* *
* *
* @param message * @param message
* @return SSE * @return
*/ */
SseEmitter sendMessage(MessagePrompt message); Flux<ServerSentEvent<?>> sendMessage(MessagePrompt message);
/** /**
* *

@ -1,12 +1,13 @@
package xyz.wbsite.achat.core.service.impl; package xyz.wbsite.achat.core.service.impl;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import xyz.wbsite.achat.core.message.MessageSseEmitter; import xyz.wbsite.achat.core.message.MessageSseEmitter;
import xyz.wbsite.achat.core.base.Message; import xyz.wbsite.achat.core.base.Message;
import xyz.wbsite.achat.core.base.Result; import xyz.wbsite.achat.core.base.Result;
import xyz.wbsite.achat.core.base.Session; import xyz.wbsite.achat.core.base.Session;
import xyz.wbsite.achat.core.message.UserMessage;
import xyz.wbsite.achat.core.prompt.MessagePrompt; import xyz.wbsite.achat.core.prompt.MessagePrompt;
import xyz.wbsite.achat.core.service.SessionService; import xyz.wbsite.achat.core.service.SessionService;
@ -82,24 +83,35 @@ public class SessionServiceMemoryImpl implements SessionService {
* *
*/ */
@Override @Override
public SseEmitter sendMessage(MessagePrompt message) { public Flux<ServerSentEvent<?>> sendMessage(MessagePrompt message) {
// 创建VChatSseEmitter来处理流式响应 // 创建MessageSseEmitter来处理流式响应
return new MessageSseEmitter(message, (emitter, message1) -> { MessageSseEmitter emitter = new MessageSseEmitter(message, null);
Flux<ServerSentEvent<?>> flux = emitter.createFlux();
// 在单独的线程中模拟LLM响应
// 注意在实际应用中这里应该调用真实的LLM API
Flux.create(sink -> {
// 这边模拟LLM复述一遍用户问题 // 这边模拟LLM复述一遍用户问题
String text = message1.getContent(); String text = message.getLastUserMessage().getContent();
for (char c : text.toCharArray()) { for (char c : text.toCharArray()) {
if (emitter.isComplete()) { if (emitter.isComplete()) {
sink.complete();
return; return;
} }
emitter.onPartialResponse(String.valueOf(c)); emitter.onPartialResponse(String.valueOf(c));
try { try {
Thread.sleep(100); Thread.sleep(100);
} catch (InterruptedException e) { } catch (InterruptedException e) {
throw new RuntimeException(e); sink.error(e);
return;
} }
} }
emitter.onCompleteResponse(null); emitter.onCompleteResponse(null);
}); sink.complete();
}).subscribeOn(Schedulers.boundedElastic())
.subscribe();
return flux;
} }
@Override @Override

Loading…
Cancel
Save

Powered by TurnKey Linux.