Browse Source

【增加】对接 Midjourney,增加nonce传递,更新Midjourney image 状态

cherishsince 1 year ago
parent
commit
03b4460eae

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java

@@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.module.ai.service.AiImageService;
 import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
 import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
-import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
 import io.swagger.v3.oas.annotations.Operation;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import lombok.AllArgsConstructor;
@@ -42,7 +41,8 @@ public class AiImageController {
 
     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
     @PostMapping("/midjourney")
-    public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
-        return CommonResult.success(aiImageService.midjourney(req));
+    public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
+        aiImageService.midjourney(req);
+        return CommonResult.success(null);
     }
 }

+ 1 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java

@@ -28,5 +28,5 @@ public interface AiImageService {
      * @param req
      * @return
      */
-    AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req);
+    void midjourney(AiImageMidjourneyReq req);
 }

+ 5 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java

@@ -95,18 +95,15 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     @Transactional(rollbackFor = Exception.class)
-    public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) {
+    public void midjourney(AiImageMidjourneyReq req) {
         // 保存数据库
-        doSave(req.getPrompt(), null, "midjoureny",
+        AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
                 null, AiChatDrawingStatusEnum.SUBMIT, null);
         // 提交 midjourney 任务
-        Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt());
+        Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
         if (!imagine) {
             throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
         }
-        //
-
-        return null;
     }
 
     private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
@@ -120,7 +117,7 @@ public class AiImageServiceImpl implements AiImageService {
         }
     }
 
-    private void doSave(String prompt,
+    private AiImageDO doSave(String prompt,
                         String size,
                         String model,
                         String imageUrl,
@@ -138,5 +135,6 @@ public class AiImageServiceImpl implements AiImageService {
         aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
         aiImageDO.setDrawingError(drawingError);
         aiImageMapper.insert(aiImageDO);
+        return aiImageDO;
     }
 }

+ 50 - 1
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java

@@ -1,7 +1,15 @@
 package cn.iocoder.yudao.module.ai.service.midjourneyHandler;
 
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
+import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
 import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
+import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
+import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
+import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
+import com.alibaba.fastjson2.JSON;
+import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Component;
 
@@ -14,10 +22,51 @@ import org.springframework.stereotype.Component;
  */
 @Component
 @Slf4j
+@AllArgsConstructor
 public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
 
+    private final AiImageMapper aiImageMapper;
+
     @Override
     public void messageHandler(MidjourneyMessage midjourneyMessage) {
-        log.info("yudao-midjourney-midjourney-message-handler", midjourneyMessage);
+        log.info("yudao-midjourney-midjourney-message-handler {}", JSON.toJSONString(midjourneyMessage));
+        if (midjourneyMessage.getContent() != null) {
+            log.info("进度id {} 状态 {} 进度 {}",
+                    midjourneyMessage.getNonce(),
+                    midjourneyMessage.getGenerateStatus(),
+                    midjourneyMessage.getContent().getProgress());
+        }
+        //
+        updateImage(midjourneyMessage);
+    }
+
+    private void updateImage(MidjourneyMessage midjourneyMessage) {
+        // Nonce 不存在不更新
+        if (StrUtil.isBlank(midjourneyMessage.getNonce())) {
+            return;
+        }
+        // 获取id
+        Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
+        // 获取生成 url
+        String imageUrl = null;
+        if (CollUtil.isNotEmpty(midjourneyMessage.getAttachments())) {
+            imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
+        }
+        // 转换状态
+        AiChatDrawingStatusEnum drawingStatusEnum = null;
+        String generateStatus = midjourneyMessage.getGenerateStatus();
+        if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
+            drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE;
+        } else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
+            drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS;
+        }  else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
+            drawingStatusEnum = AiChatDrawingStatusEnum.WAITING;
+        }
+        aiImageMapper.updateById(
+                new AiImageDO()
+                        .setId(aiImageId)
+                        .setDrawingImageUrl(imageUrl)
+                        .setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
+        );
     }
 }

+ 47 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyMessage.java

