Sfoglia il codice sorgente

【新增】AI:会话接入 API KEY 逻辑

YunaiV 1 anno fa
parent
commit
b7180d3481

+ 9 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiApiKeyMapper.java

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.mysql.model;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
 import org.apache.ibatis.annotations.Mapper;
@@ -23,4 +24,12 @@ public interface AiApiKeyMapper extends BaseMapperX<AiApiKeyDO> {
                 .orderByDesc(AiApiKeyDO::getId));
     }
 
+    default AiApiKeyDO selectFirstByPlatformAndStatus(String platform, Integer status) {
+        return selectOne(new QueryWrapperX<AiApiKeyDO>()
+                .eq("platform", platform)
+                .eq("status", status)
+                .limitN(1)
+                .orderByAsc("id"));
+    }
+
 }

+ 33 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -4,11 +4,13 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
+import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
+import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
+import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
+import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
@@ -18,6 +20,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
+import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
@@ -28,6 +31,8 @@ import org.springframework.ai.chat.StreamingChatClient;
 import org.springframework.ai.chat.messages.*;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
@@ -54,9 +59,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     @Resource
     private AiChatMessageMapper chatMessageMapper;
 
-    @Resource
-    private AiClientFactory clientFactory;
-
     @Resource
     private AiChatConversationService chatConversationService;
     @Resource
@@ -168,11 +170,33 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
         // 2. 构建 ChatOptions 对象
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-        ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(),
+        ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
                 conversation.getTemperature(), conversation.getMaxTokens());
         return new Prompt(chatMessages, chatOptions);
     }
 
+    private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
+        Float temperatureF = temperature != null ? temperature.floatValue() : null;
+        //noinspection EnhancedSwitchMigration
+        switch (platform) {
+            case OPENAI:
+                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+            case OLLAMA:
+                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
+            case YI_YAN:
+                // TODO @fan:增加一个 model
+                return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
+            case XING_HUO:
+                return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
+                        .setMaxTokens(maxTokens);
+            case QIAN_WEN:
+                // TODO @fan:增加 model、temperature 参数
+                return new QianWenOptions().setMaxTokens(maxTokens);
+            default:
+                throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
+        }
+    }
+
     /**
      * 从历史消息中,获得倒序的 n 组消息作为消息上下文
      *
@@ -183,7 +207,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
      * @param sendReqVO 发送请求
      * @return 消息上下文
      */
-    private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) {
+    private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
+                                                        AiChatConversationDO conversation,
+                                                        AiChatMessageSendReqVO sendReqVO) {
         if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
             return Collections.emptyList();
         }

+ 1 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -64,4 +64,5 @@ public interface AiImageService {
      * @return
      */
     Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
+
 }

+ 5 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -7,7 +7,6 @@ import cn.hutool.core.util.StrUtil;
 import cn.hutool.extra.spring.SpringUtil;
 import cn.hutool.http.HttpUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
 import cn.iocoder.yudao.framework.common.pojo.PageParam;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@@ -23,6 +22,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyIma
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
+import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.infra.api.file.FileApi;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
@@ -57,7 +57,7 @@ public class AiImageServiceImpl implements AiImageService {
     private FileApi fileApi;
 
     @Resource
-    private AiClientFactory aiClientFactory;
+    private AiApiKeyService apiKeyService;
 
     @Autowired
     private MidjourneyProxyClient midjourneyProxyClient;
@@ -82,17 +82,17 @@ public class AiImageServiceImpl implements AiImageService {
                 .setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
         imageMapper.insert(image);
         // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
-        getSelf().doDall(image, drawReqVO);
+        getSelf().executeDrawImage(image, drawReqVO);
         return image.getId();
     }
 
     @Async
-    public void doDall(AiImageDO image, AiImageDrawReqVO req) {
+    public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
         try {
             // 1.1 构建请求
             ImageOptions request = buildImageOptions(req);
             // 1.2 执行请求
-            ImageClient imageClient = aiClientFactory.getDefaultImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
+            ImageClient imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
             ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
 
             // 2. 上传到文件服务

+ 12 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java

@@ -1,11 +1,13 @@
 package cn.iocoder.yudao.module.ai.service.model;
 
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
 import jakarta.validation.Valid;
 import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.image.ImageClient;
 
 import java.util.List;
 
@@ -79,4 +81,14 @@ public interface AiApiKeyService {
      */
     StreamingChatClient getStreamingChatClient(Long id);
 
+    /**
+     * 获得 ImageClient 对象
+     *
+     * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择
+     *
+     * @param platform 平台
+     * @return ImageClient 对象
+     */
+    ImageClient getImageClient(AiPlatformEnum platform);
+
 }

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java

@@ -11,6 +11,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
 import jakarta.annotation.Resource;
 import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.image.ImageClient;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
@@ -101,4 +102,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
         return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
     }
 
+    @Override
+    public ImageClient getImageClient(AiPlatformEnum platform) {
+        AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
+        if (apiKey == null) {
+            return null;
+        }
+        return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl());
+    }
+
 }

