|
@@ -1,19 +1,16 @@
|
|
|
package cn.iocoder.yudao.module.ai.service.image;
|
|
|
|
|
|
-import cn.hutool.core.bean.BeanUtil;
|
|
|
import cn.hutool.core.codec.Base64;
|
|
|
+import cn.hutool.core.map.MapUtil;
|
|
|
import cn.hutool.core.util.ObjUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.hutool.extra.spring.SpringUtil;
|
|
|
import cn.hutool.http.HttpUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
-import cn.iocoder.yudao.module.ai.AiCommonConstants;
|
|
|
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
|
|
|
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
|
|
|
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
|
|
@@ -21,17 +18,18 @@ import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
|
|
|
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
|
|
-import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
|
|
|
-import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
|
|
|
+import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
|
|
|
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
|
|
-import com.google.common.collect.ImmutableMap;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
-import org.springframework.ai.image.*;
|
|
|
+import org.springframework.ai.image.ImageClient;
|
|
|
+import org.springframework.ai.image.ImageOptions;
|
|
|
+import org.springframework.ai.image.ImagePrompt;
|
|
|
+import org.springframework.ai.image.ImageResponse;
|
|
|
import org.springframework.ai.openai.OpenAiImageOptions;
|
|
|
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
@@ -41,7 +39,7 @@ import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
|
|
-import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
|
|
|
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_IMAGE_NOT_EXISTS;
|
|
|
|
|
|
/**
|
|
|
* AI 绘画 Service 实现类
|
|
@@ -78,22 +76,18 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public Long dall(Long userId, AiImageDallReqVO req) {
|
|
|
- req.setPlatform("dall"); // TODO 芋艿:临时写死
|
|
|
+ public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
|
|
|
// 1. 保存数据库
|
|
|
- AiImageDO image = BeanUtils.toBean(req, AiImageDO.class)
|
|
|
- .setUserId(userId).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
|
|
|
- .setWidth(req.getWidth()).setHeight(req.getHeight())
|
|
|
- .setDrawRequest(ImmutableMap.of(AiCommonConstants.DRAW_REQ_KEY_STYLE, req.getStyle()))
|
|
|
- .setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
|
|
|
+ AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
|
|
|
+ .setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
imageMapper.insert(image);
|
|
|
- // 2. 异步绘制,后续前端通过返回的 id 进行伦旭
|
|
|
- getSelf().doDall(image, req);
|
|
|
+ // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
|
|
|
+ getSelf().doDall(image, drawReqVO);
|
|
|
return image.getId();
|
|
|
}
|
|
|
|
|
|
@Async
|
|
|
- public void doDall(AiImageDO image, AiImageDallReqVO req) {
|
|
|
+ public void doDall(AiImageDO image, AiImageDrawReqVO req) {
|
|
|
try {
|
|
|
// 1.1 构建请求
|
|
|
ImageOptions request = buildImageOptions(req);
|
|
@@ -106,7 +100,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
String filePath = fileApi.createFile(fileContent);
|
|
|
|
|
|
// 3. 更新数据库
|
|
|
- imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.COMPLETE.getStatus())
|
|
|
+ imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
|
|
|
.setPicUrl(filePath));
|
|
|
} catch (Exception ex) {
|
|
|
log.error("[doDall][image({}) 生成异常]", image, ex);
|
|
@@ -115,30 +109,28 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private static ImageOptions buildImageOptions(AiImageDallReqVO draw) {
|
|
|
- if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPEN_AI_DALL.getPlatform())) {
|
|
|
- OpenAiImageOptions request = new OpenAiImageOptions();
|
|
|
- request.setModel(OpenAiImageModelEnum.valueOfModel(draw.getModel()).getModel());
|
|
|
- request.setStyle(OpenAiImageStyleEnum.valueOfStyle(draw.getStyle()).getStyle());
|
|
|
- request.setSize(String.format(AiCommonConstants.DALL_SIZE_TEMPLATE, draw.getWidth(), draw.getHeight()));
|
|
|
- request.setResponseFormat("b64_json");
|
|
|
- return request;
|
|
|
- } else {
|
|
|
- // https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
|
|
|
- return StabilityAiImageOptions.builder().withModel(draw.getModel())
|
|
|
+ private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
|
|
|
+ if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
|
|
|
+ // https://platform.openai.com/docs/api-reference/images/create
|
|
|
+ return OpenAiImageOptions.builder().withModel(draw.getModel())
|
|
|
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
|
|
+ .withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
|
|
|
+ .withResponseFormat("b64_json")
|
|
|
+ .build();
|
|
|
+ } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
|
|
+ // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
|
|
|
+ return StabilityAiImageOptions.builder().withModel(draw.getModel())
|
|
|
+ .withHeight(draw.getHeight()).withWidth(draw.getWidth()) // TODO @芋艿:各种参数
|
|
|
.build();
|
|
|
}
|
|
|
-// return null;
|
|
|
+ throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
|
|
|
-
|
|
|
// 1、构建 AiImageDO
|
|
|
AiImageDO aiImageDO = new AiImageDO();
|
|
|
- aiImageDO.setId(null);
|
|
|
aiImageDO.setUserId(loginUserId);
|
|
|
aiImageDO.setPrompt(req.getPrompt());
|
|
|
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
|
|
@@ -147,12 +139,6 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
aiImageDO.setWidth(null);
|
|
|
aiImageDO.setHeight(null);
|
|
|
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
|
|
- aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
|
|
|
- aiImageDO.setPicUrl(null);
|
|
|
- aiImageDO.setOriginalPicUrl(null);
|
|
|
- aiImageDO.setDrawRequest(null);
|
|
|
- aiImageDO.setDrawResponse(null);
|
|
|
- aiImageDO.setErrorMessage(null);
|
|
|
|
|
|
// 2、保存 image
|
|
|
imageMapper.insert(aiImageDO);
|
|
@@ -211,7 +197,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
//
|
|
|
String imageStatus = null;
|
|
|
if (MidjourneyTaskStatusEnum.SUCCESS == notifyReqVO.getStatus()) {
|
|
|
- imageStatus = AiImageStatusEnum.COMPLETE.getStatus();
|
|
|
+ imageStatus = AiImageStatusEnum.SUCCESS.getStatus();
|
|
|
} else if (MidjourneyTaskStatusEnum.FAILURE == notifyReqVO.getStatus()) {
|
|
|
imageStatus = AiImageStatusEnum.FAIL.getStatus();
|
|
|
}
|
|
@@ -226,8 +212,7 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
.setId(image.getId())
|
|
|
.setStatus(imageStatus)
|
|
|
.setPicUrl(filePath)
|
|
|
- .setOriginalPicUrl(notifyReqVO.getImageUrl())
|
|
|
- .setDrawResponse(BeanUtil.beanToMap(notifyReqVO))
|
|
|
+// .setOriginalPicUrl(notifyReqVO.getImageUrl()) TODO @fan:就不存原始的图片地址啦
|
|
|
);
|
|
|
return true;
|
|
|
}
|