index.tsx 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import type { FC } from 'react'
  2. import { useState } from 'react'
  3. import { useTranslation } from 'react-i18next'
  4. import {
  5. RiLoader2Line,
  6. } from '@remixicon/react'
  7. import type {
  8. CustomConfigurationModelFixedFields,
  9. ModelItem,
  10. ModelProvider,
  11. } from '../declarations'
  12. import { ConfigurationMethodEnum } from '../declarations'
  13. import {
  14. DEFAULT_BACKGROUND_COLOR,
  15. MODEL_PROVIDER_QUOTA_GET_PAID,
  16. modelTypeFormat,
  17. } from '../utils'
  18. import ProviderIcon from '../provider-icon'
  19. import ModelBadge from '../model-badge'
  20. import CredentialPanel from './credential-panel'
  21. import QuotaPanel from './quota-panel'
  22. import ModelList from './model-list'
  23. import AddModelButton from './add-model-button'
  24. import { ChevronDownDouble } from '@/app/components/base/icons/src/vender/line/arrows'
  25. import { fetchModelProviderModelList } from '@/service/common'
  26. import { useEventEmitterContextContext } from '@/context/event-emitter'
  27. import { IS_CE_EDITION } from '@/config'
  28. import { useAppContext } from '@/context/app-context'
  /**
   * Event `type` that asks a provider card to re-fetch its model list.
   *
   * A card reacts only when the event's `payload` equals its own
   * `provider.provider`, and then re-fetches using that payload.
   */
  export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' as const
  /** Props accepted by the `ProviderAddedCard` component. */
  type ProviderAddedCardProps = {
    /** The provider rendered by this card. */
    provider: ModelProvider
    /**
     * Asks the parent to open a configuration modal.
     *
     * The card calls this with `predefinedModel` from its credential panel. It
     * calls it with `customizableModel` from the add-model button, and also when
     * the model list asks to configure an existing model; in that second case
     * the model's fixed fields are passed as the second argument.
     */
    onOpenModal: (
      method: ConfigurationMethodEnum,
      fixedFields?: CustomConfigurationModelFixedFields,
    ) => void
  }
  /**
   * Card for a model provider that has been added to the current workspace.
   *
   * The header shows:
   * - the provider icon and a badge per supported model type;
   * - an optional quota panel;
   * - a credential panel, but only for workspace managers and only when the
   *   provider supports predefined models.
   *
   * Below the header is either a collapsed summary bar or the full model list.
   * The list is fetched the first time the user expands it. It is re-fetched
   * when an `UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST` event with a payload equal
   * to `provider.provider` arrives, and whenever the model list reports a change.
   */
  const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
    provider,
    onOpenModal,
  }) => {
    const { t } = useTranslation()
    const { eventEmitter } = useEventEmitterContextContext()
    const [fetched, setFetched] = useState(false)
    const [loading, setLoading] = useState(false)
    const [collapsed, setCollapsed] = useState(true)
    const [modelList, setModelList] = useState<ModelItem[]>([])

    // Only predefined/customizable methods are acted on below; drop fetchFromRemote.
    const configurationMethods = provider.configurate_methods.filter(
      method => method !== ConfigurationMethodEnum.fetchFromRemote,
    )
    const systemConfig = provider.system_configuration
    const modelCount = modelList.length
    const hasModelList = fetched && modelCount > 0
    const { isCurrentWorkspaceManager } = useAppContext()

    // Quota panel: system config enabled, provider in the paid-quota list, and not IS_CE_EDITION.
    const showQuota = systemConfig.enabled
      && [...MODEL_PROVIDER_QUOTA_GET_PAID].includes(provider.provider)
      && !IS_CE_EDITION
    const canSetupCredentials = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager
    const canAddCustomModel = configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager

    /**
     * Fetch the provider's models.
     *
     * On success, stores the models, expands the list and marks the list as
     * fetched. The loading flag is always cleared. Errors from the request are
     * not caught here.
     */
    const loadModels = async (providerName: string) => {
      // Skip if a request is already marked as in flight.
      if (loading)
        return

      setLoading(true)
      try {
        const { data } = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
        setModelList(data)
        setCollapsed(false)
        setFetched(true)
      }
      finally {
        setLoading(false)
      }
    }

    /** Expand the list, fetching it first if it has never been loaded. */
    const expandModelList = () => {
      if (!fetched) {
        loadModels(provider.provider)
        return
      }
      setCollapsed(false)
    }

    eventEmitter?.useSubscription((v: any) => {
      const targetsThisCard = v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider
      if (targetsThisCard)
        loadModels(v.payload)
    })

    // Summary text (shown normally) and expand-button text (shown on hover).
    const summaryLabel = hasModelList
      ? t('common.modelProvider.modelsNum', { num: modelCount })
      : t('common.modelProvider.showModels')
    const expandLabel = hasModelList
      ? t('common.modelProvider.showModelsNum', { num: modelCount })
      : t('common.modelProvider.showModels')

    return (
      <div
        className='mb-2 rounded-xl border-[0.5px] border-black/5 shadow-xs'
        style={{ background: provider.background || DEFAULT_BACKGROUND_COLOR }}
      >
        <div className='flex pl-3 py-2 pr-2 rounded-t-xl'>
          <div className='grow px-1 pt-1 pb-0.5'>
            <ProviderIcon className='mb-2' provider={provider} />
            <div className='flex gap-0.5'>
              {provider.supported_model_types.map(modelType => (
                <ModelBadge key={modelType}>{modelTypeFormat(modelType)}</ModelBadge>
              ))}
            </div>
          </div>
          {showQuota && <QuotaPanel provider={provider} />}
          {canSetupCredentials && (
            <CredentialPanel
              onSetup={() => onOpenModal(ConfigurationMethodEnum.predefinedModel)}
              provider={provider}
            />
          )}
        </div>
        {collapsed && (
          <div className='group flex items-center justify-between pl-2 py-1.5 pr-[11px] border-t border-t-black/5 bg-white/30 text-xs font-medium text-gray-500'>
            <div className='group-hover:hidden pl-1 pr-1.5 h-6 leading-6'>{summaryLabel}</div>
            <div
              className='hidden group-hover:flex items-center pl-1 pr-1.5 h-6 rounded-lg hover:bg-white cursor-pointer'
              onClick={expandModelList}
            >
              <ChevronDownDouble className='mr-0.5 w-3 h-3' />
              {expandLabel}
              {loading && <RiLoader2Line className='ml-0.5 animate-spin w-3 h-3' />}
            </div>
            {canAddCustomModel && (
              <AddModelButton
                onClick={() => onOpenModal(ConfigurationMethodEnum.customizableModel)}
                className='hidden group-hover:flex group-hover:text-primary-600'
              />
            )}
          </div>
        )}
        {!collapsed && (
          <ModelList
            provider={provider}
            models={modelList}
            onCollapse={() => setCollapsed(true)}
            onConfig={fixedFields => onOpenModal(ConfigurationMethodEnum.customizableModel, fixedFields)}
            onChange={(changedProvider: string) => loadModels(changedProvider)}
          />
        )}
      </div>
    )
  }
  export default ProviderAddedCard