Quellcode durchsuchen

增加yudao ai client

cherishsince vor 1 Jahr
Ursprung
Commit
97df2755f9

+ 2 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java

@@ -36,6 +36,8 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
 
     private QianWenOptions qianWenOptions;
 
+
+    public QianWenChatClient() {}
     public QianWenChatClient(QianWenApi qianWenApi) {
         this.qianWenApi = qianWenApi;
     }

+ 19 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/AiClient.java

@@ -0,0 +1,19 @@
+package cn.iocoder.yudao.framework.ai.config;
+
+import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
+import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
+import reactor.core.publisher.Flux;
+
+/**
+ * ai client传入
+ *
+ * @author fansili
+ * @time 2024/4/14 10:27
+ * @since 1.0
+ */
+public interface AiClient {
+
+    ChatResponse call(Prompt prompt, String clientName);
+
+    Flux<ChatResponse> stream(Prompt prompt, String clientName);
+}

+ 52 - 60
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -14,7 +14,6 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
 import cn.iocoder.yudao.framework.ai.exception.AiException;
 import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.InitializingBean;
-import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
 import org.springframework.boot.autoconfigure.AutoConfiguration;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
 import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -23,6 +22,7 @@ import org.springframework.context.ApplicationContextAware;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.support.GenericApplicationContext;
 
