Parcourir la source

【增加】增加 midjourney 提交任务

cherishsince il y a 1 an
Parent
commit
a1f738dd81

+ 36 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitCodeEnum.java

@@ -0,0 +1,36 @@
+package cn.iocoder.yudao.module.ai.client.vo;
+
+import com.google.common.collect.Lists;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.util.List;
+
+/**
+ * Midjourney 提交任务 code 枚举
+ *
+ * @author fansili
+ * @time 2024/5/30 14:33
+ * @since 1.0
+ */
+@Getter
+@AllArgsConstructor
+public enum MidjourneySubmitCodeEnum {
+
+    // 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
+    SUBMIT_SUCCESS("1", "提交成功"),
+    ALREADY_EXISTS("1", "已存在"),
+    QUEUING("22", "排队中"),
+
+    ;
+
+    public static final List<String> SUCCESS_CODES = Lists.newArrayList(
+            SUBMIT_SUCCESS.code,
+            ALREADY_EXISTS.code,
+            QUEUING.code
+    );
+
+    private String code;
+    private String name;
+
+}

+ 3 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/vo/MidjourneySubmitRespVO.java

@@ -3,6 +3,8 @@ package cn.iocoder.yudao.module.ai.client.vo;
 import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 
+import java.util.Map;
+
 /**
  * Midjourney:Imagine 请求
  *
@@ -20,7 +22,7 @@ public class MidjourneySubmitRespVO {
     private String description;
 
     @Schema(description = "扩展字段")
-    private String properties;
+    private Map<String, Object> properties;
 
     @Schema(description = "任务ID")
     private String result;

+ 16 - 21
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -3,13 +3,17 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageMyRespVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.service.image.AiImageService;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.Parameter;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import jakarta.annotation.Resource;
+import jakarta.servlet.http.HttpServletRequest;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
@@ -49,32 +53,17 @@ public class AiImageController {
     }
 
     // TODO @fan:建议把 dallDrawing、midjourney 融合成一个 draw 接口,异步绘制;然后返回一个 id 给前端;前端通过 get 接口轮询,直到获取到生成成功
+    // TODO @芋艿: 参数差异较大
     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
     @PostMapping("/dall")
     public CommonResult<Long> dall(@Validated @RequestBody AiImageDallReqVO req) {
         return success(aiImageService.dall(getLoginUserId(), req));
     }
 
-    @Operation(summary = "midjourney绘画", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
-    @PostMapping("/midjourney")
-    public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReqVO req) {
-        aiImageService.midjourney(req);
-        return success(null);
-    }
-
-    @Operation(summary = "midjourney绘画操作", description = "一般有选择图片、放大、换一批...")
-    @PostMapping("/midjourney-operate")
-    public CommonResult<Void> midjourneyOperate(@Validated @RequestBody AiImageMidjourneyOperateReqVO req) {
-        aiImageService.midjourneyOperate(req);
-        return success(null);
-    }
-
-    // TODO @fan:要不先不要 midjourneyOperate、cancelMidjourney 接口哈
-    @Operation(summary = "取消 midjourney 绘画", description = "取消 midjourney 绘画")
-    @PostMapping("/cancel-midjourney")
-    public CommonResult<Void> cancelMidjourney(@RequestParam("id") Long id) {
-        // @范 这里实现mj取消逻辑
-        return success(null);
+    @Operation(summary = "midjourney-imagine 绘画", description = "...")
+    @PostMapping("/midjourney/imagine")
+    public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) {
+        return success(aiImageService.midjourneyImagine(getLoginUserId(), req));
     }
 
     @Operation(summary = "删除【我的】绘画记录")
@@ -83,4 +72,10 @@ public class AiImageController {
     public CommonResult<Boolean> deleteIdMy(@RequestParam("id") Long id) {
         return success(aiImageService.deleteIdMy(id, getLoginUserId()));
     }
+
+    @Operation(summary = "删除【我的】绘画记录")
+    @RequestMapping("/midjourney-notify")
+    public CommonResult<Boolean> midjourneyNotify(HttpServletRequest request) {
+        return success(true);
+    }
 }

+ 9 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyReqVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageMidjourneyImagineReqVO.java

@@ -1,9 +1,12 @@
 package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
 
 import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
 import lombok.Data;
 import lombok.experimental.Accessors;
 
+import java.util.List;
+
 /**
  * midjourney req
  *
@@ -13,17 +16,15 @@ import lombok.experimental.Accessors;
  */
 @Data
 @Accessors(chain = true)