@@ -14,6 +14,10 @@ public class MidjourneyMessage {
 	 * id是一个重要的字段,在同时生成多个的时候,可以区分生成信息
 	 */
 	private String id;
+	/**
+	 * 提交id(nonce 可能会不存在,系统提示的时候,这个为空)
+	 */
+	private String nonce;
 	/**
 	 * 现在已知:
 	 * 0:我们发送的消息,和指令
@@ -45,6 +49,14 @@ public class MidjourneyMessage {
 	 * {@link MidjourneyGennerateStatusEnum}
 	 */
 	private String generateStatus;
+	/**
+	 * 一般用于提示信息
+	 * - 错误
+	 * - 并发队列满了
+	 * - 账号违规了、敏感词
+	 * - 账号被封
+	 */
+	private List<Embed> embeds;
 
 	@Data
 	@Accessors(chain = true)
@@ -123,4 +135,39 @@ public class MidjourneyMessage {
 		private String progress;
 		private String status;
 	}
+
+	/**
+	 * embed 用于警告、提示、错误
+	 */
+	@Data
+	@Accessors(chain = true)
+	public static class Embed {
+
+		// 内容扫描版本号
+		private int contentScanVersion;
+
+		// 颜色值,这里用Java的Color类来表示,注意实际使用中可能需要自定义方法来从int转换为Color对象
+		private String color;
+
+		// 页脚信息,包含文本
+		private Footer footer;
+
+		// 描述信息
+		private String description;
+
+		// 消息类型,这里是富文本类型(这个区分不同提示类型)
+		private String type;
+
+		// 标题
+		private String title;
+
+		// Footer类,作为嵌套类存在,用来表示footer部分的JSON对象
+		@Data
+		@Accessors(chain = true)
+		public static class Footer {
+			// 页脚文本
+			private String text;
+		}
+
+	}
 }

+ 3 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java

@@ -38,11 +38,13 @@ public class MidjourneyInteractionsApi extends MidjourneyInteractions {
         this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
     }
 
-    public Boolean imagine(String prompt) {
+    public Boolean imagine(Long id, String prompt) {
+        String nonce = String.valueOf(id);
         // 获取请求模板
         String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
         // 设置参数
         HashMap<String, String> requestParams = getDefaultParams();
+        requestParams.put("nonce", nonce);
         requestParams.put("prompt", prompt);
         // 解析 template 参数占位符
         String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);

+ 8 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MidjourneyConstants.java

@@ -6,6 +6,10 @@ public final class MidjourneyConstants {
 	 * 消息 - 编号
 	 */
 	public static final String MSG_ID = "id";
+	/**
+	 * 用于区分操作唯一性
+	 */
+	public static final String MSG_NONCE = "nonce";
 	/**
 	 * 消息 - 类型
 	 * 现在已知:
@@ -32,6 +36,10 @@ public final class MidjourneyConstants {
 	 * 附件(生成中比较模糊的图片)
 	 */
 	public static final String MSG_ATTACHMENTS = "attachments";
+	/**
+	 * 一般用于提示
+	 */
+	public static final String MSG_EMBEDS = "embeds";
 
 
 	//

+ 23 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/listener/MidjourneyMessageListener.java

@@ -42,12 +42,14 @@ public class MidjourneyMessageListener {
         if (ignoreAndLogMessage(data, messageType)) {
             return;
         }
+        log.info("socket message: {}", raw);
         // 转换几个重要的信息
         MidjourneyMessage mjMessage = new MidjourneyMessage();
-        mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID));
+        mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, ""));
+        mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, ""));
         mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
         mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
-		mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
+        mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
         // 转换 components
         if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) {
             String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8");
@@ -60,6 +62,12 @@ public class MidjourneyMessageListener {
             List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
             mjMessage.setAttachments(attachments);
         }
+        // 转换 embeds 提示信息
+        if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) {
+            String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8");
+            List<MidjourneyMessage.Embed> embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class);
+            mjMessage.setEmbeds(embeds);
+        }
         // 转换状态
         convertGenerateStatus(mjMessage);
         // message handler 调用
@@ -68,7 +76,20 @@ public class MidjourneyMessageListener {
         }
     }
 
+    private String getString(DataObject data, String key, String defaultValue) {
+        if (!data.hasKey(key)) {
+            return defaultValue;
+        }
+        return data.getString(key);
+    }
+
     private void convertGenerateStatus(MidjourneyMessage mjMessage) {
+        //
+        // tip:提示、警告、异常 content是没有内容的
+        // tip: 一般错误信息在 Embeds 只要 Embeds有值,content就没信息。
+        if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) {
+            return;
+        }
         if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
             mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
         } else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {