浏览代码

【代码优化】AI:完善 LlamaChatModelTests 单测,方便大家快速体验

YunaiV 11 月之前
父节点
当前提交
0139317ac4

+ 3 - 3
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -55,8 +55,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
             //noinspection EnhancedSwitchMigration
             switch (platform) {
-                case OLLAMA:
-                    return buildOllamaChatClient(url);
                 case YI_YAN:
                     return buildYiYanChatClient(apiKey);
                 case XING_HUO:
@@ -67,6 +65,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                     return buildDeepSeekChatClient(apiKey);
                 case OPENAI:
                     return buildOpenAiChatModel(apiKey, url);
+                case OLLAMA:
+                    return buildOllamaChatModel(url);
                 default:
                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
             }
@@ -163,7 +163,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
     /**
      * 可参考 {@link OllamaAutoConfiguration}
      */
-    private static OllamaChatModel buildOllamaChatClient(String url) {
+    private static OllamaChatModel buildOllamaChatModel(String url) {
         OllamaApi ollamaApi = new OllamaApi(url);
         return new OllamaChatModel(ollamaApi);
     }

+ 63 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java

@@ -0,0 +1,63 @@
+package cn.iocoder.yudao.framework.ai.chat;
+
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link OllamaChatModel} 集成测试
+ *
+ * @author 芋道源码
+ */
+public class LlamaChatModelTests {
+
+    private final OllamaApi ollamaApi = new OllamaApi(
+            "http://127.0.0.1:11434");
+    private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
+            OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+        System.out.println(response.getResult().getOutput());
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(response -> {
+//            System.out.println(response);
+            System.out.println(response.getResult().getOutput());
+        }).then().block();
+    }
+
+}