+ 7 - 7
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java

@@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.core.factory;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import org.springframework.ai.chat.StreamingChatClient;
-import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.image.ImageClient;
 
 /**
@@ -45,14 +44,15 @@ public interface AiClientFactory {
     ImageClient getDefaultImageClient(AiPlatformEnum platform);
 
     /**
-     * 创建 Chat 参数
+     * 基于指定配置,获得 ImageClient 对象
+     *
+     * 如果不存在,则进行创建
      *
      * @param platform 平台
-     * @param model 模型
-     * @param temperature 温度
-     * @param maxTokens 生成的最大 Token
-     * @return Chat 参数
+     * @param apiKey API KEY
+     * @param url API URL
+     * @return ImageClient 对象
      */
-    ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens);
+    ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
 
 }

+ 26 - 29
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java

@@ -11,29 +11,25 @@ import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
-import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
 import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
-import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
 import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
 import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
 import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
 import org.springframework.ai.chat.StreamingChatClient;
-import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.image.ImageClient;
 import org.springframework.ai.ollama.OllamaChatClient;
 import org.springframework.ai.ollama.api.OllamaApi;
-import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.openai.OpenAiChatClient;
-import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.ai.openai.OpenAiImageClient;
 import org.springframework.ai.openai.api.ApiUtils;
 import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiImageApi;
 import org.springframework.ai.stabilityai.StabilityAiImageClient;
+import org.springframework.ai.stabilityai.api.StabilityAiApi;
+import org.springframework.web.client.RestClient;
 
 import java.util.List;
 
@@ -100,36 +96,26 @@ public class AiClientFactoryImpl implements AiClientFactory {
         }
     }
 
-    private static String buildClientCacheKey(Class<?> clazz, Object... params) {
-        if (ArrayUtil.isEmpty(params)) {
-            return clazz.getName();
-        }
-        return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
-    }
-
     @Override
-    public ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
-        Float temperatureF = temperature != null ? temperature.floatValue() : null;
+    public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
             case OPENAI:
-                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case OLLAMA:
-                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
-            case YI_YAN:
-                // TODO @fan:增加一个 model
-                return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
-            case XING_HUO:
-                return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
-                        .setMaxTokens(maxTokens);
-            case QIAN_WEN:
-                // TODO @fan:增加 model、temperature 参数
-                return new QianWenOptions().setMaxTokens(maxTokens);
+                return buildOpenAiImageClient(apiKey, url);
+            case STABLE_DIFFUSION:
+                return buildStabilityAiImageClient(apiKey, url);
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }
     }
 
+    private static String buildClientCacheKey(Class<?> clazz, Object... params) {
+        if (ArrayUtil.isEmpty(params)) {
+            return clazz.getName();
+        }
+        return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
+    }
+
     // ========== 各种创建 spring-ai 客户端的方法 ==========
 
     /**
@@ -182,7 +168,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
         return new QianWenChatClient(qianWenApi);
     }
 
-
 //    private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
 //        List<String> keys = StrUtil.split(key, '|');
 //        Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式");
@@ -190,4 +175,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
 //        return new VertexAiGeminiChatClient(vertexApi);
 //    }
 
+    private ImageClient buildOpenAiImageClient(String openAiToken, String url) {
+        url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
+        OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
+        return new OpenAiImageClient(openAiApi);
+    }
+
+    private ImageClient buildStabilityAiImageClient(String apiKey, String url) {
+        url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
+        StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
+        return new StabilityAiImageClient(stabilityAiApi);
+    }
+
 }

+ 1 - 6
yudao-server/src/main/resources/application.yaml

@@ -161,7 +161,6 @@ spring:
           project-id: 1 # TODO 芋艿:缺配置
           location: 2
 
-
 yudao.ai:
   yiyan:
     enable: true
@@ -193,11 +192,6 @@ yudao.ai:
     topP: 0.8
     topK: 0
     api-key: sk-Zsd81gZYg7
-  openAiImage:
-    enable: true
-    api-key: ${OPEN_AI_KEY}
-    model: dall_e_2
-    style: vivid
   midjourney:
     enable: true
     token: MTE4MjE3MjY2MjkxNTY3ODIzOA.GEV1SG.c49F8lZoGCUHwsj8O0UdodmM6nyQHvuD2fXflw
@@ -206,6 +200,7 @@ yudao.ai:
   suno:
     enable: true
     token: 16b4356581984d538652354b60d69ff0
+
 --- #################### 芋道相关配置 ####################
 
 yudao: