|
@@ -2,16 +2,10 @@ package cn.iocoder.yudao.module.ai.service.image;
|
|
|
|
|
|
import cn.hutool.core.util.IdUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.exception.AiException;
|
|
|
-import org.springframework.ai.image.ImageGeneration;
|
|
|
-import org.springframework.ai.image.ImagePrompt;
|
|
|
-import org.springframework.ai.image.ImageResponse;
|
|
|
+import cn.hutool.http.HttpUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
|
|
|
-import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
|
|
|
-import org.springframework.ai.models.midjourney.api.req.ReRollReq;
|
|
|
-import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
|
|
|
-import org.springframework.ai.models.midjourney.webSocket.WssNotify;
|
|
|
+import cn.iocoder.yudao.framework.ai.core.exception.AiException;
|
|
|
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
|
@@ -23,9 +17,17 @@ import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
|
|
|
+import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
|
|
import jakarta.annotation.PostConstruct;
|
|
|
import lombok.AllArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.ai.image.ImageGeneration;
|
|
|
+import org.springframework.ai.image.ImagePrompt;
|
|
|
+import org.springframework.ai.image.ImageResponse;
|
|
|
+import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
|
|
|
+import org.springframework.ai.models.midjourney.api.req.ReRollReq;
|
|
|
+import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
|
|
|
+import org.springframework.ai.models.midjourney.webSocket.WssNotify;
|
|
|
import org.springframework.ai.openai.OpenAiImageClient;
|
|
|
import org.springframework.ai.openai.OpenAiImageOptions;
|
|
|
import org.springframework.stereotype.Service;
|
|
@@ -47,6 +49,7 @@ import java.util.List;
|
|
|
public class AiImageServiceImpl implements AiImageService {
|
|
|
|
|
|
private final AiImageMapper aiImageMapper;
|
|
|
+ private final FileApi fileApi;
|
|
|
private final OpenAiImageClient openAiImageClient;
|
|
|
private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
|
|
|
private final MidjourneyInteractionsApi midjourneyInteractionsApi;
|
|
@@ -89,8 +92,6 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
// 获取 model
|
|
|
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
|
|
|
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
|
|
|
- // 转换 AiImageDallDrawingRespVO
|
|
|
- AiImageDallRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
|
|
|
try {
|
|
|
// 转换openai 参数
|
|
|
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
|
|
@@ -100,22 +101,21 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
|
|
|
// 发送
|
|
|
ImageGeneration imageGeneration = imageResponse.getResult();
|
|
|
+ // 图片保存到服务器
|
|
|
+ String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
|
|
|
// 保存数据库
|
|
|
- doSave(req.getPrompt(), req.getSize(), req.getModel(),
|
|
|
- imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
|
|
|
+ AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
|
|
|
+ filePath, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
|
|
|
null, null, null);
|
|
|
- // 返回 flex
|
|
|
- respVO.setOriginalPicUrl(imageGeneration.getOutput().getUrl());
|
|
|
- respVO.setBase64(imageGeneration.getOutput().getB64Json());
|
|
|
- return respVO;
|
|
|
+ // 转换 AiImageDallDrawingRespVO
|
|
|
+ return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
|
|
|
} catch (AiException aiException) {
|
|
|
// 保存数据库
|
|
|
- doSave(req.getPrompt(), req.getSize(), req.getModel(),
|
|
|
- null, AiImageStatusEnum.FAIL, aiException.getMessage(),
|
|
|
+ AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
|
|
|
+ null, null, AiImageStatusEnum.FAIL, aiException.getMessage(),
|
|
|
null, null, null);
|
|
|
// 发送错误信息
|
|
|
- respVO.setErrorMessage(aiException.getMessage());
|
|
|
- return respVO;
|
|
|
+ return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -125,7 +125,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
// 保存数据库
|
|
|
String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
|
|
|
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
|
|
|
- null, AiImageStatusEnum.SUBMIT, null,
|
|
|
+ null, null, AiImageStatusEnum.SUBMIT, null,
|
|
|
messageId, null, null);
|
|
|
// 提交 midjourney 任务
|
|
|
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
|
|
@@ -149,7 +149,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
String mjOperationName = midjourneyOperationsVO.getLabel();
|
|
|
// 保存一个 image 任务记录
|
|
|
doSave(aiImageDO.getPrompt(), aiImageDO.getSize(), aiImageDO.getModel(),
|
|
|
- null, AiImageStatusEnum.SUBMIT, null,
|
|
|
+ null, null, AiImageStatusEnum.SUBMIT, null,
|
|
|
req.getMessageId(), req.getOperateId(), mjOperationName);
|
|
|
// 提交操作
|
|
|
midjourneyInteractionsApi.reRoll(
|
|
@@ -201,6 +201,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
private AiImageDO doSave(String prompt,
|
|
|
String size,
|
|
|
String model,
|
|
|
+ String picUrl,
|
|
|
String originalPicUrl,
|
|
|
AiImageStatusEnum statusEnum,
|
|
|
String errorMessage,
|
|
@@ -218,6 +219,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
// TODO @芋艿 如何上传到自己服务器
|
|
|
aiImageDO.setPicUrl(null);
|
|
|
aiImageDO.setStatus(statusEnum.getStatus());
|
|
|
+ aiImageDO.setPicUrl(picUrl);
|
|
|
aiImageDO.setOriginalPicUrl(originalPicUrl);
|
|
|
aiImageDO.setErrorMessage(errorMessage);
|
|
|
//
|