+import java.util.HashMap;
 import java.util.Map;
 
 /**
@@ -36,11 +36,33 @@ import java.util.Map;
 @EnableConfigurationProperties(YudaoAiProperties.class)
 public class YudaoAiAutoConfiguration {
 
-    // TODO @芋艿:我看sharding jdbc 差不多这么玩的
     @Bean
-    @ConditionalOnMissingBean(value = InitChatClient.class)
-    public InitChatClient initChatClient(YudaoAiProperties yudaoAiProperties) {
-        return new InitChatClient(yudaoAiProperties);
+    @ConditionalOnMissingBean(value = AiClient.class)
+    public AiClient aiClient(YudaoAiProperties yudaoAiProperties) {
+        Map<String, Object> chatClientMap = buildChatClientMap(yudaoAiProperties);
+        return new YudaoAiClient(chatClientMap);
+    }
+
+    public Map<String, Object> buildChatClientMap(YudaoAiProperties yudaoAiProperties) {
+        Map<String, Object> chatMap = new HashMap<>();
+        for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
+            String beanName = properties.getKey();
+            Map<String, Object> aiPlatformMap = properties.getValue();
+
+            // 检查平台类型是否正确
+            String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
+            if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
+                throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
+            }
+            // 获取平台类型
+            AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
+            // 获取 chat properties
+            YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
+            // 创建客户端
+            Object chatClient = createChatClient(chatProperties);
+            chatMap.put(beanName, chatClient);
+        }
+        return chatMap;
     }
 
     public static class InitChatClient implements InitializingBean, ApplicationContextAware {
@@ -53,26 +75,8 @@ public class YudaoAiAutoConfiguration {
         }
 
         @Override
-        public void afterPropertiesSet() throws Exception {
-            for (Map.Entry<String, Map<String, Object>> properties : yudaoAiProperties.entrySet()) {
-                String beanName = properties.getKey();
-                Map<String, Object> aiPlatformMap = properties.getValue();
-
-                // 检查平台类型是否正确
-                String aiPlatform = String.valueOf(aiPlatformMap.get("aiPlatform"));
-                if (!AiPlatformEnum.mapValues.containsKey(aiPlatform)) {
-                    throw new AiException("AI平台名称错误! 可以参考 AiPlatformEnum 类!");
-                }
-                // 获取平台类型
-                AiPlatformEnum aiPlatformEnum = AiPlatformEnum.mapValues.get(aiPlatform);
-                // 获取 chat properties
-                YudaoAiProperties.ChatProperties chatProperties = getChatProperties(aiPlatformEnum, aiPlatformMap);
-                // 创建客户端
-                registerChatClient(applicationContext, chatProperties, beanName);
-//                applicationContext.refresh();
-
+        public void afterPropertiesSet() {
 
-            }
 
             System.err.println(applicationContext.getBean("qianWen"));
             System.err.println(applicationContext.getBean("yiYan"));
@@ -84,53 +88,41 @@ public class YudaoAiAutoConfiguration {
         }
     }
 
-    private static void registerChatClient(GenericApplicationContext applicationContext, YudaoAiProperties.ChatProperties chatProperties, String beanName) {
-        ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory();
-        Object wrapperBean = null;
+    private static Object createChatClient(YudaoAiProperties.ChatProperties chatProperties) {
         if (AiPlatformEnum.XING_HUO == chatProperties.getAiPlatform()) {
             YudaoAiProperties.XingHuoProperties xingHuoProperties = (YudaoAiProperties.XingHuoProperties) chatProperties;
-            wrapperBean = beanFactory.initializeBean(
-                    new XingHuoChatClient(
-                            new XingHuoApi(
-                                    xingHuoProperties.getAppId(),
-                                    xingHuoProperties.getAppKey(),
-                                    xingHuoProperties.getSecretKey()
-                            ),
-                            new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
+            return new XingHuoChatClient(
+                    new XingHuoApi(
+                            xingHuoProperties.getAppId(),
+                            xingHuoProperties.getAppKey(),
+                            xingHuoProperties.getSecretKey()
                     ),
-                    beanName
+                    new XingHuoOptions().setChatModel(xingHuoProperties.getChatModel())
             );
         } else if (AiPlatformEnum.QIAN_WEN == chatProperties.getAiPlatform()) {
             YudaoAiProperties.QianWenProperties qianWenProperties = (YudaoAiProperties.QianWenProperties) chatProperties;
-            wrapperBean = beanFactory.initializeBean(new QianWenChatClient(
-                            new QianWenApi(
-                                    qianWenProperties.getAccessKeyId(),
-                                    qianWenProperties.getAccessKeySecret(),
-                                    qianWenProperties.getAgentKey(),
-                                    qianWenProperties.getEndpoint()
-                            ),
-                            new QianWenOptions()
-                                    .setAppId(qianWenProperties.getAppId())
+            return new QianWenChatClient(
+                    new QianWenApi(
+                            qianWenProperties.getAccessKeyId(),
+                            qianWenProperties.getAccessKeySecret(),
+                            qianWenProperties.getAgentKey(),
+                            qianWenProperties.getEndpoint()
                     ),
-                    beanName
+                    new QianWenOptions()
+                            .setAppId(qianWenProperties.getAppId())
             );
         } else if (AiPlatformEnum.YI_YAN == chatProperties.getAiPlatform()) {
             YudaoAiProperties.YiYanProperties yiYanProperties = (YudaoAiProperties.YiYanProperties) chatProperties;
-
-            wrapperBean = beanFactory.initializeBean(new YiYanChatClient(
-                            new YiYanApi(
-                                    yiYanProperties.getAppKey(),
-                                    yiYanProperties.getSecretKey(),
-                                    yiYanProperties.getChatModel(),
-                                    yiYanProperties.getRefreshTokenSecondTime()
-                            ),
-                            new YiYanOptions().setMax_output_tokens(2048)),
-                    beanName
-            );
-        }
-        if (wrapperBean != null) {
-            beanFactory.registerSingleton(beanName, wrapperBean);
+            return new YiYanChatClient(
+                    new YiYanApi(
+                            yiYanProperties.getAppKey(),
+                            yiYanProperties.getSecretKey(),
+                            yiYanProperties.getChatModel(),
+                            yiYanProperties.getRefreshTokenSecondTime()
+                    ),
+                    new YiYanOptions().setMax_output_tokens(2048));
         }
+        throw new AiException("不支持的Ai类型!");
     }
 
 

+ 44 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiClient.java

@@ -0,0 +1,44 @@
+package cn.iocoder.yudao.framework.ai.config;
+
+import cn.iocoder.yudao.framework.ai.chat.ChatClient;
+import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
+import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
+import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
+import cn.iocoder.yudao.framework.ai.exception.AiException;
+import reactor.core.publisher.Flux;
+
+import java.util.Map;
+
+/**
+ * yudao ai client
+ *
+ * @author fansili
+ * @time 2024/4/14 10:27
+ * @since 1.0
+ */
+public class YudaoAiClient implements AiClient{
+
+    protected Map<String, Object> chatClientMap;
+
+    public YudaoAiClient(Map<String, Object> chatClientMap) {
+        this.chatClientMap = chatClientMap;
+    }
+
+    @Override
+    public ChatResponse call(Prompt prompt, String clientName) {
+        if (!chatClientMap.containsKey(clientName)) {
+            throw new AiException("clientName不存在!");
+        }
+        ChatClient chatClient = (ChatClient) chatClientMap.get(clientName);
+        return chatClient.call(prompt);
+    }
+
+    @Override
+    public Flux<ChatResponse> stream(Prompt prompt, String clientName) {
+        if (!chatClientMap.containsKey(clientName)) {
+            throw new AiException("clientName不存在!");
+        }
+        StreamingChatClient streamingChatClient = (StreamingChatClient) chatClientMap.get(clientName);
+        return streamingChatClient.stream(prompt);
+    }
+}