Kaynağa Gözat

增加图片选择,和图片放大相关操作

cherishsince 1 yıl önce
ebeveyn
işleme
b1158fb1a7

+ 20 - 10
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyConfig.java

@@ -4,6 +4,7 @@ import lombok.Data;
 import lombok.experimental.Accessors;
 
 import java.util.Map;
+import java.util.UUID;
 
 /**
  * Midjourney 配置
@@ -15,16 +16,6 @@ import java.util.Map;
 @Accessors(chain = true)
 public class MidjourneyConfig {
 
-    public MidjourneyConfig(String token, String guildId, String channelId, Map<String, String> requestTemplates) {
-        this.token = token;
-        this.guildId = guildId;
-        this.channelId = channelId;
-        this.serverUrl = serverUrl;
-        this.apiInteractions = apiInteractions;
-        this.userAage = userAage;
-        this.requestTemplates = requestTemplates;
-    }
-
     /**
      * token信息
      *
@@ -64,4 +55,23 @@ public class MidjourneyConfig {
 
 
     private Map<String, String> requestTemplates;
+
+    //
+    //
+
+    private String sessionId;
+
+    public MidjourneyConfig(String token, String guildId, String channelId, Map<String, String> requestTemplates) {
+        this.token = token;
+        this.guildId = guildId;
+        this.channelId = channelId;
+        this.serverUrl = serverUrl;
+        this.apiInteractions = apiInteractions;
+        this.userAage = userAage;
+        this.requestTemplates = requestTemplates;
+
+        // 生成 session id
+        sessionId = UUID.randomUUID().toString().replaceAll("-", "");
+    }
+
 }

+ 1 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MjMessage.java

@@ -17,6 +17,7 @@ public class MjMessage {
 	 * 现在已知:
 	 * 0:我们发送的消息,和指令
 	 * 20: mj生成图片发送过程中
+	 * 19: 选择了某一张图片后的通知
 	 */
 	private Integer type;
 	/**

+ 1 - 2
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MjClient.java

@@ -32,8 +32,7 @@ public class MjClient {
         // 封装请求体和头部信息
         HttpEntity<String> requestEntity = new HttpEntity<>(body, headers);
         // 发送请求
-        String result = restTemplate.postForObject(url, requestEntity, String.class);
-        return result;
+        return restTemplate.postForObject(url, requestEntity, String.class);
     }
 
 

+ 30 - 3
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MjImagineInteractions.java

@@ -4,6 +4,8 @@ import cn.hutool.core.util.IdUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
 import cn.iocoder.yudao.framework.ai.midjourney.constants.MjInteractionsEnum;
+import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import lombok.extern.slf4j.Slf4j;
 
@@ -12,7 +14,7 @@ import java.util.List;
 import java.util.UUID;
 
 /**
- *
+ * 图片生成
  *
  * author: fansili
  * time: 2024/4/3 17:36
@@ -28,7 +30,7 @@ public class MjImagineInteractions implements MjInteractions {
 
     @Override
     public List<MjInteractionsEnum> supperInteractions() {
-        return null;
+        return Lists.newArrayList(MjInteractionsEnum.IMAGINE);
     }
 
     @Override
@@ -40,7 +42,7 @@ public class MjImagineInteractions implements MjInteractions {
         HashMap<String, String> requestParams = Maps.newHashMap();
         requestParams.put("guild_id", midjourneyConfig.getGuildId());
         requestParams.put("channel_id", midjourneyConfig.getChannelId());
-        requestParams.put("session_id", UUID.randomUUID().toString().replaceAll("-", ""));
+        requestParams.put("session_id", midjourneyConfig.getSessionId());
         requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
         requestParams.put("prompt", prompt);
         // 设置参数
@@ -55,4 +57,29 @@ public class MjImagineInteractions implements MjInteractions {
         log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
         return isSuccess;
     }
+
+    public Boolean reRoll(ReRoll reRoll) {
+        String url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
+        // 获取请求模板
+        String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll");
+        // 设置参数
+        HashMap<String, String> requestParams = Maps.newHashMap();
+        requestParams.put("guild_id", midjourneyConfig.getGuildId());
+        requestParams.put("channel_id", midjourneyConfig.getChannelId());
+        requestParams.put("session_id", midjourneyConfig.getSessionId());
+        requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
+        requestParams.put("custom_id", reRoll.getCustomId());
+        requestParams.put("message_id", reRoll.getMessageId());
+        // 设置参数
+        String requestBody = MjClient.setParams(requestTemplate, requestParams);
+        // 发送请求
+        String res = MjClient.post(url, midjourneyConfig.getToken(), requestBody);
+        // 这个 res 只要不返回值,就是成功!
+        boolean isSuccess = StrUtil.isBlank(res);
+        if (isSuccess) {
+            return true;
+        }
+        log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
+        return isSuccess;
+    }
 }

+ 16 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/vo/ReRoll.java

@@ -0,0 +1,16 @@
+package cn.iocoder.yudao.framework.ai.midjourney.vo;
+
+import lombok.Data;
+import lombok.experimental.Accessors;
+
+/**
+ * author: fansili
+ * time: 2024/4/6 21:33
+ */
+@Data
+@Accessors(chain = true)
+public class ReRoll {
+
+    private String messageId;
+    private String customId;
+}

+ 14 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/resources/http-body/reroll.json

@@ -0,0 +1,14 @@
+{
+  "type": 3,
+  "guild_id": "$guild_id",
+  "channel_id": "$channel_id",
+  "message_id": "$message_id",
+  "application_id": "936929561302675456",
+  "session_id": "$session_id",
+  "nonce": "$nonce",
+  "message_flags": 0,
+  "data": {
+    "component_type": 2,
+    "custom_id": "$custom_id"
+  }
+}