Преглед изворни кода

优化自动注入,可以创建多个 client

cherishsince пре 1 година
родитељ
комит
ac0de5d485

+ 30 - 73
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/ChatController.java

@@ -1,77 +1,34 @@
-//package cn.iocoder.yudao.module.ai.controller.admin;
-//
-//import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
-//import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-//import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
-//import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
-//import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
-//import io.swagger.v3.oas.annotations.Operation;
-//import io.swagger.v3.oas.annotations.tags.Tag;
-//import jakarta.servlet.http.HttpServletResponse;
-//import lombok.extern.slf4j.Slf4j;
-//import org.springframework.ai.chat.ChatClient;
-//import org.springframework.ai.chat.ChatResponse;
-//import org.springframework.ai.chat.prompt.Prompt;
-//import org.springframework.ai.openai.OpenAiChatClient;
-//import org.springframework.beans.factory.annotation.Autowired;
-//import org.springframework.context.ApplicationContext;
-//import org.springframework.validation.annotation.Validated;
-//import org.springframework.web.bind.annotation.PostMapping;
-//import org.springframework.web.bind.annotation.RequestBody;
-//import org.springframework.web.bind.annotation.RequestMapping;
-//import org.springframework.web.bind.annotation.RestController;
-//import reactor.core.publisher.Flux;
-//
-//import java.util.function.Consumer;
-//
-//// TODO done @fansili:有了 swagger 注释,就不用类注释了
-//@Tag(name = "AI模块")
-//@RestController
-//@RequestMapping("/ai-api")
-//@Slf4j
-//public class ChatController {
+package cn.iocoder.yudao.module.ai.controller.admin;
+
+import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
+import io.swagger.v3.oas.annotations.tags.Tag;
+import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+
+/**
+ * @author fansili
+ * @since 1.0
+ * @time 2024/4/13 17:44
+ */
+@Tag(name = "AI模块")
+@RestController
+@RequestMapping("/ai-api")
+@Slf4j
+@AllArgsConstructor
+public class ChatController {
+
 //
 //    @Autowired
-//    private ApplicationContext applicationContext;
-//
-//    @PostMapping("/chat")
-//    @Operation(summary = "对话聊天", description = "简单的ai聊天")
-//    public CommonResult chat(@RequestBody @Validated AiChatReqVO reqVO) {
-//        ChatClient chatClient = getChatClient(reqVO.getAiModel());
-//        String res;
-//        try {
-//            res = chatClient.call(reqVO.getPrompt());
-//        } catch (Exception e) {
-//            res = e.getMessage();
-//        }
-//        return CommonResult.success(res);
-//    }
-//
-//    @PostMapping("/chatStream")
-//    @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
-//    public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
-//        OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
-//        Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
-//        chatResponse.subscribe(new Consumer<ChatResponse>() {
-//            @Override
-//            public void accept(ChatResponse chatResponse) {
-//                System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
-//            }
-//        });
-//        return CommonResult.success(null);
-//    }
+//    private QianWenChatClient qianWenChatClient;
 //
-//    /**
-//     * 根据 ai模型 获取对于的 模型实现类
-//     *
-//     * @param aiModelEnum
-//     * @return
-//     */
-//    private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
-//        if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
-//            return applicationContext.getBean(OpenAiChatClient.class);
-//        }
-//        // AI模型暂不支持
-//        throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODULE_NOT_SUPPORTED);
+//    @GetMapping("/chat")
+//    public String chat(@RequestParam("prompt") String prompt) {
+//        return qianWenChatClient.call(prompt);
 //    }
-//}
+
+}

+ 77 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/OldChatController.java

