Browse Source

【优化】重新适配 dall2和dall3

cherishsince 1 year ago
parent
commit
6e3f34b0db

+ 7 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -1,7 +1,8 @@
 package cn.iocoder.yudao.module.ai.controller.admin.image;
 
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
 import cn.iocoder.yudao.module.ai.service.AiImageService;
 import io.swagger.v3.oas.annotations.Operation;
@@ -14,6 +15,8 @@ import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 
+import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
+
 // TODO @芋艿:整理接口定义
 /**
  * ai作图
@@ -33,17 +36,14 @@ public class AiImageController {
 
     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
     @PostMapping("/dallDrawing")
-    public void dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
-//        Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
-//        aiImageService.dallDrawing(req, sseEmitter);
-//        return sseEmitter;
-
+    public AiImageDallDrawingRespVO dallDrawing(@Validated @RequestBody AiImageDallDrawingReqVO req) {
+        return aiImageService.dallDrawing(req);
     }
 
     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
     @PostMapping("/midjourney")
     public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
         aiImageService.midjourney(req);
-        return CommonResult.success(null);
+        return success(null);
     }
 }

+ 3 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDallDrawingReq.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDallDrawingReqVO.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
 
 import io.swagger.v3.oas.annotations.media.Schema;
 import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
 import lombok.Data;
 import lombok.experimental.Accessors;
 
@@ -14,10 +15,11 @@ import lombok.experimental.Accessors;
  */
 @Data
 @Accessors(chain = true)
-public class AiImageDallDrawingReq {
+public class AiImageDallDrawingReqVO {
 
     @Schema(description = "提示词")
     @NotNull(message = "提示词不能为空!")
+    @Size(max = 1200, message = "提示词最大1200")
     private String prompt;
 
     @Schema(description = "模型")

+ 45 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageDallDrawingRespVO.java

@@ -0,0 +1,45 @@
+package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * dall2/dall2 绘画
+ *
+ * @author fansili
+ * @time 2024/4/25 16:24
+ * @since 1.0
+ */
+@Data
+@Accessors(chain = true)
+public class AiImageDallDrawingRespVO {
+
+
+    @Schema(description = "提示词")
+    @NotNull(message = "提示词不能为空!")
+    @Size(max = 1200, message = "提示词最大1200")
+    private String prompt;
+
+    @Schema(description = "模型")
+    @NotNull(message = "模型不能为空")
+    private String modal;
+
+    @Schema(description = "风格")
+    private String style;
+
+    @Schema(description = "图片size 1024x1024 ...")
+    private String size;
+
+    @Schema(description = "可以访问图像的URL。")
+    private String url;
+
+    @Schema(description = "图片base64。")
+    private String base64;
+
+    @Schema(description = "错误信息。")
+    private String errorMessage;
+
+}

+ 27 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java

@@ -0,0 +1,27 @@
+package cn.iocoder.yudao.module.ai.convert;
+
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
+import org.mapstruct.Mapper;
+import org.mapstruct.factory.Mappers;
+
+/**
+ * ai image convert
+ *
+ * @author fansili
+ * @time 2024/4/18 16:39
+ * @since 1.0
+ */
+@Mapper
+public interface AiImageConvert {
+
+    AiImageConvert INSTANCE = Mappers.getMapper(AiImageConvert.class);
+
+    /**
+     * 转换 - AiImageDallDrawingRespVO
+     *
+     * @param req
+     * @return
+     */
+    AiImageDallDrawingRespVO convertAiImageDallDrawingRespVO(AiImageDallDrawingReqVO req);
+}

+ 3 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service;
 
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
 
 /**
@@ -17,7 +18,7 @@ public interface AiImageService {
      *
      * @param req
      */
-    void dallDrawing(AiImageDallDrawingReq req);
+    AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req);
 
     /**
      * midjourney 图片生成

+ 18 - 12
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java

@@ -14,8 +14,10 @@ import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
 import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
+import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
@@ -60,31 +62,35 @@ public class AiImageServiceImpl implements AiImageService {
     }
 
     @Override
-    public void dallDrawing(AiImageDallDrawingReq req) {
+    public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) {
         // 获取 model
         OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
         OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
-        //
-        OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
-        openAiImageOptions.setModel(openAiImageModelEnum);
-        openAiImageOptions.setStyle(openAiImageStyleEnum);
-        openAiImageOptions.setSize(req.getSize());
-        ImageResponse imageResponse;
+        // 转换 AiImageDallDrawingRespVO
+        AiImageDallDrawingRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
         try {
-            imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
+            // 转换openai 参数
+            OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
+            openAiImageOptions.setModel(openAiImageModelEnum);
+            openAiImageOptions.setStyle(openAiImageStyleEnum);
+            openAiImageOptions.setSize(req.getSize());
+            ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
             // 发送
             ImageGeneration imageGeneration = imageResponse.getResult();
-            // 发送信息
-//            sendSseEmitter(sseEmitter, imageGeneration);
             // 保存数据库
             doSave(req.getPrompt(), req.getSize(), req.getModal(),
                     imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
+            // 返回 flex
+            respVO.setUrl(imageGeneration.getOutput().getUrl());
+            respVO.setBase64(imageGeneration.getOutput().getB64Json());
+            return respVO;
         } catch (AiException aiException) {
             // 保存数据库
             doSave(req.getPrompt(), req.getSize(), req.getModal(),
                     null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
             // 发送错误信息
-//            sendSseEmitter(sseEmitter, aiException.getMessage());
+            respVO.setErrorMessage(aiException.getMessage());
+            return respVO;
         }
     }
 

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/resources/http/image.http

@@ -1,7 +1,7 @@
 
 ### chat dallDrawing
 
-POST {{baseUrl}}/ai/image/dallDrawing
+POST {{baseUrl}}/admin-api/ai/image/dallDrawing
 Content-Type: application/json
 Authorization: {{token}}