use-workflow-run.ts 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. import { useCallback } from 'react'
  2. import {
  3. getIncomers,
  4. useReactFlow,
  5. useStoreApi,
  6. } from 'reactflow'
  7. import produce from 'immer'
  8. import { v4 as uuidV4 } from 'uuid'
  9. import { usePathname } from 'next/navigation'
  10. import { useWorkflowStore } from '../store'
  11. import { useNodesSyncDraft } from '../hooks'
  12. import type { Node } from '../types'
  13. import {
  14. NodeRunningStatus,
  15. WorkflowRunningStatus,
  16. } from '../types'
  17. import { DEFAULT_ITER_TIMES } from '../constants'
  18. import { useWorkflowUpdate } from './use-workflow-interactions'
  19. import { useStore as useAppStore } from '@/app/components/app/store'
  20. import type { IOtherOptions } from '@/service/base'
  21. import { ssePost } from '@/service/base'
  22. import {
  23. fetchPublishedWorkflow,
  24. stopWorkflowRun,
  25. } from '@/service/workflow'
  26. import { useFeaturesStore } from '@/app/components/base/features/hooks'
  27. import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
  28. import {
  29. getProcessedFilesFromResponse,
  30. } from '@/app/components/base/file-uploader/utils'
  /**
   * Canvas-side controller for draft workflow runs.
   *
   * Returns handlers to:
   * - back up and restore the editable draft (canvas, features, environment variables);
   * - start a draft run, streamed over SSE, whose events are mirrored onto the
   *   canvas nodes/edges and into `workflowRunningData` (result + tracing);
   * - stop a running task;
   * - replace the canvas with the currently published workflow.
   *
   * NOTE(review): relies on the ReactFlow, workflow-store and features-store hooks,
   * so it presumably must be rendered inside their providers — confirm at call sites.
   */
  export const useWorkflowRun = () => {
    const store = useStoreApi()
    const workflowStore = useWorkflowStore()
    const reactflow = useReactFlow()
    const featuresStore = useFeaturesStore()
    const { doSyncWorkflowDraft } = useNodesSyncDraft()
    const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
    const pathname = usePathname()

    /**
     * Save the current canvas (nodes, edges, viewport), features and environment
     * variables as `backupDraft`, then trigger a draft sync (not awaited).
     * Does nothing when a backup already exists, so the first snapshot is kept.
     */
    const handleBackupDraft = useCallback(() => {
      const {
        getNodes,
        edges,
      } = store.getState()
      const { getViewport } = reactflow
      const {
        backupDraft,
        setBackupDraft,
        environmentVariables,
      } = workflowStore.getState()
      const { features } = featuresStore!.getState()

      if (!backupDraft) {
        setBackupDraft({
          nodes: getNodes(),
          edges,
          viewport: getViewport(),
          features,
          environmentVariables,
        })
        doSyncWorkflowDraft()
      }
    }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft])

    /**
     * Restore the canvas, environment variables and features from `backupDraft`
     * (if any) and then clear the backup.
     */
    const handleLoadBackupDraft = useCallback(() => {
      const {
        backupDraft,
        setBackupDraft,
        setEnvironmentVariables,
      } = workflowStore.getState()

      if (backupDraft) {
        const {
          nodes,
          edges,
          viewport,
          features,
          environmentVariables,
        } = backupDraft
        handleUpdateWorkflowCanvas({
          nodes,
          edges,
          viewport,
        })
        setEnvironmentVariables(environmentVariables)
        featuresStore!.setState({ features })
        setBackupDraft(undefined)
      }
    }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore])

    /**
     * Run the current draft.
     *
     * Clears node selection and running status, syncs the draft, resets the run
     * state, then POSTs `params` to the app's draft-run endpoint via `ssePost`.
     * Each SSE event updates the canvas / run state first and then forwards to the
     * matching hook in `callback`.
     *
     * @param params Request body for the draft run. `params.token` / `params.appId`
     *   are also used to choose the text-to-audio endpoint.
     * @param callback Optional SSE hooks. The workflow/node/iteration/error hooks are
     *   invoked after the internal handling; any other entries are spread last into
     *   the `ssePost` options and therefore replace the internal handler of the same
     *   name (e.g. `onTextChunk`, `onTTSChunk`).
     */
    const handleRun = useCallback(async (
      params: any,
      callback?: IOtherOptions,
    ) => {
      const {
        getNodes,
        setNodes,
      } = store.getState()
      // Start from a clean canvas: no selection, no status left over from a previous run.
      const newNodes = produce(getNodes(), (draft) => {
        draft.forEach((node) => {
          node.data.selected = false
          node.data._runningStatus = undefined
        })
      })
      setNodes(newNodes)
      await doSyncWorkflowDraft()

      const {
        onWorkflowStarted,
        onWorkflowFinished,
        onNodeStarted,
        onNodeFinished,
        onIterationStart,
        onIterationNext,
        onIterationFinish,
        onError,
        ...restCallback
      } = callback || {}
      workflowStore.setState({ historyWorkflowData: undefined })
      const appDetail = useAppStore.getState().appDetail
      const workflowContainer = document.getElementById('workflow-container')

      const {
        clientWidth,
        clientHeight,
      } = workflowContainer!

      let url = ''
      if (appDetail?.mode === 'advanced-chat')
        url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run`
      if (appDetail?.mode === 'workflow')
        url = `/apps/${appDetail.id}/workflows/draft/run`

      // Id of the last top-level node / iteration that finished; used by
      // onIterationStart to mark the edge leading into the iteration as run.
      let prevNodeId = ''

      const {
        setWorkflowRunningData,
      } = workflowStore.getState()
      setWorkflowRunningData({
        result: {
          status: WorkflowRunningStatus.Running,
        },
        tracing: [],
        resultText: '',
      })

      let ttsUrl = ''
      let ttsIsPublic = false
      if (params.token) {
        ttsUrl = '/text-to-audio'
        ttsIsPublic = true
      }
      else if (params.appId) {
        if (pathname.search('explore/installed') > -1)
          ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
        else
          ttsUrl = `/apps/${params.appId}/text-to-audio`
      }
      const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', (_: any): any => {})

      /**
       * Pan (keeping `zoom`) so a top-level node is centred in the workflow container,
       * using a width reduced by 400px — presumably room for a side panel (TODO confirm).
       * Nodes with a `parentId` (nested, e.g. inside an iteration) are left alone.
       */
      const focusNodeOnCanvas = (node: Node, zoom: number) => {
        if (node.parentId)
          return
        const { setViewport } = reactflow
        const { position } = node
        setViewport({
          x: (clientWidth - 400 - node.width! * zoom) / 2 - position.x * zoom,
          y: (clientHeight - node.height! * zoom) / 2 - position.y * zoom,
          zoom,
        })
      }

      ssePost(
        url,
        {
          body: params,
        },
        {
          onWorkflowStarted: (params) => {
            const { task_id, data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
              setIterParallelLogMap,
            } = workflowStore.getState()
            const {
              edges,
              setEdges,
            } = store.getState()
            setIterParallelLogMap(new Map())
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.task_id = task_id
              draft.result = {
                ...draft?.result,
                ...data,
                status: WorkflowRunningStatus.Running,
              }
            }))
            // Reset the "run" highlight on every edge.
            const newEdges = produce(edges, (draft) => {
              draft.forEach((edge) => {
                edge.data = {
                  ...edge.data,
                  _run: false,
                }
              })
            })
            setEdges(newEdges)

            if (onWorkflowStarted)
              onWorkflowStarted(params)
          },
          onWorkflowFinished: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            // A single string output is shown directly in the result tab.
            const isStringOutput = data.outputs && Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'

            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.result = {
                ...draft.result,
                ...data,
                files: getProcessedFilesFromResponse(data.files || []),
              } as any
              if (isStringOutput) {
                draft.resultTabActive = true
                draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
              }
            }))

            prevNodeId = ''

            if (onWorkflowFinished)
              onWorkflowFinished(params)
          },
          onError: (params) => {
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.result = {
                ...draft.result,
                status: WorkflowRunningStatus.Failed,
              }
            }))

            if (onError)
              onError(params)
          },
          onNodeStarted: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
              iterParallelLogMap,
              setIterParallelLogMap,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
              edges,
              setEdges,
              transform,
            } = store.getState()
            const nodes = getNodes()
            const node = nodes.find(node => node.id === data.node_id)

            if (node?.parentId) {
              // Nested node: log it under its parent iteration's `details`.
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === node?.parentId)
                const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1]
                if (!data.parallel_run_id) {
                  currIteration?.push({
                    ...data,
                    status: NodeRunningStatus.Running,
                  } as any)
                }
                else {
                  // Parallel-mode iteration: logs are grouped per parallel_run_id in
                  // iterParallelLogMap, and `details` is rebuilt from those groups.
                  const nodeId = iterations?.node_id as string
                  if (!iterParallelLogMap.has(nodeId))
                    iterParallelLogMap.set(nodeId, new Map())
                  const currentIterLogMap = iterParallelLogMap.get(nodeId)!
                  if (!currentIterLogMap.has(data.parallel_run_id))
                    currentIterLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any])
                  else
                    currentIterLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any)
                  setIterParallelLogMap(iterParallelLogMap)
                  if (iterations)
                    iterations.details = Array.from(currentIterLogMap.values())
                }
              }))
            }
            else {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                draft.tracing!.push({
                  ...data,
                  status: NodeRunningStatus.Running,
                } as any)
              }))

              const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
              let updatedNodes = nodes
              // Guard: an event for a node that is not on the canvas must not abort
              // stream handling.
              if (currentNodeIndex !== -1) {
                focusNodeOnCanvas(nodes[currentNodeIndex], transform[2])
                updatedNodes = produce(nodes, (draft) => {
                  draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
                })
                setNodes(updatedNodes)
              }

              // Mark edges coming from already-succeeded upstream nodes as run.
              const incomeNodesId = getIncomers({ id: data.node_id } as Node, updatedNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id)
              const newEdges = produce(edges, (draft) => {
                draft.forEach((edge) => {
                  if (edge.target === data.node_id && incomeNodesId.includes(edge.source))
                    edge.data = { ...edge.data, _run: true } as any
                })
              })
              setEdges(newEdges)
            }

            if (onNodeStarted)
              onNodeStarted(params)
          },
          onNodeFinished: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
              iterParallelLogMap,
              setIterParallelLogMap,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
            } = store.getState()
            const nodes = getNodes()
            const nodeParentId = nodes.find(node => node.id === data.node_id)?.parentId

            if (nodeParentId) {
              if (!data.execution_metadata?.parallel_mode_run_id) {
                // Sequential iteration: update (or append) the log in details[iteration_index].
                setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                  const tracing = draft.tracing!
                  const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node

                  if (iterations && iterations.details) {
                    const iterationIndex = data.execution_metadata?.iteration_index || 0
                    if (!iterations.details[iterationIndex])
                      iterations.details[iterationIndex] = []

                    const currIteration = iterations.details[iterationIndex]
                    const nodeIndex = currIteration.findIndex(node =>
                      node.node_id === data.node_id && (
                        node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
                    )
                    if (nodeIndex !== -1) {
                      currIteration[nodeIndex] = {
                        ...currIteration[nodeIndex],
                        ...data,
                      } as any
                    }
                    else {
                      currIteration.push({
                        ...data,
                      } as any)
                    }
                  }
                }))
              }
              else {
                // Parallel-mode iteration: update the log grouped under this run id,
                // then rebuild `details` from the groups.
                setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                  const tracing = draft.tracing!
                  const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node

                  if (iterations && iterations.details) {
                    const iterRunID = data.execution_metadata?.parallel_mode_run_id

                    const currIteration = iterParallelLogMap.get(iterations.node_id)?.get(iterRunID)
                    const nodeIndex = currIteration?.findIndex(node =>
                      node.node_id === data.node_id && (
                        node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id),
                    )
                    if (currIteration) {
                      if (nodeIndex !== undefined && nodeIndex !== -1) {
                        currIteration[nodeIndex] = {
                          ...currIteration[nodeIndex],
                          ...data,
                        } as any
                      }
                      else {
                        currIteration.push({
                          ...data,
                        } as any)
                      }
                    }
                    setIterParallelLogMap(iterParallelLogMap)
                    const iterLogMap = iterParallelLogMap.get(iterations.node_id)
                    if (iterLogMap)
                      iterations.details = Array.from(iterLogMap.values())
                  }
                }))
              }
            }
            else {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const currentIndex = draft.tracing!.findIndex((trace) => {
                  if (!trace.execution_metadata?.parallel_id)
                    return trace.node_id === data.node_id
                  return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
                })
                if (currentIndex > -1 && draft.tracing) {
                  // Replace the trace with the finished data but keep its `extras`.
                  draft.tracing[currentIndex] = {
                    ...(draft.tracing[currentIndex].extras
                      ? { extras: draft.tracing[currentIndex].extras }
                      : {}),
                    ...data,
                  } as any
                }
              }))
              const newNodes = produce(nodes, (draft) => {
                const currentNode = draft.find(node => node.id === data.node_id)
                if (currentNode)
                  currentNode.data._runningStatus = data.status as any
              })
              setNodes(newNodes)
              prevNodeId = data.node_id
            }

            if (onNodeFinished)
              onNodeFinished(params)
          },
          onIterationStart: (params) => {
            const { data } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
              setIterTimes,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
              edges,
              setEdges,
              transform,
            } = store.getState()
            const nodes = getNodes()
            setIterTimes(DEFAULT_ITER_TIMES)
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.tracing!.push({
                ...data,
                status: NodeRunningStatus.Running,
                details: [],
                iterDurationMap: {},
              } as any)
            }))

            const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
            if (currentNodeIndex !== -1) {
              focusNodeOnCanvas(nodes[currentNodeIndex], transform[2])
              const newNodes = produce(nodes, (draft) => {
                draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
                draft[currentNodeIndex].data._iterationLength = data.metadata.iterator_length
              })
              setNodes(newNodes)
            }

            // Light up the edge from the previously finished node into this iteration.
            const newEdges = produce(edges, (draft) => {
              const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId)
              if (edge)
                edge.data = { ...edge.data, _run: true } as any
            })
            setEdges(newEdges)

            if (onIterationStart)
              onIterationStart(params)
          },
          onIterationNext: (params) => {
            const {
              workflowRunningData,
              setWorkflowRunningData,
              iterTimes,
              setIterTimes,
            } = workflowStore.getState()

            const { data } = params
            const {
              getNodes,
              setNodes,
            } = store.getState()

            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id)
              if (iteration) {
                if (iteration.iterDurationMap && data.duration)
                  iteration.iterDurationMap[data.parallel_mode_run_id ?? `${data.index - 1}`] = data.duration
                // Already have a details slot for every iteration: keep the duration update only.
                if (iteration.details!.length >= iteration.metadata.iterator_length!)
                  return
              }
              // Sequential mode opens a new (empty) details slot for the next round;
              // parallel mode builds `details` from iterParallelLogMap instead.
              if (!data.parallel_mode_run_id)
                iteration?.details!.push([])
            }))

            const nodes = getNodes()
            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)
              if (currentNode)
                currentNode.data._iterationIndex = iterTimes
            })
            // Store update kept outside the immer recipe so the producer stays pure.
            setIterTimes(iterTimes + 1)
            setNodes(newNodes)

            if (onIterationNext)
              onIterationNext(params)
          },
          onIterationFinish: (params) => {
            const { data } = params

            const {
              workflowRunningData,
              setWorkflowRunningData,
              setIterTimes,
            } = workflowStore.getState()
            const {
              getNodes,
              setNodes,
            } = store.getState()
            const nodes = getNodes()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const currIterationNode = tracing.find(trace => trace.node_id === data.node_id)
              if (currIterationNode) {
                Object.assign(currIterationNode, {
                  ...data,
                  status: NodeRunningStatus.Succeeded,
                })
              }
            }))
            setIterTimes(DEFAULT_ITER_TIMES)
            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)
              if (currentNode)
                currentNode.data._runningStatus = data.status
            })
            setNodes(newNodes)

            prevNodeId = data.node_id

            if (onIterationFinish)
              onIterationFinish(params)
          },
          // Parallel-branch events are currently not reflected on the canvas.
          onParallelBranchStarted: () => {},
          onParallelBranchFinished: () => {},
          onTextChunk: (params) => {
            const { data: { text } } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.resultTabActive = true
              draft.resultText += text
            }))
          },
          onTextReplace: (params) => {
            const { data: { text } } = params
            const {
              workflowRunningData,
              setWorkflowRunningData,
            } = workflowStore.getState()
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.resultText = text
            }))
          },
          onTTSChunk: (messageId: string, audio: string, audioType?: string) => {
            if (!audio)
              return
            player.playAudioWithAudio(audio, true)
            AudioPlayerManager.getInstance().resetMsgId(messageId)
          },
          onTTSEnd: (messageId: string, audio: string, audioType?: string) => {
            player.playAudioWithAudio(audio, false)
          },
          // Spread last: caller-supplied handlers override the internal ones above.
          ...restCallback,
        },
      )
    }, [store, reactflow, workflowStore, doSyncWorkflowDraft, pathname])

    /** Ask the backend to stop the workflow-run task `taskId` of the current app. */
    const handleStopRun = useCallback((taskId: string) => {
      const appId = useAppStore.getState().appDetail?.id

      stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`)
    }, [])

    /**
     * Fetch the published workflow of the current app and, if there is one, load
     * its graph, features, publish time and environment variables into the editor.
     */
    const handleRestoreFromPublishedWorkflow = useCallback(async () => {
      const appDetail = useAppStore.getState().appDetail
      const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`)

      if (publishedWorkflow) {
        const nodes = publishedWorkflow.graph.nodes
        const edges = publishedWorkflow.graph.edges
        const viewport = publishedWorkflow.graph.viewport!
        handleUpdateWorkflowCanvas({
          nodes,
          edges,
          viewport,
        })
        featuresStore?.setState({ features: publishedWorkflow.features })
        workflowStore.getState().setPublishedAt(publishedWorkflow.created_at)
        workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || [])
      }
    }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore])

    return {
      handleBackupDraft,
      handleLoadBackupDraft,
      handleRun,
      handleStopRun,
      handleRestoreFromPublishedWorkflow,
    }
  }