@@ -0,0 +1,77 @@
+//package cn.iocoder.yudao.module.ai.controller.admin;
+//
+//import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
+//import cn.iocoder.yudao.framework.common.pojo.CommonResult;
+//import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
+//import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
+//import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
+//import io.swagger.v3.oas.annotations.Operation;
+//import io.swagger.v3.oas.annotations.tags.Tag;
+//import jakarta.servlet.http.HttpServletResponse;
+//import lombok.extern.slf4j.Slf4j;
+//import org.springframework.ai.chat.ChatClient;
+//import org.springframework.ai.chat.ChatResponse;
+//import org.springframework.ai.chat.prompt.Prompt;
+//import org.springframework.ai.openai.OpenAiChatClient;
+//import org.springframework.beans.factory.annotation.Autowired;
+//import org.springframework.context.ApplicationContext;
+//import org.springframework.validation.annotation.Validated;
+//import org.springframework.web.bind.annotation.PostMapping;
+//import org.springframework.web.bind.annotation.RequestBody;
+//import org.springframework.web.bind.annotation.RequestMapping;
+//import org.springframework.web.bind.annotation.RestController;
+//import reactor.core.publisher.Flux;
+//
+//import java.util.function.Consumer;
+//
+//// TODO done @fansili:有了 swagger 注释,就不用类注释了
+//@Tag(name = "AI模块")
+//@RestController
+//@RequestMapping("/ai-api")
+//@Slf4j
+//public class ChatController {
+//
+//    @Autowired
+//    private ApplicationContext applicationContext;
+//
+//    @PostMapping("/chat")
+//    @Operation(summary = "对话聊天", description = "简单的ai聊天")
+//    public CommonResult chat(@RequestBody @Validated AiChatReqVO reqVO) {
+//        ChatClient chatClient = getChatClient(reqVO.getAiModel());
+//        String res;
+//        try {
+//            res = chatClient.call(reqVO.getPrompt());
+//        } catch (Exception e) {
+//            res = e.getMessage();
+//        }
+//        return CommonResult.success(res);
+//    }
+//
+//    @PostMapping("/chatStream")
+//    @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
+//    public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
+//        OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
+//        Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
+//        chatResponse.subscribe(new Consumer<ChatResponse>() {
+//            @Override
+//            public void accept(ChatResponse chatResponse) {
+//                System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
+//            }
+//        });
+//        return CommonResult.success(null);
+//    }
+//
+//    /**
+//     * 根据 ai模型 获取对于的 模型实现类
+//     *
+//     * @param aiModelEnum
+//     * @return
+//     */
+//    private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
+//        if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
+//            return applicationContext.getBean(OpenAiChatClient.class);
+//        }
+//        // AI模型暂不支持
+//        throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MODULE_NOT_SUPPORTED);
+//    }
+//}

+ 38 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/AiPlatformEnum.java

@@ -0,0 +1,38 @@
+package cn.iocoder.yudao.framework.ai;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * 讯飞星火 模型
+ *
+ * 文档地址:https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
+ *
+ * 1tokens 约等于1.5个中文汉字 或者 0.8个英文单词
+ * 星火V1.5支持[搜索]内置插件;星火V2.0、V3.0和V3.5支持[搜索]、[天气]、[日期]、[诗词]、[字词]、[股票]六个内置插件
+ * 星火V3.5 现已支持system、Function Call 功能。
+ *
+ * author: fansili
+ * time: 2024/3/11 10:12
+ */
+@Getter
+@AllArgsConstructor
+public enum AiPlatformEnum {
+
+
+    YI_YAN("yiyan"),
+    QIAN_WEN("qianwen"),
+    XING_HUO("xinghuo"),
+
+    ;
+
+    public static final Map<String, AiPlatformEnum> mapValues
+            = Arrays.stream(values()).collect(Collectors.toMap(AiPlatformEnum::name, o -> o));
+
+    private String value;
+
+}

+ 131 - 10
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -1,11 +1,29 @@
 package cn.iocoder.yudao.framework.ai.config;
 
+import cn.hutool.core.bean.BeanUtil;
+import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
+import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
+import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
 import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient;
 import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoOptions;
 import cn.iocoder.yudao.framework.ai.chatxinghuo.api.XingHuoApi;
+import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
+import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanOptions;
+import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
+import cn.iocoder.yudao.framework.ai.exception.AiException;
+import org.springframework.beans.BeansException;
+import org.springframework.beans.factory.InitializingBean;
+import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
 import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
 import org.springframework.context.annotation.Bean;
