import { useCallback } from 'react'
import {
  getIncomers,
  useReactFlow,
  useStoreApi,
} from 'reactflow'
import produce from 'immer'
import { v4 as uuidV4 } from 'uuid'
import { usePathname } from 'next/navigation'
import { useWorkflowStore } from '../store'
import { useNodesSyncDraft } from '../hooks'
import type { Node } from '../types'
import {
  NodeRunningStatus,
  WorkflowRunningStatus,
} from '../types'
import { DEFAULT_ITER_TIMES } from '../constants'
import { useWorkflowUpdate } from './use-workflow-interactions'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { IOtherOptions } from '@/service/base'
import { ssePost } from '@/service/base'
import {
  fetchPublishedWorkflow,
  stopWorkflowRun,
} from '@/service/workflow'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager'
import {
  getProcessedFilesFromResponse,
} from '@/app/components/base/file-uploader/utils'

/**
 * Hook bundling the actions needed to run a draft workflow from the canvas.
 *
 * Returned handlers:
 * - `handleBackupDraft`: snapshot the current canvas/features/environment
 *   variables into the workflow store (only if no snapshot exists yet).
 * - `handleLoadBackupDraft`: restore that snapshot and clear it.
 * - `handleRun`: start a draft run over SSE and mirror every streamed event
 *   into the workflow store and the reactflow node/edge state.
 * - `handleStopRun`: ask the backend to stop a running task.
 * - `handleRestoreFromPublishedWorkflow`: replace the canvas with the
 *   published workflow fetched from the backend.
 */
