Jelajahi Sumber

【代码优化】AI:新增 AIUtils,用于对接 spring ai 各种对象的构建

YunaiV 11 bulan lalu
induk
melakukan
471968eaf2

+ 2 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java

@@ -26,9 +26,10 @@ public class AiWriteController {
     private AiWriteService writeService;
 
     @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    @PermitAll
     @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
+    @PermitAll  // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
     public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
         return writeService.generateWriteContent(generateReqVO, getLoginUserId());
     }
+
 }

+ 3 - 37
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -4,8 +4,7 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
+import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@@ -19,7 +18,6 @@ import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
-import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.messages.*;
@@ -28,9 +26,6 @@ import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.ollama.api.OllamaOptions;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.qianfan.QianFanChatOptions;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
@@ -148,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         }
         // 1.2 history message 历史消息
         List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
-        contextMessages.forEach(message -> {
-            // TODO @芋艿:看看有没优化空间
-            if (MessageType.USER.getValue().equals(message.getType())) {
-                chatMessages.add(new UserMessage(message.getContent()));
-            } else {
-                chatMessages.add(new AssistantMessage(message.getContent()));
-            }
-        });
+        contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
         // 1.3 user message 新发送消息
         chatMessages.add(new UserMessage(sendReqVO.getContent()));
 
         // 2. 构建 ChatOptions 对象
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-        ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
+        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
                 conversation.getTemperature(), conversation.getMaxTokens());
         return new Prompt(chatMessages, chatOptions);
     }
 
-    private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
-        Float temperatureF = temperature != null ? temperature.floatValue() : null;
-        //noinspection EnhancedSwitchMigration
-        switch (platform) {
-            case OPENAI:
-                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case OLLAMA:
-                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
-            case YI_YAN:
-                // TODO 芋艿:貌似 model 只要一设置,就报错
-//                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-                return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case XING_HUO:
-                return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
-                        .setMaxTokens(maxTokens);
-            case QIAN_WEN:
-                return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
-            default:
-                throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
-        }
-    }
-
     /**
      * 从历史消息中,获得倒序的 n 组消息作为消息上下文
      *

+ 7 - 31
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -2,8 +2,7 @@ package cn.iocoder.yudao.module.ai.service.write;
 
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
+import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
@@ -16,16 +15,12 @@ import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
-import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.ollama.api.OllamaOptions;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.qianfan.QianFanChatOptions;
 import org.springframework.stereotype.Service;
 import reactor.core.publisher.Flux;
 
@@ -56,19 +51,21 @@ public class AiWriteServiceImpl implements AiWriteService {
 
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
-        // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?
+        // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
         AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
         StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
-        ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
 
         // 1.2 插入写作信息
+        // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性
         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
         writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
 
         // 2.1 构建提示词
+        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
         Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
+
         // 2.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
@@ -92,7 +89,9 @@ public class AiWriteServiceImpl implements AiWriteService {
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
+        // TODO @xin:建议改成 if return 哈;更简洁;
         if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
+            // TODO @xin:写成静态枚举哈
             template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
             return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
         } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
@@ -103,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService {
         }
     }
 
-    // TODO 芋艿:复用
-    private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
-        Float temperatureF = temperature != null ? temperature.floatValue() : null;
-        //noinspection EnhancedSwitchMigration
-        switch (platform) {
-            case OPENAI:
-                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case OLLAMA:
-                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
-            case YI_YAN:
-                // TODO 芋艿:貌似 model 只要一设置,就报错
-//                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-                return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
-            case XING_HUO:
-                return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
-                        .setMaxTokens(maxTokens);
-            case QIAN_WEN:
-                return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
-            default:
-                throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
-        }
-    }
-
 }

+ 59 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -0,0 +1,59 @@
+package cn.iocoder.yudao.framework.ai.core.util;
+
+import cn.hutool.core.util.StrUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
+import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
+import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
+import org.springframework.ai.chat.messages.*;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.qianfan.QianFanChatOptions;
+
+/**
+ * Spring AI 工具类
+ *
+ * @author 芋道源码
+ */
+public class AiUtils {
+
+    public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
+        Float temperatureF = temperature != null ? temperature.floatValue() : null;
+        //noinspection EnhancedSwitchMigration
+        switch (platform) {
+            case OPENAI:
+                return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+            case OLLAMA:
+                return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
+            case YI_YAN:
+                // TODO @xin:貌似 model 只要一设置,就报错;可以排查下
+//                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+                return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+            case XING_HUO:
+                return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
+                        .setMaxTokens(maxTokens);
+            case QIAN_WEN:
+                return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
+            default:
+                throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
+        }
+    }
+
+    public static Message buildMessage(String type, String content) {
+        if (MessageType.USER.getValue().equals(type)) {
+            return new UserMessage(content);
+        }
+        if (MessageType.ASSISTANT.getValue().equals(type)) {
+            return new AssistantMessage(content);
+        }
+        if (MessageType.SYSTEM.getValue().equals(type)) {
+            return new SystemMessage(content);
+        }
+        if (MessageType.FUNCTION.getValue().equals(type)) {
+            return new FunctionMessage(content);
+        }
+        throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
+    }
+
+}