上传备份

flux
王兵 6 days ago
parent 5814569607
commit 5bd40072ba

@ -45,7 +45,7 @@
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<dependency>

@ -5,13 +5,14 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import xyz.wbsite.achat.core.base.Message;
import xyz.wbsite.achat.core.base.Result;
import xyz.wbsite.achat.core.base.Session;
import xyz.wbsite.achat.core.message.UserMessage;
import xyz.wbsite.achat.core.prompt.MessagePrompt;
import xyz.wbsite.achat.core.service.SessionService;
@ -36,17 +37,19 @@ public class ChatController {
*
*/
@GetMapping("/chat")
public String chat() {
return "AChat is running!";
public Mono<String> chat() {
return Mono.just("AChat is running!");
}
/**
*
*
* @return SSE
* @return
*/
@PostMapping(value = "/chat/send", produces = "text/event-stream;charset=UTF-8")
public SseEmitter send(@RequestBody MessagePrompt message) {
@PostMapping(value = "/chat/send", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<?>> send(@RequestBody MessagePrompt message) {
// 注意这里需要SessionService的sendMessage方法返回Flux<ServerSentEvent<?>>
// 由于我们没有修改SessionService暂时返回一个模拟的流
return sessionService.sendMessage(message);
}
@ -56,9 +59,11 @@ public class ChatController {
* @return
*/
@PostMapping("/chat/session")
public Result<Session> createSession() {
public Mono<Result<Session>> createSession() {
String sid = UUID.randomUUID().toString();
return sessionService.createSession(sid);
// 注意这里需要SessionService的createSession方法返回Mono<Result<Session>>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.createSession(sid));
}
/**
@ -66,11 +71,12 @@ public class ChatController {
*
* @return
*/
@ResponseBody
@GetMapping("/chat/session/list")
public Result<List<Session>> listSession() {
public Mono<Result<List<Session>>> listSession() {
String sid = UUID.randomUUID().toString();
return sessionService.listSessions(sid);
// 注意这里需要SessionService的listSessions方法返回Mono<Result<List<Session>>>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.listSessions(sid));
}
/**
@ -79,10 +85,11 @@ public class ChatController {
* @param sid ID
* @return
*/
@ResponseBody
@DeleteMapping("/chat/session/{sid}")
public Result deleteSession(@PathVariable("sid") String sid) {
return sessionService.deleteSession(sid);
public Mono<Result> deleteSession(@PathVariable("sid") String sid) {
// 注意这里需要SessionService的deleteSession方法返回Mono<Result>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.deleteSession(sid));
}
@ -92,10 +99,11 @@ public class ChatController {
* @param sid ID
* @return
*/
@ResponseBody
@GetMapping("/chat/session/{sid}/history")
public Result<List<Message>> getSessionHistory(@PathVariable("sid") String sid) {
return sessionService.listMessage(sid);
public Mono<Result<List<Message>>> getSessionHistory(@PathVariable("sid") String sid) {
// 注意这里需要SessionService的listMessage方法返回Mono<Result<List<Message>>>
// 由于我们没有修改SessionService暂时包装为Mono
return Mono.just(sessionService.listMessage(sid));
}
/**
@ -104,7 +112,6 @@ public class ChatController {
* @param sid ID
* @return
*/
@ResponseBody
@PostMapping("/chat/session/{sid}/stop")
public Result<Void> stopSession(@PathVariable("sid") String sid) {
return sessionService.stopSession(sid);

@ -0,0 +1,44 @@
package xyz.wbsite.achat.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.reactive.CorsWebFilter;
import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource;
import org.springframework.web.reactive.config.WebFluxConfigurer;
/**
* WebFlux(,).
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
@Configuration
public class WebFluxConfig implements WebFluxConfigurer {
/**
*
* WebFlux使CorsWebFilter
*/
@Bean
public CorsWebFilter corsWebFilter() {
CorsConfiguration config = new CorsConfiguration();
// 允许的域,不要写*否则cookie就无法使用了
config.addAllowedOriginPattern("http://localhost:5173");
config.addAllowedHeader("*");
// 是否发送Cookie
config.setAllowCredentials(true);
config.addAllowedMethod("GET");
config.addAllowedMethod("POST");
config.addAllowedMethod("DELETE");
config.addAllowedMethod("PUT");
config.addExposedHeader("*");
config.setMaxAge(3600L);
UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource();
source.registerCorsConfiguration("/chat/**", config);
return new CorsWebFilter(source);
}
}

@ -1,35 +0,0 @@
package xyz.wbsite.achat.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* Web(,).
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {
/**
*
*
* @param registry
*/
@Override
public void addCorsMappings(CorsRegistry registry) {
// 注意,如果授权认认证未通过会直接返回,此跨域配置则不会生效,前端仍然会提示跨域
registry.addMapping("/chat/**")
//允许的域,不要写*否则cookie就无法使用了
.allowedOriginPatterns("http://localhost:5173")
.allowedHeaders("*")
//是否发送Cookie
.allowCredentials(true)
.allowedMethods("GET", "POST", "DELETE", "PUT")
.exposedHeaders("*")
.maxAge(3600);
}
}

@ -1,33 +1,35 @@
package xyz.wbsite.achat.core.message;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import xyz.wbsite.achat.core.base.Event;
import xyz.wbsite.achat.core.event.CompleteEvent;
import xyz.wbsite.achat.core.event.PartialEvent;
import xyz.wbsite.achat.core.prompt.MessagePrompt;
import xyz.wbsite.achat.core.service.MessageGenerator;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* SSE
* SSE
* SSE
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
public class MessageSseEmitter extends SseEmitter {
public class MessageSseEmitter {
/**
*
*/
private MessagePrompt messagePrompt;
private final MessagePrompt messagePrompt;
/**
*
*/
private boolean complete;
private final AtomicBoolean complete = new AtomicBoolean(false);
/**
* AI
@ -37,58 +39,39 @@ public class MessageSseEmitter extends SseEmitter {
/**
*
*/
private MessageGenerator messageGenerator;
private final MessageGenerator messageGenerator; // 可能为null使用时需要检查
/**
* Flux
*/
private FluxSink<ServerSentEvent<?>> sink;
/**
*
*
* @param message
* @param processor
*/
public MessageSseEmitter(MessagePrompt message, MessageGenerator processor) {
super(0L);
if (message == null) {
throw new IllegalArgumentException("MessagePrompt cannot be null");
}
this.messagePrompt = message;
this.messageGenerator = processor;
// TaskUtil.taskAsync(task);
this.messageGenerator = processor; // 可能为null使用时需要检查
}
/**
*
*/
// public Runnable task = () -> {
// try {
// // 检查会话是否存在
// if (!messageProcessor.checkSessionExists(message.getChatId(), uid)) {
// this.sendMessage(createPartialMessage("当前会话不存在,请刷新后再试!"));
// return;
// }
//
// // 更新新会话的标题(如果为空)
// messageProcessor.updateSessionTitleIfEmpty(message.getChatId(), message.getText(), uid);
//
// // 保存本次用户消息
// messageProcessor.saveUserMessage(message.getChatId(), message.getText(), uid);
//
// // 根据是否有附件选择不同的处理方式
// TokenStream tokenStream;
// if (this.hasAttachment()) {
// String attachment = messageProcessor.parseAttachment(message.getAttachments(), uid);
// tokenStream = messageProcessor.createAssistantStreamWithAttachment(
// message.getChatId(), message.getText(), attachment, uid);
// } else {
// tokenStream = messageProcessor.createAssistantStream(
// message.getChatId(), message.getText(), uid);
// }
//
// // 设置流回调
// tokenStream
// .onPartialResponse(this::onPartialResponse)
// .onCompleteResponse(this::onCompleteResponse)
// .onError(this::onError)
// .start();
// } catch (Exception e) {
// onError(e);
// }
// };
*
*/
public Flux<ServerSentEvent<?>> createFlux() {
return Flux.create(emitter -> {
this.sink = emitter;
// 设置取消回调
emitter.onCancel(() -> complete.set(true));
// 设置释放回调当sink完成或取消时调用
emitter.onDispose(() -> complete.set(true));
});
}
/**
*
@ -103,7 +86,7 @@ public class MessageSseEmitter extends SseEmitter {
*
*/
public void onPartialResponse(String msg) {
if (complete) {
if (complete.get()) {
return;
}
this.sendMessage(createPartialMessage(msg));
@ -114,7 +97,7 @@ public class MessageSseEmitter extends SseEmitter {
*
*/
public void onCompleteResponse(Object chatResponse) {
if (this.complete) {
if (this.complete.get()) {
return;
}
// 推送结束
@ -127,6 +110,9 @@ public class MessageSseEmitter extends SseEmitter {
*
*/
private Event createPartialMessage(String partial) {
if (messagePrompt == null) {
throw new IllegalStateException("MessagePrompt is null");
}
return new PartialEvent(messagePrompt.getSid(), partial);
}
@ -134,29 +120,34 @@ public class MessageSseEmitter extends SseEmitter {
*
*/
private Event createCompleteMessage() {
if (messagePrompt == null) {
throw new IllegalStateException("MessagePrompt is null");
}
return new CompleteEvent(messagePrompt.getSid());
}
/**
* send
*
*/
@Override
public void send(SseEventBuilder builder) throws IOException {
try {
super.send(builder);
} catch (Exception e) {
complete = true;
private void sendMessage(Event message) {
if (sink != null && !complete.get()) {
try {
sink.next(ServerSentEvent.builder(message).build());
} catch (Exception e) {
complete.set(true);
if (sink != null && !sink.isCancelled()) {
sink.error(e);
}
}
}
}
/**
*
*
*/
private void sendMessage(Event message) {
try {
this.send(message);
} catch (Exception e) {
complete = true;
public void complete() {
if (!complete.getAndSet(true) && sink != null && !sink.isCancelled()) {
sink.complete();
}
}
@ -164,13 +155,16 @@ public class MessageSseEmitter extends SseEmitter {
*
*/
public boolean isComplete() {
return complete;
return complete.get();
}
/**
*
*/
public void setComplete(boolean complete) {
this.complete = complete;
this.complete.set(complete);
if (complete && sink != null && !sink.isCancelled()) {
sink.complete();
}
}
}

@ -8,6 +8,13 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
*
*
* @author wangbing
* @version 0.0.1
* @since 1.8
*/
public class MessagePrompt extends Prompt {
/**
@ -91,7 +98,7 @@ public class MessagePrompt extends Prompt {
public UserMessage getLastUserMessage() {
List<Message> messageList = messages.stream().filter(message -> message instanceof UserMessage).collect(Collectors.toList());
UserMessage userMessage = (UserMessage)messageList.get(messageList.size() - 1);
UserMessage userMessage = (UserMessage) messageList.get(messageList.size() - 1);
return userMessage;
}
@ -99,7 +106,7 @@ public class MessagePrompt extends Prompt {
return getLastUserMessage().getUid();
}
public String getSid(){
public String getSid() {
return getLastUserMessage().getSid();
}

@ -1,6 +1,7 @@
package xyz.wbsite.achat.core.service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
import xyz.wbsite.achat.core.base.Message;
import xyz.wbsite.achat.core.base.Result;
import xyz.wbsite.achat.core.base.Session;
@ -63,9 +64,9 @@ public interface SessionService {
*
*
* @param message
* @return SSE
* @return
*/
SseEmitter sendMessage(MessagePrompt message);
Flux<ServerSentEvent<?>> sendMessage(MessagePrompt message);
/**
*

@ -1,12 +1,13 @@
package xyz.wbsite.achat.core.service.impl;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import xyz.wbsite.achat.core.message.MessageSseEmitter;
import xyz.wbsite.achat.core.base.Message;
import xyz.wbsite.achat.core.base.Result;
import xyz.wbsite.achat.core.base.Session;
import xyz.wbsite.achat.core.message.UserMessage;
import xyz.wbsite.achat.core.prompt.MessagePrompt;
import xyz.wbsite.achat.core.service.SessionService;
@ -82,24 +83,35 @@ public class SessionServiceMemoryImpl implements SessionService {
*
*/
@Override
public SseEmitter sendMessage(MessagePrompt message) {
// 创建VChatSseEmitter来处理流式响应
return new MessageSseEmitter(message, (emitter, message1) -> {
public Flux<ServerSentEvent<?>> sendMessage(MessagePrompt message) {
// 创建MessageSseEmitter来处理流式响应
MessageSseEmitter emitter = new MessageSseEmitter(message, null);
Flux<ServerSentEvent<?>> flux = emitter.createFlux();
// 在单独的线程中模拟LLM响应
// 注意在实际应用中这里应该调用真实的LLM API
Flux.create(sink -> {
// 这边模拟LLM复述一遍用户问题
String text = message1.getContent();
String text = message.getLastUserMessage().getContent();
for (char c : text.toCharArray()) {
if (emitter.isComplete()) {
sink.complete();
return;
}
emitter.onPartialResponse(String.valueOf(c));
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
sink.error(e);
return;
}
}
emitter.onCompleteResponse(null);
});
sink.complete();
}).subscribeOn(Schedulers.boundedElastic())
.subscribe();
return flux;
}
@Override

Loading…
Cancel
Save

Powered by TurnKey Linux.