瀏覽代碼

【代码优化】AI:MJ 生成图片 ACTION 的优化

YunaiV 1 年之前
父節點
當前提交
4c3add508b

+ 12 - 17
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -8,7 +8,8 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.service.image.AiImageService;
 import io.swagger.v3.oas.annotations.Operation;
@@ -67,30 +68,24 @@ public class AiImageController {
 
     @Operation(summary = "【Midjourney】生成图片")
     @PostMapping("/midjourney/imagine")
-    public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO reqVO) {
-        if (true) {
-            imageService.midjourneySync();
-            return null;
-        }
+    public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiMidjourneyImagineReqVO reqVO) {
         Long imageId = imageService.midjourneyImagine(getLoginUserId(), reqVO);
         return success(imageId);
     }
 
-    @Operation(summary = "Midjourney 生成图片的回调通知", description = "由 Midjourney Proxy 回调")
-    @PostMapping("/midjourney-notify")
+    @Operation(summary = "【Midjourney】通知图片进展", description = "由 Midjourney Proxy 回调")
+    @PostMapping("/midjourney/notify") // 必须是 POST 方法,否则会报错
     @PermitAll
-    public void midjourneyNotify(@RequestBody MidjourneyApi.Notify notify) {
+    public CommonResult<Boolean> midjourneyNotify(@Validated @RequestBody MidjourneyApi.Notify notify) {
         imageService.midjourneyNotify(notify);
+        return success(true);
     }
 
-    @Operation(summary = "Midjourney Action", description = "例如说:放大、缩小、U1、U2 等")
-    @GetMapping("/midjourney/action")
-    @Parameter(name = "id", description = "图片id", example = "1")
-    @Parameter(name = "customId", description = "操作id", example = "MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862")
-    public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId,
-                                                  @RequestParam("customId") String customId) {
-        imageService.midjourneyAction(getLoginUserId(), imageId, customId);
-        return success(true);
+    @Operation(summary = "【Midjourney】Action 操作(二次生成图片)", description = "例如说:放大、缩小、U1、U2 等")
+    @PostMapping("/midjourney/action")
+    public CommonResult<Long> midjourneyAction(@Validated @RequestBody AiMidjourneyActionReqVO reqVO) {
+        Long imageId = imageService.midjourneyAction(getLoginUserId(), reqVO);
+        return success(imageId);
     }
 
 }

+ 0 - 31
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/MidjourneyActionReqVO.java

@@ -1,31 +0,0 @@
-package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
-
-import io.swagger.v3.oas.annotations.media.Schema;
-import jakarta.validation.constraints.NotNull;
-import lombok.Data;
-
-/**
- * Midjourney:action 请求
- *
- * @author fansili
- * @time 2024/5/30 14:02
- * @since 1.0
- */
-@Data
-public class MidjourneyActionReqVO {
-
-    @Schema(description = "操作按钮id", required = true)
-    @NotNull(message = "customId 不能为空!")
-    private String customId;
-
-    @Schema(description = "操作按钮id", required = true)
-    @NotNull(message = "customId 不能为空!")
-    private String taskId;
-
-    @Schema(description = "通知地址", required = false)
-    @NotNull(message = "回调地址不能为空!")
-    private String notifyHook;
-
-    @Schema(description = "自定义参数", required = false)
-    private String state;
-}

+ 20 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyActionReqVO.java

@@ -0,0 +1,20 @@
+package cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotEmpty;
+import jakarta.validation.constraints.NotNull;
+import lombok.Data;
+
+@Schema(description = "管理后台 - Action(Midjourney) Request VO")
+@Data
+public class AiMidjourneyActionReqVO {
+
+    @Schema(description = "图片编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
+    @NotNull(message = "图片编号不能为空")
+    private Long id;
+
+    @Schema(description = "操作按钮编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "MJ::JOB::variation::4::06aa3e66-0e97-49cc-8201-e0295d883de4")
+    @NotEmpty(message = "操作按钮编号不能为空")
+    private String customId;
+
+}

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiImageMidjourneyImagineReqVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/midjourney/AiMidjourneyImagineReqVO.java

@@ -9,7 +9,7 @@ import java.util.List;
 
 @Schema(description = "管理后台 - 绘画生成(Midjourney) Request VO")
 @Data
-public class AiImageMidjourneyImagineReqVO {
+public class AiMidjourneyImagineReqVO {
 
     @Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED, example = "中国神龙")
     @NotEmpty(message = "提示词不能为空!")

+ 8 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -4,7 +4,8 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.common.pojo.PageParam;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 
 /**
@@ -57,7 +58,7 @@ public interface AiImageService {
      * @param reqVO 绘制请求
      * @return 绘画编号
      */
-    Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO);
+    Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO);
 
     /**
      * 【Midjourney】同步图片进展
@@ -74,13 +75,12 @@ public interface AiImageService {
     void midjourneyNotify(MidjourneyApi.Notify notify);
 
     /**
-     * midjourney - action(放大、缩小、U1、U2...)
+     * 【Midjourney】Action 操作(放大、缩小、U1、U2...)
      *
-     * @param loginUserId
-     * @param imageId
-     * @param customId
-     * @return
+     * @param userId 用户编号
+     * @param reqVO 绘制请求
+     * @return 绘画编号
      */
-    void midjourneyAction(Long loginUserId, Long imageId, String customId);
+    Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO);
 
 }

