|
@@ -1,5 +1,6 @@
|
|
|
package cn.iocoder.yudao.module.ai.service.impl;
|
|
|
|
|
|
+import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.exception.AiException;
|
|
|
import cn.iocoder.yudao.framework.ai.image.ImageGeneration;
|
|
|
import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
|
|
@@ -9,18 +10,20 @@ import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
|
|
|
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
|
|
|
+import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
|
|
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
|
|
|
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
|
|
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
+import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
|
|
+import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
|
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
|
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
|
|
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
|
|
|
-import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
|
|
+import cn.iocoder.yudao.module.ai.enums.AiImageDrawingStatusEnum;
|
|
|
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
|
|
import jakarta.annotation.PostConstruct;
|
|
|
import lombok.AllArgsConstructor;
|
|
@@ -28,6 +31,9 @@ import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.List;
|
|
|
+
|
|
|
/**
|
|
|
* ai 作图
|
|
|
*
|
|
@@ -61,6 +67,23 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public PageResult<AiImageListRespVO> list(AiImageListReqVO req) {
|
|
|
+ // 获取登录用户
|
|
|
+ Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
|
+ // 查询当前用户下所有的绘画记录
|
|
|
+ PageResult<AiImageDO> pageResult = aiImageMapper.selectPage(req,
|
|
|
+ new LambdaQueryWrapperX<AiImageDO>()
|
|
|
+ .eq(AiImageDO::getUserId, loginUserId)
|
|
|
+ .orderByDesc(AiImageDO::getId)
|
|
|
+ );
|
|
|
+ // 转换 PageResult<AiImageListRespVO> 返回
|
|
|
+ PageResult<AiImageListRespVO> result = new PageResult<>();
|
|
|
+ result.setTotal(pageResult.getTotal());
|
|
|
+ result.setList(AiImageConvert.INSTANCE.convertAiImageListRespVO(pageResult.getList()));
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) {
|
|
|
// 获取 model
|
|
@@ -79,7 +102,8 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
ImageGeneration imageGeneration = imageResponse.getResult();
|
|
|
// 保存数据库
|
|
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
|
|
- imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
|
|
|
+ imageGeneration.getOutput().getUrl(), AiImageDrawingStatusEnum.COMPLETE, null,
|
|
|
+ null, null, null);
|
|
|
// 返回 flex
|
|
|
respVO.setUrl(imageGeneration.getOutput().getUrl());
|
|
|
respVO.setBase64(imageGeneration.getOutput().getB64Json());
|
|
@@ -87,7 +111,8 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
} catch (AiException aiException) {
|
|
|
// 保存数据库
|
|
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
|
|
- null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
|
|
|
+ null, AiImageDrawingStatusEnum.FAIL, aiException.getMessage(),
|
|
|
+ null, null, null);
|
|
|
// 发送错误信息
|
|
|
respVO.setErrorMessage(aiException.getMessage());
|
|
|
return respVO;
|
|
@@ -99,7 +124,8 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
public void midjourney(AiImageMidjourneyReqVO req) {
|
|
|
// 保存数据库
|
|
|
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
|
|
|
- null, AiChatDrawingStatusEnum.SUBMIT, null);
|
|
|
+ null, AiImageDrawingStatusEnum.SUBMIT, null,
|
|
|
+ null, null, null);
|
|
|
// 提交 midjourney 任务
|
|
|
Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
|
|
|
if (!imagine) {
|
|
@@ -107,23 +133,71 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
|
|
-// try {
|
|
|
-// sseEmitter.send(object, MediaType.APPLICATION_JSON);
|
|
|
-// } catch (IOException e) {
|
|
|
-// throw new RuntimeException(e);
|
|
|
-// } finally {
|
|
|
-// // 发送 complete
|
|
|
-// sseEmitter.complete();
|
|
|
-// }
|
|
|
-// }
|
|
|
+ @Transactional(rollbackFor = Exception.class)
|
|
|
+ @Override
|
|
|
+ public void midjourneyOperate(AiImageMidjourneyOperateReqVO req) {
|
|
|
+ // 校验是否存在
|
|
|
+ AiImageDO aiImageDO = validateExists(req);
|
|
|
+ // 获取 midjourneyOperations
|
|
|
+ List<AiImageMidjourneyOperationsVO> midjourneyOperations = getMidjourneyOperations(aiImageDO);
|
|
|
+ // 校验 OperateId 是否存在
|
|
|
+ AiImageMidjourneyOperationsVO midjourneyOperationsVO = validateMidjourneyOperationsExists(midjourneyOperations, req.getOperateId());
|
|
|
+ // 校验 messageId
|
|
|
+ validateMessageId(aiImageDO.getMjMessageId(), req.getMessageId());
|
|
|
+ // 获取 mjOperationName
|
|
|
+ String mjOperationName = midjourneyOperationsVO.getLabel();
|
|
|
+ // 保存一个 image 任务记录
|
|
|
+ doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModal(),
|
|
|
+ null, AiImageDrawingStatusEnum.SUBMIT, null,
|
|
|
+ req.getMessageId(), req.getOperateId(), mjOperationName);
|
|
|
+ // 提交操作
|
|
|
+ midjourneyInteractionsApi.reRoll(
|
|
|
+ new ReRollReq()
|
|
|
+ .setCustomId(req.getOperateId())
|
|
|
+ .setMessageId(req.getMessageId())
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ private void validateMessageId(String mjMessageId, String messageId) {
|
|
|
+ if (!mjMessageId.equals(messageId)) {
|
|
|
+ throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_MESSAGE_ID_INCORRECT);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private AiImageMidjourneyOperationsVO validateMidjourneyOperationsExists(List<AiImageMidjourneyOperationsVO> midjourneyOperations, String operateId) {
|
|
|
+ for (AiImageMidjourneyOperationsVO midjourneyOperation : midjourneyOperations) {
|
|
|
+ if (midjourneyOperation.getCustom_id().equals(operateId)) {
|
|
|
+ return midjourneyOperation;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_OPERATION_NOT_EXISTS);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ private List<AiImageMidjourneyOperationsVO> getMidjourneyOperations(AiImageDO aiImageDO) {
|
|
|
+ if (StrUtil.isBlank(aiImageDO.getMjOperations())) {
|
|
|
+ return Collections.emptyList();
|
|
|
+ }
|
|
|
+ return JsonUtils.parseArray(aiImageDO.getMjOperations(), AiImageMidjourneyOperationsVO.class);
|
|
|
+ }
|
|
|
+
|
|
|
+ private AiImageDO validateExists(AiImageMidjourneyOperateReqVO req) {
|
|
|
+ AiImageDO aiImageDO = aiImageMapper.selectById(req.getId());
|
|
|
+ if (aiImageDO == null) {
|
|
|
+ throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
|
|
|
+ }
|
|
|
+ return aiImageDO;
|
|
|
+ }
|
|
|
|
|
|
private AiImageDO doSave(String prompt,
|
|
|
- String size,
|
|
|
- String model,
|
|
|
- String imageUrl,
|
|
|
- AiChatDrawingStatusEnum drawingStatusEnum,
|
|
|
- String drawingError) {
|
|
|
+ String size,
|
|
|
+ String model,
|
|
|
+ String drawingImageUrl,
|
|
|
+ AiImageDrawingStatusEnum drawingStatusEnum,
|
|
|
+ String drawingErrorMessage,
|
|
|
+ String mjMessageId,
|
|
|
+ String mjOperationId,
|
|
|
+ String mjOperationName) {
|
|
|
// 保存数据库
|
|
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
|
|
AiImageDO aiImageDO = new AiImageDO();
|
|
@@ -132,9 +206,15 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
aiImageDO.setSize(size);
|
|
|
aiImageDO.setModal(model);
|
|
|
aiImageDO.setUserId(loginUserId);
|
|
|
- aiImageDO.setDrawingImageUrl(imageUrl);
|
|
|
+ // TODO @芋艿 如何上传到自己服务器
|
|
|
+ aiImageDO.setImageUrl(null);
|
|
|
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
|
|
|
- aiImageDO.setDrawingError(drawingError);
|
|
|
+ aiImageDO.setDrawingImageUrl(drawingImageUrl);
|
|
|
+ aiImageDO.setDrawingErrorMessage(drawingErrorMessage);
|
|
|
+ //
|
|
|
+ aiImageDO.setMjMessageId(mjMessageId);
|
|
|
+ aiImageDO.setMjOperationId(mjOperationId);
|
|
|
+ aiImageDO.setMjOperationName(mjOperationName);
|
|
|
aiImageMapper.insert(aiImageDO);
|
|
|
return aiImageDO;
|
|
|
}
|