|
@@ -12,21 +12,29 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
|
|
|
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
|
|
|
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
|
|
+import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
|
|
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
|
|
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
-import org.springframework.ai.chat.messages.*;
|
|
|
+import org.springframework.ai.chat.messages.Message;
|
|
|
+import org.springframework.ai.chat.messages.MessageType;
|
|
|
+import org.springframework.ai.chat.messages.SystemMessage;
|
|
|
+import org.springframework.ai.chat.messages.UserMessage;
|
|
|
import org.springframework.ai.chat.model.ChatModel;
|
|
|
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
|
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
+import org.springframework.ai.chat.prompt.PromptTemplate;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
import reactor.core.publisher.Flux;
|
|
@@ -59,6 +67,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
private AiChatModelService chatModalService;
|
|
|
@Resource
|
|
|
private AiApiKeyService apiKeyService;
|
|
|
+ @Resource
|
|
|
+ private AiKnowledgeSegmentService knowledgeSegmentService;
|
|
|
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
|
@@ -80,13 +90,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
|
|
|
|
|
- // 3.2 创建 chat 需要的 Prompt
|
|
|
- Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
|
|
+ // 3.2 召回段落
|
|
|
+ List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
|
|
+
|
|
|
+ // 3.3 创建 chat 需要的 Prompt
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
|
|
|
ChatResponse chatResponse = chatModel.call(prompt);
|
|
|
|
|
|
- // 3.3 段式返回
|
|
|
+ // 3.4 段式返回
|
|
|
String newContent = chatResponse.getResult().getOutput().getContent();
|
|
|
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
|
|
|
+ chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
|
|
|
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
|
|
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
|
|
|
}
|
|
@@ -111,11 +124,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
|
|
|
|
|
- // 3.2 构建 Prompt,并进行调用
|
|
|
- Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
|
|
+
|
|
|
+ // 3.2 召回段落
|
|
|
+ List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
|
|
+
|
|
|
+ // 3.3 构建 Prompt,并进行调用
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
|
|
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
|
|
|
|
- // 3.3 流式返回
|
|
|
+ // 3.4 流式返回
|
|
|
// TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
return streamResponse.map(chunk -> {
|
|
@@ -128,7 +145,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
}).doOnComplete(() -> {
|
|
|
// 忽略租户,因为 Flux 异步无法透传租户
|
|
|
TenantUtils.executeIgnore(() ->
|
|
|
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
|
|
|
+ chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
|
|
|
+ .setContent(contentBuffer.toString())));
|
|
|
}).doOnError(throwable -> {
|
|
|
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
|
|
// 忽略租户,因为 Flux 异步无法透传租户
|
|
@@ -137,18 +155,35 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
|
|
}
|
|
|
|
|
|
- private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
|
|
+ private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
|
|
|
+ if (Objects.isNull(knowledgeId)) {
|
|
|
+ return Collections.emptyList();
|
|
|
+ }
|
|
|
+ return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
|
|
|
+ }
|
|
|
+
|
|
|
+ private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
|
|
|
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
|
|
// 1. 构建 Prompt Message 列表
|
|
|
List<Message> chatMessages = new ArrayList<>();
|
|
|
- // 1.1 system context 角色设定
|
|
|
+
|
|
|
+ // 1.1 召回内容消息构建
|
|
|
+ if (CollUtil.isNotEmpty(segmentList)) {
|
|
|
+ PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
|
|
|
+ StringBuilder infoBuilder = StrUtil.builder();
|
|
|
+ segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
|
|
|
+ Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
|
|
|
+ chatMessages.add(message);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 1.2 system context 角色设定
|
|
|
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
|
|
|
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
|
|
|
}
|
|
|
- // 1.2 history message 历史消息
|
|
|
+ // 1.3 history message 历史消息
|
|
|
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
|
|
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
|
|
|
- // 1.3 user message 新发送消息
|
|
|
+ // 1.4 user message 新发送消息
|
|
|
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
|
|
|
|
|
// 2. 构建 ChatOptions 对象
|
|
@@ -160,12 +195,12 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
|
|
|
/**
|
|
|
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
|
|
- *
|
|
|
+ * <p>
|
|
|
* n 组:指的是 user + assistant 形成一组
|
|
|
*
|
|
|
- * @param messages 消息列表
|
|
|
+ * @param messages 消息列表
|
|
|
* @param conversation 对话
|
|
|
- * @param sendReqVO 发送请求
|
|
|
+ * @param sendReqVO 发送请求
|
|
|
* @return 消息上下文
|
|
|
*/
|
|
|
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages,
|
|
@@ -182,7 +217,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
}
|
|
|
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
|
|
|
if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|
|
|
- || StrUtil.isEmpty(assistantMessage.getContent())) {
|
|
|
+ || StrUtil.isEmpty(assistantMessage.getContent())) {
|
|
|
continue;
|
|
|
}
|
|
|
// 由于后续要 reverse 反转,所以先添加 assistantMessage
|