瀏覽代碼

【增加】Midjourney Proxy 回调通知

cherishsince 1 年之前
父節點
當前提交
56e8707e38

+ 4 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneyNotifyVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneyNotifyReqVO.java

@@ -11,7 +11,10 @@ import lombok.Data;
  * @since 1.0
  */
 @Data
-public class MidjourneyNotifyVO {
+public class MidjourneyNotifyReqVO {
+
+    @Schema(description = "job id")
+    private String id;
 
     @Schema(description = "任务类型")
     private MidjourneyTaskActionEnum action;

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
@@ -13,7 +14,6 @@ import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import jakarta.annotation.Resource;
-import jakarta.servlet.http.HttpServletRequest;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
@@ -74,9 +74,9 @@ public class AiImageController {
         return success(aiImageService.deleteIdMy(id, getLoginUserId()));
     }
 
-    @Operation(summary = "删除【我的】绘画记录")
+    @Operation(summary = "midjourney proxy - 回调通知")
     @RequestMapping("/midjourney-notify")
-    public CommonResult<Boolean> midjourneyNotify(HttpServletRequest request) {
-        return success(true);
+    public CommonResult<Boolean> midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
+        return success(aiImageService.midjourneyNotify(getLoginUserId(), notifyReqVO));
     }
 }

+ 3 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java

@@ -28,6 +28,9 @@ public class AiImageDO extends BaseDO {
     @Schema(description = "用户编号")
     private Long userId;
 
+    @Schema(description = "midjourney proxy 关联的 job id")
+    private String jobId;
+
     @Schema(description = "提示词")
     private String prompt;
 

+ 10 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/image/AiImageMapper.java

@@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import org.apache.ibatis.annotations.Mapper;
-import org.springframework.stereotype.Repository;
 
 /**
  * AI 绘图 Mapper
@@ -26,4 +25,14 @@ public interface AiImageMapper extends BaseMapperX<AiImageDO> {
         return;
     }
 
+    /**
+     * 查询 - 根据 job id
+     *
+     * @param id
+     * @return
+     */
+    default AiImageDO selectByJobId(String id) {
+        return this.selectOne(new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getJobId, id));
+    }
+
 }

+ 9 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.image;
 
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
@@ -65,4 +66,12 @@ public interface AiImageService {
      */
     Boolean deleteIdMy(Long id, Long loginUserId);
 
+    /**
+     * midjourney proxy - 回调通知
+     *
+     * @param loginUserId
+     * @param notifyReqVO
+     * @return
+     */
+    Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
 }

+ 37 - 34
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -1,5 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.image;
 
+import cn.hutool.core.bean.BeanUtil;
+import cn.hutool.core.util.StrUtil;
 import cn.hutool.http.HttpUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
@@ -14,9 +16,14 @@ import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
 import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
 import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
+import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
@@ -36,15 +43,9 @@ import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 
 
-// TODO @fan:注释优化下哈
-
 /**
  * AI 绘画(接入 dall2/dall3、midjourney)
  *
@@ -56,9 +57,6 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
 @Slf4j
 public class AiImageServiceImpl implements AiImageService {
 
-    // TODO @fan:使用 @Resource 注入
-
-    // TODO @fan:imageMapper
     @Resource
     private AiImageMapper imageMapper;
     @Resource
@@ -173,19 +171,16 @@ public class AiImageServiceImpl implements AiImageService {
         // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
         String updateStatus = null;
         String errorMessage = null;
-        Map<String, Object> drawResponse = new HashMap<>();
 
         if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
             updateStatus = AiImageStatusEnum.FAIL.getStatus();
             errorMessage = submitRespVO.getDescription();
-        } else {
-            drawResponse.put("jobId", submitRespVO.getResult());
         }
         imageMapper.updateById(new AiImageDO()
                 .setId(aiImageDO.getId())
                 .setStatus(updateStatus)
                 .setErrorMessage(errorMessage)
-                .setDrawResponse(drawResponse)
+                .setJobId(submitRespVO.getResult())
         );
         return aiImageDO.getId();
     }
@@ -228,28 +223,36 @@ public class AiImageServiceImpl implements AiImageService {
         return imageMapper.deleteById(id) > 0;
     }
 
-    private void validateMessageId(String mjMessageId, String messageId) {
-        if (!mjMessageId.equals(messageId)) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_MESSAGE_ID_INCORRECT);
+    @Override
+    public Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO) {
+        // 1、根据 job id 查询关联的 image
+        AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
+        if (image == null) {
+            log.warn("midjourneyNotify 回调的 jobId 不存在! jobId: {}", notifyReqVO.getId());
+            return false;
         }
-    }
-
-    private AiImageMidjourneyOperationsVO validateMidjourneyOperationsExists(List<AiImageMidjourneyOperationsVO> midjourneyOperations, String operateId) {
-        for (AiImageMidjourneyOperationsVO midjourneyOperation : midjourneyOperations) {
-            if (midjourneyOperation.getCustom_id().equals(operateId)) {
-                return midjourneyOperation;
-            }
+        //
+        String imageStatus = null;
+        if (MidjourneyTaskStatusEnum.SUCCESS == notifyReqVO.getStatus()) {
+            imageStatus = AiImageStatusEnum.COMPLETE.getStatus();
+        } else if (MidjourneyTaskStatusEnum.FAILURE == notifyReqVO.getStatus()) {
+            imageStatus = AiImageStatusEnum.FAIL.getStatus();
         }
-        throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_OPERATION_NOT_EXISTS);
-    }
-
-
-    private List<AiImageMidjourneyOperationsVO> getMidjourneyOperations(AiImageDO aiImageDO) {
-//        if (StrUtil.isBlank(aiImageDO.getMjOperations())) {
-//            return Collections.emptyList();
-//        }
-//        return JsonUtils.parseArray(aiImageDO.getMjOperations(), AiImageMidjourneyOperationsVO.class);
-        return null;
+        // 2、上传图片
+        String filePath = null;
+        if (!StrUtil.isBlank(notifyReqVO.getImageUrl())) {
+            filePath = fileApi.createFile(HttpUtil.downloadBytes(notifyReqVO.getImageUrl()));
+        }
+        // 2、更新 image 状态
+        imageMapper.updateById(
+                new AiImageDO()
+                        .setId(image.getId())
+                        .setStatus(imageStatus)
+                        .setPicUrl(filePath)
+                        .setOriginalPicUrl(notifyReqVO.getImageUrl())
+                        .setDrawResponse(BeanUtil.beanToMap(notifyReqVO))
+        );
+        return true;
     }
 
     private AiImageDO validateExists(Long id) {