Selaa lähdekoodia

百度 文心一言 适配chatOptions

cherishsince 1 vuosi sitten
vanhempi
commit
f41e43713c

+ 41 - 20
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanChatClient.java

@@ -1,10 +1,8 @@
 package cn.iocoder.yudao.framework.ai.chatyiyan;
 
-import cn.iocoder.yudao.framework.ai.chat.ChatClient;
-import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
-import cn.iocoder.yudao.framework.ai.chat.Generation;
-import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
-import cn.iocoder.yudao.framework.ai.chat.messages.Message;
+import cn.hutool.core.bean.BeanUtil;
+import cn.iocoder.yudao.framework.ai.chat.*;
+import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletion;
 import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
@@ -18,7 +16,6 @@ import org.springframework.retry.support.RetryTemplate;
 import reactor.core.publisher.Flux;
 
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -32,10 +29,17 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
 
     private YiYanApi yiYanApi;
 
+    private YiYanOptions yiYanOptions;
+
     /**
      * Creates a client without client-level default options.
      *
      * <p>The {@code yiYanOptions} field is left {@code null} by this constructor, so
      * {@code createRequest} throws a {@code ChatException} unless the {@code Prompt}
      * carries its own options.</p>
      *
      * @param yiYanApi API object the client sends its completion requests through
      */
     public YiYanChatClient(YiYanApi yiYanApi) {
         this.yiYanApi = yiYanApi;
     }
 
+    public YiYanChatClient(YiYanApi yiYanApi, YiYanOptions yiYanOptions) { // Creates a client with client-level default options
+        this.yiYanApi = yiYanApi; // API object the client sends its completion requests through
+        this.yiYanOptions = yiYanOptions; // Defaults used by createRequest when the Prompt carries no options of its own
+    }
+
     public final RetryTemplate retryTemplate = RetryTemplate.builder()
             // 最大重试次数 10
             .maxAttempts(10)
@@ -70,20 +74,6 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
         });
     }
 
-    private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
-        List<YiYanChatCompletionRequest.Message> messages = new ArrayList<>();
-        List<Message> instructions = prompt.getInstructions();
-        for (Message instruction : instructions) {
-            YiYanChatCompletionRequest.Message message = new YiYanChatCompletionRequest.Message();
-            message.setContent(instruction.getContent());
-            message.setRole(instruction.getMessageType().getValue());
-            messages.add(message);
-        }
-        YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messages);
-        request.setStream(stream);
-        return request;
-    }
-
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
         // ctx 会有重试的信息
@@ -93,4 +83,35 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
         Flux<YiYanChatCompletion> response = this.yiYanApi.chatCompletionStream(request);
         return response.map(res -> new ChatResponse(List.of(new Generation(res.getResult()))));
     }
+
+    /**
+     * Builds the Wenxin Yiyan completion request for the given prompt.
+     *
+     * <p>Options carried by the prompt take precedence over the client-level defaults passed to the
+     * constructor. The selected options are copied onto the request by property name, and the
+     * {@code stream} flag is then set explicitly from the argument.</p>
+     *
+     * @param prompt the prompt whose instructions become the request messages
+     * @param stream whether the request should ask for a streaming response
+     * @return the populated request
+     * @throws ChatException if neither the prompt nor the client provides options, or if the
+     *                       selected options are not a {@link YiYanOptions}
+     */
+    private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+        // Prompt-level options win; otherwise fall back to the client-level defaults.
+        // The candidate is held as Object so that a foreign options type is rejected below with a
+        // ChatException instead of failing earlier on an unchecked cast.
+        Object candidate = prompt.getOptions();
+        if (candidate == null) {
+            candidate = yiYanOptions;
+        }
+        // Nothing configured on either side.
+        if (candidate == null) {
+            throw new ChatException("ChatOptions 未配置参数!");
+        }
+        // Callers may put any options implementation into a Prompt, so verify the concrete type.
+        if (!(candidate instanceof YiYanOptions)) {
+            throw new ChatException("Prompt 传入的不是 YiYanOptions!");
+        }
+        YiYanOptions effectiveOptions = (YiYanOptions) candidate;
+        // Map every prompt instruction to a request message (role + content).
+        List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream().map(
+                msg -> new YiYanChatCompletionRequest.Message()
+                        .setRole(msg.getMessageType().getValue())
+                        .setContent(msg.getContent())
+        ).toList();
+        YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
+        // Copy the options onto the request; the option properties largely mirror the request properties.
+        BeanUtil.copyProperties(effectiveOptions, request);
+        // Set stream last so the method argument overrides whatever the options carried.
+        request.setStream(stream);
+        return request;
+    }
 }

