diff --git a/src/main/java/xyz/wbsite/ai/Test.java b/src/main/java/xyz/wbsite/ai/Test.java index 375f0c2..219cfe4 100644 --- a/src/main/java/xyz/wbsite/ai/Test.java +++ b/src/main/java/xyz/wbsite/ai/Test.java @@ -1,11 +1,9 @@ package xyz.wbsite.ai; import cn.hutool.core.collection.CollUtil; +import dev.langchain4j.agent.tool.*; import dev.langchain4j.data.document.Document; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.data.message.SystemMessage; -import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.message.*; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.StreamingResponseHandler; @@ -17,10 +15,15 @@ import dev.langchain4j.model.output.Response; import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; import dev.langchain4j.service.AiServices; import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.tool.DefaultToolExecutor; +import dev.langchain4j.service.tool.ToolExecutor; import dev.langchain4j.store.embedding.EmbeddingStoreIngestor; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; +import java.util.ArrayList; import java.util.List; +import java.util.UUID; +import java.util.function.Consumer; /** * 主函数入口 @@ -30,8 +33,8 @@ public class Test { public static void main(String[] args) { // testSimpleChat(args); // testStreamChat(args); - testRagChat(args); -// testTool(args); +// testRagChat(args); + testTool(args); } public static void testSimpleChat(String[] args) { @@ -138,62 +141,57 @@ public class Test { } public static void testTool(String[] args) { -// ChatLanguageModel model = OllamaChatModel.builder() -// .baseUrl("http://36.138.207.178:11434") -// .modelName("qwen2.5:7b") -// .logRequests(true) -// .logResponses(true) -// .build(); -// -// List chatMessages = new ArrayList<>(); -// chatMessages.add(UserMessage.from("请问,泰州市的天气怎么样?")); -// -// Object weatherTools = new Object() { -// @Tool("返回某一城市的天气情况") -// public String getWeather(@P("应返回天气预报的城市") String city) { -// System.out.println(city); -// return "天气阴转多云,1~6℃"; -// } -// }; -// -// List toolSpecifications = ToolSpecifications.toolSpecificationsFrom(weatherTools); -// -// ChatRequest chatRequest = ChatRequest.builder() -// .messages(chatMessages) -// .parameters(ChatRequestParameters.builder() -// .toolSpecifications(toolSpecifications) -// .build()) -// .build(); -// -// -// ChatResponse chatResponse = model.chat(chatRequest); -// AiMessage aiMessage = chatResponse.aiMessage(); -// chatMessages.add(aiMessage); -// if (aiMessage.hasToolExecutionRequests()) { -// System.out.println("LLM决定调用工具"); -// System.out.println(chatResponse.aiMessage()); -// List toolExecutionRequests = chatResponse.aiMessage().toolExecutionRequests(); -// toolExecutionRequests.forEach(new Consumer() { -// @Override -// public void accept(ToolExecutionRequest toolExecutionRequest) { -// ToolExecutor toolExecutor = new DefaultToolExecutor(weatherTools, toolExecutionRequest); -// System.out.println("Now let's execute the tool " + toolExecutionRequest.name()); -// String result = toolExecutor.execute(toolExecutionRequest, UUID.randomUUID().toString()); -// ToolExecutionResultMessage toolExecutionResultMessages = ToolExecutionResultMessage.from(toolExecutionRequest, result); -// chatMessages.add(toolExecutionResultMessages); -// } -// }); -// } -// -// // STEP 4: Model generates final response -// ChatRequest chatRequest2 = ChatRequest.builder() -// .messages(chatMessages) -// .parameters(ChatRequestParameters.builder() -// .toolSpecifications(toolSpecifications) -// .build()) -// .build(); -// ChatResponse finalChatResponse = model.chat(chatRequest2); -// System.out.println(finalChatResponse.aiMessage().text()); + OpenAiChatModel model = OpenAiChatModel.builder() + .baseUrl("http://36.138.207.178:11434/v1") + .apiKey("1") + .modelName("qwen2.5:7b") + .build(); + + List chatMessages = new ArrayList<>(); + chatMessages.add(UserMessage.from("请问,泰州市的天气怎么样?")); + + Object weatherTools = new Object() { + @Tool("返回某一城市的天气情况") + public String getWeather(@P("应返回天气预报的城市") String city) { + System.out.println(city); + return "天气阴转多云,1~6℃"; + } + }; + + List toolSpecifications = ToolSpecifications.toolSpecificationsFrom(weatherTools); + + ChatRequest chatRequest = ChatRequest.builder() + .messages(chatMessages) + .toolSpecifications(toolSpecifications) + .build(); + + + ChatResponse chatResponse = model.chat(chatRequest); + AiMessage aiMessage = chatResponse.aiMessage(); + chatMessages.add(aiMessage); + if (aiMessage.hasToolExecutionRequests()) { + System.out.println("LLM决定调用工具"); + System.out.println(chatResponse.aiMessage()); + List toolExecutionRequests = chatResponse.aiMessage().toolExecutionRequests(); + toolExecutionRequests.forEach(new Consumer() { + @Override + public void accept(ToolExecutionRequest toolExecutionRequest) { + ToolExecutor toolExecutor = new DefaultToolExecutor(weatherTools, toolExecutionRequest); + System.out.println("Now let's execute the tool " + toolExecutionRequest.name()); + String result = toolExecutor.execute(toolExecutionRequest, UUID.randomUUID().toString()); + ToolExecutionResultMessage toolExecutionResultMessages = ToolExecutionResultMessage.from(toolExecutionRequest, result); + chatMessages.add(toolExecutionResultMessages); + } + }); + } + + // STEP 4: Model generates final response + ChatRequest chatRequest2 = ChatRequest.builder() + .messages(chatMessages) + .toolSpecifications(toolSpecifications) + .build(); + ChatResponse finalChatResponse = model.chat(chatRequest2); + System.out.println(finalChatResponse.aiMessage().text()); } // 创建一个助手接口