浏览代码

聊天对话,增加 创建对话、还是继续对话逻辑

cherishsince 1 年之前
父节点
当前提交
7794992225

+ 34 - 0
yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ChatTypeEnum.java

@@ -0,0 +1,34 @@
+package cn.iocoder.yudao.module.ai.enums;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+/**
+ * 聊天类型
+ *
+ * @author fansili
+ * @time 2024/4/14 17:58
+ * @since 1.0
+ */
+@AllArgsConstructor
+@Getter
+public enum ChatTypeEnum {
+
+    ROLE_CHAT("roleChat", "角色模板聊天"),
+    USER_CHAT("userChat", "用户普通聊天"),
+
+    ;
+
+    private String type;
+
+    private String name;
+
+    public static ChatTypeEnum valueOfType(String type) {
+        for (ChatTypeEnum itemEnum : ChatTypeEnum.values()) {
+            if (itemEnum.getType().equals(type)) {
+                return itemEnum;
+            }
+        }
+        throw new IllegalArgumentException("Invalid MessageType value: " + type);
+    }
+}

+ 3 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dataobject/AiChatMessageDO.java

@@ -23,12 +23,12 @@ public class AiChatMessageDO {
     /**
      * 聊天ID,关联到特定的会话或对话
      */
-    private Long chatId;
+    private Long chatConversationId;
 
     /**
      * 角色ID,用于标识发送消息的用户或系统的身份
      */
-    private String userId;
+    private Long userId;
 
     /**
      * 消息具体内容,存储用户的发言或者系统响应的文字信息
@@ -38,7 +38,7 @@ public class AiChatMessageDO {
     /**
      * 消息类型,枚举值可能包括'system'(系统消息)、'user'(用户消息)和'assistant'(助手消息)
      */
-    private Double messageType;
+    private String messageType;
 
     /**
      * 在生成消息时采用的Top-K采样大小,

+ 94 - 8
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/ChatServiceImpl.java

@@ -1,14 +1,28 @@
 package cn.iocoder.yudao.module.ai.service.impl;
 
+import cn.hutool.core.exceptions.ExceptionUtil;
 import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
+import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
 import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
 import cn.iocoder.yudao.framework.ai.config.AiClient;
+import cn.iocoder.yudao.framework.common.exception.ServerException;
+import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
+import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
+import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
+import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
+import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
+import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
+import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
+import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
+import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
+import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
 import cn.iocoder.yudao.module.ai.service.ChatService;
 import cn.iocoder.yudao.module.ai.vo.ChatReq;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
 
 /**
@@ -24,6 +38,10 @@ import reactor.core.publisher.Flux;
 public class ChatServiceImpl implements ChatService {
 
     private final AiClient aiClient;
+    private final AiChatRoleMapper aiChatRoleMapper;
+    private final AiChatMessageMapper aiChatMessageMapper;
+    private final AiChatConversationMapper aiChatConversationMapper;
+
 
     /**
      * chat
@@ -31,16 +49,84 @@ public class ChatServiceImpl implements ChatService {
      * @param req
      * @return
      */
+    @Transactional(rollbackFor = Exception.class)
     public String chat(ChatReq req) {
+        // 获取 client 类型
         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
-        // 创建 chat 需要的 Prompt
-        Prompt prompt = new Prompt(req.getPrompt());
-        req.setTopK(req.getTopK());
-        req.setTopP(req.getTopP());
-        req.setTemperature(req.getTemperature());
-        // 发送 call 调用
-        ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
-        return call.getResult().getOutput().getContent();
+        // 获取 对话类型(新建还是继续)
+        ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
+
+        AiChatConversationDO aiChatConversationDO;
+        Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
+        if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
+            // 创建一个新的对话
+            aiChatConversationDO = createNewChatConversation(req, loginUserId);
+        } else {
+            // 继续对话
+            if (req.getConversationId() == null) {
+                throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
+            }
+            aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
+        }
+
+        String content;
+        try {
+            // 创建 chat 需要的 Prompt
+            Prompt prompt = new Prompt(req.getPrompt());
+            req.setTopK(req.getTopK());
+            req.setTopP(req.getTopP());
+            req.setTemperature(req.getTemperature());
+            // 发送 call 调用
+            ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
+            content = call.getResult().getOutput().getContent();
+        } catch (Exception e) {
+            content = ExceptionUtil.getMessage(e);
+        }
+
+        // 增加 chat message 记录
+        aiChatMessageMapper.insert(
+                new AiChatMessageDO()
+                        .setId(null)
+                        .setChatConversationId(aiChatConversationDO.getId())
+                        .setUserId(loginUserId)
+                        .setMessage(req.getPrompt())
+                        .setMessageType(MessageType.USER.getValue())
+                        .setTopK(req.getTopK())
+                        .setTopP(req.getTopP())
+                        .setTemperature(req.getTemperature())
+        );
+
+        // chat count 先+1
+        aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
+        return content;
+    }
+
+    private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
+        // 获取 chat 角色
+        String chatRoleName = null;
+        ChatTypeEnum chatTypeEnum = null;
+        Long chatRoleId = req.getChatRoleId();
+        if (req.getChatRoleId() != null) {
+            AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
+            if (aiChatRoleDO == null) {
+                throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
+            }
+            chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
+            chatRoleName = aiChatRoleDO.getRoleName();
+        } else {
+            chatTypeEnum = ChatTypeEnum.USER_CHAT;
+        }
+        //
+        AiChatConversationDO insertChatConversation = new AiChatConversationDO()
+                .setId(null)
+                .setUserId(loginUserId)
+                .setChatRoleId(req.getChatRoleId())
+                .setChatRoleName(chatRoleName)
+                .setChatType(chatTypeEnum.getType())
+                .setChatCount(1)
+                .setChatTitle(req.getPrompt().substring(0, 20) + "...");
+        aiChatConversationMapper.insert(insertChatConversation);
+        return insertChatConversation;
     }
 
     /**

+ 13 - 3
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/vo/ChatReq.java

@@ -24,19 +24,29 @@ public class ChatReq {
     @Schema(description = "填入固定值,1 issues, 2 pr")
     private String prompt;
 
+    @Schema(description = "chat角色模板")
+    private Long chatRoleId;
+
     @Schema(description = "用于控制随机性和多样性的温度参数")
-    private Float temperature;
+    private Double temperature;
 
     @Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" +
             "     * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" +
             "     * 默认值为0.8。注意,取值不要大于等于1\n")
-    private Float topP;
+    private Double topP;
 
     @Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小")
-    private Integer topK;
+    private Double topK;
 
     @Schema(description = "ai模型(查看 AiClientNameEnum)")
     @NotNull(message = "模型不能为空!")
     @Size(max = 30, message = "模型字符最大30个字符!")
     private String modal;
+
+    @Schema(description = "对话类型(new、continue)")
+    @NotNull(message = "对话类型,不能为空!")
+    private String conversationType;
+
+    @Schema(description = "对话Id")
+    private Long conversationId;
 }

+ 1 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chat/messages/AbstractMessage.java

@@ -59,7 +59,7 @@ public abstract class AbstractMessage implements Message {
 	}
 
 	protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData,
-			Map<String, Object> messageProperties) {
+							  Map<String, Object> messageProperties) {
 
 		Assert.notNull(messageType, "Message type must not be null");
 		Assert.notNull(textContent, "Content must not be null");