-public class AiImageMidjourneyReqVO {
+public class AiImageMidjourneyImagineReqVO {
 
     @Schema(description = "提示词")
+    @NotNull(message = "提示词不能为空!")
     private String prompt;
 
-    @Schema(description = "绘画比例 1:1、3:4、4:3、9:16、16:9")
-    private String size;
-
-    @Schema(description = "风格")
-    private String style;
+    @Schema(description = "模型(midjourney、niji)")
+    private String model;
 
-    @Schema(description = "参考图")
-    private String referImage;
+    @Schema(description = "垫图(参考图)base64数组")
+    private List<String> base64Array;
 }

+ 3 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -3,8 +3,8 @@ package cn.iocoder.yudao.module.ai.service.image;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 
 /**
@@ -44,10 +44,11 @@ public interface AiImageService {
     /**
      * midjourney 图片生成
      *
+     * @param loginUserId
      * @param req
      * @return
      */
-    void midjourney(AiImageMidjourneyReqVO req);
+    Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req);
 
     /**
      * midjourney 操作(u1、u2、放大、换一批...)

+ 60 - 37
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -1,7 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.image;
 
-import cn.hutool.core.util.IdUtil;
 import cn.hutool.http.HttpUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
 import cn.iocoder.yudao.framework.ai.core.exception.AiException;
@@ -11,6 +11,10 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 import cn.iocoder.yudao.module.ai.AiCommonConstants;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
+import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitCodeEnum;
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
@@ -18,21 +22,22 @@ import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
 import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
 import cn.iocoder.yudao.module.infra.api.file.FileApi;
 import com.google.common.collect.ImmutableMap;
-import jakarta.annotation.PostConstruct;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.image.ImageGeneration;
 import org.springframework.ai.image.ImagePrompt;
 import org.springframework.ai.image.ImageResponse;
-import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
-import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
 import org.springframework.ai.openai.OpenAiImageClient;
 import org.springframework.ai.openai.OpenAiImageOptions;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.scheduling.annotation.Async;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
 
@@ -59,28 +64,11 @@ public class AiImageServiceImpl implements AiImageService {
     private FileApi fileApi;
     @Resource
     private OpenAiImageClient openAiImageClient;
-    @Resource
-    private MidjourneyWebSocketStarter midjourneyWebSocketStarter;
-    @Resource
-    private MidjourneyInteractionsApi midjourneyInteractionsApi;
-
-    // TODO @fan:接 mj proxy
-    @PostConstruct
-    public void startMidjourney() {
-        // todo @fan 暂时注释掉
-//        log.info("midjourney web socket starter...");
-//        midjourneyWebSocketStarter.start(new WssNotify() {
-//            @Override
-//            public void notify(int code, String message) {
-//                log.info("code: {}, message: {}", code, message);
-//                if (message.contains("Authentication failed")) {
-//                    // TODO 芋艿,这里看怎么处理,token无效的时候会认证失败!
-//                    // 认证失败
-//                    log.error("midjourney socket 认证失败,检查token是否失效!");
-//                }
-//            }
-//        });
-    }
+    @Autowired
+    private MidjourneyProxyClient midjourneyProxyClient;
+
+    @Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
+    private String midjourneyNotifyUrl;
 
     @Override
     public PageResult<AiImageDO> getImagePageMy(Long loginUserId, AiImageListReqVO req) {
@@ -143,18 +131,53 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     @Transactional(rollbackFor = Exception.class)
-    public void midjourney(AiImageMidjourneyReqVO req) {
-        // 保存数据库
-        String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
-        // todo
-//        AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
-//                null, null, AiImageStatusEnum.SUBMIT, null,
-//                messageId, null, null);
-        // 提交 midjourney 任务
-        Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
-        if (!imagine) {
-            throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
+    public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
+
+        // 1、构建 AiImageDO
+        AiImageDO aiImageDO = new AiImageDO();
+        aiImageDO.setId(null);
+        aiImageDO.setUserId(loginUserId);
+        aiImageDO.setPrompt(req.getPrompt());
+        aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
+        // todo @范 平台需要转换(mj 模型一般分版本)
+        aiImageDO.setModel(null);
+        aiImageDO.setWidth(null);
+        aiImageDO.setHeight(null);
+        aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
+        aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
+        aiImageDO.setPicUrl(null);
+        aiImageDO.setOriginalPicUrl(null);
+        aiImageDO.setDrawRequest(null);
+        aiImageDO.setDrawResponse(null);
+        aiImageDO.setErrorMessage(null);
+
+        // 2、保存 image
+        imageMapper.insert(aiImageDO);
+
+        // 3、调用 MidjourneyProxy 提交任务
+        MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
+        imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
+        imagineReqVO.setState(String.valueOf(aiImageDO.getId()));
+        MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
+
+        // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
+        String updateStatus = null;
+        String errorMessage = null;
+        Map<String, Object> drawResponse = new HashMap<>();
+
+        if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
+            updateStatus = AiImageStatusEnum.FAIL.getStatus();
+            errorMessage = submitRespVO.getDescription();
+        } else {
+            drawResponse.put("jobId", submitRespVO.getResult());
         }
+        imageMapper.updateById(new AiImageDO()
+                .setId(aiImageDO.getId())
+                .setStatus(updateStatus)
+                .setErrorMessage(errorMessage)
+                .setDrawResponse(drawResponse)
+        );
+        return aiImageDO.getId();
     }
 
     @Transactional(rollbackFor = Exception.class)