|
@@ -1,7 +1,7 @@
|
|
|
package cn.iocoder.yudao.module.ai.service.image;
|
|
|
|
|
|
-import cn.hutool.core.util.IdUtil;
|
|
|
import cn.hutool.http.HttpUtil;
|
|
|
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.exception.AiException;
|
|
@@ -11,6 +11,10 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
|
|
import cn.iocoder.yudao.module.ai.AiCommonConstants;
|
|
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
|
|
+import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
|
|
|
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitCodeEnum;
|
|
|
+import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
|
@@ -18,21 +22,22 @@ import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
|
|
|
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
|
|
|
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
|
|
import com.google.common.collect.ImmutableMap;
|
|
|
-import jakarta.annotation.PostConstruct;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.ai.image.ImageGeneration;
|
|
|
import org.springframework.ai.image.ImagePrompt;
|
|
|
import org.springframework.ai.image.ImageResponse;
|
|
|
-import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
|
|
|
-import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
|
|
|
import org.springframework.ai.openai.OpenAiImageClient;
|
|
|
import org.springframework.ai.openai.OpenAiImageOptions;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.beans.factory.annotation.Value;
|
|
|
import org.springframework.scheduling.annotation.Async;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
|
|
|
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
|
|
|
|
@@ -59,28 +64,11 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
private FileApi fileApi;
|
|
|
@Resource
|
|
|
private OpenAiImageClient openAiImageClient;
|
|
|
- @Resource
|
|
|
- private MidjourneyWebSocketStarter midjourneyWebSocketStarter;
|
|
|
- @Resource
|
|
|
- private MidjourneyInteractionsApi midjourneyInteractionsApi;
|
|
|
-
|
|
|
- // TODO @fan:接 mj proxy
|
|
|
- @PostConstruct
|
|
|
- public void startMidjourney() {
|
|
|
- // todo @fan 暂时注释掉
|
|
|
-// log.info("midjourney web socket starter...");
|
|
|
-// midjourneyWebSocketStarter.start(new WssNotify() {
|
|
|
-// @Override
|
|
|
-// public void notify(int code, String message) {
|
|
|
-// log.info("code: {}, message: {}", code, message);
|
|
|
-// if (message.contains("Authentication failed")) {
|
|
|
-// // TODO 芋艿,这里看怎么处理,token无效的时候会认证失败!
|
|
|
-// // 认证失败
|
|
|
-// log.error("midjourney socket 认证失败,检查token是否失效!");
|
|
|
-// }
|
|
|
-// }
|
|
|
-// });
|
|
|
- }
|
|
|
+ @Autowired
|
|
|
+ private MidjourneyProxyClient midjourneyProxyClient;
|
|
|
+
|
|
|
+ @Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
|
|
|
+ private String midjourneyNotifyUrl;
|
|
|
|
|
|
@Override
|
|
|
public PageResult<AiImageDO> getImagePageMy(Long loginUserId, AiImageListReqVO req) {
|
|
@@ -143,18 +131,53 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
@Override
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
- public void midjourney(AiImageMidjourneyReqVO req) {
|
|
|
- // 保存数据库
|
|
|
- String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
|
|
|
- // todo
|
|
|
-// AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
|
|
|
-// null, null, AiImageStatusEnum.SUBMIT, null,
|
|
|
-// messageId, null, null);
|
|
|
- // 提交 midjourney 任务
|
|
|
- Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
|
|
|
- if (!imagine) {
|
|
|
- throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
|
|
|
+ public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
|
|
|
+
|
|
|
+ // 1、构建 AiImageDO
|
|
|
+ AiImageDO aiImageDO = new AiImageDO();
|
|
|
+ aiImageDO.setId(null);
|
|
|
+ aiImageDO.setUserId(loginUserId);
|
|
|
+ aiImageDO.setPrompt(req.getPrompt());
|
|
|
+ aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
|
+ // todo @范 平台需要转换(mj 模型一般分版本)
|
|
|
+ aiImageDO.setModel(null);
|
|
|
+ aiImageDO.setWidth(null);
|
|
|
+ aiImageDO.setHeight(null);
|
|
|
+ aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
+ aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
|
|
|
+ aiImageDO.setPicUrl(null);
|
|
|
+ aiImageDO.setOriginalPicUrl(null);
|
|
|
+ aiImageDO.setDrawRequest(null);
|
|
|
+ aiImageDO.setDrawResponse(null);
|
|
|
+ aiImageDO.setErrorMessage(null);
|
|
|
+
|
|
|
+ // 2、保存 image
|
|
|
+ imageMapper.insert(aiImageDO);
|
|
|
+
|
|
|
+ // 3、调用 MidjourneyProxy 提交任务
|
|
|
+ MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
|
|
|
+ imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
|
|
|
+ imagineReqVO.setState(String.valueOf(aiImageDO.getId()));
|
|
|
+ MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
|
|
|
+
|
|
|
+ // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
|
|
|
+ String updateStatus = null;
|
|
|
+ String errorMessage = null;
|
|
|
+ Map<String, Object> drawResponse = new HashMap<>();
|
|
|
+
|
|
|
+ if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
|
|
|
+ updateStatus = AiImageStatusEnum.FAIL.getStatus();
|
|
|
+ errorMessage = submitRespVO.getDescription();
|
|
|
+ } else {
|
|
|
+ drawResponse.put("jobId", submitRespVO.getResult());
|
|
|
}
|
|
|
+ imageMapper.updateById(new AiImageDO()
|
|
|
+ .setId(aiImageDO.getId())
|
|
|
+ .setStatus(updateStatus)
|
|
|
+ .setErrorMessage(errorMessage)
|
|
|
+ .setDrawResponse(drawResponse)
|
|
|
+ );
|
|
|
+ return aiImageDO.getId();
|
|
|
}
|
|
|
|
|
|
@Transactional(rollbackFor = Exception.class)
|