فهرست منبع

【优化】dall 绘画,改为异步。

cherishsince 10 ماه پیش
والد
کامیت
e97408b3ac

+ 47 - 28
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java

@@ -35,6 +35,9 @@ import org.springframework.transaction.annotation.Transactional;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
 
 /**
  * ai 作图
@@ -53,6 +56,8 @@ public class AiImageServiceImpl implements AiImageService {
     private final OpenAiImageClient openAiImageClient;
     private final MidjourneyWebSocketStarter midjourneyWebSocketStarter;
     private final MidjourneyInteractionsApi midjourneyInteractionsApi;
+    private static ThreadPoolExecutor EXECUTOR = new ThreadPoolExecutor(
+            3, 5, 1, TimeUnit.HOURS, new LinkedBlockingQueue<>(32));
 
     @PostConstruct
     public void startMidjourney() {
@@ -89,34 +94,48 @@ public class AiImageServiceImpl implements AiImageService {
 
     @Override
     public AiImageDallRespVO dallDrawing(AiImageDallReqVO req) {
-        // 获取 model
-        OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
-        OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
-        try {
-            // 转换openai 参数
-            OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
-            openAiImageOptions.setModel(openAiImageModelEnum.getModel());
-            openAiImageOptions.setStyle(openAiImageStyleEnum.getStyle());
-            openAiImageOptions.setSize(req.getSize());
-            ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
-            // 发送
-            ImageGeneration imageGeneration = imageResponse.getResult();
-            // 图片保存到服务器
-            String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
-            // 保存数据库
-            AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
-                    filePath, imageGeneration.getOutput().getUrl(), AiImageStatusEnum.COMPLETE, null,
-                    null, null, null);
-            // 转换 AiImageDallDrawingRespVO
-            return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
-        } catch (AiException aiException) {
-            // 保存数据库
-            AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
-                    null, null, AiImageStatusEnum.FAIL, aiException.getMessage(),
-                    null, null, null);
-            // 发送错误信息
-            return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
-        }
+        // 保存数据库
+        AiImageDO aiImageDO = doSave(req.getPrompt(), req.getSize(), req.getModel(),
+                null, null, AiImageStatusEnum.IN_PROGRESS, null,
+                null, null, null);
+        // 异步执行
+        EXECUTOR.execute(() -> {
+            try {
+
+                // 获取 model
+                OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
+                OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
+
+                // 转换openai 参数
+                OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
+                openAiImageOptions.setModel(openAiImageModelEnum.getModel());
+                openAiImageOptions.setStyle(openAiImageStyleEnum.getStyle());
+                openAiImageOptions.setSize(req.getSize());
+                ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
+                // 发送
+                ImageGeneration imageGeneration = imageResponse.getResult();
+                // 图片保存到服务器
+                String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
+                // 更新数据库
+                aiImageMapper.updateById(
+                        new AiImageDO()
+                                .setId(aiImageDO.getId())
+                                .setStatus(AiImageStatusEnum.COMPLETE.getStatus())
+                                .setPicUrl(filePath)
+                                .setOriginalPicUrl(imageGeneration.getOutput().getUrl())
+                );
+            } catch (AiException aiException) {
+                // 更新错误信息
+                aiImageMapper.updateById(
+                        new AiImageDO()
+                                .setId(aiImageDO.getId())
+                                .setStatus(AiImageStatusEnum.FAIL.getStatus())
+                                .setErrorMessage(aiException.getMessage())
+                );
+            }
+        });
+        // 转换 AiImageDallDrawingRespVO
+        return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
     }
 
     @Override