Bladeren bron

【新增】AI:绘图(MJ)接入 API KEY 管理

YunaiV 11 maanden geleden
bovenliggende
commit
6225e18f70

+ 5 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/AiImageController.java

@@ -9,7 +9,7 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@@ -114,11 +114,11 @@ public class AiImageController {
         return success(BeanUtils.toBean(pageResult, AiImageRespVO.class));
     }
 
-    @PutMapping("/update-public-status")
-    @Operation(summary = "更新绘画发布状态")
+    @PutMapping("/update")
+    @Operation(summary = "更新绘画")
     @PreAuthorize("@ss.hasPermission('ai:image:update')")
-    public CommonResult<Boolean> updateImagePublicStatus(@Valid @RequestBody AiImageUpdatePublicStatusReqVO updateReqVO) {
-        imageService.updateImagePublicStatus(updateReqVO);
+    public CommonResult<Boolean> updateImage(@Valid @RequestBody AiImageUpdateReqVO updateReqVO) {
+        imageService.updateImage(updateReqVO);
         return success(true);
     }
 

+ 4 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageUpdatePublicStatusReqVO.java → yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/image/vo/AiImageUpdateReqVO.java

@@ -4,15 +4,15 @@ import io.swagger.v3.oas.annotations.media.Schema;
 import jakarta.validation.constraints.NotNull;
 import lombok.Data;
 
-@Schema(description = "管理后台 - AI 绘画修改发布状态 Request VO")
+@Schema(description = "管理后台 - AI 绘画修改 Request VO")
 @Data
-public class AiImageUpdatePublicStatusReqVO {
+public class AiImageUpdateReqVO {
 
     @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583")
+    @NotNull(message = "编号不能为空")
     private Long id;
 
-    @Schema(description = "是否发布", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
-    @NotNull(message = "是否发布不能为空")
+    @Schema(description = "是否发布", example = "true")
     private Boolean publicStatus;
 
 }

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java

@@ -5,7 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@@ -71,11 +71,11 @@ public interface AiImageService {
     PageResult<AiImageDO> getImagePage(AiImagePageReqVO pageReqVO);
 
     /**
-     * 更新绘画发布状态
+     * 更新绘画
      *
      * @param updateReqVO 更新信息
      */
-    void updateImagePublicStatus(@Valid AiImageUpdatePublicStatusReqVO updateReqVO);
+    void updateImage(@Valid AiImageUpdateReqVO updateReqVO);
 
     /**
      * 删除绘画

+ 5 - 5
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -15,7 +15,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@@ -62,9 +62,6 @@ public class AiImageServiceImpl implements AiImageService {
     @Resource
     private AiApiKeyService apiKeyService;
 
-    @Resource
-    private MidjourneyApi midjourneyApi;
-
     @Override
     public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) {
         return imageMapper.selectPage(userId, pageReqVO);
@@ -151,7 +148,7 @@ public class AiImageServiceImpl implements AiImageService {
     }
 
     @Override
-    public void updateImagePublicStatus(AiImageUpdatePublicStatusReqVO updateReqVO) {
+    public void updateImage(AiImageUpdateReqVO updateReqVO) {
         // 1. 校验存在
         validateImageExists(updateReqVO.getId());
         // 2. 更新发布状态
@@ -179,6 +176,7 @@ public class AiImageServiceImpl implements AiImageService {
     @Override
     @Transactional(rollbackFor = Exception.class)
     public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
+        MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
         // 1. 保存数据库
         AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
                 .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@@ -206,6 +204,7 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public Integer midjourneySync() {
+        MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
         // 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
         List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
                 AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
@@ -272,6 +271,7 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
+        MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
         // 1.1 检查 image
         AiImageDO image = validateImageExists(reqVO.getId());
         if (ObjUtil.notEqual(userId, image.getUserId())) {

+ 0 - 58
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/MidjourneyImageOptions.java

@@ -1,58 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.image;
-
-import lombok.Data;
-import org.springframework.ai.image.ImageOptions;
-
-/**
- * @author fansili
- * @time 2024/6/5 10:34
- * @since 1.0
- */
-@Data
-public class MidjourneyImageOptions implements ImageOptions {
-    /**
-     * 模型
-     */
-    private String model;
-    /**
-     * 宽度
-     */
-    private Integer width;
-    /**
-     * 高度
-     */
-    private Integer height;
-    /**
-     * 版本
-     */
-    private String version;
-    /**
-     * 参数
-     */
-    private String state;
-
-    @Override
-    public Integer getN() {
-        return 0;
-    }
-
-    @Override
-    public String getModel() {
-        return model;
-    }
-
-    @Override
-    public Integer getWidth() {
-        return width;
-    }
-
-    @Override
-    public Integer getHeight() {
-        return height;
-    }
-
-    @Override
-    public String getResponseFormat() {
-        return "";
-    }
-}

+ 10 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.module.ai.service.model;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
@@ -92,6 +93,15 @@ public interface AiApiKeyService {
      */
     ImageClient getImageClient(AiPlatformEnum platform);
 
+    /**
+     * 获得 MidjourneyApi 对象
+     *
+     * TODO 可优化点:目前默认获取 Midjourney 对应的第一个开启的配置用于绘画;后续可以支持配置选择
+     *
+     * @return MidjourneyApi 对象
+     */
+    MidjourneyApi getMidjourneyApi();
+
     /**
      * 获得 SunoApi 对象
      *

+ 11 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java

@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
@@ -112,6 +113,16 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
         return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl());
     }
 
+    @Override
+    public MidjourneyApi getMidjourneyApi() {
+        AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
+                AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+        if (apiKey == null) {
+            return null;
+        }
+        return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
+    }
+
     @Override
     public SunoApi getSunoApi() {
         AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(

+ 12 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java

@@ -1,6 +1,7 @@
 package cn.iocoder.yudao.framework.ai.core.factory;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import org.springframework.ai.chat.StreamingChatClient;
 import org.springframework.ai.image.ImageClient;
@@ -56,6 +57,17 @@ public interface AiClientFactory {
      */
     ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
 
+    /**
+     * 基于指定配置,获得 MidjourneyApi 对象
+     *
+     * 如果不存在,则进行创建
+     *
+     * @param apiKey API KEY
+     * @param url API URL
+     * @return MidjourneyApi 对象
+     */
+    MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url);
+
     /**
      * 基于指定配置,获得 SunoApi 对象
      *

+ 12 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java

@@ -9,6 +9,7 @@ import cn.hutool.extra.spring.SpringUtil;
 import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
 import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
 import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
@@ -110,9 +111,19 @@ public class AiClientFactoryImpl implements AiClientFactory {
         }
     }
 
+    @Override
+    public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
+        String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
+        return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
+            YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
+            return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
+        });
+    }
+
     @Override
     public SunoApi getOrCreateSunoApi(String apiKey, String url) {
-        return new SunoApi(url);
+        String cacheKey = buildClientCacheKey(SunoApi.class, AiPlatformEnum.SUNO.getPlatform(), apiKey, url);
+        return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
     }
 
     private static String buildClientCacheKey(Class<?> clazz, Object... params) {