package xyz.wbsite.achat.chat; import org.springframework.util.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; /** * 流式对话生成器 * * @author wangbing * @version 0.0.1 * @since 1.8 */ public class StreamEmitter extends SseEmitter { /** * 流式输出默认超时时间 */ private static final Long DEFAULT_TIMEOUT = 5 * 60 * 1000L; /** * 对话请求 */ private ChatCompletionRequest request; /** * 当前对话状态 */ private Status status; /** * 当前对话ID */ private String chatId; /** * 是否完成 */ private boolean complete; public StreamEmitter(ChatCompletionRequest request) { super(DEFAULT_TIMEOUT); this.request = request; } // /** // * 错误处理 // */ // private void onError(Throwable e) { //// this.sendMessage(createPartialMessage("" + e.getMessage())); // this.answer.append("" + e.getMessage()); // this.onCompleteResponse(null); // } // // /** // * 部分响应处理 // */ // public void onPartialResponse(String msg) { // if (complete) { // return; // } //// this.sendMessage(createPartialMessage(msg)); // this.answer.append(msg); // } /** * 完成响应处理 */ public void onCompleteResponse(Object chatResponse) { if (this.complete) { return; } // 推送结束 // this.sendMessage(createCompleteMessage()); // 关闭链接 this.complete(); } /** * 重写send方法,处理异常 */ @Override public void send(SseEventBuilder builder) { try { super.send(builder); } catch (Exception e) { complete = true; } } /** * 发送片段 */ private void pushChunk(ChatCompletionChunk chunk) { try { this.send(chunk); } catch (Exception e) { complete = true; } } /** * 发送开始片段 * * @param chatId 对话id */ public void onStart(String chatId) { if (!StringUtils.hasText(chatId)) { throw new IllegalArgumentException("chatId is empty!"); } if (this.status != null) { throw new IllegalArgumentException("chunk has been started!"); } ChatCompletionChunk chunk = ChatCompletionChunk.builder() .id(chatId) .object("chat.completion.chunk") .created(System.currentTimeMillis() / 1000) .model(request.getModel()) .withChoices(choices -> { choices.add(ChatCompletionChunk.choiceBuilder().index(0) .role(Role.ASSISTANT) .build()); }) .build(); this.pushChunk(chunk); this.status = Status.PENDING; this.chatId = chatId; } /** * 发送部分片段 * * @param text 片段内容 */ public void onPartial(String text) { ChatCompletionChunk chunk = ChatCompletionChunk.builder() .id(chatId) .object("chat.completion.chunk") .created(System.currentTimeMillis() / 1000) .model(request.getModel()) .withChoices(choices -> { choices.add(ChatCompletionChunk.choiceBuilder().index(0) .role(Role.ASSISTANT) .content(text) .build()); }) .build(); this.pushChunk(chunk); } /** * 发送完成信号,根据OpenAI API规范,最后一个片段需要包含finish_reason字段 */ public void onComplete() { ChatCompletionChunk chunk = ChatCompletionChunk.builder() .id(chatId) .object("chat.completion.chunk") .created(System.currentTimeMillis() / 1000) .model(request.getModel()) .withChoices(choices -> { choices.add(ChatCompletionChunk.choiceBuilder().index(0) .role(Role.ASSISTANT) // 符合OpenAI规范,最后一个片段需要设置finish_reason .finish_reason("stop") .build()); }) .build(); this.pushChunk(chunk); this.complete(); this.status = Status.SUCCESS; } /** * 获取完成状态 */ public boolean isComplete() { return complete; } /** * 设置完成状态 */ public void setComplete(boolean complete) { this.complete = complete; } }