ソースを参照

阿里通义千问,继承 chat options

cherishsince 1 年間 前
コミット
7a785b1ec0

+ 15 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chat/ChatException.java

@@ -0,0 +1,15 @@
+package cn.iocoder.yudao.framework.ai.chat;
+
+/**
+ * 聊天异常
+ *
+ * author: fansili
+ * time: 2024/3/15 20:45
+ */
+public class ChatException extends RuntimeException {
+
+    public ChatException(String message) {
+        super(message);
+    }
+
+}

+ 7 - 32
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenApi.java

@@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.chatqianwen;
 
 import com.aliyun.broadscope.bailian.sdk.AccessTokenClient;
 import com.aliyun.broadscope.bailian.sdk.ApplicationClient;
-import com.aliyun.broadscope.bailian.sdk.models.ChatRequestMessage;
 import com.aliyun.broadscope.bailian.sdk.models.CompletionsRequest;
 import com.aliyun.broadscope.bailian.sdk.models.CompletionsResponse;
 import lombok.Getter;
@@ -10,8 +9,6 @@ import org.springframework.http.HttpStatusCode;
 import org.springframework.http.ResponseEntity;
 import reactor.core.publisher.Flux;
 
-import java.util.List;
-
 /**
  * 阿里 通义千问
  *
@@ -35,11 +32,10 @@ public class QianWenApi {
     private String token;
     private ApplicationClient client;
 
-    public QianWenApi(String accessKeyId, String accessKeySecret, String agentKey, String appId, String endpoint) {
+    public QianWenApi(String accessKeyId, String accessKeySecret, String agentKey, String endpoint) {
         this.accessKeyId = accessKeyId;
         this.accessKeySecret = accessKeySecret;
         this.agentKey = agentKey;
-        this.appId = appId;
 
         if (endpoint != null) {
             this.endpoint = endpoint;
@@ -54,35 +50,14 @@ public class QianWenApi {
                 .build();
     }
 
-    public ResponseEntity<CompletionsResponse> chatCompletionEntity(ChatRequestMessage message) {
-        // 创建request
-        CompletionsRequest request = new CompletionsRequest()
-                // 设置 appid
-                .setAppId(appId)
-                .setMessages(List.of(message))
-                // 返回choice message结果
-                .setParameters(new CompletionsRequest.Parameter().setResultFormat("message"));
-        //
+    public ResponseEntity<CompletionsResponse> chatCompletionEntity(CompletionsRequest request) {
+        // 发送请求
         CompletionsResponse response = client.completions(request);
-        int httpCode = 200;
-        if (!response.isSuccess()) {
-            System.out.printf("failed to create completion, requestId: %s, code: %s, message: %s\n",
-                    response.getRequestId(), response.getCode(), response.getMessage());
-            httpCode = 500;
-        }
-        return new ResponseEntity<>(response, HttpStatusCode.valueOf(httpCode));
+        // 阿里云的这个 http code 随便设置,外面判断是否成功用的 CompletionsResponse.isSuccess
+        return new ResponseEntity<>(response, HttpStatusCode.valueOf(200));
     }
 
-    public Flux<CompletionsResponse> chatCompletionStream(ChatRequestMessage message) {
-        return client.streamCompletions(
-                new CompletionsRequest()
-                        // 设置 appid
-                        .setAppId(appId)
-                        // 开启 stream
-                        .setStream(true)
-                        .setMessages(List.of(message))
-                        //开启增量输出模式,后面输出不会包含已经输出的内容
-                        .setParameters(new CompletionsRequest.Parameter().setIncrementalOutput(true))
-        );
+    public Flux<CompletionsResponse> chatCompletionStream(CompletionsRequest request) {
+        return client.streamCompletions(request);
     }
 }

+ 69 - 14
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java

@@ -1,14 +1,13 @@
 package cn.iocoder.yudao.framework.ai.chatqianwen;
 
-import cn.iocoder.yudao.framework.ai.chat.ChatClient;
-import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
-import cn.iocoder.yudao.framework.ai.chat.Generation;
-import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
+import cn.hutool.core.util.IdUtil;
+import cn.hutool.json.JSONUtil;
+import cn.iocoder.yudao.framework.ai.chat.*;
+import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
+import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
-import com.aliyun.broadscope.bailian.sdk.models.ChatRequestMessage;
-import com.aliyun.broadscope.bailian.sdk.models.ChatUserMessage;
-import com.aliyun.broadscope.bailian.sdk.models.CompletionsResponse;
+import com.aliyun.broadscope.bailian.sdk.models.*;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.http.ResponseEntity;
 import org.springframework.retry.RetryCallback;
@@ -19,6 +18,7 @@ import reactor.core.publisher.Flux;
 
 import java.time.Duration;
 import java.util.List;
+import java.util.Optional;
 import java.util.stream.Collectors;
 
 /**
@@ -34,10 +34,17 @@ public class QianWenChatClient  implements ChatClient, StreamingChatClient {
 
     private QianWenApi qianWenApi;
 
+    private ChatOptions chatOptions;
+
     public QianWenChatClient(QianWenApi qianWenApi) {
         this.qianWenApi = qianWenApi;
     }
 
+    public QianWenChatClient(QianWenApi qianWenApi, ChatOptions chatOptions) {
+        this.qianWenApi = qianWenApi;
+        this.chatOptions = chatOptions;
+    }
+
     public final RetryTemplate retryTemplate = RetryTemplate.builder()
             // 最大重试次数 10
             .maxAttempts(10)
@@ -58,7 +65,7 @@ public class QianWenChatClient  implements ChatClient, StreamingChatClient {
         return this.retryTemplate.execute(ctx -> {
             // ctx 会有重试的信息
             // 创建 request 请求,stream模式需要供应商支持
-            ChatRequestMessage request = this.createRequest(prompt, false);
+            CompletionsRequest request = this.createRequest(prompt, false);
             // 调用 callWithFunctionSupport 发送请求
             ResponseEntity<CompletionsResponse> responseEntity = qianWenApi.chatCompletionEntity(request);
             // 获取结果封装 chatCompletion
@@ -67,21 +74,69 @@ public class QianWenChatClient  implements ChatClient, StreamingChatClient {
                 return new ChatResponse(List.of(new Generation(String.format("failed to create completion, requestId: %s, code: %s, message: %s\n",
                         response.getRequestId(), response.getCode(), response.getMessage()))));
             }
-            List<Generation> generations = response.getData().getChoices().stream()
-                    .map(item -> new Generation(item.getMessage().getContent())).collect(Collectors.toList());
-            return new ChatResponse(generations);
+            // 转换为 Generation 返回
+            return new ChatResponse(List.of(new Generation(response.getData().getText())));
         });
     }
 
-    private ChatRequestMessage createRequest(Prompt prompt, boolean b) {
-        return new ChatUserMessage(prompt.getContents());
+    private CompletionsRequest createRequest(Prompt prompt, boolean stream) {
+        // 两个都为null 则没有配置文件
+        if (chatOptions == null && prompt.getOptions() == null) {
+            throw new ChatException("ChatOptions 未配置参数!");
+        }
+        // 优先使用 Prompt 里面的 ChatOptions
+        ChatOptions options = chatOptions;
+        if (prompt.getOptions() != null) {
+            options = (ChatOptions) prompt.getOptions();
+        }
+        QianWenOptions qianWenOptions = (QianWenOptions) options;
+        // 需要额外处理
+        if (!stream) {
+            // 如果不需要 stream 输出,那么需要将这个设置为false,不然只会输出最后几个文字
+            if (qianWenOptions.getParameters() == null) {
+                qianWenOptions.setParameters(new CompletionsRequest.Parameter().setIncrementalOutput(false));
+            } else {
+                qianWenOptions.getParameters().setIncrementalOutput(false);
+            }
+        } else {
+            // 如果不需要 stream 输出,设置为true这样不会输出累加内容
+            if (qianWenOptions.getParameters() == null) {
+                qianWenOptions.setParameters(new CompletionsRequest.Parameter().setIncrementalOutput(true));
+            } else {
+                qianWenOptions.getParameters().setIncrementalOutput(true);
+            }
+        }
+
+        // 创建request
+        return new CompletionsRequest()
+                // 请求唯一标识,请确保RequestId不重复。
+                .setRequestId(IdUtil.getSnowflakeNextIdStr())
+                // 设置 appid
+                .setAppId(qianWenOptions.getAppId())
+                .setMessages(prompt.getInstructions().stream().map(m -> {
+                    // 转换成 千问 对于的请求message
+                    if (MessageType.USER == m.getMessageType()) {
+                        return new ChatUserMessage(m.getContent());
+                    } else if (MessageType.SYSTEM == m.getMessageType()) {
+                        return new ChatSystemMessage(m.getContent());
+                    } else if (MessageType.ASSISTANT == m.getMessageType()) {
+                        return new ChatAssistantMessage(m.getContent());
+                    }
+                    throw new ChatException(String.format("存在不能适配的消息! %s", JSONUtil.toJsonPrettyStr(m)));
+                }).collect(Collectors.toList()))
+                // 返回choice message结果
+                .setParameters(qianWenOptions.getParameters())
+                // 设置 ChatOptions 里面公共的参数
+                .setTopP(options.getTopP() == null ? null : options.getTopP().doubleValue())
+                // 设置输出方式
+                .setStream(stream);
     }
 
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
         // ctx 会有重试的信息
         // 创建 request 请求,stream模式需要供应商支持
-        ChatRequestMessage request = this.createRequest(prompt, true);
+        CompletionsRequest request = this.createRequest(prompt, true);
         // 调用 callWithFunctionSupport 发送请求
         Flux<CompletionsResponse> response = this.qianWenApi.chatCompletionStream(request);
         return response.map(res -> {

+ 128 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenOptions.java

@@ -0,0 +1,128 @@
+package cn.iocoder.yudao.framework.ai.chatqianwen;
+
+import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
+import com.aliyun.broadscope.bailian.sdk.models.CompletionsRequest;
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+import java.util.List;
+
+/**
+ * 阿里云 千问 属性
+ *
+ * 地址:https://help.aliyun.com/document_detail/2684682.html?spm=a2c4g.2621347.0.0.195117e7Ytpkyo
+ *
+ * author: fansili
+ * time: 2024/3/15 19:57
+ */
+@Data
+@Accessors
+public class QianWenOptions implements ChatOptions {
+
+    private String appId;
+    /**
+     * 是否流式输出, 默认为否。
+     */
+    private Boolean stream;
+    /**
+     * 用户与模型的对话历史
+     */
+    private List<Message> messages;
+    /**
+     * 生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,
+     * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。
+     * 默认值为0.8。注意,取值不要大于等于1
+     */
+    private Float topP;
+    /**
+     * 模型参数设置。
+     */
+    private CompletionsRequest.Parameter parameters = new CompletionsRequest.Parameter();
+
+    //
+    // 适配 ChatOptions
+
+    @Override
+    public Float getTemperature() {
+        return Float.parseFloat(this.parameters.getTemperature().toString());
+    }
+
+    @Override
+    public void setTemperature(Float temperature) {
+        this.parameters.setTemperature(Double.valueOf(temperature.toString()));
+    }
+
+    @Override
+    public void setTopP(Float topP) {
+        this.topP = topP;
+    }
+
+    @Override
+    public Integer getTopK() {
+        return this.parameters.getTopK();
+    }
+
+    @Override
+    public void setTopK(Integer topK) {
+        this.parameters.setTopK(topK);
+    }
+
+    @Data
+    @Accessors
+    public static class Message {
+        /**
+         * 角色: system、user或assistant
+         */
+        private String role;
+        /**
+         * 提示词或模型内容
+         */
+        private String content;
+    }
+
+    @Data
+    @Accessors
+    public static class Parameters {
+        /**
+         * 输出格式, 默认为"text"
+         * "text"表示旧版本的text
+         * "message"表示兼容openai的message
+         */
+        private String resultFormat;
+        /**
+         * 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。
+         * 取值越大,生成的随机性越高;取值越小,生成的确定性越高。
+         * 注意:如果top_k参数为空或者top_k的值大于100,表示不启用top_k策略,此时仅有top_p策略生效,默认是空。
+         */
+        private Integer topK;
+        /**
+         * 生成时使用的随机数种子,用户控制模型生成内容的随机性。
+         * seed支持无符号64位整数,默认值为1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
+         */
+        private Integer seed;
+        /**
+         * 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。
+         * 较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,
+         * 生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
+         * 取值范围: [0, 2),系统默认值1.0。不建议取值为0,无意义。
+         */
+        private Float temperature;
+        /**
+         * 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。
+         * 其中qwen-turbo 最大值和默认值为1500, qwen-max、qwen-max-1201 、qwen-max-longcontext 和 qwen-plus最大值和默认值均为2000。
+         */
+        private Integer maxTokens;
+        /**
+         * stop参数用于实现内容生成过程的精确控制,在生成内容即将包含指定的字符串或token_ids时自动停止,生成内容不包含指定的内容。
+         * 例如,如果指定stop为"你好",表示将要生成"你好"时停止;如果指定stop为[37763, 367],表示将要生成"Observation"时停止。
+         */
+        private List<String> stop;
+        /**
+         * 用于控制流式输出模式,默认False,即后面内容会包含已经输出的内容;设置为True,将开启增量输出模式,
+         * 后面输出不会包含已经输出的内容,您需要自行拼接整体输出,参考流式输出示例代码。
+         */
+        private Boolean incrementalOutput;
+    }
+}
+
+

+ 7 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java

@@ -3,6 +3,8 @@ package cn.iocoder.yudao.framework.ai.chat;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenApi;
 import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
+import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
+import com.aliyun.broadscope.bailian.sdk.models.CompletionsRequest;
 import org.junit.Before;
 import org.junit.Test;
 import reactor.core.publisher.Flux;
@@ -24,10 +26,13 @@ public class QianWenChatClientTests {
                 "",
                 "",
                 "",
-                "",
                 null
         );
-        qianWenChatClient = new QianWenChatClient(qianWenApi);
+        qianWenChatClient = new QianWenChatClient(
+                qianWenApi,
+                new QianWenOptions()
+                        .setAppId("5f14955f201a44eb8dbe0c57250a32ce")
+        );
     }
 
     @Test