default.ts 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import { BlockEnum } from '../../types'
  2. import type { NodeDefault } from '../../types'
  3. import type { KnowledgeRetrievalNodeType } from './types'
  4. import { checkoutRerankModelConfigedInRetrievalSettings } from './utils'
  5. import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
  6. import { DATASET_DEFAULT } from '@/config'
  7. import { RETRIEVE_TYPE } from '@/types/app'
  /** Namespace prefix for every i18n key used by this node's validation messages. */
  const i18nPrefix = 'workflow'

  /**
   * Default payload plus connection and validation hooks for the Knowledge Retrieval workflow node.
   */
  const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
    // Payload for a newly inserted node: no query variable or knowledge selected yet,
    // multi-way retrieval with DATASET_DEFAULT.top_k, no score threshold, reranking off.
    defaultValue: {
      query_variable_selector: [],
      dataset_ids: [],
      retrieval_mode: RETRIEVE_TYPE.multiWay,
      multiple_retrieval_config: {
        top_k: DATASET_DEFAULT.top_k,
        score_threshold: undefined,
        reranking_enable: false,
      },
    },

    /**
     * Block types that may be connected before this node.
     * Chat mode uses the chat block list unchanged; completion mode uses the
     * completion block list with the End block filtered out.
     */
    getAvailablePrevNodes(isChatMode: boolean) {
      return isChatMode
        ? ALL_CHAT_AVAILABLE_BLOCKS
        : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End)
    },

    /** Block types that may be connected after this node, picked by app mode. */
    getAvailableNextNodes(isChatMode: boolean) {
      return isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
    },

    /**
     * Validate a node payload and report the first rule it breaks.
     *
     * Rules are checked in this order, stopping at the first failure:
     * 1. a query variable must be selected;
     * 2. at least one knowledge base (`dataset_ids`) must be selected;
     * 3. in one-way mode, `single_retrieval_config.model.provider` must be set;
     * 4. in multi-way mode, `checkoutRerankModelConfigedInRetrievalSettings` must accept
     *    the attached `_datasets` and `multiple_retrieval_config`.
     *
     * @param payload - node data to validate
     * @param t - translation function, called as `t(key, options)`
     *   (presumably i18next's `t` — typed `any` here; confirm against callers)
     * @returns `errorMessage` holds the translated "field required" message for the first
     *   failed rule, or '' when all rules pass; `isValid` is true exactly when it is empty
     */
    checkValid(payload: KnowledgeRetrievalNodeType, t: any) {
      const {
        query_variable_selector,
        dataset_ids,
        retrieval_mode,
        single_retrieval_config,
        multiple_retrieval_config,
        _datasets,
      } = payload
      // Every rule reports the same "field is required" template; only the field label differs.
      const fieldRequired = (field: string): string =>
        t(`${i18nPrefix}.errorMsg.fieldRequired`, { field })

      let errorMessage = ''
      if (!query_variable_selector?.length)
        errorMessage = fieldRequired(t(`${i18nPrefix}.nodes.knowledgeRetrieval.queryVariable`))
      else if (!dataset_ids?.length)
        errorMessage = fieldRequired(t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`))
      else if (retrieval_mode === RETRIEVE_TYPE.oneWay && !single_retrieval_config?.model?.provider)
        errorMessage = fieldRequired(t('common.modelProvider.systemReasoningModel.key'))
      // Kept last so the rerank-settings helper only runs once every cheaper check has passed.
      else if (
        retrieval_mode === RETRIEVE_TYPE.multiWay
        && !checkoutRerankModelConfigedInRetrievalSettings(_datasets ?? [], multiple_retrieval_config)
      )
        errorMessage = fieldRequired(t(`${i18nPrefix}.errorMsg.fields.rerankModel`))

      return {
        isValid: !errorMessage,
        errorMessage,
      }
    },
  }

  export default nodeDefault