|
@@ -2,8 +2,7 @@ package cn.iocoder.yudao.module.ai.service.write;
|
|
|
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
|
|
+import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
|
@@ -16,16 +15,12 @@ import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
|
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
|
|
-import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
|
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
-import org.springframework.ai.ollama.api.OllamaOptions;
|
|
|
-import org.springframework.ai.openai.OpenAiChatOptions;
|
|
|
-import org.springframework.ai.qianfan.QianFanChatOptions;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
|
|
@@ -56,19 +51,21 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
|
|
|
@Override
|
|
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
|
|
- // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?
|
|
|
+ // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
|
|
|
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
|
|
|
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
|
- ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
|
|
|
|
|
// 1.2 插入写作信息
|
|
|
+ // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性
|
|
|
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
|
|
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
|
|
|
|
|
// 2.1 构建提示词
|
|
|
+ ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
|
|
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
|
|
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
|
|
+
|
|
|
// 2.2 流式返回
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
return streamResponse.map(chunk -> {
|
|
@@ -92,7 +89,9 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
|
|
|
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
|
|
|
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
|
|
|
+ // TODO @xin:建议改成 if return 哈;更简洁;
|
|
|
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
|
|
|
+ // TODO @xin:写成静态枚举哈
|
|
|
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
|
|
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
|
|
|
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
|
|
@@ -103,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // TODO 芋艿:复用
|
|
|
- private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
|
|
- Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
|
|
- //noinspection EnhancedSwitchMigration
|
|
|
- switch (platform) {
|
|
|
- case OPENAI:
|
|
|
- return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
|
- case OLLAMA:
|
|
|
- return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
|
|
- case YI_YAN:
|
|
|
- // TODO 芋艿:貌似 model 只要一设置,就报错
|
|
|
-// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
|
- return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
|
|
- case XING_HUO:
|
|
|
- return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
|
|
- .setMaxTokens(maxTokens);
|
|
|
- case QIAN_WEN:
|
|
|
- return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
|
|
- default:
|
|
|
- throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
}
|