|
|
package xyz.wbsite.achat.chat;
|
|
|
|
|
|
import org.springframework.util.StringUtils;
|
|
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
|
|
/**
|
|
|
* 流式对话生成器
|
|
|
*
|
|
|
* @author wangbing
|
|
|
* @version 0.0.1
|
|
|
* @since 1.8
|
|
|
*/
|
|
|
public class StreamEmitter extends SseEmitter {
|
|
|
|
|
|
/**
|
|
|
* 流式输出默认超时时间
|
|
|
*/
|
|
|
private static final Long DEFAULT_TIMEOUT = 5 * 60 * 1000L;
|
|
|
|
|
|
/**
|
|
|
* 对话请求
|
|
|
*/
|
|
|
private ChatCompletionRequest request;
|
|
|
|
|
|
/**
|
|
|
* 当前对话状态
|
|
|
*/
|
|
|
private Status status;
|
|
|
|
|
|
/**
|
|
|
* 当前对话ID
|
|
|
*/
|
|
|
private String chatId;
|
|
|
|
|
|
/**
|
|
|
* 是否完成
|
|
|
*/
|
|
|
private boolean complete;
|
|
|
|
|
|
public StreamEmitter(ChatCompletionRequest request) {
|
|
|
super(DEFAULT_TIMEOUT);
|
|
|
this.request = request;
|
|
|
}
|
|
|
|
|
|
// /**
|
|
|
// * 错误处理
|
|
|
// */
|
|
|
// private void onError(Throwable e) {
|
|
|
//// this.sendMessage(createPartialMessage("</think></think>" + e.getMessage()));
|
|
|
// this.answer.append("</think></think>" + e.getMessage());
|
|
|
// this.onCompleteResponse(null);
|
|
|
// }
|
|
|
//
|
|
|
// /**
|
|
|
// * 部分响应处理
|
|
|
// */
|
|
|
// public void onPartialResponse(String msg) {
|
|
|
// if (complete) {
|
|
|
// return;
|
|
|
// }
|
|
|
//// this.sendMessage(createPartialMessage(msg));
|
|
|
// this.answer.append(msg);
|
|
|
// }
|
|
|
|
|
|
/**
|
|
|
* 完成响应处理
|
|
|
*/
|
|
|
public void onCompleteResponse(Object chatResponse) {
|
|
|
if (this.complete) {
|
|
|
return;
|
|
|
}
|
|
|
// 推送结束
|
|
|
// this.sendMessage(createCompleteMessage());
|
|
|
// 关闭链接
|
|
|
this.complete();
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 重写send方法,处理异常
|
|
|
*/
|
|
|
@Override
|
|
|
public void send(SseEventBuilder builder) {
|
|
|
try {
|
|
|
super.send(builder);
|
|
|
} catch (Exception e) {
|
|
|
complete = true;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 发送片段
|
|
|
*/
|
|
|
private void pushChunk(ChatCompletionChunk chunk) {
|
|
|
try {
|
|
|
this.send(chunk);
|
|
|
} catch (Exception e) {
|
|
|
complete = true;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 发送开始片段
|
|
|
*
|
|
|
* @param chatId 对话id
|
|
|
*/
|
|
|
public void onStart(String chatId) {
|
|
|
if (!StringUtils.hasText(chatId)) {
|
|
|
throw new IllegalArgumentException("chatId is empty!");
|
|
|
}
|
|
|
if (this.status != null) {
|
|
|
throw new IllegalArgumentException("chunk has been started!");
|
|
|
}
|
|
|
ChatCompletionChunk chunk = ChatCompletionChunk.builder()
|
|
|
.id(chatId)
|
|
|
.object("chat.completion.chunk")
|
|
|
.created(System.currentTimeMillis() / 1000)
|
|
|
.model(request.getModel())
|
|
|
.withChoices(choices -> {
|
|
|
choices.add(ChatCompletionChunk.choiceBuilder().index(0)
|
|
|
.role(Role.ASSISTANT)
|
|
|
.build());
|
|
|
})
|
|
|
.build();
|
|
|
this.pushChunk(chunk);
|
|
|
this.status = Status.PENDING;
|
|
|
this.chatId = chatId;
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 发送部分片段
|
|
|
*
|
|
|
* @param text 片段内容
|
|
|
*/
|
|
|
public void onPartial(String text) {
|
|
|
ChatCompletionChunk chunk = ChatCompletionChunk.builder()
|
|
|
.id(chatId)
|
|
|
.object("chat.completion.chunk")
|
|
|
.created(System.currentTimeMillis() / 1000)
|
|
|
.model(request.getModel())
|
|
|
.withChoices(choices -> {
|
|
|
choices.add(ChatCompletionChunk.choiceBuilder().index(0)
|
|
|
.role(Role.ASSISTANT)
|
|
|
.content(text)
|
|
|
.build());
|
|
|
})
|
|
|
.build();
|
|
|
this.pushChunk(chunk);
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 发送完成信号,根据OpenAI API规范,最后一个片段需要包含finish_reason字段
|
|
|
*/
|
|
|
public void onComplete() {
|
|
|
ChatCompletionChunk chunk = ChatCompletionChunk.builder()
|
|
|
.id(chatId)
|
|
|
.object("chat.completion.chunk")
|
|
|
.created(System.currentTimeMillis() / 1000)
|
|
|
.model(request.getModel())
|
|
|
.withChoices(choices -> {
|
|
|
choices.add(ChatCompletionChunk.choiceBuilder().index(0)
|
|
|
.role(Role.ASSISTANT)
|
|
|
// 符合OpenAI规范,最后一个片段需要设置finish_reason
|
|
|
.finish_reason("stop")
|
|
|
.build());
|
|
|
})
|
|
|
.build();
|
|
|
this.pushChunk(chunk);
|
|
|
this.complete();
|
|
|
this.status = Status.SUCCESS;
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 获取完成状态
|
|
|
*/
|
|
|
public boolean isComplete() {
|
|
|
return complete;
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 设置完成状态
|
|
|
*/
|
|
|
public void setComplete(boolean complete) {
|
|
|
this.complete = complete;
|
|
|
}
|
|
|
}
|