use-config.ts 11 KB


  1. import {
  2. useCallback,
  3. useEffect,
  4. useRef,
  5. useState,
  6. } from 'react'
  7. import produce from 'immer'
  8. import { isEqual } from 'lodash-es'
  9. import type { ValueSelector, Var } from '../../types'
  10. import { BlockEnum, VarType } from '../../types'
  11. import {
  12. useIsChatMode, useNodesReadOnly,
  13. useWorkflow,
  14. } from '../../hooks'
  15. import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
  16. import {
  17. getMultipleRetrievalConfig,
  18. getSelectedDatasetsMode,
  19. } from './utils'
  20. import { RETRIEVE_TYPE } from '@/types/app'
  21. import { DATASET_DEFAULT } from '@/config'
  22. import type { DataSet } from '@/models/datasets'
  23. import { fetchDatasets } from '@/service/datasets'
  24. import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
  25. import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
  26. import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
  27. import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  /**
   * State and handlers backing the Knowledge Retrieval node's config panel.
   *
   * @param id - id of the workflow node being edited.
   * @param payload - the node's data, handed to `useNodeCrud`.
   * @returns The node `inputs`, a `readOnly` flag, handlers for the query
   *   variable, retrieval mode, single/multiple retrieval config, completion
   *   params and dataset selection, the single-step-run controls, and the
   *   rerank-model popup open state.
   */
  const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
    const { nodesReadOnly: readOnly } = useNodesReadOnly()
    const isChatMode = useIsChatMode()
    const { getBeforeNodesInSameBranch } = useWorkflow()
    const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
    const startNodeId = startNode?.id
    const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
    // Mirror of the most recent value written through setInputs, so effects and
    // callbacks that do not re-run on every render still read current inputs.
    const inputRef = useRef(inputs)

    // Persist new inputs, dropping whichever retrieval config belongs to the
    // mode that is NOT active, and keep inputRef in sync.
    const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
      const newInputs = produce(s, (draft) => {
        if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
          delete draft.single_retrieval_config
        else
          delete draft.multiple_retrieval_config
      })
      // NOTE(review): the original author noted this pruning "did not work" when
      // done on a draft passed in by callers, hence this separate produce() pass.
      doSetInputs(newInputs)
      inputRef.current = newInputs
    }, [doSetInputs])

    const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
      const newInputs = produce(inputs, (draft) => {
        draft.query_variable_selector = newVar as ValueSelector
      })
      setInputs(newInputs)
    }, [inputs, setInputs])

    const {
      currentProvider,
      currentModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

    const {
      modelList: rerankModelList,
      defaultModel: rerankDefaultModel,
    } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

    const {
      currentModel: currentRerankModel,
      currentProvider: currentRerankProvider,
    } = useCurrentProviderAndModel(
      rerankModelList,
      rerankDefaultModel
        ? {
          ...rerankDefaultModel,
          // useCurrentProviderAndModel expects the provider as a plain string.
          provider: rerankDefaultModel.provider.provider,
        }
        : undefined,
    )

    // Update the single-retrieval model, creating an empty config if none exists.
    const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.single_retrieval_config) {
          draft.single_retrieval_config = {
            model: {
              provider: '',
              name: '',
              mode: '',
              completion_params: {},
            },
          }
        }
        const draftModel = draft.single_retrieval_config?.model
        draftModel.provider = model.provider
        draftModel.name = model.modelId
        draftModel.mode = model.mode!
      })
      setInputs(newInputs)
    }, [setInputs])

    // Update the single-retrieval model's completion params; no-op when unchanged.
    const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
      // NOTE: right after a provider switch, inputRef.current may still hold the
      // previous model's params; the equality guard only skips no-op updates.
      if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
        return

      const newInputs = produce(inputRef.current, (draft) => {
        if (!draft.single_retrieval_config) {
          draft.single_retrieval_config = {
            model: {
              provider: '',
              name: '',
              mode: '',
              completion_params: {},
            },
          }
        }
        draft.single_retrieval_config.model.completion_params = newParams
      })
      setInputs(newInputs)
    }, [setInputs])

    // Fill in default models once the provider/model lists are available, unless
    // the active retrieval mode already has a model configured.
    useEffect(() => {
      const inputs = inputRef.current
      if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
        return

      if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
        return

      const newInput = produce(inputs, (draft) => {
        if (currentProvider?.provider && currentModel?.model) {
          const hasSetModel = draft.single_retrieval_config?.model?.provider
          if (!hasSetModel) {
            draft.single_retrieval_config = {
              model: {
                provider: currentProvider?.provider,
                name: currentModel?.model,
                mode: currentModel?.model_properties?.mode as string,
                completion_params: {},
              },
            }
          }
        }
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = {
          top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
          score_threshold: multipleRetrievalConfig?.score_threshold,
          reranking_model: multipleRetrievalConfig?.reranking_model,
          reranking_mode: multipleRetrievalConfig?.reranking_mode,
          weights: multipleRetrievalConfig?.weights,
          // Keep an explicit user choice; otherwise enable reranking only when a
          // default rerank model is available.
          reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
            ? multipleRetrievalConfig.reranking_enable
            : Boolean(currentRerankModel && rerankDefaultModel),
        }
      })
      setInputs(newInput)
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [currentProvider?.provider, currentModel, rerankDefaultModel])

    const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
    const [rerankModelOpen, setRerankModelOpen] = useState(false)

    // Switch retrieval mode, initialising the config the new mode needs.
    const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
      const newInputs = produce(inputs, (draft) => {
        draft.retrieval_mode = newMode
        if (newMode === RETRIEVE_TYPE.multiWay) {
          const multipleRetrievalConfig = draft.multiple_retrieval_config
          draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, {
            provider: currentRerankProvider?.provider,
            model: currentRerankModel?.model,
          })
        }
        else {
          const hasSetModel = draft.single_retrieval_config?.model?.provider
          if (!hasSetModel) {
            draft.single_retrieval_config = {
              model: {
                provider: currentProvider?.provider || '',
                name: currentModel?.model || '',
                mode: currentModel?.model_properties?.mode as string,
                completion_params: {},
              },
            }
          }
        }
      })
      setInputs(newInputs)
    }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])

    const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
      const newInputs = produce(inputs, (draft) => {
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
          provider: currentRerankProvider?.provider,
          model: currentRerankModel?.model,
        })
      })
      setInputs(newInputs)
    }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])

    // On mount: load details for the stored dataset ids and record them in
    // both local state and the node inputs.
    useEffect(() => {
      (async () => {
        const datasetIds = inputRef.current.dataset_ids
        // Use the fetched list directly; the `selectedDatasets` state captured by
        // this closure is still the initial empty array.
        let datasetsWithDetail: DataSet[] = []
        if (datasetIds?.length > 0) {
          const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
          datasetsWithDetail = dataSetsWithDetail
          setSelectedDatasets(dataSetsWithDetail)
        }
        // Read the ref AFTER the await: the other mount effects may have updated
        // the inputs while the request was in flight, and producing from a
        // pre-await snapshot would silently revert those updates.
        const newInputs = produce(inputRef.current, (draft) => {
          draft.dataset_ids = datasetIds
          draft._datasets = datasetsWithDetail
        })
        setInputs(newInputs)
      })()
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [])

    // On mount in chat mode: default an empty query variable to the start
    // node's `sys.query`.
    useEffect(() => {
      const inputs = inputRef.current
      let query_variable_selector: ValueSelector = inputs.query_variable_selector
      if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
        query_variable_selector = [startNodeId, 'sys.query']

      setInputs(produce(inputs, (draft) => {
        draft.query_variable_selector = query_variable_selector
      }))
      // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [])

    // Replace the selected datasets, recompute the multi-way config if needed,
    // and open the rerank-model popup for selections that call for reranking.
    const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
      const {
        mixtureHighQualityAndEconomic,
        mixtureInternalAndExternal,
        inconsistentEmbeddingModel,
        allInternal,
        allExternal,
      } = getSelectedDatasetsMode(newDatasets)
      const newInputs = produce(inputs, (draft) => {
        draft.dataset_ids = newDatasets.map(d => d.id)
        draft._datasets = newDatasets

        // Read the mode from the inputs being edited, like the other handlers do.
        if (draft.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
          const multipleRetrievalConfig = draft.multiple_retrieval_config
          draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
            provider: currentRerankProvider?.provider,
            model: currentRerankModel?.model,
          })
        }
      })
      setInputs(newInputs)
      setSelectedDatasets(newDatasets)

      if (
        (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
        || mixtureInternalAndExternal
        || allExternal
      )
        setRerankModelOpen(true)
    }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])

    // Only string variables may be used as the query.
    const filterVar = useCallback((varPayload: Var) => {
      return varPayload.type === VarType.string
    }, [])

    // single run
    const {
      isShowSingleRun,
      hideSingleRun,
      runningStatus,
      handleRun,
      handleStop,
      runInputData,
      setRunInputData,
      runResult,
    } = useOneStepRun<KnowledgeRetrievalNodeType>({
      id,
      data: inputs,
      defaultRunInputData: {
        query: '',
      },
    })

    const query = runInputData.query
    const setQuery = useCallback((newQuery: string) => {
      setRunInputData({
        ...runInputData,
        query: newQuery,
      })
    }, [runInputData, setRunInputData])

    return {
      readOnly,
      inputs,
      handleQueryVarChange,
      filterVar,
      handleRetrievalModeChange,
      handleMultipleRetrievalConfigChange,
      handleModelChanged,
      handleCompletionParamsChange,
      // Only datasets that have a name are exposed.
      selectedDatasets: selectedDatasets.filter(d => d.name),
      handleOnDatasetsChange,
      isShowSingleRun,
      hideSingleRun,
      runningStatus,
      handleRun,
      handleStop,
      query,
      setQuery,
      runResult,
      rerankModelOpen,
      setRerankModelOpen,
    }
  }

  export default useConfig