export const useWorkflowRun = () => {
  const store = useStoreApi()
  const workflowStore = useWorkflowStore()
  const reactflow = useReactFlow()
  const featuresStore = useFeaturesStore()
  const { doSyncWorkflowDraft } = useNodesSyncDraft()
  const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
  const pathname = usePathname()

  /**
   * Stores a snapshot of the current draft in the workflow store and then
   * triggers a draft sync. Does nothing when a backup already exists, so an
   * earlier snapshot is never overwritten.
   */
  const handleBackupDraft = useCallback(() => {
    const {
      getNodes,
      edges,
    } = store.getState()
    const { getViewport } = reactflow
    const {
      backupDraft,
      setBackupDraft,
      environmentVariables,
    } = workflowStore.getState()
    const { features } = featuresStore!.getState()

    if (!backupDraft) {
      setBackupDraft({
        nodes: getNodes(),
        edges,
        viewport: getViewport(),
        features,
        environmentVariables,
      })
      // Not awaited here (unlike in handleRun).
      doSyncWorkflowDraft()
    }
  }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft])

  /**
   * Restores the snapshot taken by `handleBackupDraft` (canvas, environment
   * variables, features) and clears it. No-op when there is no backup.
   */
  const handleLoadBackupDraft = useCallback(() => {
    const {
      backupDraft,
      setBackupDraft,
      setEnvironmentVariables,
    } = workflowStore.getState()

    if (backupDraft) {
      const {
        nodes,
        edges,
        viewport,
        features,
        environmentVariables,
      } = backupDraft
      handleUpdateWorkflowCanvas({
        nodes,
        edges,
        viewport,
      })
      setEnvironmentVariables(environmentVariables)
      featuresStore!.setState({ features })
      setBackupDraft(undefined)
    }
  }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore])

  /**
   * Starts a draft run and wires the SSE event stream into local state.
   *
   * @param params - Request body posted to the run endpoint. `params.token`
   *   and `params.appId` are additionally read to pick the text-to-audio URL.
   * @param callback - Optional caller hooks. The workflow/node/iteration/error
   *   hooks are invoked after the built-in handling of the same event; any
   *   remaining entries are spread last into the options passed to `ssePost`
   *   and therefore override the built-in handlers of the same name.
   */
  const handleRun = useCallback(async (
    params: any,
    callback?: IOtherOptions,
  ) => {
    const {
      getNodes,
      setNodes,
    } = store.getState()

    // Clear selection and any running status left on the nodes before starting.
    const newNodes = produce(getNodes(), (draft) => {
      draft.forEach((node) => {
        node.data.selected = false
        node.data._runningStatus = undefined
      })
    })
    setNodes(newNodes)
    await doSyncWorkflowDraft()

    const {
      onWorkflowStarted,
      onWorkflowFinished,
      onNodeStarted,
      onNodeFinished,
      onIterationStart,
      onIterationNext,
      onIterationFinish,
      onError,
      ...restCallback
    } = callback || {}
    workflowStore.setState({ historyWorkflowData: undefined })
    const appDetail = useAppStore.getState().appDetail
    // The container size is captured once and reused to center running nodes.
    const workflowContainer = document.getElementById('workflow-container')

    const {
      clientWidth,
      clientHeight,
    } = workflowContainer!

    let url = ''
    if (appDetail?.mode === 'advanced-chat')
      url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run`

    if (appDetail?.mode === 'workflow')
      url = `/apps/${appDetail.id}/workflows/draft/run`

    // Id of the last finished top-level node / iteration; used by
    // onIterationStart to find the edge leading into the iteration node.
    let prevNodeId = ''

    const {
      setWorkflowRunningData,
    } = workflowStore.getState()
    setWorkflowRunningData({
      result: {
        status: WorkflowRunningStatus.Running,
      },
      tracing: [],
      resultText: '',
    })

    // Pick the text-to-audio endpoint from the request params and current route.
    let ttsUrl = ''
    let ttsIsPublic = false
    if (params.token) {
      ttsUrl = '/text-to-audio'
      ttsIsPublic = true
    }
    else if (params.appId) {
      if (pathname.search('explore/installed') > -1)
        ttsUrl = `/installed-apps/${params.appId}/text-to-audio`
      else
        ttsUrl = `/apps/${params.appId}/text-to-audio`
    }
    const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', (_: any): any => {})

    ssePost(
      url,
      {
        body: params,
      },
      {
        // Record the task id / initial result and reset the `_run` flag on every edge.
        onWorkflowStarted: (params) => {
          const { task_id, data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            edges,
            setEdges,
          } = store.getState()
          setIterParallelLogMap(new Map())
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.task_id = task_id
            draft.result = {
              ...draft?.result,
              ...data,
              status: WorkflowRunningStatus.Running,
            }
          }))

          const newEdges = produce(edges, (draft) => {
            draft.forEach((edge) => {
              edge.data = {
                ...edge.data,
                _run: false,
              }
            })
          })
          setEdges(newEdges)

          if (onWorkflowStarted)
            onWorkflowStarted(params)
        },
        // Merge the final result; a single string output is also surfaced as result text.
        onWorkflowFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          const isStringOutput = data.outputs && Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              ...data,
              files: getProcessedFilesFromResponse(data.files || []),
            } as any
            if (isStringOutput) {
              draft.resultTabActive = true
              draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
            }
          }))

          prevNodeId = ''

          if (onWorkflowFinished)
            onWorkflowFinished(params)
        },
        // Mark the whole run as failed.
        onError: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.result = {
              ...draft.result,
              status: WorkflowRunningStatus.Failed,
            }
          }))

          if (onError)
            onError(params)
        },
        // A node started: nodes with a parent are logged under the parent's
        // trace entry; other nodes get their own trace entry, are centered in
        // the viewport and have their incoming edges highlighted.
        onNodeStarted: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterParallelLogMap,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
            edges,
            setEdges,
            transform,
          } = store.getState()
          const nodes = getNodes()
          const node = nodes.find(node => node.id === data.node_id)
          if (node?.parentId) {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              const tracing = draft.tracing!
              const iterations = tracing.find(trace => trace.node_id === node?.parentId)
              // Fall back to the latest iteration round when the indexed one is missing.
              const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1]
              if (!data.parallel_run_id) {
                currIteration?.push({
                  ...data,
                  status: NodeRunningStatus.Running,
                } as any)
              }
              else {
                // Parallel mode: logs are grouped per iteration node, then per parallel run id.
                const nodeId = iterations?.node_id as string
                if (!iterParallelLogMap.has(nodeId))
                  iterParallelLogMap.set(nodeId, new Map())

                const currentIterLogMap = iterParallelLogMap.get(nodeId)!
                if (!currentIterLogMap.has(data.parallel_run_id))
                  currentIterLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any])
                else
                  currentIterLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any)
                // NOTE(review): store update issued from inside an immer recipe;
                // kept in place so the order of store updates is unchanged.
                setIterParallelLogMap(iterParallelLogMap)
                if (iterations)
                  iterations.details = Array.from(currentIterLogMap.values())
              }
            }))
          }
          else {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              draft.tracing!.push({
                ...data,
                status: NodeRunningStatus.Running,
              } as any)
            }))

            const {
              setViewport,
            } = reactflow
            const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
            const currentNode = nodes[currentNodeIndex]
            const position = currentNode.position
            const zoom = transform[2]

            if (!currentNode.parentId) {
              // Center the running node in the container.
              // NOTE(review): the 400px offset presumably reserves room for a side panel — confirm.
              setViewport({
                x: (clientWidth - 400 - currentNode.width! * zoom) / 2 - position.x * zoom,
                y: (clientHeight - currentNode.height! * zoom) / 2 - position.y * zoom,
                zoom,
              })
            }
            const newNodes = produce(nodes, (draft) => {
              draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
            })
            setNodes(newNodes)
            // Only edges coming from already-succeeded incomers are marked as run.
            const incomeNodesId = getIncomers({ id: data.node_id } as Node, newNodes, edges).filter(node => node.data._runningStatus === NodeRunningStatus.Succeeded).map(node => node.id)
            const newEdges = produce(edges, (draft) => {
              draft.forEach((edge) => {
                if (edge.target === data.node_id && incomeNodesId.includes(edge.source))
                  edge.data = { ...edge.data, _run: true } as any
              })
            })
            setEdges(newEdges)
          }
          if (onNodeStarted)
            onNodeStarted(params)
        },
        // A node finished: update its entry in the matching trace list and,
        // for top-level nodes, its running status on the canvas.
        onNodeFinished: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterParallelLogMap,
            setIterParallelLogMap,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId
          if (nodeParentId) {
            // Null-safe like every other execution_metadata access in this
            // handler: an event without metadata takes the non-parallel path
            // instead of throwing.
            if (!data.execution_metadata?.parallel_mode_run_id) {
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node

                if (iterations && iterations.details) {
                  const iterationIndex = data.execution_metadata?.iteration_index || 0
                  if (!iterations.details[iterationIndex])
                    iterations.details[iterationIndex] = []

                  const currIteration = iterations.details[iterationIndex]
                  const nodeIndex = currIteration.findIndex(node =>
                    node.node_id === data.node_id && (
                      node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
                  )
                  if (nodeIndex !== -1) {
                    currIteration[nodeIndex] = {
                      ...currIteration[nodeIndex],
                      ...data,
                    } as any
                  }
                  else {
                    currIteration.push({
                      ...data,
                    } as any)
                  }
                }
              }))
            }
            else {
              // open parallel mode
              setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
                const tracing = draft.tracing!
                const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node

                if (iterations && iterations.details) {
                  const iterRunID = data.execution_metadata?.parallel_mode_run_id

                  const currIteration = iterParallelLogMap.get(iterations.node_id)?.get(iterRunID)
                  const nodeIndex = currIteration?.findIndex(node =>
                    node.node_id === data.node_id && (
                      node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id),
                  )
                  if (currIteration) {
                    if (nodeIndex !== undefined && nodeIndex !== -1) {
                      currIteration[nodeIndex] = {
                        ...currIteration[nodeIndex],
                        ...data,
                      } as any
                    }
                    else {
                      currIteration.push({
                        ...data,
                      } as any)
                    }
                  }
                  // NOTE(review): store update issued from inside an immer recipe;
                  // kept in place so the order of store updates is unchanged.
                  setIterParallelLogMap(iterParallelLogMap)
                  const iterLogMap = iterParallelLogMap.get(iterations.node_id)
                  if (iterLogMap)
                    iterations.details = Array.from(iterLogMap.values())
                }
              }))
            }
          }
          else {
            setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
              // Match on node id alone unless the trace carries a parallel id.
              const currentIndex = draft.tracing!.findIndex((trace) => {
                if (!trace.execution_metadata?.parallel_id)
                  return trace.node_id === data.node_id
                return trace.node_id === data.node_id && trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
              })
              if (currentIndex > -1 && draft.tracing) {
                // Replace the entry with the finished payload, keeping existing `extras`.
                draft.tracing[currentIndex] = {
                  ...(draft.tracing[currentIndex].extras
                    ? { extras: draft.tracing[currentIndex].extras }
                    : {}),
                  ...data,
                } as any
              }
            }))

            const newNodes = produce(nodes, (draft) => {
              const currentNode = draft.find(node => node.id === data.node_id)!
              currentNode.data._runningStatus = data.status as any
            })
            setNodes(newNodes)

            prevNodeId = data.node_id
          }

          if (onNodeFinished)
            onNodeFinished(params)
        },
        // An iteration node started: open a trace entry with empty details,
        // center it and mark the edge from the previously finished node as run.
        onIterationStart: (params) => {
          const { data } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterTimes,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
            edges,
            setEdges,
            transform,
          } = store.getState()
          const nodes = getNodes()
          setIterTimes(DEFAULT_ITER_TIMES)
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.tracing!.push({
              ...data,
              status: NodeRunningStatus.Running,
              details: [],
              iterDurationMap: {},
            } as any)
          }))

          const {
            setViewport,
          } = reactflow
          const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
          const currentNode = nodes[currentNodeIndex]
          const position = currentNode.position
          const zoom = transform[2]

          if (!currentNode.parentId) {
            // Same centering computation as in onNodeStarted.
            setViewport({
              x: (clientWidth - 400 - currentNode.width! * zoom) / 2 - position.x * zoom,
              y: (clientHeight - currentNode.height! * zoom) / 2 - position.y * zoom,
              zoom,
            })
          }
          const newNodes = produce(nodes, (draft) => {
            draft[currentNodeIndex].data._runningStatus = NodeRunningStatus.Running
            draft[currentNodeIndex].data._iterationLength = data.metadata.iterator_length
          })
          setNodes(newNodes)
          const newEdges = produce(edges, (draft) => {
            const edge = draft.find(edge => edge.target === data.node_id && edge.source === prevNodeId)

            if (edge)
              edge.data = { ...edge.data, _run: true } as any
          })
          setEdges(newEdges)

          if (onIterationStart)
            onIterationStart(params)
        },
        // Next iteration round: record the round duration, open a new details
        // bucket (non-parallel mode only) and advance the displayed index.
        onIterationNext: (params) => {
          const {
            workflowRunningData,
            setWorkflowRunningData,
            iterTimes,
            setIterTimes,
          } = workflowStore.getState()

          const { data } = params
          const {
            getNodes,
            setNodes,
          } = store.getState()

          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id)
            if (iteration) {
              if (iteration.iterDurationMap && data.duration)
                iteration.iterDurationMap[data.parallel_mode_run_id ?? `${data.index - 1}`] = data.duration
              // Never open more buckets than the iterator has items.
              if (iteration.details!.length >= iteration.metadata.iterator_length!)
                return
            }
            if (!data.parallel_mode_run_id)
              iteration?.details!.push([])
          }))
          const nodes = getNodes()
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)!
            currentNode.data._iterationIndex = iterTimes
          })
          // Issued right after the recipe rather than inside it, so the
          // recipe only mutates its draft; the order of effects is unchanged.
          setIterTimes(iterTimes + 1)
          setNodes(newNodes)

          if (onIterationNext)
            onIterationNext(params)
        },
        // Iteration finished: merge the final payload into its trace entry and
        // update the node status on the canvas.
        onIterationFinish: (params) => {
          const { data } = params

          const {
            workflowRunningData,
            setWorkflowRunningData,
            setIterTimes,
          } = workflowStore.getState()
          const {
            getNodes,
            setNodes,
          } = store.getState()
          const nodes = getNodes()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            const tracing = draft.tracing!
            const currIterationNode = tracing.find(trace => trace.node_id === data.node_id)
            if (currIterationNode) {
              Object.assign(currIterationNode, {
                ...data,
                status: NodeRunningStatus.Succeeded,
              })
            }
          }))
          setIterTimes(DEFAULT_ITER_TIMES)
          const newNodes = produce(nodes, (draft) => {
            const currentNode = draft.find(node => node.id === data.node_id)!

            currentNode.data._runningStatus = data.status
          })
          setNodes(newNodes)

          prevNodeId = data.node_id

          if (onIterationFinish)
            onIterationFinish(params)
        },
        onParallelBranchStarted: () => {
          // Intentionally a no-op.
        },
        onParallelBranchFinished: () => {
          // Intentionally a no-op.
        },
        // Append streamed text to the result and switch to the result tab.
        onTextChunk: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultTabActive = true
            draft.resultText += text
          }))
        },
        // Replace the whole result text.
        onTextReplace: (params) => {
          const { data: { text } } = params
          const {
            workflowRunningData,
            setWorkflowRunningData,
          } = workflowStore.getState()
          setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
            draft.resultText = text
          }))
        },
        onTTSChunk: (messageId: string, audio: string, audioType?: string) => {
          // Skip empty chunks.
          if (!audio)
            return
          player.playAudioWithAudio(audio, true)
          AudioPlayerManager.getInstance().resetMsgId(messageId)
        },
        onTTSEnd: (messageId: string, audio: string, audioType?: string) => {
          player.playAudioWithAudio(audio, false)
        },
        ...restCallback,
      },
    )
    // `pathname` is read above to choose the text-to-audio endpoint, so it must
    // be a dependency; otherwise the memoized callback keeps a stale route.
  }, [store, reactflow, workflowStore, doSyncWorkflowDraft, pathname])

  /**
   * Requests the backend to stop the given task of the current app.
   *
   * @param taskId - Id of the running task to stop.
   */
  const handleStopRun = useCallback((taskId: string) => {
    const appId = useAppStore.getState().appDetail?.id

    stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`)
  }, [])

  /**
   * Fetches the published workflow of the current app and, when one is
   * returned, loads its graph, features, publish time and environment
   * variables into the canvas and stores.
   */
  const handleRestoreFromPublishedWorkflow = useCallback(async () => {
    const appDetail = useAppStore.getState().appDetail
    const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`)

    if (publishedWorkflow) {
      const nodes = publishedWorkflow.graph.nodes
      const edges = publishedWorkflow.graph.edges
      const viewport = publishedWorkflow.graph.viewport!

      handleUpdateWorkflowCanvas({
        nodes,
        edges,
        viewport,
      })
      featuresStore?.setState({ features: publishedWorkflow.features })
      workflowStore.getState().setPublishedAt(publishedWorkflow.created_at)
      workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || [])
    }
  }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore])

  return {
    handleBackupDraft,
    handleLoadBackupDraft,
    handleRun,
    handleStopRun,
    handleRestoreFromPublishedWorkflow,
  }
}