provider-context.tsx 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. 'use client'
  2. import { createContext, useContext, useContextSelector } from 'use-context-selector'
  3. import useSWR from 'swr'
  4. import { useEffect, useState } from 'react'
  5. import {
  6. fetchModelList,
  7. fetchModelProviders,
  8. fetchSupportRetrievalMethods,
  9. } from '@/service/common'
  10. import {
  11. ModelStatusEnum,
  12. ModelTypeEnum,
  13. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  14. import type { Model, ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations'
  15. import type { RETRIEVE_METHOD } from '@/types/app'
  16. import { Plan, type UsagePlanInfo } from '@/app/components/billing/type'
  17. import { fetchCurrentPlanInfo } from '@/service/billing'
  18. import { parseCurrentPlan } from '@/app/components/billing/utils'
  19. import { defaultPlan } from '@/app/components/billing/config'
  20. type ProviderContextState = {
  21. modelProviders: ModelProvider[]
  22. textGenerationModelList: Model[]
  23. supportRetrievalMethods: RETRIEVE_METHOD[]
  24. isAPIKeySet: boolean
  25. plan: {
  26. type: Plan
  27. usage: UsagePlanInfo
  28. total: UsagePlanInfo
  29. }
  30. isFetchedPlan: boolean
  31. enableBilling: boolean
  32. onPlanInfoChanged: () => void
  33. enableReplaceWebAppLogo: boolean
  34. modelLoadBalancingEnabled: boolean
  35. datasetOperatorEnabled: boolean
  36. }
  37. const ProviderContext = createContext<ProviderContextState>({
  38. modelProviders: [],
  39. textGenerationModelList: [],
  40. supportRetrievalMethods: [],
  41. isAPIKeySet: true,
  42. plan: {
  43. type: Plan.sandbox,
  44. usage: {
  45. vectorSpace: 32,
  46. buildApps: 12,
  47. teamMembers: 1,
  48. annotatedResponse: 1,
  49. documentsUploadQuota: 50,
  50. },
  51. total: {
  52. vectorSpace: 200,
  53. buildApps: 50,
  54. teamMembers: 1,
  55. annotatedResponse: 10,
  56. documentsUploadQuota: 500,
  57. },
  58. },
  59. isFetchedPlan: false,
  60. enableBilling: false,
  61. onPlanInfoChanged: () => { },
  62. enableReplaceWebAppLogo: false,
  63. modelLoadBalancingEnabled: false,
  64. datasetOperatorEnabled: false,
  65. })
  66. export const useProviderContext = () => useContext(ProviderContext)
  67. // Adding a dangling comma to avoid the generic parsing issue in tsx, see:
  68. // https://github.com/microsoft/TypeScript/issues/15713
  69. // eslint-disable-next-line @typescript-eslint/comma-dangle
  70. export const useProviderContextSelector = <T,>(selector: (state: ProviderContextState) => T): T =>
  71. useContextSelector(ProviderContext, selector)
  72. type ProviderContextProviderProps = {
  73. children: React.ReactNode
  74. }
  75. export const ProviderContextProvider = ({
  76. children,
  77. }: ProviderContextProviderProps) => {
  78. const { data: providersData } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
  79. const fetchModelListUrlPrefix = '/workspaces/current/models/model-types/'
  80. const { data: textGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelTypeEnum.textGeneration}`, fetchModelList)
  81. const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)
  82. const [plan, setPlan] = useState(defaultPlan)
  83. const [isFetchedPlan, setIsFetchedPlan] = useState(false)
  84. const [enableBilling, setEnableBilling] = useState(true)
  85. const [enableReplaceWebAppLogo, setEnableReplaceWebAppLogo] = useState(false)
  86. const [modelLoadBalancingEnabled, setModelLoadBalancingEnabled] = useState(false)
  87. const [datasetOperatorEnabled, setDatasetOperatorEnabled] = useState(false)
  88. const fetchPlan = async () => {
  89. const data = await fetchCurrentPlanInfo()
  90. const enabled = data.billing.enabled
  91. setEnableBilling(enabled)
  92. setEnableReplaceWebAppLogo(data.can_replace_logo)
  93. if (enabled) {
  94. setPlan(parseCurrentPlan(data))
  95. setIsFetchedPlan(true)
  96. }
  97. if (data.model_load_balancing_enabled)
  98. setModelLoadBalancingEnabled(true)
  99. if (data.dataset_operator_enabled)
  100. setDatasetOperatorEnabled(true)
  101. }
  102. useEffect(() => {
  103. fetchPlan()
  104. }, [])
  105. return (
  106. <ProviderContext.Provider value={{
  107. modelProviders: providersData?.data || [],
  108. textGenerationModelList: textGenerationModelList?.data || [],
  109. isAPIKeySet: !!textGenerationModelList?.data.some(model => model.status === ModelStatusEnum.active),
  110. supportRetrievalMethods: supportRetrievalMethods?.retrieval_method || [],
  111. plan,
  112. isFetchedPlan,
  113. enableBilling,
  114. onPlanInfoChanged: fetchPlan,
  115. enableReplaceWebAppLogo,
  116. modelLoadBalancingEnabled,
  117. datasetOperatorEnabled,
  118. }}>
  119. {children}
  120. </ProviderContext.Provider>
  121. )
  122. }
  123. export default ProviderContext