소스 검색

【代码优化】AI:完善 TongYiChatModelTests 单测,方便大家快速体验

YunaiV 9 달 전
부모
커밋
4f11d00cfd

+ 1 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/enums/AiPlatformEnum.java

@@ -14,8 +14,8 @@ public enum AiPlatformEnum {
 
 
     // ========== 国内平台 ==========
     // ========== 国内平台 ==========
 
 
+    TONG_YI("TongYi", "通义千问"), // 阿里
     YI_YAN("YiYan", "文心一言"), // 百度
     YI_YAN("YiYan", "文心一言"), // 百度
-    QIAN_WEN("QianWen", "千问"), // 阿里
     DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
     DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
     XING_HUO("XingHuo", "星火"), // 讯飞
     XING_HUO("XingHuo", "星火"), // 讯飞
 
 

+ 20 - 19
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -55,12 +55,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
             //noinspection EnhancedSwitchMigration
             //noinspection EnhancedSwitchMigration
             switch (platform) {
             switch (platform) {
+                case TONG_YI:
+                    return buildTongYiChatModel(apiKey);
                 case YI_YAN:
                 case YI_YAN:
                     return buildYiYanChatClient(apiKey);
                     return buildYiYanChatClient(apiKey);
                 case XING_HUO:
                 case XING_HUO:
                     return buildXingHuoChatClient(apiKey);
                     return buildXingHuoChatClient(apiKey);
-                case QIAN_WEN:
-                    return buildQianWenChatClient(apiKey);
                 case DEEP_SEEK:
                 case DEEP_SEEK:
                     return buildDeepSeekChatClient(apiKey);
                     return buildDeepSeekChatClient(apiKey);
                 case OPENAI:
                 case OPENAI:
@@ -77,16 +77,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
     public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
     public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
         //noinspection EnhancedSwitchMigration
         //noinspection EnhancedSwitchMigration
         switch (platform) {
         switch (platform) {
-            case OLLAMA:
-                return SpringUtil.getBean(OllamaChatModel.class);
+            case TONG_YI:
+                return SpringUtil.getBean(TongYiChatModel.class);
             case YI_YAN:
             case YI_YAN:
                 return SpringUtil.getBean(QianFanChatModel.class);
                 return SpringUtil.getBean(QianFanChatModel.class);
             case XING_HUO:
             case XING_HUO:
                 return SpringUtil.getBean(XingHuoChatClient.class);
                 return SpringUtil.getBean(XingHuoChatClient.class);
-            case QIAN_WEN:
-                return SpringUtil.getBean(TongYiChatModel.class);
             case OPENAI:
             case OPENAI:
                 return SpringUtil.getBean(OpenAiChatModel.class);
                 return SpringUtil.getBean(OpenAiChatModel.class);
+            case OLLAMA:
+                return SpringUtil.getBean(OllamaChatModel.class);
             default:
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }
         }
@@ -142,6 +142,20 @@ public class AiModelFactoryImpl implements AiModelFactory {
 
 
     // ========== 各种创建 spring-ai 客户端的方法 ==========
     // ========== 各种创建 spring-ai 客户端的方法 ==========
 
 
+    /**
+     * 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
+     */
+    private static TongYiChatModel buildTongYiChatModel(String key) {
+        com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
+        TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
+        // TODO @芋艿:貌似 apiKey 是全局唯一的???得测试下
+        // TODO @芋艿:貌似阿里云不是增量返回的
+        // 该 issue 进行跟进中 https://github.com/alibaba/spring-cloud-alibaba/issues/3790
+        TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
+        connectionProperties.setApiKey(key);
+        return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
+    }
+
     /**
     /**
      * 可参考 {@link OpenAiAutoConfiguration}
      * 可参考 {@link OpenAiAutoConfiguration}
      */
      */
@@ -196,19 +210,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new DeepSeekChatClient(apiKey);
         return new DeepSeekChatClient(apiKey);
     }
     }
 
 
-    /**
-     * 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
-     */
-    private static TongYiChatModel buildQianWenChatClient(String key) {
-        com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
-        TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
-        // TODO @xin:貌似 apiKey 是全局唯一的???得测试下
-        // TODO @xin:貌似阿里云不是增量返回的
-        TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
-        connectionProperties.setApiKey(key);
-        return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
-    }
-
     private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
     private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
         url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
         url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
         StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
         StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);

+ 1 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -32,7 +32,7 @@ public class AiUtils {
                 return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
                 return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
             case XING_HUO:
             case XING_HUO:
                 return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
                 return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
-            case QIAN_WEN:
+            case TONG_YI:
                 return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
                 return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
             case DEEP_SEEK:
             case DEEP_SEEK:
                 return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
                 return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();

+ 0 - 105
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java

@@ -1,105 +0,0 @@
-//package cn.iocoder.yudao.framework.ai.chat;
-//
-//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
-//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
-//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
-//import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
-//import com.alibaba.dashscope.aigc.generation.GenerationResult;
-//import com.alibaba.dashscope.aigc.generation.models.QwenParam;
-//import com.alibaba.dashscope.common.Message;
-//import com.alibaba.dashscope.common.MessageManager;
-//import com.alibaba.dashscope.common.Role;
-//import com.alibaba.dashscope.exception.InputRequiredException;
-//import com.alibaba.dashscope.exception.NoApiKeyException;
-//import org.junit.Before;
-//import org.junit.Test;
-//import org.springframework.ai.chat.messages.SystemMessage;
-//import org.springframework.ai.chat.messages.UserMessage;
-//import org.springframework.ai.chat.model.ChatResponse;
-//import org.springframework.ai.chat.prompt.Prompt;
-//import reactor.core.publisher.Flux;
-//
-//import java.util.ArrayList;
-//import java.util.List;
-//import java.util.Scanner;
-//import java.util.function.Consumer;
-//
-//// TODO 芋艿:整理单测
-///**
-// * author: fansili
-// * time: 2024/3/13 21:37
-// */
-//public class QianWenChatClientTests {
-//
-//    private QianWenChatClient qianWenChatClient;
-//
-//    @Before
-//    public void setup() {
-//        QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
-//        QianWenOptions qianWenOptions = new QianWenOptions();
-//        qianWenOptions.setTopP(0.8F);
-////        qianWenOptions.setTopK(3); TODO 芋艿:临时处理
-////        qianWenOptions.setTemperature(0.6F); TODO 芋艿:临时处理
-//        qianWenChatClient = new QianWenChatClient(
-//                qianWenApi,
-//                qianWenOptions
-//        );
-//    }
-//
-//    @Test
-//    public void callTest() {
-//        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
-//        messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
-//        messages.add(new UserMessage("长沙怎么样?"));
-//
-//        ChatResponse call = qianWenChatClient.call(new Prompt(messages));
-//        System.err.println(call.getResult());
-//    }
-//
-//    @Test
-//    public void streamTest() {
-//        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
-//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
-//        messages.add(new UserMessage("长沙怎么样?"));
-//
-//        Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
-//        flux.subscribe(new Consumer<ChatResponse>() {
-//            @Override
-//            public void accept(ChatResponse chatResponse) {
-//                System.err.print(chatResponse.getResult().getOutput().getContent());
-//            }
-//        });
-//
-//        // 阻止退出
-//        Scanner scanner = new Scanner(System.in);
-//        scanner.nextLine();
-//    }
-//
-//    @Test
-//    public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
-//        com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
-//        MessageManager msgManager = new MessageManager(10);
-//        Message systemMsg =
-//                Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
-//        Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
-//        msgManager.add(systemMsg);
-//        msgManager.add(userMsg);
-//        QwenParam param =
-//                QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
-//                        .resultFormat(QwenParam.ResultFormat.MESSAGE)
-//                        .topP(0.8)
-//                        /* set the random seed, optional, default to 1234 if not set */
-//                        .seed(100)
-//                        .apiKey("sk-Zsd81gZYg7")
-//                        .build();
-//        GenerationResult result = gen.call(param);
-//        System.out.println(result);
-//        System.out.println("-----------------");
-//        System.out.println("-----------------");
-//        msgManager.add(result);
-//        param.setPrompt("能否缩短一些,只讲三点");
-//        param.setMessages(msgManager.get());
-//        result = gen.call(param);
-//        System.out.println(result);
-//    }
-//}

+ 75 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/TongYiChatModelTests.java

@@ -0,0 +1,75 @@
+package cn.iocoder.yudao.framework.ai.chat;
+
+import cn.hutool.core.util.ReflectUtil;
+import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
+import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
+import com.alibaba.dashscope.aigc.generation.Generation;
+import com.alibaba.dashscope.common.MessageManager;
+import com.alibaba.dashscope.utils.Constants;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link TongYiChatModel} 集成测试类
+ *
+ * @author fansili
+ */
+public class TongYiChatModelTests {
+
+    private final Generation generation = new Generation();
+    private final TongYiChatModel chatModel = new TongYiChatModel(generation,
+            TongYiChatOptions.builder().withModel("qwen1.5-72b-chat").build());
+
+    static {
+        Constants.apiKey = "sk-Zsd81gZYg7";
+    }
+
+    @BeforeEach
+    public void before() {
+        // 防止 TongYiChatModel 调用空指针
+        ReflectUtil.setFieldValue(chatModel, "msgManager", new MessageManager());
+    }
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+        System.out.println(response.getResult().getOutput());
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(response -> {
+//            System.out.println(response);
+            System.out.println(response.getResult().getOutput());
+        }).then().block();
+    }
+
+}