Ver Fonte

【增加】AI 写作:支持撰写

xiaoxin há 11 meses atrás
pai
commit
77ead4859c

+ 5 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java

@@ -42,4 +42,9 @@ public interface ErrorCodeConstants {
     // ========== API 音乐 1-040-006-000 ==========
     ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
 
+
+    // ========== API 写作 1-022-007-000 ==========
+    ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
+    ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!");
+
 }

+ 1 - 1
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/music/AiMusicGenerateModeEnum.java

@@ -7,7 +7,7 @@ import lombok.Getter;
 import java.util.Arrays;
 
 /**
- * AI 音乐状态的枚举
+ * AI 音乐生成模式的枚举
  *
  * @author xiaoxin
  */

+ 37 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiWriteTypeEnum.java

@@ -0,0 +1,37 @@
+package cn.iocoder.yudao.module.ai.enums.write;
+
+import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.util.Arrays;
+
+/**
+ * AI 写作类型的枚举
+ *
+ * @author xiaoxin
+ */
+@AllArgsConstructor
+@Getter
+public enum AiWriteTypeEnum implements IntArrayValuable {
+
+    DESCRIPTION(1, "撰写"),
+    LYRIC(2, "回复");
+
+    /**
+     * 类型
+     */
+    private final Integer type;
+    /**
+     * 类型名
+     */
+    private final String name;
+
+    public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
+
+    @Override
+    public int[] array() {
+        return ARRAYS;
+    }
+
+}

+ 3 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/AiWriteController.java

@@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 import reactor.core.publisher.Flux;
 
+import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
+
 @Tag(name = "管理后台 - AI 写作")
 @RestController
 @RequestMapping("/ai/write")
@@ -27,6 +29,6 @@ public class AiWriteController {
     @PermitAll
     @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
     public Flux<CommonResult<String>> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
-        return writeService.generateComposition(generateReqVO);
+        return writeService.generateWriteContent(generateReqVO, getLoginUserId());
     }
 }

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java

