Преглед на файлове

【增加】niji 模型参数设置

cherishsince преди 1 година
родител
ревизия
8f3076b2ea

+ 30 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/enums/MidjourneyModelEnum.java

@@ -0,0 +1,30 @@
+package cn.iocoder.yudao.module.ai.client.enums;
+
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+/**
+ * 来源于 midjourney-proxy
+ */
+@Getter
+@AllArgsConstructor
+public enum MidjourneyModelEnum {
+
+	MIDJOURNEY("midjourney", "midjourney"),
+	NIJI("Niji", "Niji"),
+
+	;
+
+	private String model;
+	private String name;
+
+	public static MidjourneyModelEnum valueOfModel(String model) {
+		for (MidjourneyModelEnum itemEnum : MidjourneyModelEnum.values()) {
+			if (itemEnum.getModel().equals(model)) {
+				return itemEnum;
+			}
+		}
+		throw new IllegalArgumentException("Invalid MessageType value: " + model);
+	}
+}

+ 11 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -12,6 +12,7 @@ import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 import cn.iocoder.yudao.module.ai.AiCommonConstants;
 import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
+import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
 import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
 import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
@@ -157,10 +158,16 @@ public class AiImageServiceImpl implements AiImageService {
         // 3、调用 MidjourneyProxy 提交任务
         MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
         imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
-        // 设置 midjourney 扩展参数,通过 --ar 来设置尺寸
-        String midjourneySizeParam = String.format("--ar %s:%s", req.getWidth(), req.getHeight());
-        String midjourneyVersionParam = String.format("--v %s", req.getVersion());
-        imagineReqVO.setState(midjourneySizeParam.concat(" ").concat(midjourneyVersionParam));
+        // 设置 midjourney 扩展参数
+        //  --ar 来设置尺寸
+        String midjourneySizeParam = String.format(" --ar %s:%s ", req.getWidth(), req.getHeight());
+        // --v 版本
+        String midjourneyVersionParam = String.format(" --v %s ", req.getVersion());
+        // --niji 模型
+        MidjourneyModelEnum midjourneyModelEnum = MidjourneyModelEnum.valueOfModel(req.getModel());
+        String midjourneyNijiParam = MidjourneyModelEnum.NIJI == midjourneyModelEnum ? " --niji " : "";
+        // 设置参数
+        imagineReqVO.setState(midjourneySizeParam.concat(midjourneyVersionParam).concat(midjourneyNijiParam));
         MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
 
         // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))