+import org.springframework.context.support.GenericApplicationContext;
+
+import java.util.Map;
 
 /**
  * ai 自动配置
@@ -18,15 +36,118 @@ import org.springframework.context.annotation.Bean;
 @EnableConfigurationProperties(YudaoAiProperties.class)
 public class YudaoAiAutoConfiguration {
 
+    // TODO @芋艿:我看sharding jdbc 差不多这么玩的
     @Bean
-    public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
-        return new XingHuoChatClient(
-                new XingHuoApi(
-                        yudaoAiProperties.getXingHuo().getAppId(),
-                        yudaoAiProperties.getXingHuo().getAppKey(),
-                        yudaoAiProperties.getXingHuo().getSecretKey()
-                ),
-                new XingHuoOptions().setChatModel(yudaoAiProperties.getXingHuo().getChatModel())
-        );
+    @ConditionalOnMissingBean(value = InitChatClient.class)
+    public InitChatClient initChatClient(YudaoAiProperties yudaoAiProperties) {
+        return new InitChatClient(yudaoAiProperties);
+    }
+
+    public static class InitChatClient implements InitializingBean, ApplicationContextAware {
+
+        private GenericApplicationContext applicationContext;
+        private YudaoAiProperties yudaoAiProperties;
+
+        public InitChatClient(YudaoAiProperties yudaoAiProperties) {
+            this.yudaoAiProperties = yudaoAiProperties;
+        }
+
+        @Override
+        public void afterPropertiesSet() throws Exception {
+            for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
+                String beanName = properties.getKey();
+                Map<String, Object> aiPlatformMap = properties.getValue();
+
+                // 检查平台类型是否正确
+                String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
+                if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
+                    throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
+                }
+                // 获取平台类型
+                AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
+                // 获取 chat properties
+                YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
+                // 创建客户端
+                registerChatClient(applicationContext, chatProperties, beanName);
+//                applicationContext.refresh();
+
+
+            }
+
+            System.err.println(applicationContext.getBean("qianWen"));
+            System.err.println(applicationContext.getBean("yiYan"));
+        }
+
+        @Override
+        public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
+            this.applicationContext = (GenericApplicationContext) applicationContext;
+        }
+    }
+
+    private static void registerChatClient(GenericApplicationContext applicationContext, YudaoAiProperties.ChatProperties chatProperties, String beanName) {
+        ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory();
+        Object wrapperBean = null;
+        if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) {
+            YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties;
+            wrapperBean = beanFactory.initializeBean(
+                    new XingHuoChatClient(
+                            new XingHuoApi(
+                                    xingHuoProperties.getAppId(),
+                                    xingHuoProperties.getAppKey(),
+                                    xingHuoProperties.getSecretKey()
+                            ),
+                            new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
+                    ),
+                    beanName
+            );
+        } else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) {
+            YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties;
+            wrapperBean = beanFactory.initializeBean(new QianWenChatClient(
+                            new QianWenApi(
+                                    qianWenProperties.getAccessKeyId(),
+                                    qianWenProperties.getAccessKeySecret(),
+                                    qianWenProperties.getAgentKey(),
+                                    qianWenProperties.getEndpoint()
+                            ),
+                            new QianWenOptions()
+                                    .setAppId(qianWenProperties.getAppId())
+                    ),
+                    beanName
+            );
+        } else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) {
+            YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties;
+
+            wrapperBean = beanFactory.initializeBean(new YiYanChatClient(
+                            new YiYanApi(
+                                    yiYanProperties.getAppKey(),
+                                    yiYanProperties.getSecretKey(),
+                                    yiYanProperties.getChatModel(),
+                                    yiYanProperties.getRefreshTokenSecondTime()
+                            ),
+                            new YiYanOptions().setMax_output_tokens(2048)),
+                    beanName
+            );
+        }
+        if (wrapperBean != null) {
+            beanFactory.registerSingleton(beanName, wrapperBean);
+        }
+    }
+
+
+    private static YudaoAiProperties.ChatProperties getChatProperties(AiPlatformEnum aiPlatformEnum, Map<String, Object> aiPlatformMap) {
+        if (AiPlatformEnum.XING_HUO == aiPlatformEnum) {
+            YudaoAiProperties.XingHuoProperties xingHuoProperties = new YudaoAiProperties.XingHuoProperties();
+            BeanUtil.fillBeanWithMap(aiPlatformMap, xingHuoProperties, true);
+            return xingHuoProperties;
+        } else if (AiPlatformEnum.YI_YAN == aiPlatformEnum) {
+            YudaoAiProperties.YiYanProperties yiYanProperties = new YudaoAiProperties.YiYanProperties();
+            BeanUtil.fillBeanWithMap(aiPlatformMap, yiYanProperties, true);
+            return yiYanProperties;
+        } else if (AiPlatformEnum.QIAN_WEN == aiPlatformEnum) {
+            YudaoAiProperties.QianWenProperties qianWenProperties = new YudaoAiProperties.QianWenProperties();
+            BeanUtil.fillBeanWithMap(aiPlatformMap, qianWenProperties, true);
+            return qianWenProperties;
+        }
+        throw new AiException("不支持的Ai类型!");
     }
-}
+}

+ 16 - 5
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java

@@ -1,11 +1,15 @@
 package cn.iocoder.yudao.framework.ai.config;
 
+import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatModel;
 import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
 import lombok.Data;
 import lombok.experimental.Accessors;
 import org.springframework.boot.context.properties.ConfigurationProperties;
 
+import java.util.LinkedHashMap;
+import java.util.Map;
+
 /**
  * ai 自动配置
  *
@@ -15,16 +19,18 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
  */
 @Data
 @ConfigurationProperties(prefix = "yudao.ai")
