Prechádzať zdrojové kódy

【代码评审】AI:写作实现

YunaiV 9 mesiacov pred
rodič
commit
f20c27a7ef

+ 2 - 2
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java

@@ -33,7 +33,7 @@ public interface ErrorCodeConstants {
     // ========== API 聊天消息 1-040-004-000 ==========
 
     ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
-    ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!");
+    ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
 
     // ========== API 绘画 1-040-005-000 ==========
 
@@ -48,6 +48,6 @@ public interface ErrorCodeConstants {
 
     // ========== API 写作 1-022-007-000 ==========
     ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
-    ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!");
+    ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!");
 
 }

+ 1 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/write/AiLanguageEnum.java

@@ -6,6 +6,7 @@ import lombok.Getter;
 
 import java.util.Arrays;
 
+// TODO @xin:写作的几个,不用枚举类哈。直接搞字段就好了。AiWriteTypeEnum 还是需要的哈
 @AllArgsConstructor
 @Getter
 public enum AiLanguageEnum implements IntArrayValuable {

+ 7 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWriteGenerateReqVO.java

@@ -1,5 +1,7 @@
 package cn.iocoder.yudao.module.ai.controller.admin.write.vo;
 
+import cn.iocoder.yudao.framework.common.validation.InEnum;
+import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
 import io.swagger.v3.oas.annotations.media.Schema;
 import jakarta.validation.constraints.NotNull;
 import lombok.Data;
@@ -8,6 +10,11 @@ import lombok.Data;
 @Data
 public class AiWriteGenerateReqVO {
 
+    @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
+    @InEnum(AiWriteTypeEnum.class)
+    private Integer type;
+
+    // TODO @xin:如果非必填,可以不用写 requiredMode
     @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写:田忌赛马;2.回复:不批")
     private String prompt;
 
@@ -30,7 +37,4 @@ public class AiWriteGenerateReqVO {
     @NotNull(message = "语言不能为空")
     private Integer language;
 
-
-    @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
-    private Integer type; //参见 AiWriteTypeEnum 枚举
 }

+ 13 - 10
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.write;
 
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
 import com.baomidou.mybatisplus.annotation.IdType;
 import com.baomidou.mybatisplus.annotation.TableId;
@@ -34,6 +35,18 @@ public class AiWriteDO extends BaseDO {
      */
     private Integer type;
 
+    /**
+     * 模型
+     */
+    private String model;
+
+    /**
+     * 平台
+     *
+     * 枚举 {@link AiPlatformEnum}
+     */
+    private String platform;
+
     /**
      * 生成内容提示
      */
@@ -69,16 +82,6 @@ public class AiWriteDO extends BaseDO {
      */
     private Integer language;
 
-    /**
-     * 模型
-     */
-    private String model;
-
-    /**
-     * 平台
-     */
-    private String platform;
-
     /**
      * 错误信息
      */

+ 0 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteService.java

@@ -11,7 +11,6 @@ import reactor.core.publisher.Flux;
  */
 public interface AiWriteService {
 
-
     /**
      * 生成写作内容
      *
@@ -21,5 +20,4 @@ public interface AiWriteService {
      */
     Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId);
 
-
 }

+ 4 - 6
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/write/AiWriteServiceImpl.java

@@ -46,23 +46,22 @@ public class AiWriteServiceImpl implements AiWriteService {
     @Resource
     private AiChatModelService chatModalService;
     @Resource
-    private AiWriteMapper writeMapper;
-
+    private AiWriteMapper writeMapper; // TODO @xin:上面空一行;因为同类之间不要空行,非同类空行;
 
     @Override
     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
-        //TODO 芋艿 写作的模型配置放哪好 先用千问测试
         // 1.1 校验模型
+        // TODO @xin:可以约定大于配置先,查询某个名字。例如说,写作助手!然后写作助手,上面是有个 model 的,可以使用它。
         AiChatModelDO model = chatModalService.validateChatModel(14L);
         StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
 
-        //1.2 插入写作信息
+        // 1.2 插入写作信息
         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
         writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
 
-        //2.1 构建提示词
+        // 2.1 构建提示词
         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
         Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
         // 2.2 流式返回
@@ -81,7 +80,6 @@ public class AiWriteServiceImpl implements AiWriteService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
     }
 
-
     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
         String template;
         Integer writeType = generateReqVO.getType();

+ 16 - 7
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/model/midjourney/api/MidjourneyApi.java

@@ -9,12 +9,17 @@ import lombok.Data;
 import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.openai.api.ApiUtils;
+import org.springframework.http.HttpRequest;
+import org.springframework.http.HttpStatusCode;
+import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.WebClient;
 import reactor.core.publisher.Mono;
 
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Predicate;
 
 /**
  * Midjourney API
@@ -25,6 +30,16 @@ import java.util.Map;
 @Slf4j
 public class MidjourneyApi {
 
+    private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
+
+    private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
+            reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
+                HttpRequest request = response.request();
+                log.error("[midjourney-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
+                        request.getMethod(), request.getURI(), reqParam, responseBody);
+                sink.error(new IllegalStateException("[midjourney-api] 调用失败!"));
+            });
+
     private final WebClient webClient;
 
     /**
@@ -80,17 +95,11 @@ public class MidjourneyApi {
     }
 
     private String post(String uri, Object body) {
-        // 1、发送 post 请求
         return webClient.post()
                 .uri(uri)
                 .body(Mono.just(JsonUtils.toJsonString(body)), String.class)
                 .retrieve()
-                .onStatus(status -> !status.is2xxSuccessful(),
-                        response -> response.bodyToMono(String.class)
-                                .handle((respBody, sink) -> {
-                                    log.error("【Midjourney api】调用失败!resp: 【{}】", respBody);
-                                    sink.error(new IllegalStateException("【Midjourney api】调用失败!"));
-                                }))
+                .onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(body))
                 .bodyToMono(String.class)
                 .block();
     }