Browse Source

【代码评审】AI:写作部分的建议

YunaiV 9 months ago
parent
commit
b4014bf2df

+ 3 - 3
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/AiChatRoleEnum.java

@@ -7,7 +7,7 @@ import lombok.Getter;
 import java.util.Arrays;
 import java.util.Arrays;
 
 
 /**
 /**
- * AI 写作类型的枚举
+ * AI 内置聊天角色的枚举
  *
  *
  * @author xiaoxin
  * @author xiaoxin
  */
  */
@@ -21,6 +21,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
             2.	回复生成:根据用户提供的场景和提示词,生成合适的对话或文字回复,确保语气和风格符合场景需求。
             2.	回复生成:根据用户提供的场景和提示词,生成合适的对话或文字回复,确保语气和风格符合场景需求。
             除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。
             除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。
             """),
             """),
+
     AI_MIND_MAP_ROLE(2, "脑图助手", """
     AI_MIND_MAP_ROLE(2, "脑图助手", """
              你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
              你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
              # Geek-AI 助手
              # Geek-AI 助手
@@ -38,7 +39,6 @@ public enum AiChatRoleEnum implements IntArrayValuable {
             除此之外不要任何解释性语句。
             除此之外不要任何解释性语句。
             """);
             """);
 
 
-
     /**
     /**
      * 角色
      * 角色
      */
      */
@@ -51,7 +51,7 @@ public enum AiChatRoleEnum implements IntArrayValuable {
     /**
     /**
      * 角色设定
      * 角色设定
      */
      */
-    private final String prompt;
+    private final String systemMessage;
 
 
     public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray();
     public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray();
 
 

+ 2 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/mindmap/vo/AiMindMapGenerateReqVO.java

@@ -7,7 +7,9 @@ import lombok.Data;
 @Schema(description = "管理后台 - AI 思维导图生成 Request VO")
 @Schema(description = "管理后台 - AI 思维导图生成 Request VO")
 @Data
 @Data
 public class AiMindMapGenerateReqVO {
 public class AiMindMapGenerateReqVO {
+
     @Schema(description = "思维导图内容提示", example = "Java 学习路线")
     @Schema(description = "思维导图内容提示", example = "Java 学习路线")
     @NotBlank(message = "思维导图内容提示不能为空")
     @NotBlank(message = "思维导图内容提示不能为空")
     private String prompt;
     private String prompt;
+
 }
 }

+ 7 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java

@@ -12,6 +12,7 @@ import lombok.Data;
  *
  *
  * @author xiaoxin
  * @author xiaoxin
  */
  */
+// TODO @xin:如果没 typehandler 的需求,autoResultMap 可以去掉哈
 @TableName(value = "ai_mind_map", autoResultMap = true)
 @TableName(value = "ai_mind_map", autoResultMap = true)
 @Data
 @Data
 public class AiMindMapDO extends BaseDO {
 public class AiMindMapDO extends BaseDO {
@@ -24,20 +25,21 @@ public class AiMindMapDO extends BaseDO {
 
 
     /**
     /**
      * 用户编号
      * 用户编号
+     *
+     * 关联 AdminUserDO 的 userId 字段
      */
      */
     private Long userId;
     private Long userId;
 
 
-    /**
-     * 模型
-     */
-    private String model;
-
     /**
     /**
      * 平台
      * 平台
      * <p>
      * <p>
      * 枚举 {@link AiPlatformEnum}
      * 枚举 {@link AiPlatformEnum}
      */
      */
     private String platform;
     private String platform;
+    /**
+     * 模型
+     */
+    private String model;
 
 
     /**
     /**
      * 生成内容提示
      * 生成内容提示

+ 16 - 11
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java

@@ -2,18 +2,18 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
 
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
 import com.baomidou.mybatisplus.annotation.IdType;
 import com.baomidou.mybatisplus.annotation.IdType;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
 import com.baomidou.mybatisplus.annotation.TableName;
 import lombok.Data;
 import lombok.Data;
-import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
 
 
 /**
 /**
  * AI 写作 DO
  * AI 写作 DO
  *
  *
  * @author xiaoxin
  * @author xiaoxin
  */
  */
-@TableName(value = "ai_write", autoResultMap = true)
+@TableName("ai_write")
 @Data
 @Data
 public class AiWriteDO extends BaseDO {
 public class AiWriteDO extends BaseDO {
 
 
@@ -25,6 +25,8 @@ public class AiWriteDO extends BaseDO {
 
 
     /**
     /**
      * 用户编号
      * 用户编号
+     *
+     * 关联 AdminUserDO 的 userId 字段
      */
      */
     private Long userId;
     private Long userId;
 
 
@@ -35,17 +37,16 @@ public class AiWriteDO extends BaseDO {
      */
      */
     private Integer type;
     private Integer type;
 
 
-    /**
-     * 模型
-     */
-    private String model;
-
     /**
     /**
      * 平台
      * 平台
      *
      *
      * 枚举 {@link AiPlatformEnum}
      * 枚举 {@link AiPlatformEnum}
      */
      */
     private String platform;
     private String platform;
+    /**
+     * 模型
+     */
+    private String model;
 
 
     /**
     /**
      * 生成内容提示
      * 生成内容提示
@@ -56,7 +57,6 @@ public class AiWriteDO extends BaseDO {
      * 生成的内容
      * 生成的内容
      */
      */
     private String generatedContent;
     private String generatedContent;
-
     /**
     /**
      * 原文
      * 原文
      */
      */
@@ -64,21 +64,26 @@ public class AiWriteDO extends BaseDO {
 
 
     /**
     /**
      * 长度提示词
      * 长度提示词
+     *
+     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH}
      */
      */
     private Integer length;
     private Integer length;
-
     /**
     /**
      * 格式提示词
      * 格式提示词
+     *
+     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT}
      */
      */
     private Integer format;
     private Integer format;
-
     /**
     /**
      * 语气提示词
      * 语气提示词
+     *
+     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE}
      */
      */
     private Integer tone;
     private Integer tone;
-
     /**
     /**
      * 语言提示词
      * 语言提示词
+     *
+     * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE}
      */
      */
     private Integer language;
     private Integer language;
 
 

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java

@@ -66,7 +66,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
             systemMessage = mindMapRole.getSystemMessage();
             systemMessage = mindMapRole.getSystemMessage();
         } else {
         } else {
             model = chatModalService.getRequiredDefaultChatModel();
             model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getPrompt();
+            systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
         }
         }
 
 
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());

+ 1 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java

@@ -120,6 +120,7 @@ public interface AiChatRoleService {
 
 
     /**
     /**
      * 根据名字获得聊天角色
      * 根据名字获得聊天角色
+     *
      * @param name 名字
      * @param name 名字
      * @return 聊天角色列表
      * @return 聊天角色列表
      */
      */

+ 9 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -65,6 +65,7 @@ public class AiWriteServiceImpl implements AiWriteService {
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
         // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
         // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
         AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
         AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
+        // TODO @xin:如果有 writeRole,但是没 modeId,是不是也可以用 systemMessage 哈?建议的写法是:先通过 modelId 获取 model。如果 model == null,则 chatModalService.getRequiredDefaultChatModel();如果还是 null,则抛出异常;。。。。。。。。。。。。。。然后,systemMessage = writeRole != null && writeRole.systemPrompt != "" 这样处理。
         AiChatModelDO model;
         AiChatModelDO model;
         String systemMessage;
         String systemMessage;
         if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
         if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
@@ -72,18 +73,21 @@ public class AiWriteServiceImpl implements AiWriteService {
             systemMessage = writeRole.getSystemMessage();
             systemMessage = writeRole.getSystemMessage();
         } else {
         } else {
             model = chatModalService.getRequiredDefaultChatModel();
             model = chatModalService.getRequiredDefaultChatModel();
-            systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getPrompt();
+            systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
         }
         }
         // 1.2 校验平台
         // 1.2 校验平台
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 
 
         // 2. 插入写作信息
         // 2. 插入写作信息
-        AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+        AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class,
+                write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
         writeMapper.insert(writeDO);
         writeMapper.insert(writeDO);
 
 
+        // 3. 调用大模型,写作生成
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
         // 3.1 角色设定
         // 3.1 角色设定
+        // TODO @xin:要不把 90 到 97 这部分,合并到一个方法里。目的是:让这个方法的主干更明确
         List<Message> chatMessages = new ArrayList<>();
         List<Message> chatMessages = new ArrayList<>();
         if (StrUtil.isNotBlank(systemMessage)) {
         if (StrUtil.isNotBlank(systemMessage)) {
             chatMessages.add(new SystemMessage(systemMessage));
             chatMessages.add(new SystemMessage(systemMessage));
@@ -94,7 +98,7 @@ public class AiWriteServiceImpl implements AiWriteService {
         Prompt prompt = new Prompt(chatMessages, chatOptions);
         Prompt prompt = new Prompt(chatMessages, chatOptions);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
 
-        // 3.2 流式返回
+        // 4. 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -115,13 +119,13 @@ public class AiWriteServiceImpl implements AiWriteService {
     }
     }
 
 
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
-        Integer type = generateReqVO.getType();
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
+        // 格式化 prompt
         String prompt = generateReqVO.getPrompt();
         String prompt = generateReqVO.getPrompt();
-        if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) {
+        if (Objects.equals(generateReqVO.getType(), AiWriteTypeEnum.WRITING.getType())) {
             return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length);
             return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length);
         } else {
         } else {
             return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
             return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);