Browse Source

【代码优化】AI:完善 YiYanChatTests 单测,方便大家快速体验

YunaiV 9 months ago
parent
commit
4daff93313

+ 13 - 14
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -58,7 +58,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
                 case TONG_YI:
                     return buildTongYiChatModel(apiKey);
                 case YI_YAN:
-                    return buildYiYanChatClient(apiKey);
+                    return buildYiYanChatModel(apiKey);
                 case XING_HUO:
                     return buildXingHuoChatClient(apiKey);
                 case DEEP_SEEK:
@@ -156,6 +156,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
     }
 
+    /**
+     * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
+     */
+    private static QianFanChatModel buildYiYanChatModel(String key) {
+        List<String> keys = StrUtil.split(key, '|');
+        Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
+        String appKey = keys.get(0);
+        String secretKey = keys.get(1);
+        QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
+        return new QianFanChatModel(qianFanApi);
+    }
+
     /**
      * 可参考 {@link OpenAiAutoConfiguration}
      */
@@ -182,19 +194,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return new OllamaChatModel(ollamaApi);
     }
 
-    /**
-     * 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
-     */
-    private static QianFanChatModel buildYiYanChatClient(String key) {
-        // TODO @xin:貌似目前设置,request 势必会报错;看看能不能有办法,参考 buildQianWenChatClient,调用 QianFanAutoConfiguration#qianFanChatModel初始化,当然 key 要用自己的哈
-        List<String> keys = StrUtil.split(key, '|');
-        Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
-        String appKey = keys.get(0);
-        String secretKey = keys.get(1);
-        QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
-        return new QianFanChatModel(qianFanApi);
-    }
-
     /**
      * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
      */

+ 1 - 3
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -27,9 +27,7 @@ public class AiUtils {
             case OLLAMA:
                 return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
             case YI_YAN:
-                // TODO @xin:貌似 model 只要一设置,就报错;可以排查下
-//                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-                return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
             case XING_HUO:
                 return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
             case TONG_YI:

+ 53 - 66
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java

@@ -1,74 +1,61 @@
 package cn.iocoder.yudao.framework.ai.chat;
 
-//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
-//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
-//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
-//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
-//import org.junit.Before;
-//import org.junit.Test;
-//import org.springframework.ai.chat.messages.Message;
-//import org.springframework.ai.chat.messages.SystemMessage;
-//import org.springframework.ai.chat.messages.UserMessage;
-//import org.springframework.ai.chat.model.ChatResponse;
-//import org.springframework.ai.chat.prompt.Prompt;
-//import reactor.core.publisher.Flux;
-//
-//import java.util.ArrayList;
-//import java.util.List;
-//import java.util.Scanner;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.qianfan.QianFanChatModel;
+import org.springframework.ai.qianfan.QianFanChatOptions;
+import org.springframework.ai.qianfan.api.QianFanApi;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
 
-// TODO 芋艿:整理单测
 /**
- * chat 文心一言
- * <p>
- * author: fansili
- * time: 2024/3/12 20:59
+ * {@link QianFanChatModel} 的集成测试
+ *
+ * @author fansili
  */
 public class YiYanChatTests {
 
-//    private YiYanChatClient yiYanChatClient;
-//
-//    @Before
-//    public void setup() {
-//        YiYanApi yiYanApi = new YiYanApi(
-//                "x0cuLZ7XsaTCU08vuJWO87Lg",
-//                "R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK",
-//                YiYanChatModel.ERNIE4_3_5_8K,
-//                86400
-//        );
-//        YiYanChatOptions yiYanOptions = new YiYanChatOptions();
-//        yiYanOptions.setMaxOutputTokens(2048);
-//        yiYanOptions.setTopP(0.6f);
-//        yiYanOptions.setTemperature(0.85f);
-//        yiYanChatClient = new YiYanChatClient(
-//                yiYanApi,
-//                yiYanOptions
-//        );
-//    }
-//
-//    @Test
-//    public void callTest() {
-//
-//        // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息,前面的message为历史对话信息)
-//        // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
-//        List<Message> messages = new ArrayList<>();
-//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
-//        messages.add(new UserMessage("长沙怎么样?"));
-//
-//        ChatResponse call = yiYanChatClient.call(new Prompt(messages));
-//        System.err.println(call.getResult());
-//    }
-//
-//    @Test
-//    public void streamTest() {
-//        List<Message> messages = new ArrayList<>();
-//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
-//        messages.add(new UserMessage("长沙怎么样?"));
-//
-//        Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
-//        fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
-//        // 阻止退出
-//        Scanner scanner = new Scanner(System.in);
-//        scanner.nextLine();
-//    }
+    private final QianFanApi qianFanApi = new QianFanApi(
+            "qS8k8dYr2nXunagK4SSU8Xjj",
+            "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
+    private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi,
+            QianFanChatOptions.builder().withModel(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build()
+    );
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        // TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名!
+//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        // TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名!
+//        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(System.out::println).then().block();
+    }
+
 }