@@ -8,14 +8,14 @@ import lombok.Data;
 @Data
 public class AiWriteGenerateReqVO {
 
-    @Schema(description = "写作内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马")
-    private String content;
+    @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马")
+    private String contentPrompt;
 
     @Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职")
     private String originalContent;
 
     @Schema(description = "回复内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "准了")
-    private String replyContent;
+    private String replyContentPrompt;
 
     @Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等")
     @NotBlank(message = "长度不能为空")
@@ -35,5 +35,5 @@ public class AiWriteGenerateReqVO {
 
 
     @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
-    private Integer writeType;
+    private Integer writeType; //参见 AiWriteTypeEnum 枚举
 }

+ 97 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java

@@ -0,0 +1,97 @@
+package cn.iocoder.yudao.module.ai.dal.dataobject.write;
+
+import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
+import com.baomidou.mybatisplus.annotation.TableName;
+import lombok.Data;
+import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
+
+/**
+ * AI 写作 DO
+ *
+ * @author xiaoxin
+ */
+@TableName(value = "ai_write", autoResultMap = true)
+@Data
+public class AiWriteDO extends BaseDO {
+
+    /**
+     * 编号
+     */
+    @TableId(type = IdType.AUTO)
+    private Long id;
+
+    /**
+     * 用户编号
+     */
+    private Long userId;
+
+    /**
+     * 写作类型
+     * <p>
+     * 枚举 {@link AiWriteTypeEnum}
+     */
+    private Integer writeType;
+
+    /**
+     * 撰写内容提示
+     */
+    private String contentPrompt;
+
+    /**
+     * 生成的撰写内容
+     */
+    private String generatedContent;
+
+    /**
+     * 原文
+     */
+    private String originalContent;
+
+    /**
+     * 回复内容提示
+     */
+    private String replyContentPrompt;
+
+    /**
+     * 生成的回复内容
+     */
+    private String generatedReplyContent;
+
+    /**
+     * 长度提示词
+     */
+    private String length;
+
+    /**
+     * 格式提示词
+     */
+    private String format;
+
+    /**
+     * 语气提示词
+     */
+    private String tone;
+
+    /**
+     * 语言提示词
+     */
+    private String language;
+
+    /**
+     * 模型
+     */
+    private String model;
+
+    /**
+     * 平台
+     */
+    private String platform;
+
+    /**
+     * 错误信息
+     */
+    private String errorMessage;
+
+}

+ 14 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/write/AiWriteMapper.java

@@ -0,0 +1,14 @@
+package cn.iocoder.yudao.module.ai.dal.mysql.write;
+
+import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
+import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
+import org.apache.ibatis.annotations.Mapper;
+
+/**
+ * AI 音乐 Mapper
+ *
+ * @author xiaoxin
+ */
+@Mapper
+public interface AiWriteMapper extends BaseMapperX<AiWriteDO> {
+}

+ 8 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java

@@ -12,7 +12,14 @@ import reactor.core.publisher.Flux;
 public interface AiWriteService {
 
 
-    Flux<CommonResult<String>> generateComposition(AiWriteGenerateReqVO generateReqVO);
+    /**
+     * 生成写作内容
+     *
+     * @param generateReqVO 作文生成请求参数
+     * @param userId        用户编号
+     * @return 生成结果
+     */
+    Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId);
 
 
 }

+ 38 - 20
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -2,22 +2,27 @@ package cn.iocoder.yudao.module.ai.service.write;
 
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
-import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
+import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
+import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.chat.ChatResponse;
-import org.springframework.ai.chat.StreamingChatClient;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.ollama.api.OllamaOptions;
 import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.qianfan.QianFanChatOptions;
 import org.springframework.stereotype.Service;
 import reactor.core.publisher.Flux;
 
@@ -35,16 +40,29 @@ public class AiWriteServiceImpl implements AiWriteService {
 
     @Resource
     private AiApiKeyService apiKeyService;
+    @Resource
+    private AiChatModelService chatModalService;
+    @Resource
+    private AiWriteMapper writeMapper;
 
 
     @Override
-    public Flux<CommonResult<String>> generateComposition(AiWriteGenerateReqVO generateReqVO) {
-        StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(6L);
-        AiPlatformEnum platform = AiPlatformEnum.validatePlatform("QianWen");
-        ChatOptions chatOptions = buildChatOptions(platform, "qwen-72b-chat", 1.0, 1000);
+    public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
+        //TODO 芋艿 写作的模型配置放哪好 先用千问测试
+        // 1.1 校验模型
+        AiChatModelDO model = chatModalService.validateChatModel(14L);
+        StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
+        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
+        ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
+
+        //1.2 插入写作信息
+        AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
+        writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+
+        //2.1 构建提示词
         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
         Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
-        // 3.3 流式返回
+        // 2.2 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
             String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@@ -53,17 +71,17 @@ public class AiWriteServiceImpl implements AiWriteService {
             // 响应结果
             return success(newContent);
         }).doOnComplete(() -> {
-            log.info("generateComposition complete, content: {}", contentBuffer);
-        }).onErrorResume(error -> {
-            log.error("[AI 写作] 发生异常", error);
-            return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR));
-        });
+            writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setGeneratedContent(contentBuffer.toString()));
+        }).doOnError(throwable -> {
+            log.error("[AI Write][generateReqVO({}) 发生异常]", generateReqVO, throwable);
+            writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setErrorMessage(throwable.getMessage()));
+        }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
 
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
-        String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要任何额外的解释或道歉。";
-        String content = generateReqVO.getContent();
+        String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
+        String content = generateReqVO.getContentPrompt();
         String format = generateReqVO.getFormat();
         String tone = generateReqVO.getTone();
         String language = generateReqVO.getLanguage();
@@ -81,14 +99,14 @@ public class AiWriteServiceImpl implements AiWriteService {
             case OLLAMA:
                 return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
             case YI_YAN:
-                // TODO @fan:增加一个 model
-                return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
+                // TODO 芋艿:貌似 model 只要一设置,就报错
+//                return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
+                return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
             case XING_HUO:
                 return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
                         .setMaxTokens(maxTokens);
             case QIAN_WEN:
-                // TODO @fan:增加 model、temperature 参数
-                return new QianWenOptions().setMaxTokens(maxTokens);
+                return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }