|
@@ -14,8 +14,10 @@ import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
|
|
|
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
|
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
|
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
|
|
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReqVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingRespVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
|
|
+import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
|
@@ -60,31 +62,35 @@ public class AiImageServiceImpl implements AiImageService {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void dallDrawing(AiImageDallDrawingReq req) {
|
|
|
+ public AiImageDallDrawingRespVO dallDrawing(AiImageDallDrawingReqVO req) {
|
|
|
// 获取 model
|
|
|
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
|
|
|
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
|
|
|
- //
|
|
|
- OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
|
|
|
- openAiImageOptions.setModel(openAiImageModelEnum);
|
|
|
- openAiImageOptions.setStyle(openAiImageStyleEnum);
|
|
|
- openAiImageOptions.setSize(req.getSize());
|
|
|
- ImageResponse imageResponse;
|
|
|
+ // 转换 AiImageDallDrawingRespVO
|
|
|
+ AiImageDallDrawingRespVO respVO = AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(req);
|
|
|
try {
|
|
|
- imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
|
|
|
+ // 转换openai 参数
|
|
|
+ OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
|
|
|
+ openAiImageOptions.setModel(openAiImageModelEnum);
|
|
|
+ openAiImageOptions.setStyle(openAiImageStyleEnum);
|
|
|
+ openAiImageOptions.setSize(req.getSize());
|
|
|
+ ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
|
|
|
// 发送
|
|
|
ImageGeneration imageGeneration = imageResponse.getResult();
|
|
|
- // 发送信息
|
|
|
-// sendSseEmitter(sseEmitter, imageGeneration);
|
|
|
// 保存数据库
|
|
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
|
|
imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
|
|
|
+ // 返回 flex
|
|
|
+ respVO.setUrl(imageGeneration.getOutput().getUrl());
|
|
|
+ respVO.setBase64(imageGeneration.getOutput().getB64Json());
|
|
|
+ return respVO;
|
|
|
} catch (AiException aiException) {
|
|
|
// 保存数据库
|
|
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
|
|
null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
|
|
|
// 发送错误信息
|
|
|
-// sendSseEmitter(sseEmitter, aiException.getMessage());
|
|
|
+ respVO.setErrorMessage(aiException.getMessage());
|
|
|
+ return respVO;
|
|
|
}
|
|
|
}
|
|
|
|