Browse Source

【代码优化】AI:思维导入、写作的生成

YunaiV 1 year ago
parent
commit
68ed8cd6f8

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
 
-        // 3.2 创建 chat 需要的 Prompt
+        // 3.2 构建 Prompt,并进行调用
         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 

+ 27 - 22
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java

@@ -32,13 +32,12 @@ import reactor.core.publisher.Flux;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
 import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
 
 /**
- * AI 写作 Service 实现类
+ * AI 思维导图 Service 实现类
  *
  * @author xiaoxin
  */
@@ -58,30 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     @Override
     public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
-        // 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
-        AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
+        // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
+        AiChatRoleDO role = CollUtil.getFirst(
+                chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
         // 1.1 获取脑图执行模型
-        AiChatModelDO model = getModel(mindMapRole);
+        AiChatModelDO model = getModel(role);
         // 1.2 获取角色设定消息
-        String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
-                ? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
+        String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
+                ? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
         // 1.3 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
-        // 2 插入思维导图信息
+        // 2. 插入思维导图信息
         AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
                 mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
         mindMapMapper.insert(mindMapDO);
 
-        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        // 3.1 角色设定
-        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
-        // 3.3 构建提示词
-        Prompt prompt = new Prompt(chatMessages, chatOptions);
-
+        // 3.1 构建 Prompt,并进行调用
+        Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
-        // 3.4 流式返回
+
+        // 3.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -102,24 +99,32 @@ public class AiMindMapServiceImpl implements AiMindMapService {
 
     }
 
+    private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+        // 1. 构建 message 列表
+        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
+        // 2. 构建 options 对象
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+        ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+        return new Prompt(chatMessages, options);
+    }
+
     private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
         List<Message> chatMessages = new ArrayList<>();
+        // 1. 角色设定
         if (StrUtil.isNotBlank(systemMessage)) {
-            // 1.1 角色设定
             chatMessages.add(new SystemMessage(systemMessage));
         }
-        // 1.2 用户输入
+        // 2. 用户输入
         chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
         return chatMessages;
     }
 
-    // TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去
-    private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
+    private AiChatModelDO getModel(AiChatRoleDO role) {
         AiChatModelDO model = null;
-        if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
-            model = chatModalService.getChatModel(chatRoleDO.getModelId());
+        if (role != null && role.getModelId() != null) {
+            model = chatModalService.getChatModel(role.getModelId());
         }
-        if (Objects.isNull(model)) {
+        if (model != null) {
             model = chatModalService.getRequiredDefaultChatModel();
         }
         Assert.notNull(model, "[AI] 获取不到模型");

+ 17 - 12
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -68,8 +68,9 @@ public class AiWriteServiceImpl implements AiWriteService {
 
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
-        // 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型
-        AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
+        // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
+        AiChatRoleDO writeRole = CollUtil.getFirst(
+                chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
         // 1.1 获取写作执行模型
         AiChatModelDO model = getModel(writeRole);
         // 1.2 获取角色设定消息
@@ -84,16 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
                 write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
         writeMapper.insert(writeDO);
 
-        // 3. 调用大模型,写作生成
-        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
-        // 3.1 构建消息列表
-        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
-        // 3.2 构建提示词
-        Prompt prompt = new Prompt(chatMessages, chatOptions);
-        // 3.3 流式调用
+        // 3.1 构建 Prompt,并进行调用
+        Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
-        // 4. 流式返回
+        // 3.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -125,6 +121,15 @@ public class AiWriteServiceImpl implements AiWriteService {
         return model;
     }
 
+    private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+        // 1. 构建 message 列表
+        List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
+        // 2. 构建 options 对象
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+        ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+        return new Prompt(chatMessages, options);
+    }
+
     private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
         List<Message> chatMessages = new ArrayList<>();
         if (StrUtil.isNotBlank(systemMessage)) {
@@ -132,11 +137,11 @@ public class AiWriteServiceImpl implements AiWriteService {
             chatMessages.add(new SystemMessage(systemMessage));
         }
         // 1.2 用户输入
-        chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
+        chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
         return chatMessages;
     }
 
-    private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
+    private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());