diff --git a/apps/sim/app/api/workflows/[id]/execute/cancel/route.ts b/apps/sim/app/api/workflows/[id]/execute/cancel/route.ts new file mode 100644 index 0000000000..830214a504 --- /dev/null +++ b/apps/sim/app/api/workflows/[id]/execute/cancel/route.ts @@ -0,0 +1,39 @@ +import { type NextRequest, NextResponse } from 'next/server' +import { z } from 'zod' +import { checkHybridAuth } from '@/lib/auth/hybrid' +import { requestCancellation } from '@/lib/execution/cancellation' + +const CancelExecutionSchema = z.object({ + executionId: z.string().uuid(), +}) + +export const runtime = 'nodejs' +export const dynamic = 'force-dynamic' + +export async function POST(req: NextRequest, { params }: { params: Promise<{ id: string }> }) { + await params + + const auth = await checkHybridAuth(req, { requireWorkflowId: false }) + if (!auth.success || !auth.userId) { + return NextResponse.json({ error: auth.error || 'Unauthorized' }, { status: 401 }) + } + + let body: any = {} + try { + const text = await req.text() + if (text) { + body = JSON.parse(text) + } + } catch { + return NextResponse.json({ error: 'Invalid request body' }, { status: 400 }) + } + + const validation = CancelExecutionSchema.safeParse(body) + if (!validation.success) { + return NextResponse.json({ error: 'Invalid request body' }, { status: 400 }) + } + + const { executionId } = validation.data + const success = await requestCancellation(executionId) + return NextResponse.json({ success }) +} diff --git a/apps/sim/app/api/workflows/[id]/execute/route.ts b/apps/sim/app/api/workflows/[id]/execute/route.ts index dd70158d38..5c318d89e6 100644 --- a/apps/sim/app/api/workflows/[id]/execute/route.ts +++ b/apps/sim/app/api/workflows/[id]/execute/route.ts @@ -7,6 +7,7 @@ import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags' import { generateRequestId } from '@/lib/core/utils/request' import { SSE_HEADERS } from '@/lib/core/utils/sse' import { getBaseUrl } from '@/lib/core/utils/urls' +import { clearCancellation } from '@/lib/execution/cancellation' import { processInputFileFields } from '@/lib/execution/files' import { preprocessExecution } from '@/lib/execution/preprocessing' import { createLogger } from '@/lib/logs/console/logger' @@ -496,7 +497,6 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: } const encoder = new TextEncoder() - let executorInstance: any = null let isStreamClosed = false const stream = new ReadableStream({ @@ -688,9 +688,6 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: onBlockStart, onBlockComplete, onStream, - onExecutorCreated: (executor) => { - executorInstance = executor - }, }, loggingSession, }) @@ -757,24 +754,18 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: }, }) } finally { + await clearCancellation(executionId) + if (!isStreamClosed) { try { controller.enqueue(encoder.encode('data: [DONE]\n\n')) controller.close() } catch { - // Stream already closed - nothing to do + // Stream already closed } } } }, - cancel() { - isStreamClosed = true - logger.info(`[${requestId}] Client aborted SSE stream, cancelling executor`) - - if (executorInstance && typeof executorInstance.cancel === 'function') { - executorInstance.cancel() - } - }, }) return new NextResponse(stream, { diff --git a/apps/sim/executor/execution/block-executor.ts b/apps/sim/executor/execution/block-executor.ts index b0723df04e..814340a404 100644 --- a/apps/sim/executor/execution/block-executor.ts +++ b/apps/sim/executor/execution/block-executor.ts @@ -2,6 +2,7 @@ import { db } from '@sim/db' import { mcpServers } from '@sim/db/schema' import { and, eq, inArray, isNull } from 'drizzle-orm' import { getBaseUrl } from '@/lib/core/utils/urls' +import { isCancellationRequested } from '@/lib/execution/cancellation' import { createLogger } from '@/lib/logs/console/logger' import { BlockType, @@ -32,6 +33,8 @@ import type { SubflowType } from '@/stores/workflows/workflow/types' const logger = createLogger('BlockExecutor') +const CANCELLATION_CHECK_INTERVAL_MS = 1000 + export class BlockExecutor { constructor( private blockHandlers: BlockHandler[], @@ -548,10 +551,16 @@ export class BlockExecutor { return } - const [clientStream, executorStream] = stream.tee() + const { clientStream: controlledClientStream, consume } = this.createControlledStream( + ctx, + stream, + blockId, + responseFormat, + streamingExec + ) const processedClientStream = streamingResponseFormatProcessor.processStream( - clientStream, + controlledClientStream, blockId, selectedOutputs, responseFormat @@ -562,13 +571,6 @@ export class BlockExecutor { stream: processedClientStream, } - const executorConsumption = this.consumeExecutorStream( - executorStream, - streamingExec, - blockId, - responseFormat - ) - const clientConsumption = (async () => { try { await ctx.onStream?.(clientStreamingExec) @@ -577,7 +579,7 @@ export class BlockExecutor { } })() - await Promise.all([clientConsumption, executorConsumption]) + await Promise.all([clientConsumption, consume()]) } private async forwardStream( @@ -605,57 +607,98 @@ export class BlockExecutor { } } - private async consumeExecutorStream( - stream: ReadableStream, - streamingExec: { execution: any }, + private createControlledStream( + ctx: ExecutionContext, + sourceStream: ReadableStream, blockId: string, - responseFormat: any - ): Promise { - const reader = stream.getReader() - const decoder = new TextDecoder() + responseFormat: any, + streamingExec: { execution: any } + ): { clientStream: ReadableStream; consume: () => Promise } { + let clientController: ReadableStreamDefaultController | null = null let fullContent = '' - try { - while (true) { - const { done, value } = await reader.read() - if (done) break - fullContent += decoder.decode(value, { stream: true }) - } - } catch (error) { - logger.error('Error reading executor stream for block', { blockId, error }) - } finally { - try { - reader.releaseLock() - } catch {} - } - - if (!fullContent) { - return - } + const clientStream = new ReadableStream({ + start(controller) { + clientController = controller + }, + }) - const executionOutput = streamingExec.execution?.output - if (!executionOutput || typeof executionOutput !== 'object') { - return - } + const consume = async () => { + const reader = sourceStream.getReader() + const decoder = new TextDecoder() + let lastCancellationCheck = Date.now() - if (responseFormat) { try { - const parsed = JSON.parse(fullContent.trim()) - - streamingExec.execution.output = { - ...parsed, - tokens: executionOutput.tokens, - toolCalls: executionOutput.toolCalls, - providerTiming: executionOutput.providerTiming, - cost: executionOutput.cost, - model: executionOutput.model, + while (true) { + const now = Date.now() + if (ctx.executionId && now - lastCancellationCheck >= CANCELLATION_CHECK_INTERVAL_MS) { + lastCancellationCheck = now + const cancelled = await isCancellationRequested(ctx.executionId) + if (cancelled) { + ctx.isCancelled = true + try { + clientController?.close() + } catch {} + reader.cancel() + break + } + } + + const { done, value } = await reader.read() + if (done) { + try { + clientController?.close() + } catch {} + break + } + + fullContent += decoder.decode(value, { stream: true }) + try { + clientController?.enqueue(value) + } catch {} } - return } catch (error) { - logger.warn('Failed to parse streamed content for response format', { blockId, error }) + if (!ctx.isCancelled) { + logger.error('Error reading stream for block', { blockId, error }) + } + try { + clientController?.close() + } catch {} + } finally { + try { + reader.releaseLock() + } catch {} + } + + if (!fullContent) { + return + } + + const executionOutput = streamingExec.execution?.output + if (!executionOutput || typeof executionOutput !== 'object') { + return } + + if (responseFormat) { + try { + const parsed = JSON.parse(fullContent.trim()) + streamingExec.execution.output = { + ...parsed, + tokens: executionOutput.tokens, + toolCalls: executionOutput.toolCalls, + providerTiming: executionOutput.providerTiming, + cost: executionOutput.cost, + model: executionOutput.model, + } + return + } catch (error) { + logger.warn('Failed to parse streamed content for response format', { blockId, error }) + } + } + + executionOutput.content = fullContent } - executionOutput.content = fullContent + return { clientStream, consume } } } diff --git a/apps/sim/executor/execution/engine.ts b/apps/sim/executor/execution/engine.ts index bf33df5961..fffa156f75 100644 --- a/apps/sim/executor/execution/engine.ts +++ b/apps/sim/executor/execution/engine.ts @@ -1,3 +1,4 @@ +import { isCancellationRequested } from '@/lib/execution/cancellation' import { createLogger } from '@/lib/logs/console/logger' import { BlockType } from '@/executor/constants' import type { DAG } from '@/executor/dag/builder' @@ -33,13 +34,24 @@ export class ExecutionEngine { this.allowResumeTriggers = this.context.metadata.resumeFromSnapshot === true } + private async checkCancellation(): Promise { + if (this.context.isCancelled) return true + const executionId = this.context.executionId + if (!executionId) return false + const cancelled = await isCancellationRequested(executionId) + if (cancelled) { + this.context.isCancelled = true + } + return cancelled + } + async run(triggerBlockId?: string): Promise { const startTime = Date.now() try { this.initializeQueue(triggerBlockId) while (this.hasWork()) { - if (this.context.isCancelled && this.executing.size === 0) { + if ((await this.checkCancellation()) && this.executing.size === 0) { break } await this.processQueue() @@ -234,7 +246,7 @@ export class ExecutionEngine { private async processQueue(): Promise { while (this.readyQueue.length > 0) { - if (this.context.isCancelled) { + if (await this.checkCancellation()) { break } const nodeId = this.dequeue() diff --git a/apps/sim/executor/execution/executor.ts b/apps/sim/executor/execution/executor.ts index 2f6e21573a..76c0cba5c1 100644 --- a/apps/sim/executor/execution/executor.ts +++ b/apps/sim/executor/execution/executor.ts @@ -54,9 +54,11 @@ export class DAGExecutor { const dag = this.dagBuilder.build(this.workflow, triggerBlockId, savedIncomingEdges) const { context, state } = this.createExecutionContext(workflowId, triggerBlockId) - // Link cancellation flag to context Object.defineProperty(context, 'isCancelled', { get: () => this.isCancelled, + set: (value: boolean) => { + this.isCancelled = value + }, enumerable: true, configurable: true, }) diff --git a/apps/sim/executor/execution/snapshot.ts b/apps/sim/executor/execution/snapshot.ts index 60ffdbd484..dfa0d1cc37 100644 --- a/apps/sim/executor/execution/snapshot.ts +++ b/apps/sim/executor/execution/snapshot.ts @@ -34,7 +34,6 @@ export interface ExecutionCallbacks { blockType: string, output: any ) => Promise - onExecutorCreated?: (executor: any) => void } export interface SerializableExecutionState { diff --git a/apps/sim/hooks/use-execution-stream.ts b/apps/sim/hooks/use-execution-stream.ts index f5fe211908..62835966f2 100644 --- a/apps/sim/hooks/use-execution-stream.ts +++ b/apps/sim/hooks/use-execution-stream.ts @@ -76,6 +76,10 @@ export interface ExecuteStreamOptions { */ export function useExecutionStream() { const abortControllerRef = useRef(null) + const currentExecutionRef = useRef<{ workflowId: string; executionId: string | null }>({ + workflowId: '', + executionId: null, + }) const execute = useCallback(async (options: ExecuteStreamOptions) => { const { workflowId, callbacks = {}, ...payload } = options @@ -89,6 +93,8 @@ export function useExecutionStream() { const abortController = new AbortController() abortControllerRef.current = abortController + currentExecutionRef.current = { workflowId, executionId: null } + try { const response = await fetch(`/api/workflows/${workflowId}/execute`, { method: 'POST', @@ -108,6 +114,11 @@ export function useExecutionStream() { throw new Error('No response body') } + const executionId = response.headers.get('X-Execution-Id') + if (executionId) { + currentExecutionRef.current.executionId = executionId + } + // Read SSE stream const reader = response.body.getReader() const decoder = new TextDecoder() @@ -215,6 +226,7 @@ export function useExecutionStream() { throw error } finally { abortControllerRef.current = null + currentExecutionRef.current = { workflowId: '', executionId: null } } }, []) @@ -223,6 +235,17 @@ export function useExecutionStream() { abortControllerRef.current.abort() abortControllerRef.current = null } + + const { workflowId, executionId } = currentExecutionRef.current + if (workflowId && executionId) { + fetch(`/api/workflows/${workflowId}/execute/cancel`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ executionId }), + }).catch(() => {}) + } + + currentExecutionRef.current = { workflowId: '', executionId: null } }, []) return { diff --git a/apps/sim/lib/execution/cancellation.ts b/apps/sim/lib/execution/cancellation.ts new file mode 100644 index 0000000000..e99fe3ec15 --- /dev/null +++ b/apps/sim/lib/execution/cancellation.ts @@ -0,0 +1,50 @@ +import { getRedisClient } from '@/lib/core/config/redis' + +const KEY_PREFIX = 'execution:cancel:' +const TTL_SECONDS = 300 +const TTL_MS = TTL_SECONDS * 1000 + +const memoryStore = new Map() + +export async function requestCancellation(executionId: string): Promise { + const redis = getRedisClient() + if (redis) { + try { + await redis.set(`${KEY_PREFIX}${executionId}`, '1', 'EX', TTL_SECONDS) + return true + } catch { + return false + } + } + memoryStore.set(executionId, Date.now() + TTL_MS) + return true +} + +export async function isCancellationRequested(executionId: string): Promise { + const redis = getRedisClient() + if (redis) { + try { + return (await redis.exists(`${KEY_PREFIX}${executionId}`)) === 1 + } catch { + return false + } + } + const expiry = memoryStore.get(executionId) + if (!expiry) return false + if (Date.now() > expiry) { + memoryStore.delete(executionId) + return false + } + return true +} + +export async function clearCancellation(executionId: string): Promise { + const redis = getRedisClient() + if (redis) { + try { + await redis.del(`${KEY_PREFIX}${executionId}`) + } catch {} + return + } + memoryStore.delete(executionId) +} diff --git a/apps/sim/lib/workflows/executor/execution-core.ts b/apps/sim/lib/workflows/executor/execution-core.ts index 26673e831b..2f3ad76241 100644 --- a/apps/sim/lib/workflows/executor/execution-core.ts +++ b/apps/sim/lib/workflows/executor/execution-core.ts @@ -102,7 +102,7 @@ export async function executeWorkflowCore( const { metadata, workflow, input, workflowVariables, selectedOutputs } = snapshot const { requestId, workflowId, userId, triggerType, executionId, triggerBlockId, useDraftState } = metadata - const { onBlockStart, onBlockComplete, onStream, onExecutorCreated } = callbacks + const { onBlockStart, onBlockComplete, onStream } = callbacks const providedWorkspaceId = metadata.workspaceId if (!providedWorkspaceId) { @@ -349,10 +349,6 @@ export async function executeWorkflowCore( } } - if (onExecutorCreated) { - onExecutorCreated(executorInstance) - } - const result = (await executorInstance.execute( workflowId, resolvedTriggerBlockId