瀏覽代碼

【代码优化】AI:将 ChatClient 替换成 ChatModel,和 Spring AI 对齐

YunaiV 9 月之前
父節點
當前提交
e0f08a0f02

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -70,7 +70,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
         // 1.2 校验模型
         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
-        ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
+        ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
         // 2. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -82,7 +82,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
         // 3.2 创建 chat 需要的 Prompt
         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
-        ChatResponse chatResponse = chatClient.call(prompt);
+        ChatResponse chatResponse = chatModel.call(prompt);
 
         // 3.3 段式返回
         String newContent = chatResponse.getResult().getOutput().getContent();
@@ -101,7 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
         // 1.2 校验模型
         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
-        StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
+        StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
         // 2. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -113,7 +113,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
         // 3.2 创建 chat 需要的 Prompt
         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
-        Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
+        Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 3.3 流式返回
         // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -98,8 +98,8 @@ public class AiImageServiceImpl implements AiImageService {
             // 1.1 构建请求
             ImageOptions request = buildImageOptions(req);
             // 1.2 执行请求
-            ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
-            ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
+            ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
+            ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
 
             // 2. 上传到文件服务
             byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json());

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java

@@ -81,17 +81,17 @@ public interface AiApiKeyService {
      * @param id 编号
      * @return ChatModel 对象
      */
-    ChatModel getChatClient(Long id);
+    ChatModel getChatModel(Long id);
 
     /**
-     * 获得 ImageClient 对象
+     * 获得 ImageModel 对象
      *
      * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择
      *
      * @param platform 平台
-     * @return ImageClient 对象
+     * @return ImageModel 对象
      */
-    ImageModel getImageClient(AiPlatformEnum platform);
+    ImageModel getImageModel(AiPlatformEnum platform);
 
     /**
      * 获得 MidjourneyApi 对象

+ 8 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java

@@ -1,7 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.model;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
+import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@@ -35,7 +35,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
     private AiApiKeyMapper apiKeyMapper;
 
     @Resource
-    private AiClientFactory clientFactory;
+    private AiModelFactory modelFactory;
 
     @Override
     public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@@ -98,19 +98,19 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
     // ========== 与 spring-ai 集成 ==========
 
     @Override
-    public ChatModel getChatClient(Long id) {
+    public ChatModel getChatModel(Long id) {
         AiApiKeyDO apiKey = validateApiKey(id);
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
-        return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
+        return modelFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
     }
 
     @Override
-    public ImageModel getImageClient(AiPlatformEnum platform) {
+    public ImageModel getImageModel(AiPlatformEnum platform) {
         AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
         if (apiKey == null) {
             throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName());
         }
-        return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl());
+        return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
     }
 
     @Override
@@ -120,7 +120,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
         if (apiKey == null) {
             throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
         }
-        return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
+        return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
     }
 
     @Override
@@ -130,7 +130,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
         if (apiKey == null) {
             throw exception(API_KEY_SUNO_NOT_FOUND);
         }
-        return clientFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
+        return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
     }
 
 }

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -54,7 +54,7 @@ public class AiWriteServiceImpl implements AiWriteService {
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
         // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
         AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
-        StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
+        StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
 
         // 1.2 插入写作信息
@@ -65,7 +65,7 @@ public class AiWriteServiceImpl implements AiWriteService {
         // 2.1 构建提示词
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
-        Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
+        Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 2.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();

+ 4 - 4
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -1,7 +1,7 @@
 package cn.iocoder.yudao.framework.ai.config;
 
-import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
-import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
+import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
+import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@@ -28,8 +28,8 @@ import org.springframework.context.annotation.Import;
 public class YudaoAiAutoConfiguration {
 
     @Bean
-    public AiClientFactory aiClientFactory() {
-        return new AiClientFactoryImpl();
+    public AiModelFactory aiModelFactory() {
+        return new AiModelFactoryImpl();
     }
 
     // ========== 各种 AI Client 创建 ==========

+ 9 - 9
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java → yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java

@@ -7,11 +7,11 @@ import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.image.ImageModel;
 
 /**
- * AI 客户端工厂的接口类
+ * AI Model 模型工厂的接口类
  *
  * @author fansili
  */
-public interface AiClientFactory {
+public interface AiModelFactory {
 
     /**
      * 基于指定配置,获得 ChatModel 对象
@@ -33,29 +33,29 @@ public interface AiClientFactory {
      * @param platform 平台
      * @return ChatModel 对象
      */
-    ChatModel getDefaultChatClient(AiPlatformEnum platform);
+    ChatModel getDefaultChatModel(AiPlatformEnum platform);
 
     /**
-     * 基于默认配置,获得 ImageClient 对象
+     * 基于默认配置,获得 ImageModel 对象
      *
      * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
      *
      * @param platform 平台
-     * @return ImageClient 对象
+     * @return ImageModel 对象
      */
-    ImageModel getDefaultImageClient(AiPlatformEnum platform);
+    ImageModel getDefaultImageModel(AiPlatformEnum platform);
 
     /**
-     * 基于指定配置,获得 ImageClient 对象
+     * 基于指定配置,获得 ImageModel 对象
      *
      * 如果不存在,则进行创建
      *
      * @param platform 平台
      * @param apiKey API KEY
      * @param url API URL
-     * @return ImageClient 对象
+     * @return ImageModel 对象
      */
-    ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
+    ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url);
 
     /**
      * 基于指定配置,获得 MidjourneyApi 对象

+ 20 - 17
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java → yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -43,11 +43,11 @@ import org.springframework.web.client.RestClient;
 import java.util.List;
 
 /**
- * AI 客户端工厂的实现类
+ * AI Model 模型工厂的实现类
  *
  * @author 芋道源码
  */
-public class AiClientFactoryImpl implements AiClientFactory {
+public class AiModelFactoryImpl implements AiModelFactory {
 
     @Override
     public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) {
@@ -55,8 +55,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
             //noinspection EnhancedSwitchMigration
             switch (platform) {
-                case OPENAI:
-                    return buildOpenAiChatClient(apiKey, url);
                 case OLLAMA:
                     return buildOllamaChatClient(url);
                 case YI_YAN:
@@ -67,6 +65,8 @@ public class AiClientFactoryImpl implements AiClientFactory {
                     return buildQianWenChatClient(apiKey);
                 case DEEP_SEEK:
                     return buildDeepSeekChatClient(apiKey);
+                case OPENAI:
+                    return buildOpenAiChatModel(apiKey, url);
                 default:
                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
             }
@@ -74,11 +74,9 @@ public class AiClientFactoryImpl implements AiClientFactory {
     }
 
     @Override
-    public ChatModel getDefaultChatClient(AiPlatformEnum platform) {
+    public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
-            case OPENAI:
-                return SpringUtil.getBean(OpenAiChatModel.class);
             case OLLAMA:
                 return SpringUtil.getBean(OllamaChatModel.class);
             case YI_YAN:
@@ -87,13 +85,15 @@ public class AiClientFactoryImpl implements AiClientFactory {
                 return SpringUtil.getBean(XingHuoChatClient.class);
             case QIAN_WEN:
                 return SpringUtil.getBean(TongYiChatModel.class);
+            case OPENAI:
+                return SpringUtil.getBean(OpenAiChatModel.class);
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }
     }
 
     @Override
-    public ImageModel getDefaultImageClient(AiPlatformEnum platform) {
+    public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
             case OPENAI:
@@ -106,11 +106,11 @@ public class AiClientFactoryImpl implements AiClientFactory {
     }
 
     @Override
-    public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
+    public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
         //noinspection EnhancedSwitchMigration
         switch (platform) {
             case OPENAI:
-                return buildOpenAiImageClient(apiKey, url);
+                return buildOpenAiImageModel(apiKey, url);
             case STABLE_DIFFUSION:
                 return buildStabilityAiImageClient(apiKey, url);
             default:
@@ -145,12 +145,21 @@ public class AiClientFactoryImpl implements AiClientFactory {
     /**
      * 可参考 {@link OpenAiAutoConfiguration}
      */
-    private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) {
+    private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
         url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
         OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
         return new OpenAiChatModel(openAiApi);
     }
 
+    /**
+     * 可参考 {@link OpenAiAutoConfiguration}
+     */
+    private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
+        url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
+        OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
+        return new OpenAiImageModel(openAiApi);
+    }
+
     /**
      * 可参考 {@link OllamaAutoConfiguration}
      */
@@ -200,12 +209,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
         return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
     }
 
-    private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) {
-        url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
-        OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
-        return new OpenAiImageModel(openAiApi);
-    }
-
     private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
         url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
         StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);