+ 144 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/YiYanOptions.java

@@ -0,0 +1,144 @@
+package cn.iocoder.yudao.framework.ai.chatyiyan;
+
+import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
+import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+import java.util.List;
+
+/**
+ * Chat options for Baidu Wenxin Yiyan (文心一言).
+ *
+ * Field names are kept in snake_case, matching the parameter names used in the API documentation.
+ *
+ * API documentation: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+ *
+ * author: fansili
+ * time: 2024/3/16 19:33
+ */
+@Data
+@Accessors(chain = true)
+public class YiYanOptions implements ChatOptions {
+
+    /**
+     * List of descriptions of functions that may be triggered. Notes:
+     * (1) There is no limit on the number of functions.
+     * (2) Length limit: the content of the last message (i.e. the question of this turn), functions and the
+     *     system field together must not exceed 20480 characters, nor 5120 tokens.
+     * Required: no
+     */
+    private List<YiYanChatCompletionRequest.Function> functions;
+    /**
+     * Notes:
+     * (1) Higher values make the output more random; lower values make it more focused and deterministic.
+     * (2) Default 0.8, range (0, 1.0]; must not be 0.
+     * Required: no
+     */
+    private Float temperature;
+    /**
+     * Notes:
+     * (1) Affects the diversity of the output text; the larger the value, the more diverse the generated text.
+     * (2) Default 0.8, range [0, 1.0].
+     * Required: no
+     */
+    private Float top_p;
+    /**
+     * Reduces repetition by penalising tokens that have already been generated. Notes:
+     * (1) The larger the value, the larger the penalty.
+     * (2) Default 1.0, range [1.0, 2.0].
+     *
+     * Required: no
+     */
+    private Float penalty_score;
+    /**
+     * Whether data is returned through the streaming interface; default false.
+     * Required: no
+     *
+     * NOTE(review): YiYanChatClient#createRequest calls request.setStream(stream) after copying these options,
+     * so a value set here appears to be overridden on that path - confirm before relying on it.
+     */
+    private Boolean stream;
+    /**
+     * Model persona, mainly used for persona setting, e.g. "you are an AI assistant built by xxx company". Notes:
+     * (1) Length limit: the content of the last message (i.e. the question of this turn), functions and the
+     *     system field together must not exceed 20480 characters, nor 5120 tokens.
+     * (2) If system and functions are used together, the result quality may not be guaranteed for now; this is
+     *     being improved continuously.
+     * Required: no
+     */
+    private String system;
+    /**
+     * Stop markers: text generation stops when the generated result ends with one of the elements in stop. Notes:
+     * (1) Each element is at most 20 characters long.
+     * (2) At most 4 elements.
+     * Required: no
+     */
+    private List<String> stop;
+    /**
+     * Whether to forcibly disable the real-time search feature; default false, meaning it is not disabled.
+     * Required: no
+     */
+    private Boolean disable_search;
+    /**
+     * Whether to enable superscript (citation marker) output. Notes:
+     * (1) When enabled, search provenance information (search_info) may be triggered; see the response
+     *     parameter description for the content of search_info.
+     * (2) Default false, i.e. not enabled.
+     * Required: no
+     */
+    private Boolean enable_citation;
+    /**
+     * Maximum number of output tokens for the model, range [2, 2048].
+     * Required: no
+     */
+    private Integer max_output_tokens;
+    /**
+     * Format of the response content. Notes:
+     * (1) Allowed values:
+     * · json_object: returned as JSON; the result may not always meet expectations
+     * · text: returned as plain text
+     * (2) If response_format is not set, it defaults to text.
+     * Required: no
+     */
+    private String response_format;
+    /**
+     * Unique identifier of the end user.
+     * Required: no
+     */
+    private String user_id;
+    /**
+     * In function-calling scenarios, hints the model to choose the given function (not enforced). Note: the given
+     * function name must exist in functions.
+     * Required: no
+     *
+     * The ERNIE-4.0-8K model does not have this field.
+     */
+    private String tool_choice;
+
+    //
+    // The accessors below exist for compatibility with the spring-ai ChatOptions contract; nothing else uses them yet
+
+    @Override
+    public Float getTemperature() {
+        return this.temperature;
+    }
+
+    @Override
+    public void setTemperature(Float temperature) {
+        this.temperature = temperature;
+    }
+
+    @Override
+    public Float getTopP() {
+        return top_p;
+    }
+
+    @Override
+    public void setTopP(Float topP) {
+        this.top_p = topP;
+    }
+
+    // Baidu has no topK parameter, so the getter always returns null and the setter discards its argument
+
+    @Override
+    public Integer getTopK() {
+        return null;
+    }
+
+    @Override
+    public void setTopK(Integer topK) {
+        // Intentionally empty: this class has no topK field to store the value in
+    }
+}