+ 31 - 46
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -14,7 +14,8 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiImageMidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
@@ -149,7 +150,7 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     @Transactional(rollbackFor = Exception.class)
-    public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO reqVO) {
+    public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
         // 1. 保存数据库
         AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
                 .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@@ -170,11 +171,8 @@ public class AiImageServiceImpl implements AiImageService {
         }
 
         // 4. 情况二【成功】:更新 taskId 和参数
-        imageMapper.updateById(new AiImageDO()
-                .setId(image.getId())
-                .setTaskId(imagineResponse.result())
-                .setOptions(BeanUtil.beanToMap(reqVO))
-        );
+        imageMapper.updateById(new AiImageDO().setId(image.getId())
+                .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO)));
         return image.getId();
     }
 
@@ -245,49 +243,36 @@ public class AiImageServiceImpl implements AiImageService {
     }
 
     @Override
-    public void midjourneyAction(Long loginUserId, Long imageId, String customId) {
-        // 1、检查 image
-        AiImageDO image = validateImageExists(imageId);
-        // 2、检查 customId
-        validateCustomId(customId, image.getButtons());
-
-        // 3、调用 midjourney proxy
-        MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.action(
-                new MidjourneyApi.ActionRequest(customId, image.getTaskId(), midjourneyNotifyUrl));
-        // 4、检查错误 code (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
-        if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
-            throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
+    public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
+        // 1.1 检查 image
+        AiImageDO image = validateImageExists(reqVO.getId());
+        if (ObjUtil.notEqual(userId, image.getUserId())) {
+            throw exception(AI_IMAGE_NOT_EXISTS);
+        }
+        // 1.2 检查 customId
+        MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
+                buttonX -> buttonX.customId().equals(reqVO.getCustomId()));
+        if (button == null) {
+            throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
         }
 
-        // 5、新增 image 记录(根据 image 新增一个)
-        AiImageDO newImage = new AiImageDO();
-        newImage.setUserId(image.getUserId());
-        newImage.setPrompt(image.getPrompt());
-
-        newImage.setPlatform(image.getPlatform());
-        newImage.setModel(image.getModel());
-        newImage.setWidth(image.getWidth());
-        newImage.setHeight(image.getHeight());
-
-        newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
-        newImage.setPublicStatus(image.getPublicStatus());
+        // 2. 调用 Midjourney Proxy 提交任务
+        MidjourneyApi.SubmitResponse actionResponse = midjourneyApi.action(
+                new MidjourneyApi.ActionRequest(button.customId(), image.getTaskId(), midjourneyNotifyUrl));
+        if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(actionResponse.code())) {
+            String description = actionResponse.description().contains("quota_not_enough") ?
+                    "账户余额不足" : actionResponse.description();
+            throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
+        }
 
-        newImage.setOptions(image.getOptions());
-        newImage.setTaskId(submitResponse.result());
+        // 3. 新增 image 记录
+        AiImageDO newImage = new AiImageDO().setUserId(image.getUserId()).setPublicStatus(false).setPrompt(image.getPrompt())
+                .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
+                .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform())
+                .setModel(image.getModel()).setWidth(image.getWidth()).setHeight(image.getHeight())
+                .setOptions(image.getOptions()).setTaskId(actionResponse.result());
         imageMapper.insert(newImage);
-    }
-
-    private static void validateCustomId(String customId, List<MidjourneyApi.Button> buttons) {
-        boolean isTrue = false;
-        for (MidjourneyApi.Button button : buttons) {
-            if (button.customId().equals(customId)) {
-                isTrue = true;
-                break;
-            }
-        }
-        if (!isTrue) {
-            throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
-        }
+        return newImage.getId();
     }
 
     /**