Browse Source

【新增】AI:对话消息记录召回段落

xiaoxin 10 months ago
parent
commit
c05d7c9f95

+ 16 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -1,13 +1,18 @@
 package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
 
-import com.baomidou.mybatisplus.annotation.TableId;
-import org.springframework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import com.baomidou.mybatisplus.annotation.KeySequence;
+import com.baomidou.mybatisplus.annotation.TableField;
+import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
+import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
 import lombok.*;
+import org.springframework.ai.chat.messages.MessageType;
+
+import java.util.List;
 
 /**
  * AI Chat 消息 DO
@@ -66,6 +71,15 @@ public class AiChatMessageDO extends BaseDO {
      */
     private Long roleId;
 
+
+    /**
+     * 段落编号数组
+     *
+     * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
+     */
+    @TableField(typeHandler = JacksonTypeHandler.class)
+    private List<Long> segmentIds;
+
     /**
      * 模型标志
      */

+ 32 - 19
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -90,13 +90,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
 
-        // 3.2 创建 chat 需要的 Prompt
-        Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
+        // 3.2 召回段落
+        List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
+
+        // 3.3 创建 chat 需要的 Prompt
+        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);
 
-        // 3.3 段式返回
+        // 3.4 段式返回
         String newContent = chatResponse.getResult().getOutput().getContent();
-        chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
+        chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
         return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
                 .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
     }
@@ -121,11 +124,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
 
-        // 3.2 构建 Prompt,并进行调用
-        Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
+
+        // 3.2 召回段落
+        List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
+
+        // 3.3 构建 Prompt,并进行调用
+        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
-        // 3.3 流式返回
+        // 3.4 流式返回
         // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
         StringBuffer contentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
@@ -138,7 +145,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         }).doOnComplete(() -> {
             // 忽略租户,因为 Flux 异步无法透传租户
             TenantUtils.executeIgnore(() ->
-                    chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
+                    chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
+                            .setContent(contentBuffer.toString())));
         }).doOnError(throwable -> {
             log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
             // 忽略租户,因为 Flux 异步无法透传租户
@@ -147,21 +155,26 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
     }
 
-    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
+    private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
+        List<AiKnowledgeSegmentDO> segmentList = new ArrayList<>();
+        if (Objects.nonNull(knowledgeId)) {
+            segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
+        }
+        return segmentList;
+    }
+
+    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
                                AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
         // 1. 构建 Prompt Message 列表
         List<Message> chatMessages = new ArrayList<>();
 
-        // 1.1 知识库召回
-        if (Objects.nonNull(conversation.getKnowledgeId())) {
-            List<AiKnowledgeSegmentDO> segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent()));
-            if (CollUtil.isNotEmpty(segmentList)) {
-                PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
-                StringBuilder infoBuilder = StrUtil.builder();
-                segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
-                Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
-                chatMessages.add(message);
-            }
+        // 1.1 召回内容消息构建
+        if (CollUtil.isNotEmpty(segmentList)) {
+            PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
+            StringBuilder infoBuilder = StrUtil.builder();
+            segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
+            Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
+            chatMessages.add(message);
         }
 
         // 1.2 system context 角色设定