Browse Source

【优化】AI 知识库: VectorStore 获取 抽到 AiModelFactory

xiaoxin 10 months ago
parent
commit
8e012b10bb

+ 1 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java

@@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.service.model;
 
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
-import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactory;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@@ -39,8 +38,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
 
     @Resource
     private AiModelFactory modelFactory;
-    @Resource
-    private AiVectorStoreFactory vectorFactory;
 
     @Override
     public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@@ -149,7 +146,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
     public VectorStore getOrCreateVectorStore(Long id) {
         AiApiKeyDO apiKey = validateApiKey(id);
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
-        return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
+        return modelFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
     }
 
 }

+ 0 - 7
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java

@@ -2,8 +2,6 @@ package cn.iocoder.yudao.framework.ai.config;
 
 import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
 import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
-import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactory;
-import cn.iocoder.yudao.framework.ai.core.factory.AiVectorStoreFactoryImpl;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@@ -38,11 +36,6 @@ public class YudaoAiAutoConfiguration {
         return new AiModelFactoryImpl();
     }
 
-    @Bean
-    public AiVectorStoreFactory aiVectorFactory() {
-        return new AiVectorStoreFactoryImpl();
-    }
-
 
     // ========== 各种 AI Client 创建 ==========
 

+ 14 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactory.java

@@ -6,6 +6,7 @@ import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.embedding.EmbeddingModel;
 import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.vectorstore.VectorStore;
 
 /**
  * AI Model 模型工厂的接口类
@@ -92,4 +93,17 @@ public interface AiModelFactory {
      */
     EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
 
+    /**
+     * 基于指定配置,获得 VectorStore 对象
+     * <p>
+     * 如果不存在,则进行创建
+     *
+     * @param embeddingModel 嵌入模型
+     * @param platform       平台
+     * @param apiKey         API KEY
+     * @param url            API URL
+     * @return VectorStore 对象
+     */
+    VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
+
 }

+ 27 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java

@@ -13,6 +13,7 @@ import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
 import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
 import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
 import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
+import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
 import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
 import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
 import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
@@ -54,13 +55,17 @@ import org.springframework.ai.qianfan.api.QianFanApi;
 import org.springframework.ai.qianfan.api.QianFanImageApi;
 import org.springframework.ai.stabilityai.StabilityAiImageModel;
 import org.springframework.ai.stabilityai.api.StabilityAiApi;
+import org.springframework.ai.vectorstore.RedisVectorStore;
+import org.springframework.ai.vectorstore.VectorStore;
 import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
 import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
 import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
 import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
+import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClient;
+import redis.clients.jedis.JedisPooled;
 
 import java.util.List;
 
@@ -191,6 +196,28 @@ public class AiModelFactoryImpl implements AiModelFactory {
         });
     }
 
+    @Override
+    public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
+        String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
+        return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
+            // TODO 芋艿 @xin 这两个配置取哪好呢
+            // TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上
+            // TODO 回复:好的哈
+            String index = "default-index";
+            String prefix = "default:";
+            var config = RedisVectorStore.RedisVectorStoreConfig.builder()
+                    .withIndexName(index)
+                    .withPrefix(prefix)
+                    .build();
+            RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
+            RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
+                    new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
+                    true);
+            redisVectorStore.afterPropertiesSet();
+            return redisVectorStore;
+        });
+    }
+
     private static String buildClientCacheKey(Class<?> clazz, Object... params) {
         if (ArrayUtil.isEmpty(params)) {
             return clazz.getName();

+ 0 - 28
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactory.java

@@ -1,28 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.factory;
-
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.VectorStore;
-
-// TODO @xin:也放到 AiModelFactory 里面好了,后续改成 AiFactory
-/**
- * AI Vector 模型工厂的接口类
- *
- * @author xiaoxin
- */
-public interface AiVectorStoreFactory {
-
-    /**
-     * 基于指定配置,获得 VectorStore 对象
-     * <p>
-     * 如果不存在,则进行创建
-     *
-     * @param embeddingModel 嵌入模型
-     * @param platform       平台
-     * @param apiKey         API KEY
-     * @param url            API URL
-     * @return VectorStore 对象
-     */
-    VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
-
-}

+ 0 - 52
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiVectorStoreFactoryImpl.java

@@ -1,52 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.factory;
-
-import cn.hutool.core.lang.Singleton;
-import cn.hutool.core.lang.func.Func0;
-import cn.hutool.core.util.ArrayUtil;
-import cn.hutool.core.util.StrUtil;
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.RedisVectorStore;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
-import redis.clients.jedis.JedisPooled;
-
-/**
- * AI Vector 模型工厂的实现类
- * 使用 redisVectorStore 实现 VectorStore
- *
- * @author xiaoxin
- */
-public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {
-
-    @Override
-    public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
-        String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
-        return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
-            // TODO 芋艿 @xin 这两个配置取哪好呢
-            // TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上
-            // TODO 回复:好的哈
-            String index = "default-index";
-            String prefix = "default:";
-            var config = RedisVectorStore.RedisVectorStoreConfig.builder()
-                    .withIndexName(index)
-                    .withPrefix(prefix)
-                    .build();
-            RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
-            RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
-                    new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
-                    true);
-            redisVectorStore.afterPropertiesSet();
-            return redisVectorStore;
-        });
-    }
-
-    private static String buildClientCacheKey(Class<?> clazz, Object... params) {
-        if (ArrayUtil.isEmpty(params)) {
-            return clazz.getName();
-        }
-        return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
-    }
-
-}