+ 6 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatyiyan/api/YiYanChatCompletionRequest.java

@@ -37,14 +37,14 @@ public class YiYanChatCompletionRequest {
      * (2)默认0.8,范围 (0, 1.0],不能为0
      * 必填:否
      */
-    private String temperature;
+    private Float temperature;
     /**
      * 说明:
      * (1)影响输出文本的多样性,取值越大,生成文本的多样性越强
      * (2)默认0.8,取值范围 [0, 1.0]
      * 必填:否
      */
-    private String top_p;
+    private Float top_p;
     /**
      * 通过对已生成的token增加惩罚,减少重复生成的现象。说明:
      * (1)值越大表示惩罚越大
@@ -52,7 +52,7 @@ public class YiYanChatCompletionRequest {
      *
      * 必填:否
      */
-    private String penalty_score;
+    private Float penalty_score;
     /**
      * 是否以流式接口的形式返回数据,默认false
      * 必填:否
@@ -71,7 +71,7 @@ public class YiYanChatCompletionRequest {
      * (2)最多4个元素
      * 必填:否
      */
-    private String stop;
+    private List<String> stop;
     /**
      * 是否强制关闭实时搜索功能,默认false,表示不关闭
      * 必填:否
@@ -106,6 +106,8 @@ public class YiYanChatCompletionRequest {
     /**
      * 在函数调用场景下,提示大模型选择指定的函数(非强制),说明:指定的函数名必须在functions中存在
      * 必填:否
+     *
+     * ERNIE-4.0-8K 模型没有这个字段
      */
     private String tool_choice;
 

+ 3 - 3
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java

@@ -23,9 +23,9 @@ public class QianWenChatClientTests {
     @Before
     public void setup() {
         QianWenApi qianWenApi = new QianWenApi(
-                "",
-                "",
-                "",
+                // Credentials must never be committed to source control; read them from the environment.
+                // Any key that was ever pushed should be treated as leaked and rotated.
+                // NOTE(review): the variable names are placeholders - align them with the meaning of the
+                // corresponding QianWenApi constructor parameters.
+                System.getenv("QIANWEN_ACCESS_KEY_ID"),
+                System.getenv("QIANWEN_ACCESS_KEY_SECRET"),
+                System.getenv("QIANWEN_AGENT_KEY"),
                 null
         );
         qianWenChatClient = new QianWenChatClient(

+ 3 - 8
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/YiYanChatTests.java

@@ -4,12 +4,12 @@ import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanApi;
 import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
 import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
+import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanOptions;
 import org.junit.Before;
 import org.junit.Test;
 import reactor.core.publisher.Flux;
 
 import java.util.Scanner;
-import java.util.function.Consumer;
 
 /**
  * chat 文心一言
@@ -29,7 +29,7 @@ public class YiYanChatTests {
                 YiYanChatModel.ERNIE4_3_5_8K,
                 86400
         );
-        yiYanChatClient = new YiYanChatClient(yiYanApi);
+        yiYanChatClient = new YiYanChatClient(yiYanApi, new YiYanOptions().setMax_output_tokens(2048));
     }
 
     @Test
@@ -41,12 +41,7 @@ public class YiYanChatTests {
     @Test
     public void streamTest() {
         Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt("用java帮我写一个快排算法?"));
-        fluxResponse.subscribe(new Consumer<ChatResponse>() {
-            @Override
-            public void accept(ChatResponse chatResponse) {
-                System.err.print(chatResponse.getResult().getOutput().getContent());
-            }
-        });
+        fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
         // 阻止退出
         Scanner scanner = new Scanner(System.in);
         scanner.nextLine();