check-rerank-model.ts 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app'
  2. import type {
  3. DefaultModelResponse,
  4. Model,
  5. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  6. import { RerankingModeEnum } from '@/models/datasets'
  /**
   * Reports whether a retrieval config that requires a rerank model actually has one.
   *
   * Only the `'high_quality'` index method is validated; any other value returns `true`.
   * For `'high_quality'`, a rerank model is required when:
   * - the search method is hybrid and `reranking_mode` is not weighted score, or
   * - the search method is semantic or full-text and `reranking_enable` is set.
   *
   * A model counts as selected when the config names one that exists in
   * `rerankModelList` under the configured provider. If the config names no model
   * at all, it counts as selected when `isRerankDefaultModelValid` is true and a
   * `rerankDefaultModel` is given.
   *
   * @param rerankDefaultModel - fallback model, used only when the config names none.
   * @param isRerankDefaultModelValid - whether the fallback model may be relied upon.
   * @param retrievalConfig - retrieval settings being validated.
   * @param rerankModelList - available providers and their models.
   * @param indexMethod - dataset index method; only `'high_quality'` is checked.
   * @returns `false` when a required rerank model is missing, otherwise `true`.
   */
  export const isReRankModelSelected = ({
    rerankDefaultModel,
    isRerankDefaultModelValid,
    retrievalConfig,
    rerankModelList,
    indexMethod,
  }: {
    rerankDefaultModel?: DefaultModelResponse
    isRerankDefaultModelValid: boolean
    retrievalConfig: RetrievalConfig
    rerankModelList: Model[]
    indexMethod?: string
  }): boolean => {
    if (indexMethod !== 'high_quality')
      return true

    const wantedModel = retrievalConfig.reranking_model?.reranking_model_name
    const wantedProvider = retrievalConfig.reranking_model?.reranking_provider_name

    // An explicitly named model must be available; the default is not used as a
    // fallback for a named-but-missing model.
    const rerankModelSelected = wantedModel
      ? !!rerankModelList
        .find(entry => entry.provider === wantedProvider)
        ?.models
        .some(entry => entry.model === wantedModel)
      : isRerankDefaultModelValid && !!rerankDefaultModel

    if (rerankModelSelected)
      return true

    switch (retrievalConfig.search_method) {
      case RETRIEVE_METHOD.hybrid:
        // Weighted-score fusion does not need a rerank model.
        return retrievalConfig.reranking_mode === RerankingModeEnum.WeightedScore
      case RETRIEVE_METHOD.semantic:
      case RETRIEVE_METHOD.fullText:
        // Reranking here is opt-in; a model is only required once it is enabled
        // (matches the `reranking_enable` check in ensureRerankModelSelected).
        return !retrievalConfig.reranking_enable
      default:
        return true
    }
  }
  /**
   * Fills in the default rerank model when the config needs one but names none.
   *
   * Fill-in happens only when all of the following hold:
   * - `indexMethod` is `'high_quality'`;
   * - `reranking_enable` is set, or the search method is hybrid;
   * - the config has no `reranking_model.reranking_model_name` (an empty name counts as none);
   * - `rerankDefaultModel` is truthy.
   *
   * @param rerankDefaultModel - model whose provider and name are copied into the config.
   * @param indexMethod - dataset index method.
   * @param retrievalConfig - retrieval settings; never mutated.
   * @returns a new config with `reranking_model` set from the default when the conditions
   *   hold, otherwise the same `retrievalConfig` object that was passed in.
   */
  export const ensureRerankModelSelected = ({
    rerankDefaultModel,
    indexMethod,
    retrievalConfig,
  }: {
    rerankDefaultModel: DefaultModelResponse
    retrievalConfig: RetrievalConfig
    indexMethod?: string
  }) => {
    const hasNamedModel = !!retrievalConfig.reranking_model?.reranking_model_name
    const needsRerank = retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid

    if (indexMethod !== 'high_quality' || !needsRerank || hasNamedModel || !rerankDefaultModel)
      return retrievalConfig

    const { provider, model } = rerankDefaultModel
    return {
      ...retrievalConfig,
      reranking_model: {
        reranking_provider_name: provider.provider,
        reranking_model_name: model,
      },
    }
  }