You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
4.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package xyz.wbsite.achat.chat;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
 * Streaming chat emitter.
 *
 * <p>An {@link SseEmitter} that pushes {@code chat.completion.chunk} objects for a single
 * {@link ChatCompletionRequest}. The intended call order is {@link #onStart(String)}, any number
 * of {@link #onPartial(String)} calls, then {@link #onComplete()}.
 *
 * <p>Once the emitter is marked complete (normal completion, a failed send, or
 * {@link #setComplete(boolean)}), further chunks are silently skipped.
 *
 * @author wangbing
 * @version 0.0.1
 * @since 1.8
 */
public class StreamEmitter extends SseEmitter {

    /**
     * Default timeout for streamed output (5 minutes, passed to the {@link SseEmitter} constructor).
     */
    private static final Long DEFAULT_TIMEOUT = 5 * 60 * 1000L;

    /**
     * The chat request this emitter answers; its model name is echoed in every chunk.
     */
    private final ChatCompletionRequest request;

    /**
     * Current conversation status; {@code null} until {@link #onStart(String)} has run.
     */
    private Status status;

    /**
     * Current conversation id; {@code null} until {@link #onStart(String)} has run.
     */
    private String chatId;

    /**
     * Whether the stream is finished (closed normally, failed while sending, or marked by a caller).
     */
    private boolean complete;

    /**
     * Creates an emitter for the given request using the default timeout.
     *
     * @param request the chat request being answered
     */
    public StreamEmitter(ChatCompletionRequest request) {
        super(DEFAULT_TIMEOUT);
        this.request = request;
    }

    /**
     * Closes the connection without sending a final chunk. Does nothing if the emitter is already
     * marked complete.
     *
     * @param chatResponse the final response object; currently unused
     */
    public void onCompleteResponse(Object chatResponse) {
        if (this.complete) {
            return;
        }
        // Record completion so isComplete() is accurate and repeated calls are no-ops.
        this.complete = true;
        // Close the connection.
        this.complete();
    }

    /**
     * Sends an event, converting any failure into the "complete" state instead of propagating it.
     *
     * @param builder the event to send
     */
    @Override
    public void send(SseEventBuilder builder) {
        try {
            super.send(builder);
        } catch (Exception ignored) {
            // Deliberately best-effort: a failed write (e.g. the client went away) must not
            // propagate to the producer; mark the stream finished so later chunks are skipped.
            complete = true;
        }
    }

    /**
     * Pushes a single chunk to the client; a failure marks the emitter complete.
     *
     * @param chunk the chunk to send
     */
    private void pushChunk(ChatCompletionChunk chunk) {
        try {
            this.send(chunk);
        } catch (Exception ignored) {
            // Same best-effort policy as send(SseEventBuilder): never throw to the producer.
            complete = true;
        }
    }

    /**
     * Sends the opening chunk (assistant role, no content) and records the conversation id.
     *
     * @param chatId the conversation id used for this and all following chunks
     * @throws IllegalArgumentException if {@code chatId} is blank, or if the stream has already
     *                                  been started
     */
    public void onStart(String chatId) {
        if (!StringUtils.hasText(chatId)) {
            throw new IllegalArgumentException("chatId is empty!");
        }
        if (this.status != null) {
            throw new IllegalArgumentException("chunk has been started!");
        }
        ChatCompletionChunk chunk = ChatCompletionChunk.builder()
                .id(chatId)
                .object("chat.completion.chunk")
                .created(System.currentTimeMillis() / 1000)
                .model(request.getModel())
                .withChoices(choices -> {
                    choices.add(ChatCompletionChunk.choiceBuilder().index(0)
                            .role(Role.ASSISTANT)
                            .build());
                })
                .build();
        this.pushChunk(chunk);
        this.status = Status.PENDING;
        this.chatId = chatId;
    }

    /**
     * Sends a content chunk. Skipped when the emitter is already marked complete.
     *
     * <p>If {@link #onStart(String)} has not been called, the chunk is sent with a {@code null} id.
     *
     * @param text the content fragment to send
     */
    public void onPartial(String text) {
        if (this.complete) {
            // Nothing can be delivered any more; avoid building and pushing a doomed chunk.
            return;
        }
        ChatCompletionChunk chunk = ChatCompletionChunk.builder()
                .id(chatId)
                .object("chat.completion.chunk")
                .created(System.currentTimeMillis() / 1000)
                .model(request.getModel())
                .withChoices(choices -> {
                    choices.add(ChatCompletionChunk.choiceBuilder().index(0)
                            .role(Role.ASSISTANT)
                            .content(text)
                            .build());
                })
                .build();
        this.pushChunk(chunk);
    }

    /**
     * Sends the completion signal and closes the connection.
     *
     * <p>Following the OpenAI API convention, the last chunk carries a {@code finish_reason}
     * field. The final chunk is skipped when the emitter is already marked complete, but the
     * connection is still closed and the status is still set to {@code SUCCESS}.
     */
    public void onComplete() {
        if (!this.complete) {
            ChatCompletionChunk chunk = ChatCompletionChunk.builder()
                    .id(chatId)
                    .object("chat.completion.chunk")
                    .created(System.currentTimeMillis() / 1000)
                    .model(request.getModel())
                    .withChoices(choices -> {
                        choices.add(ChatCompletionChunk.choiceBuilder().index(0)
                                .role(Role.ASSISTANT)
                                // Per the OpenAI convention the last chunk sets finish_reason.
                                .finish_reason("stop")
                                .build());
                    })
                    .build();
            this.pushChunk(chunk);
        }
        // Record completion so isComplete() reflects the closed stream.
        this.complete = true;
        this.complete();
        this.status = Status.SUCCESS;
    }

    /**
     * Returns whether the stream is finished.
     *
     * @return {@code true} once the stream was closed, a send failed, or a caller marked it complete
     */
    public boolean isComplete() {
        return complete;
    }

    /**
     * Sets the completion flag. This only updates the flag; it does not close the connection.
     *
     * @param complete the new completion state
     */
    public void setComplete(boolean complete) {
        this.complete = complete;
    }
}

Powered by TurnKey Linux.