hooks.ts 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import {
  2. useCallback,
  3. useEffect,
  4. useMemo,
  5. useState,
  6. } from 'react'
  7. import useSWR, { useSWRConfig } from 'swr'
  8. import { useContext } from 'use-context-selector'
  9. import type {
  10. CustomConfigurationModelFixedFields,
  11. DefaultModel,
  12. DefaultModelResponse,
  13. Model,
  14. ModelTypeEnum,
  15. } from './declarations'
  16. import {
  17. ConfigurationMethodEnum,
  18. ModelStatusEnum,
  19. } from './declarations'
  20. import I18n from '@/context/i18n'
  21. import {
  22. fetchDefaultModal,
  23. fetchModelList,
  24. fetchModelProviderCredentials,
  25. fetchModelProviders,
  26. getPayUrl,
  27. } from '@/service/common'
  28. import { useProviderContext } from '@/context/provider-context'
  /**
   * Shape of the hook below: given the default-model response and the list of
   * provider entries, it yields a `[selection, setSelection]` tuple.
   */
  type UseDefaultModelAndModelList = (
    defaultModel: DefaultModelResponse | undefined,
    modelList: Model[],
  ) => [DefaultModel | undefined, (model: DefaultModel) => void]

  /**
   * Keeps a locally editable default-model selection that is seeded from, and
   * re-synchronised with, the default model resolved against `modelList`.
   *
   * @param defaultModel - Default-model response; its `provider.provider` and `model` are used for the lookup.
   * @param modelList - Provider entries, each exposing `provider` and a `models` array.
   * @returns A tuple of the current selection (or `undefined` when the lookup
   *   fails and nothing was chosen) and a setter for overriding it.
   */
  export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = (
    defaultModel,
    modelList,
  ) => {
    // Resolve the default model to a { model, provider } pair that actually exists in the list.
    const resolvedDefault = useMemo(() => {
      const matchedProvider = modelList.find(entry => entry.provider === defaultModel?.provider.provider)
      const matchedModel = matchedProvider?.models.find(entry => entry.model === defaultModel?.model)
      if (!matchedProvider || !matchedModel)
        return undefined

      return {
        model: matchedModel.model,
        provider: matchedProvider.provider,
      }
    }, [defaultModel, modelList])

    const [selection, setSelection] = useState<DefaultModel | undefined>(resolvedDefault)

    const handleDefaultModelChange = useCallback((model: DefaultModel) => {
      setSelection(model)
    }, [])

    // Whenever the resolved default changes, it replaces the local selection.
    useEffect(() => {
      setSelection(resolvedDefault)
    }, [resolvedDefault])

    return [selection, handleDefaultModelChange]
  }
  /**
   * Reads `locale` from the I18n context and returns it with a hyphen swapped
   * for an underscore.
   *
   * Note: `replace` is called with a string pattern, so only the first `-` in
   * the locale is converted.
   */
  export const useLanguage = () => {
    const i18n = useContext(I18n)
    return i18n.locale.replace('-', '_')
  }
  /**
   * Loads the stored credentials and load-balancing data for a model provider.
   *
   * Two SWR requests are declared; each is only active (non-null key) for its
   * own configuration method:
   * - predefined model: requires `configured` to be truthy;
   * - customizable model: requires `currentCustomConfigurationModelFixedFields`.
   *
   * @param provider - Provider identifier interpolated into the request path.
   * @param configurationMethod - Selects which of the two requests/results is used.
   * @param configured - Gate for the predefined-model request.
   * @param currentCustomConfigurationModelFixedFields - Fixed fields of the custom
   *   model; `__model_name` / `__model_type` go into the query string and the whole
   *   object is merged over the fetched credentials.
   * @returns `credentials`, `loadBalancing` (the `load_balancing` field of the
   *   active response) and a `mutate` function that revalidates both requests.
   */
  export const useProviderCredentialsAndLoadBalancing = (
    provider: string,
    configurationMethod: ConfigurationMethodEnum,
    configured?: boolean,
    currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
  ) => {
    const isPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel
    const isCustomizable = configurationMethod === ConfigurationMethodEnum.customizableModel

    // A null key tells SWR not to issue the request.
    const predefinedKey = (isPredefined && configured)
      ? `/workspaces/current/model-providers/${provider}/credentials`
      : null
    // NOTE(review): the query values are interpolated without URL encoding — confirm
    // that model names/types can never contain reserved characters.
    const customizedKey = (isCustomizable && currentCustomConfigurationModelFixedFields)
      ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}`
      : null

    const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR(
      predefinedKey,
      fetchModelProviderCredentials,
    )
    const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR(
      customizedKey,
      fetchModelProviderCredentials,
    )

    const credentials = useMemo(() => {
      if (configurationMethod === ConfigurationMethodEnum.predefinedModel)
        return predefinedFormSchemasValue?.credentials

      if (!customFormSchemasValue?.credentials)
        return undefined

      // Fixed fields are spread last so they win over fetched values on key clashes.
      return {
        ...customFormSchemasValue?.credentials,
        ...currentCustomConfigurationModelFixedFields,
      }
    }, [
      configurationMethod,
      currentCustomConfigurationModelFixedFields,
      customFormSchemasValue?.credentials,
      predefinedFormSchemasValue?.credentials,
    ])

    // Revalidates both requests regardless of which one is currently active.
    const mutate = useCallback(() => {
      mutatePredefined()
      mutateCustomized()
    }, [mutateCustomized, mutatePredefined])

    const activeFormSchemasValue = isPredefined
      ? predefinedFormSchemasValue
      : customFormSchemasValue

    return {
      credentials,
      loadBalancing: activeFormSchemasValue?.load_balancing,
      mutate,
    }
  }
  /**
   * Fetches the model list for one model type through SWR.
   *
   * @param type - Model type appended to the request path (also the SWR key).
   * @returns `data` (the response's `data` field, or an empty array when it is
   *   missing/falsy), plus SWR's `mutate` and `isLoading`.
   */
  export const useModelList = (type: ModelTypeEnum) => {
    const { data: response, mutate, isLoading } = useSWR(
      `/workspaces/current/models/model-types/${type}`,
      fetchModelList,
    )
    const models = response?.data || []

    return {
      data: models,
      mutate,
      isLoading,
    }
  }
  /**
   * Fetches the default model for one model type through SWR.
   *
   * @param type - Model type sent as the `model_type` query parameter (part of the SWR key).
   * @returns `data` (the response's `data` field, `undefined` until available),
   *   plus SWR's `mutate` and `isLoading`.
   */
  export const useDefaultModel = (type: ModelTypeEnum) => {
    const { data: response, mutate, isLoading } = useSWR(
      `/workspaces/current/default-model?model_type=${type}`,
      fetchDefaultModal,
    )

    return {
      data: response?.data,
      mutate,
      isLoading,
    }
  }
  /**
   * Looks up the provider entry and model entry that `defaultModel` refers to.
   *
   * @param modelList - Provider entries, each with a `provider` id and a `models` array.
   * @param defaultModel - Optional `{ provider, model }` pair to look for.
   * @returns `currentProvider` (first entry whose `provider` matches) and
   *   `currentModel` (first model of that provider whose `model` matches);
   *   either is `undefined` when no match exists.
   */
  export const useCurrentProviderAndModel = (modelList: Model[], defaultModel?: DefaultModel) => {
    const wantedProvider = defaultModel?.provider
    const wantedModel = defaultModel?.model

    const currentProvider = modelList.find(entry => entry.provider === wantedProvider)
    // The model is only searched inside the matched provider.
    const currentModel = currentProvider?.models.find(entry => entry.model === wantedModel)

    return { currentProvider, currentModel }
  }
  /**
   * Combines the text-generation model list from the provider context with the
   * provider/model lookup for `defaultModel`.
   *
   * @param defaultModel - Optional `{ provider, model }` pair to resolve.
   * @returns The resolved `currentProvider` / `currentModel`, the full
   *   `textGenerationModelList`, and `activeTextGenerationModelList` (entries
   *   whose `status` equals `ModelStatusEnum.active`). The lookup runs against
   *   the full list, not the active subset.
   */
  export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultModel?: DefaultModel) => {
    const { textGenerationModelList } = useProviderContext()
    const activeTextGenerationModelList = textGenerationModelList.filter(
      entry => entry.status === ModelStatusEnum.active,
    )
    const resolved = useCurrentProviderAndModel(textGenerationModelList, defaultModel)

    return {
      currentProvider: resolved.currentProvider,
      currentModel: resolved.currentModel,
      textGenerationModelList,
      activeTextGenerationModelList,
    }
  }
  /**
   * Convenience hook returning both the model list and the default model for a
   * model type, by delegating to `useModelList` and `useDefaultModel`.
   *
   * @param type - Model type forwarded to both underlying hooks.
   * @returns `{ modelList, defaultModel }` — the `data` fields of the two hooks.
   */
  export const useModelListAndDefaultModel = (type: ModelTypeEnum) => {
    const modelList = useModelList(type).data
    const defaultModel = useDefaultModel(type).data

    return { modelList, defaultModel }
  }
  /**
   * Returns the model list and default model for a model type together with
   * the provider/model entries the default model resolves to.
   *
   * @param type - Model type forwarded to `useModelListAndDefaultModel`.
   * @returns `{ modelList, defaultModel, currentProvider, currentModel }`.
   */
  export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeEnum) => {
    const { modelList, defaultModel } = useModelListAndDefaultModel(type)

    // Flatten the response into a { provider, model } pair; missing values fall
    // back to empty strings.
    const lookupTarget = {
      provider: defaultModel?.provider.provider || '',
      model: defaultModel?.model || '',
    }
    const { currentProvider, currentModel } = useCurrentProviderAndModel(modelList, lookupTarget)

    return {
      modelList,
      defaultModel,
      currentProvider,
      currentModel,
    }
  }
  /**
   * Provides a callback that triggers SWR's global `mutate` for the
   * model-list key of a given model type (the same key `useModelList` uses).
   *
   * @returns A memoised `(type) => void` function.
   */
  export const useUpdateModelList = () => {
    const { mutate: globalMutate } = useSWRConfig()

    return useCallback((type: ModelTypeEnum) => {
      globalMutate(`/workspaces/current/models/model-types/${type}`)
    }, [globalMutate])
  }
  /**
   * Provides an async handler that requests the Anthropic checkout URL and
   * navigates the window to it.
   *
   * The handler does nothing while the `loading` state from the current render
   * is true; the flag is reset in `finally`, so it is cleared whether the
   * request succeeds or throws. Errors from `getPayUrl` are not caught here.
   *
   * @returns The `handleGetPayUrl` function.
   */
  export const useAnthropicBuyQuota = () => {
    const [loading, setLoading] = useState(false)

    const handleGetPayUrl = async () => {
      if (!loading) {
        setLoading(true)
        try {
          const { url } = await getPayUrl('/workspaces/current/model-providers/anthropic/checkout-url')
          window.location.href = url
        }
        finally {
          setLoading(false)
        }
      }
    }

    return handleGetPayUrl
  }
  /**
   * Fetches the workspace's model providers through SWR.
   *
   * @returns `data` (the response's `data` field, or an empty array when it is
   *   missing/falsy), plus SWR's `mutate` and `isLoading`.
   */
  export const useModelProviders = () => {
    const { data: response, mutate, isLoading } = useSWR(
      '/workspaces/current/model-providers',
      fetchModelProviders,
    )
    const providers = response?.data || []

    return {
      data: providers,
      mutate,
      isLoading,
    }
  }
  /**
   * Provides a callback that triggers SWR's global `mutate` for the
   * model-providers key (the same key `useModelProviders` uses).
   *
   * @returns A memoised `() => void` function.
   */
  export const useUpdateModelProviders = () => {
    const { mutate: globalMutate } = useSWRConfig()

    return useCallback(() => {
      globalMutate('/workspaces/current/model-providers')
    }, [globalMutate])
  }