-public class YudaoAiProperties {
+public class YudaoAiProperties extends LinkedHashMap<String, Map<String, Object>> {
 
-    private QianWenProperties qianWen;
-    private XingHuoProperties xingHuo;
-    private YiYanProperties yiYan;
+//    private QianWenProperties qianWen;
+//    private XingHuoProperties xingHuo;
+//    private YiYanProperties yiYan;
 
     @Data
     @Accessors(chain = true)
     public static class ChatProperties {
 
+        private AiPlatformEnum aiPlatform;
+
         private Float temperature;
 
         private Float topP;
@@ -48,9 +54,14 @@ public class YudaoAiProperties {
          */
         private String accessKeySecret;
         /**
-         * 阿里云:agentKey(相当于应用id)
+         * 阿里云:agentKey
          */
         private String agentKey;
+        /**
+         * 阿里云:agentKey(相当于应用id)
+         */
+        private String appId;
+
     }
 
     @Data

+ 15 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/exception/AiException.java

@@ -0,0 +1,15 @@
+package cn.iocoder.yudao.framework.ai.exception;
+
+/**
+ * ai 异常
+ *
+ * @author fansili
+ * @time 2024/4/13 17:05
+ * @since 1.0
+ */
+public class AiException extends RuntimeException {
+
+    public AiException(String message) {
+        super(message);
+    }
+}

+ 13 - 3
yudao-server/src/main/resources/application-local.yaml

@@ -224,20 +224,30 @@ wx:
 # 芋道配置项,设置当前项目所有自定义的配置
 yudao:
   ai:
-    temperature: 1
-    topP: 1
-    topK: 1
     qianWen:
+      aiPlatform: QIAN_WEN
+      temperature: 1
+      topP: 1
+      topK: 1
       endpoint: bailian.cn-beijing.aliyuncs.com
       accessKeyId: LTAI5tNTVhXW4fLKUjMrr98z
       accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP
       agentKey: f0c1088824594f589c8f10567ccd929f_p_efm
+      appId: 5f14955f201a44eb8dbe0c57250a32ce
     xingHuo:
+      aiPlatform: XING_HUO
+      temperature: 1
+      topP: 1
+      topK: 1
       appId: 13c8cca6
       appKey: cb6415c19d6162cda07b47316fcb0416
       secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh
       chatModel: XING_HUO_3_5
     yiYan:
+      aiPlatform: YI_YAN
+      temperature: 1
+      topP: 1
+      topK: 1
       appKey: x0cuLZ7XsaTCU08vuJWO87Lg
       secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
       refreshTokenSecondTime: 86400