|
@@ -1,13 +1,17 @@
|
|
|
package cn.iocoder.yudao.module.ai.service.write;
|
|
|
|
|
|
+import cn.hutool.core.collection.CollUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
|
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
|
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
|
|
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
|
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|
|
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
|
|
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
|
|
@@ -15,6 +19,7 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
|
|
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
|
+import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
|
|
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
@@ -25,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
|
|
|
+import java.util.List;
|
|
|
import java.util.Objects;
|
|
|
|
|
|
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
|
@@ -43,6 +49,8 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
private AiApiKeyService apiKeyService;
|
|
|
@Resource
|
|
|
private AiChatModelService chatModalService;
|
|
|
+ @Resource
|
|
|
+ private AiChatRoleService chatRoleService;
|
|
|
|
|
|
@Resource
|
|
|
private DictDataApi dictDataApi;
|
|
@@ -52,15 +60,22 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
|
|
|
@Override
|
|
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
|
|
- // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
|
|
|
- AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
|
|
|
- StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
|
|
+ // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
|
|
|
+ AiChatRoleDO writeRole = selectOneWriteRole();
|
|
|
+ AiChatModelDO model;
|
|
|
+ if (Objects.nonNull(writeRole)) {
|
|
|
+ model = chatModalService.getChatModel(writeRole.getModelId());
|
|
|
+ } else {
|
|
|
+ model = chatModalService.getRequiredDefaultChatModel();
|
|
|
+ }
|
|
|
+
|
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
|
|
|
|
+ StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
|
|
+
|
|
|
// 1.2 插入写作信息
|
|
|
- // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性
|
|
|
- AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
|
|
- writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
|
|
+ AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
|
|
+ writeMapper.insert(writeDO);
|
|
|
|
|
|
// 2.1 构建提示词
|
|
|
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
|
@@ -87,23 +102,30 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
|
|
}
|
|
|
|
|
|
+ private AiChatRoleDO selectOneWriteRole() {
|
|
|
+ AiChatRoleDO chatRoleDO = null;
|
|
|
+ PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
|
|
|
+ List<AiChatRoleDO> list = writeRolePage.getList();
|
|
|
+ if (CollUtil.isNotEmpty(list)) {
|
|
|
+ chatRoleDO = list.get(0);
|
|
|
+ }
|
|
|
+ return chatRoleDO;
|
|
|
+ }
|
|
|
+
|
|
|
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
|
|
- String template;
|
|
|
- Integer writeType = generateReqVO.getType();
|
|
|
+ Integer type = generateReqVO.getType();
|
|
|
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
|
|
- String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
|
|
|
- String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
|
|
|
- String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
|
|
|
- // TODO @xin:建议改成 if return 哈;更简洁;
|
|
|
- if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
|
|
|
- // TODO @xin:写成静态枚举哈
|
|
|
- template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
|
|
- return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
|
|
|
- } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
|
|
|
- template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
|
|
- return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length);
|
|
|
+ String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
|
|
+ String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|
|
|
+ String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
|
|
|
+ String prompt = generateReqVO.getPrompt();
|
|
|
+ // 校验写作类型是否合法
|
|
|
+ AiWriteTypeEnum.validateType(type);
|
|
|
+
|
|
|
+ if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) {
|
|
|
+ return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length);
|
|
|
} else {
|
|
|
- throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType));
|
|
|
+ return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
|
|
|
}
|
|
|
}
|
|
|
|