Bläddra i källkod

【修改】AI Image dall 请求返回结构优化

cherishsince 10 månader sedan
förälder
incheckning
63a8cc244d

+ 2 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -32,8 +32,8 @@ public class AiImageController {
 
     @Operation(summary = "获取image列表", description = "dall3、midjourney")
     @GetMapping("/list")
-    public PageResult<AiImageListRespVO> list(@Validated @RequestBody AiImageListReqVO req) {
-        return aiImageService.list(req);
+    public CommonResult<PageResult<AiImageListRespVO>> list(@Validated @ModelAttribute AiImageListReqVO req) {
+        return CommonResult.success(aiImageService.list(req));
     }
 
     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")

+ 49 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageListRespVO.java

@@ -1,6 +1,9 @@
 package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
 
 import cn.iocoder.yudao.framework.common.pojo.PageParam;
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
+import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 import lombok.experimental.Accessors;
 
@@ -15,4 +18,50 @@ import lombok.experimental.Accessors;
 @Accessors(chain = true)
 public class AiImageListRespVO extends PageParam {
 
+    private Long id;
+
+    @Schema(description = "用户id")
+    private Long userId;
+
+    @Schema(description = "提示词")
+    private String prompt;
+
+    @Schema(description = "模型 dall2/dall3、MJ、NIJI")
+    private String model;
+
+    @Schema(description = "生成图像的尺寸大小。对于dall-e-2模型,尺寸可为256x256, 512x512, 或 1024x1024。对于dall-e-3模型,尺寸可为1024x1024, 1792x1024, 或 1024x1792。")
+    private String size;
+
+    @Schema(description = "风格")
+    private String style;
+
+    @Schema(description = "图片地址(自己服务器)")
+    private String picUrl;
+
+    @Schema(description = "绘画状态:提交、排队、绘画中、绘画完成、绘画失败")
+    private String status;
+
+    @Schema(description = "绘画图片地址(绘画好的服务器)")
+    private String originalPicUrl;
+
+    @Schema(description = "绘画错误信息")
+    private String errorMessage;
+
+    @Schema(description = "是否发布")
+    private String publicStatus;
+
+    // ============ mj 需要字段
+
+    @Schema(description = "用户操作的Nonce编号(MJ返回)")
+    private String mjNonceId;
+
+    @Schema(description = "用户操作的操作编号(MJ返回)")
+    private String mjOperationId;
+
+    @Schema(description = "用户操作的操作名字(MJ返回)")
+    private String mjOperationName;
+
+    @Schema(description = "mj图片生产成功保存的 components json 数组")
+    private String mjOperations;
+
 }

+ 8 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/convert/AiImageConvert.java

@@ -23,6 +23,14 @@ public interface AiImageConvert {
 
     AiImageConvert INSTANCE = Mappers.getMapper(AiImageConvert.class);
 
+    /**
+     * 转换 - AiImageDallDrawingRespVO
+     *
+     * @param req
+     * @return
+     */
+    AiImageDallRespVO convertAiImageDallDrawingRespVO(AiImageDO req);
+
     /**
      * 转换 - AiImageDallDrawingRespVO
      *

+ 24 - 22
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -2,16 +2,10 @@ package cn.iocoder.yudao.module.ai.service.image;
 
 import cn.hutool.core.util.IdUtil;
 import cn.hutool.core.util.StrUtil;
-import cn.iocoder.yudao.framework.ai.core.exception.AiException;
-import org.springframework.ai.image.ImageGeneration;
-import org.springframework.ai.image.ImagePrompt;
-import org.springframework.ai.image.ImageResponse;
+import cn.hutool.http.HttpUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
 import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
-import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
-import org.springframework.ai.models.midjourney.api.req.ReRollReq;
-import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
-import org.springframework.ai.models.midjourney.webSocket.WssNotify;
+import cn.iocoder.yudao.framework.ai.core.exception.AiException;
 import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
@@ -23,9 +17,17 @@ import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
 import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
+import cn.iocoder.yudao.module.infra.api.file.FileApi;
 import jakarta.annotation.PostConstruct;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.image.ImageGeneration;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
+import org.springframework.ai.models.midjourney.api.req.ReRollReq;
+import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
+import org.springframework.ai.models.midjourney.webSocket.WssNotify;
 import org.springframework.ai.openai.OpenAiImageClient;
 import org.springframework.ai.openai.OpenAiImageOptions;
 import org.springframework.stereotype.Service;
@@ -47,6 +49,7 @@ import java.util.List;
 public class AiImageServiceImpl implements AiImageService {
 
     private final AiImageMapper aiImageMapper;
+    private final FileApi fileApi;
     private final OpenAiImageClient openAiImageClient;
     private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
     private final MidjourneyInteractionsApi midjourneyInteractionsApi;
@@ -89,8 +92,6 @@ public class AiImageServiceImpl implements AiImageService {
         // 获取 model
         OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
         OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
-        // 转换 AiImageDallDrawingRespVO
-        AiImageDallRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
         try {
             // 转换openai 参数
             OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
@@ -100,22 +101,21 @@ public class AiImageServiceImpl implements AiImageService {
             ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
             // 发送
             ImageGeneration imageGeneration = imageResponse.getResult();
+            // 图片保存到服务器
+            String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
             // 保存数据库
-            doSave(req.getPrompt(), req.getSize(), req.getModel(),
-                    imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
+            AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
+                    filePath, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
                     null, null, null);
-            // 返回 flex
-            respVO.setOriginalPicUrl(imageGeneration.getOutput().getUrl());
-            respVO.setBase64(imageGeneration.getOutput().getB64Json());
-            return respVO;
+            // 转换 AiImageDallDrawingRespVO
+            return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
         } catch (AiException aiException) {
             // 保存数据库
-            doSave(req.getPrompt(), req.getSize(), req.getModel(),
-                    null, AiImageStatusEnum.FAIL, aiException.getMessage(),
+            AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
+                    null, null, AiImageStatusEnum.FAIL, aiException.getMessage(),
                     null, null, null);
             // 发送错误信息
-            respVO.setErrorMessage(aiException.getMessage());
-            return respVO;
+            return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
         }
     }
 
@@ -125,7 +125,7 @@ public class AiImageServiceImpl implements AiImageService {
         // 保存数据库
         String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
         AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
-                null, AiImageStatusEnum.SUBMIT, null,
+                null, null, AiImageStatusEnum.SUBMIT, null,
                 messageId, null, null);
         // 提交 midjourney 任务
         Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
@@ -149,7 +149,7 @@ public class AiImageServiceImpl implements AiImageService {
         String mjOperationName = midjourneyOperationsVO.getLabel();
         // 保存一个 image 任务记录
         doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
-                null, AiImageStatusEnum.SUBMIT, null,
+                null, null, AiImageStatusEnum.SUBMIT, null,
                 req.getMessageId(), req.getOperateId(), mjOperationName);
         // 提交操作
         midjourneyInteractionsApi.reRoll(
@@ -201,6 +201,7 @@ public class AiImageServiceImpl implements AiImageService {
     private AiImageDO doSave(String prompt,
                              String size,
                              String model,
+                             String picUrl,
                              String originalPicUrl,
                              AiImageStatusEnum statusEnum,
                              String errorMessage,
@@ -218,6 +219,7 @@ public class AiImageServiceImpl implements AiImageService {
         // TODO @芋艿 如何上传到自己服务器
         aiImageDO.setPicUrl(null);
         aiImageDO.setStatus(statusEnum.getStatus());
+        aiImageDO.setPicUrl(picUrl);
         aiImageDO.setOriginalPicUrl(originalPicUrl);
         aiImageDO.setErrorMessage(errorMessage);
         //