mirror of
https://github.com/simstudioai/sim.git
synced 2026-02-12 07:24:55 -05:00
improvement(executor): redesign executor + add start block (#1790)
* fix(billing): should allow restoring subscription (#1728) * fix(already-cancelled-sub): UI should allow restoring subscription * restore functionality fixed * fix * improvement(start): revert to start block * make it work with start block * fix start block persistence * cleanup triggers * debounce status checks * update docs * improvement(start): revert to start block * make it work with start block * fix start block persistence * cleanup triggers * debounce status checks * update docs * SSE v0.1 * v0.2 * v0.3 * v0.4 * v0.5 * v0.6 * broken checkpoint * Executor progress - everything preliminarily tested except while loops and triggers * Executor fixes * Fix var typing * Implement while loop execution * Loop and parallel result agg * Refactor v1 - loops work * Fix var resolution in for each loop * Fix while loop condition and variable resolution * Fix loop iteration counts * Fix loop badges * Clean logs * Fix variable references from start block * Fix condition block * Fix conditional convergence * Dont execute orphaned nodse * Code cleanup 1 and error surfacing * compile time try catch * Some fixes * Fix error throwing * Sentinels v1 * Fix multiple start and end nodes in loop * Edge restoration * Fix reachable nodes execution * Parallel subflows * Fix loop/parallel sentinel convergence * Loops and parallels orchestrator * Split executor * Variable resolution split * Dag phase * Refactor * Refactor * Refactor 3 * Lint + refactor * Lint + cleanup + refactor * Readability * Initial logs * Fix trace spans * Console pills for iters * Add input/output pills * Checkpoint * remove unused code * THIS IS THE COMMIT THAT CAN BREAK A LOT OF THINGS * ANOTHER BIG REFACTOR * Lint + fix tests * Fix webhook * Remove comment * Merge stash * Fix triggers? * Stuff * Fix error port * Lint * Consolidate state * Clean up some var resolution * Remove some var resolution logs * Fix chat * Fix chat triggers * Fix chat trigger fully * Snapshot refactor * Fix mcp and custom tools * Lint * Fix parallel default count and trace span overlay * Agent purple * Fix test * Fix test --------- Co-authored-by: Waleed <walif6@gmail.com> Co-authored-by: Vikhyath Mondreti <vikhyathvikku@gmail.com> Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
This commit is contained in:
committed by
GitHub
parent
7d67ae397d
commit
3bf00cbd2a
381
apps/sim/executor/orchestrators/loop.ts
Normal file
381
apps/sim/executor/orchestrators/loop.ts
Normal file
@@ -0,0 +1,381 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { buildLoopIndexCondition, DEFAULTS, EDGE } from '@/executor/consts'
|
||||
import type { ExecutionContext, NormalizedBlockOutput } from '@/executor/types'
|
||||
import type { LoopConfigWithNodes } from '@/executor/types/loop'
|
||||
import {
|
||||
buildSentinelEndId,
|
||||
buildSentinelStartId,
|
||||
extractBaseBlockId,
|
||||
} from '@/executor/utils/subflow-utils'
|
||||
import type { SerializedLoop } from '@/serializer/types'
|
||||
import type { DAG } from '../dag/builder'
|
||||
import type { ExecutionState, LoopScope } from '../execution/state'
|
||||
import type { VariableResolver } from '../variables/resolver'
|
||||
|
||||
const logger = createLogger('LoopOrchestrator')
|
||||
|
||||
export type LoopRoute = typeof EDGE.LOOP_CONTINUE | typeof EDGE.LOOP_EXIT
|
||||
|
||||
export interface LoopContinuationResult {
|
||||
shouldContinue: boolean
|
||||
shouldExit: boolean
|
||||
selectedRoute: LoopRoute
|
||||
aggregatedResults?: NormalizedBlockOutput[][]
|
||||
currentIteration?: number
|
||||
}
|
||||
|
||||
export class LoopOrchestrator {
|
||||
constructor(
|
||||
private dag: DAG,
|
||||
private state: ExecutionState,
|
||||
private resolver: VariableResolver
|
||||
) {}
|
||||
|
||||
initializeLoopScope(ctx: ExecutionContext, loopId: string): LoopScope {
|
||||
const loopConfig = this.dag.loopConfigs.get(loopId) as SerializedLoop | undefined
|
||||
if (!loopConfig) {
|
||||
throw new Error(`Loop config not found: ${loopId}`)
|
||||
}
|
||||
|
||||
const scope: LoopScope = {
|
||||
iteration: 0,
|
||||
currentIterationOutputs: new Map(),
|
||||
allIterationOutputs: [],
|
||||
}
|
||||
|
||||
const loopType = loopConfig.loopType
|
||||
logger.debug('Initializing loop scope', { loopId, loopType })
|
||||
|
||||
switch (loopType) {
|
||||
case 'for':
|
||||
scope.maxIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
|
||||
scope.condition = buildLoopIndexCondition(scope.maxIterations)
|
||||
logger.debug('For loop initialized', { loopId, maxIterations: scope.maxIterations })
|
||||
break
|
||||
|
||||
case 'forEach': {
|
||||
const items = this.resolveForEachItems(ctx, loopConfig.forEachItems)
|
||||
scope.items = items
|
||||
scope.maxIterations = items.length
|
||||
scope.item = items[0]
|
||||
scope.condition = buildLoopIndexCondition(scope.maxIterations)
|
||||
logger.debug('ForEach loop initialized', { loopId, itemCount: items.length })
|
||||
break
|
||||
}
|
||||
|
||||
case 'while':
|
||||
scope.condition = loopConfig.whileCondition
|
||||
logger.debug('While loop initialized', { loopId, condition: scope.condition })
|
||||
break
|
||||
|
||||
case 'doWhile':
|
||||
if (loopConfig.doWhileCondition) {
|
||||
scope.condition = loopConfig.doWhileCondition
|
||||
} else {
|
||||
scope.maxIterations = loopConfig.iterations || DEFAULTS.MAX_LOOP_ITERATIONS
|
||||
scope.condition = buildLoopIndexCondition(scope.maxIterations)
|
||||
}
|
||||
scope.skipFirstConditionCheck = true
|
||||
logger.debug('DoWhile loop initialized', { loopId, condition: scope.condition })
|
||||
break
|
||||
|
||||
default:
|
||||
throw new Error(`Unknown loop type: ${loopType}`)
|
||||
}
|
||||
|
||||
this.state.setLoopScope(loopId, scope)
|
||||
return scope
|
||||
}
|
||||
|
||||
storeLoopNodeOutput(
|
||||
ctx: ExecutionContext,
|
||||
loopId: string,
|
||||
nodeId: string,
|
||||
output: NormalizedBlockOutput
|
||||
): void {
|
||||
const scope = this.state.getLoopScope(loopId)
|
||||
if (!scope) {
|
||||
logger.warn('Loop scope not found for node output storage', { loopId, nodeId })
|
||||
return
|
||||
}
|
||||
|
||||
const baseId = extractBaseBlockId(nodeId)
|
||||
scope.currentIterationOutputs.set(baseId, output)
|
||||
logger.debug('Stored loop node output', {
|
||||
loopId,
|
||||
nodeId: baseId,
|
||||
iteration: scope.iteration,
|
||||
outputsCount: scope.currentIterationOutputs.size,
|
||||
})
|
||||
}
|
||||
|
||||
evaluateLoopContinuation(ctx: ExecutionContext, loopId: string): LoopContinuationResult {
|
||||
const scope = this.state.getLoopScope(loopId)
|
||||
if (!scope) {
|
||||
logger.error('Loop scope not found during continuation evaluation', { loopId })
|
||||
return {
|
||||
shouldContinue: false,
|
||||
shouldExit: true,
|
||||
selectedRoute: EDGE.LOOP_EXIT,
|
||||
}
|
||||
}
|
||||
|
||||
const iterationResults: NormalizedBlockOutput[] = []
|
||||
for (const blockOutput of scope.currentIterationOutputs.values()) {
|
||||
iterationResults.push(blockOutput)
|
||||
}
|
||||
|
||||
if (iterationResults.length > 0) {
|
||||
scope.allIterationOutputs.push(iterationResults)
|
||||
logger.debug('Collected iteration results', {
|
||||
loopId,
|
||||
iteration: scope.iteration,
|
||||
resultsCount: iterationResults.length,
|
||||
})
|
||||
}
|
||||
|
||||
scope.currentIterationOutputs.clear()
|
||||
|
||||
const isFirstIteration = scope.iteration === 0
|
||||
const shouldSkipFirstCheck = scope.skipFirstConditionCheck && isFirstIteration
|
||||
if (!shouldSkipFirstCheck) {
|
||||
if (!this.evaluateCondition(ctx, scope, scope.iteration + 1)) {
|
||||
logger.debug('Loop condition false for next iteration - exiting', {
|
||||
loopId,
|
||||
currentIteration: scope.iteration,
|
||||
nextIteration: scope.iteration + 1,
|
||||
})
|
||||
return this.createExitResult(ctx, loopId, scope)
|
||||
}
|
||||
}
|
||||
|
||||
scope.iteration++
|
||||
if (scope.items && scope.iteration < scope.items.length) {
|
||||
scope.item = scope.items[scope.iteration]
|
||||
}
|
||||
|
||||
logger.debug('Loop will continue', {
|
||||
loopId,
|
||||
nextIteration: scope.iteration,
|
||||
})
|
||||
|
||||
return {
|
||||
shouldContinue: true,
|
||||
shouldExit: false,
|
||||
selectedRoute: EDGE.LOOP_CONTINUE,
|
||||
currentIteration: scope.iteration,
|
||||
}
|
||||
}
|
||||
|
||||
private createExitResult(
|
||||
ctx: ExecutionContext,
|
||||
loopId: string,
|
||||
scope: LoopScope
|
||||
): LoopContinuationResult {
|
||||
const results = scope.allIterationOutputs
|
||||
ctx.blockStates?.set(loopId, {
|
||||
output: { results },
|
||||
executed: true,
|
||||
executionTime: DEFAULTS.EXECUTION_TIME,
|
||||
})
|
||||
|
||||
logger.debug('Loop exiting', { loopId, totalIterations: scope.iteration })
|
||||
|
||||
return {
|
||||
shouldContinue: false,
|
||||
shouldExit: true,
|
||||
selectedRoute: EDGE.LOOP_EXIT,
|
||||
aggregatedResults: results,
|
||||
currentIteration: scope.iteration,
|
||||
}
|
||||
}
|
||||
|
||||
private evaluateCondition(ctx: ExecutionContext, scope: LoopScope, iteration?: number): boolean {
|
||||
if (!scope.condition) {
|
||||
logger.warn('No condition defined for loop')
|
||||
return false
|
||||
}
|
||||
|
||||
const currentIteration = scope.iteration
|
||||
if (iteration !== undefined) {
|
||||
scope.iteration = iteration
|
||||
}
|
||||
|
||||
const result = this.evaluateWhileCondition(ctx, scope.condition, scope)
|
||||
|
||||
if (iteration !== undefined) {
|
||||
scope.iteration = currentIteration
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
clearLoopExecutionState(loopId: string, executedBlocks: Set<string>): void {
|
||||
const loopConfig = this.dag.loopConfigs.get(loopId) as LoopConfigWithNodes | undefined
|
||||
if (!loopConfig) {
|
||||
logger.warn('Loop config not found for state clearing', { loopId })
|
||||
return
|
||||
}
|
||||
|
||||
const sentinelStartId = buildSentinelStartId(loopId)
|
||||
const sentinelEndId = buildSentinelEndId(loopId)
|
||||
const loopNodes = loopConfig.nodes
|
||||
|
||||
executedBlocks.delete(sentinelStartId)
|
||||
executedBlocks.delete(sentinelEndId)
|
||||
for (const loopNodeId of loopNodes) {
|
||||
executedBlocks.delete(loopNodeId)
|
||||
}
|
||||
|
||||
logger.debug('Cleared loop execution state', {
|
||||
loopId,
|
||||
nodesCleared: loopNodes.length + 2,
|
||||
})
|
||||
}
|
||||
|
||||
restoreLoopEdges(loopId: string): void {
|
||||
const loopConfig = this.dag.loopConfigs.get(loopId) as LoopConfigWithNodes | undefined
|
||||
if (!loopConfig) {
|
||||
logger.warn('Loop config not found for edge restoration', { loopId })
|
||||
return
|
||||
}
|
||||
|
||||
const sentinelStartId = buildSentinelStartId(loopId)
|
||||
const sentinelEndId = buildSentinelEndId(loopId)
|
||||
const loopNodes = loopConfig.nodes
|
||||
const allLoopNodeIds = new Set([sentinelStartId, sentinelEndId, ...loopNodes])
|
||||
|
||||
let restoredCount = 0
|
||||
for (const nodeId of allLoopNodeIds) {
|
||||
const nodeToRestore = this.dag.nodes.get(nodeId)
|
||||
if (!nodeToRestore) continue
|
||||
|
||||
for (const [potentialSourceId, potentialSourceNode] of this.dag.nodes) {
|
||||
if (!allLoopNodeIds.has(potentialSourceId)) continue
|
||||
|
||||
for (const [_, edge] of potentialSourceNode.outgoingEdges) {
|
||||
if (edge.target === nodeId) {
|
||||
const isBackwardEdge =
|
||||
edge.sourceHandle === EDGE.LOOP_CONTINUE ||
|
||||
edge.sourceHandle === EDGE.LOOP_CONTINUE_ALT
|
||||
|
||||
if (!isBackwardEdge) {
|
||||
nodeToRestore.incomingEdges.add(potentialSourceId)
|
||||
restoredCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('Restored loop edges', { loopId, edgesRestored: restoredCount })
|
||||
}
|
||||
|
||||
getLoopScope(loopId: string): LoopScope | undefined {
|
||||
return this.state.getLoopScope(loopId)
|
||||
}
|
||||
|
||||
shouldExecuteLoopNode(nodeId: string, loopId: string, context: ExecutionContext): boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
private findLoopForNode(nodeId: string): string | undefined {
|
||||
for (const [loopId, config] of this.dag.loopConfigs) {
|
||||
const nodes = (config as any).nodes || []
|
||||
if (nodes.includes(nodeId)) {
|
||||
return loopId
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
private evaluateWhileCondition(
|
||||
ctx: ExecutionContext,
|
||||
condition: string,
|
||||
scope: LoopScope
|
||||
): boolean {
|
||||
if (!condition) {
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
const referencePattern = /<([^>]+)>/g
|
||||
let evaluatedCondition = condition
|
||||
const replacements: Record<string, string> = {}
|
||||
|
||||
evaluatedCondition = evaluatedCondition.replace(referencePattern, (match) => {
|
||||
const resolved = this.resolver.resolveSingleReference(ctx, '', match, scope)
|
||||
if (resolved !== undefined) {
|
||||
if (typeof resolved === 'string') {
|
||||
replacements[match] = `"${resolved}"`
|
||||
return `"${resolved}"`
|
||||
}
|
||||
replacements[match] = String(resolved)
|
||||
return String(resolved)
|
||||
}
|
||||
return match
|
||||
})
|
||||
|
||||
const result = Boolean(new Function(`return (${evaluatedCondition})`)())
|
||||
|
||||
logger.debug('Evaluated loop condition', {
|
||||
condition,
|
||||
replacements,
|
||||
evaluatedCondition,
|
||||
result,
|
||||
iteration: scope.iteration,
|
||||
})
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Failed to evaluate loop condition', { condition, error })
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private resolveForEachItems(ctx: ExecutionContext, items: any): any[] {
|
||||
if (Array.isArray(items)) {
|
||||
return items
|
||||
}
|
||||
|
||||
if (typeof items === 'object' && items !== null) {
|
||||
return Object.entries(items)
|
||||
}
|
||||
|
||||
if (typeof items === 'string') {
|
||||
if (items.startsWith('<') && items.endsWith('>')) {
|
||||
const resolved = this.resolver.resolveSingleReference(ctx, '', items)
|
||||
return Array.isArray(resolved) ? resolved : []
|
||||
}
|
||||
|
||||
try {
|
||||
const normalized = items.replace(/'/g, '"')
|
||||
const parsed = JSON.parse(normalized)
|
||||
return Array.isArray(parsed) ? parsed : []
|
||||
} catch (error) {
|
||||
logger.error('Failed to parse forEach items', { items, error })
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const resolved = this.resolver.resolveInputs(ctx, 'loop_foreach_items', { items }).items
|
||||
|
||||
if (Array.isArray(resolved)) {
|
||||
return resolved
|
||||
}
|
||||
|
||||
logger.warn('ForEach items did not resolve to array', {
|
||||
items,
|
||||
resolved,
|
||||
})
|
||||
|
||||
return []
|
||||
} catch (error: any) {
|
||||
logger.error('Error resolving forEach items, returning empty array:', {
|
||||
error: error.message,
|
||||
})
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
227
apps/sim/executor/orchestrators/node.ts
Normal file
227
apps/sim/executor/orchestrators/node.ts
Normal file
@@ -0,0 +1,227 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { EDGE } from '@/executor/consts'
|
||||
import type { ExecutionContext, NormalizedBlockOutput } from '@/executor/types'
|
||||
import { extractBaseBlockId } from '@/executor/utils/subflow-utils'
|
||||
import type { DAG, DAGNode } from '../dag/builder'
|
||||
import type { BlockExecutor } from '../execution/block-executor'
|
||||
import type { ExecutionState } from '../execution/state'
|
||||
import type { LoopOrchestrator } from './loop'
|
||||
import type { ParallelOrchestrator } from './parallel'
|
||||
|
||||
const logger = createLogger('NodeExecutionOrchestrator')
|
||||
|
||||
export interface NodeExecutionResult {
|
||||
nodeId: string
|
||||
output: NormalizedBlockOutput
|
||||
isFinalOutput: boolean
|
||||
}
|
||||
|
||||
export class NodeExecutionOrchestrator {
|
||||
constructor(
|
||||
private dag: DAG,
|
||||
private state: ExecutionState,
|
||||
private blockExecutor: BlockExecutor,
|
||||
private loopOrchestrator: LoopOrchestrator,
|
||||
private parallelOrchestrator: ParallelOrchestrator
|
||||
) {}
|
||||
|
||||
async executeNode(nodeId: string, context: any): Promise<NodeExecutionResult> {
|
||||
const node = this.dag.nodes.get(nodeId)
|
||||
if (!node) {
|
||||
throw new Error(`Node not found in DAG: ${nodeId}`)
|
||||
}
|
||||
|
||||
if (this.state.hasExecuted(nodeId)) {
|
||||
logger.debug('Node already executed, skipping', { nodeId })
|
||||
const output = this.state.getBlockOutput(nodeId) || {}
|
||||
return {
|
||||
nodeId,
|
||||
output,
|
||||
isFinalOutput: false,
|
||||
}
|
||||
}
|
||||
|
||||
const loopId = node.metadata.loopId
|
||||
if (loopId && !this.loopOrchestrator.getLoopScope(loopId)) {
|
||||
logger.debug('Initializing loop scope before first execution', { loopId, nodeId })
|
||||
this.loopOrchestrator.initializeLoopScope(context, loopId)
|
||||
}
|
||||
|
||||
if (loopId && !this.loopOrchestrator.shouldExecuteLoopNode(nodeId, loopId, context)) {
|
||||
logger.debug('Loop node should not execute', { nodeId, loopId })
|
||||
return {
|
||||
nodeId,
|
||||
output: {},
|
||||
isFinalOutput: false,
|
||||
}
|
||||
}
|
||||
|
||||
if (node.metadata.isSentinel) {
|
||||
logger.debug('Executing sentinel node', {
|
||||
nodeId,
|
||||
sentinelType: node.metadata.sentinelType,
|
||||
loopId,
|
||||
})
|
||||
const output = this.handleSentinel(node, context)
|
||||
const isFinalOutput = node.outgoingEdges.size === 0
|
||||
return {
|
||||
nodeId,
|
||||
output,
|
||||
isFinalOutput,
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('Executing node', { nodeId, blockType: node.block.metadata?.id })
|
||||
const output = await this.blockExecutor.execute(context, node, node.block)
|
||||
const isFinalOutput = node.outgoingEdges.size === 0
|
||||
return {
|
||||
nodeId,
|
||||
output,
|
||||
isFinalOutput,
|
||||
}
|
||||
}
|
||||
|
||||
private handleSentinel(node: DAGNode, context: any): NormalizedBlockOutput {
|
||||
const sentinelType = node.metadata.sentinelType
|
||||
const loopId = node.metadata.loopId
|
||||
if (sentinelType === 'start') {
|
||||
logger.debug('Sentinel start - loop entry', { nodeId: node.id, loopId })
|
||||
return { sentinelStart: true }
|
||||
}
|
||||
|
||||
if (sentinelType === 'end') {
|
||||
logger.debug('Sentinel end - evaluating loop continuation', { nodeId: node.id, loopId })
|
||||
if (!loopId) {
|
||||
logger.warn('Sentinel end called without loopId')
|
||||
return { shouldExit: true, selectedRoute: EDGE.LOOP_EXIT }
|
||||
}
|
||||
|
||||
const continuationResult = this.loopOrchestrator.evaluateLoopContinuation(context, loopId)
|
||||
logger.debug('Loop continuation evaluated', {
|
||||
loopId,
|
||||
shouldContinue: continuationResult.shouldContinue,
|
||||
shouldExit: continuationResult.shouldExit,
|
||||
iteration: continuationResult.currentIteration,
|
||||
})
|
||||
|
||||
if (continuationResult.shouldContinue) {
|
||||
return {
|
||||
shouldContinue: true,
|
||||
shouldExit: false,
|
||||
selectedRoute: continuationResult.selectedRoute,
|
||||
loopIteration: continuationResult.currentIteration,
|
||||
}
|
||||
}
|
||||
return {
|
||||
results: continuationResult.aggregatedResults || [],
|
||||
shouldContinue: false,
|
||||
shouldExit: true,
|
||||
selectedRoute: continuationResult.selectedRoute,
|
||||
totalIterations: continuationResult.aggregatedResults?.length || 0,
|
||||
}
|
||||
}
|
||||
logger.warn('Unknown sentinel type', { sentinelType })
|
||||
return {}
|
||||
}
|
||||
|
||||
async handleNodeCompletion(
|
||||
nodeId: string,
|
||||
output: NormalizedBlockOutput,
|
||||
context: any
|
||||
): Promise<void> {
|
||||
const node = this.dag.nodes.get(nodeId)
|
||||
if (!node) {
|
||||
logger.error('Node not found during completion handling', { nodeId })
|
||||
return
|
||||
}
|
||||
|
||||
logger.debug('Handling node completion', {
|
||||
nodeId: node.id,
|
||||
hasLoopId: !!node.metadata.loopId,
|
||||
isParallelBranch: !!node.metadata.isParallelBranch,
|
||||
isSentinel: !!node.metadata.isSentinel,
|
||||
})
|
||||
|
||||
const loopId = node.metadata.loopId
|
||||
const isParallelBranch = node.metadata.isParallelBranch
|
||||
const isSentinel = node.metadata.isSentinel
|
||||
if (isSentinel) {
|
||||
logger.debug('Handling sentinel node', { nodeId: node.id, loopId })
|
||||
this.handleRegularNodeCompletion(node, output, context)
|
||||
} else if (loopId) {
|
||||
logger.debug('Handling loop node', { nodeId: node.id, loopId })
|
||||
this.handleLoopNodeCompletion(node, output, loopId, context)
|
||||
} else if (isParallelBranch) {
|
||||
const parallelId = this.findParallelIdForNode(node.id)
|
||||
if (parallelId) {
|
||||
logger.debug('Handling parallel node', { nodeId: node.id, parallelId })
|
||||
this.handleParallelNodeCompletion(node, output, parallelId)
|
||||
} else {
|
||||
this.handleRegularNodeCompletion(node, output, context)
|
||||
}
|
||||
} else {
|
||||
logger.debug('Handling regular node', { nodeId: node.id })
|
||||
this.handleRegularNodeCompletion(node, output, context)
|
||||
}
|
||||
}
|
||||
|
||||
private handleLoopNodeCompletion(
|
||||
node: DAGNode,
|
||||
output: NormalizedBlockOutput,
|
||||
loopId: string,
|
||||
context: ExecutionContext
|
||||
): void {
|
||||
this.loopOrchestrator.storeLoopNodeOutput(context, loopId, node.id, output)
|
||||
this.state.setBlockOutput(node.id, output)
|
||||
}
|
||||
|
||||
private handleParallelNodeCompletion(
|
||||
node: DAGNode,
|
||||
output: NormalizedBlockOutput,
|
||||
parallelId: string
|
||||
): void {
|
||||
const scope = this.parallelOrchestrator.getParallelScope(parallelId)
|
||||
if (!scope) {
|
||||
const totalBranches = node.metadata.branchTotal || 1
|
||||
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
|
||||
const nodesInParallel = (parallelConfig as any)?.nodes?.length || 1
|
||||
this.parallelOrchestrator.initializeParallelScope(parallelId, totalBranches, nodesInParallel)
|
||||
}
|
||||
const allComplete = this.parallelOrchestrator.handleParallelBranchCompletion(
|
||||
parallelId,
|
||||
node.id,
|
||||
output
|
||||
)
|
||||
if (allComplete) {
|
||||
this.parallelOrchestrator.aggregateParallelResults(parallelId)
|
||||
}
|
||||
|
||||
this.state.setBlockOutput(node.id, output)
|
||||
}
|
||||
|
||||
private handleRegularNodeCompletion(
|
||||
node: DAGNode,
|
||||
output: NormalizedBlockOutput,
|
||||
context: any
|
||||
): void {
|
||||
this.state.setBlockOutput(node.id, output)
|
||||
|
||||
if (
|
||||
node.metadata.isSentinel &&
|
||||
node.metadata.sentinelType === 'end' &&
|
||||
output.selectedRoute === 'loop_continue'
|
||||
) {
|
||||
const loopId = node.metadata.loopId
|
||||
if (loopId) {
|
||||
logger.debug('Preparing loop for next iteration', { loopId })
|
||||
this.loopOrchestrator.clearLoopExecutionState(loopId, this.state.executedBlocks)
|
||||
this.loopOrchestrator.restoreLoopEdges(loopId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private findParallelIdForNode(nodeId: string): string | undefined {
|
||||
const baseId = extractBaseBlockId(nodeId)
|
||||
return this.parallelOrchestrator.findParallelIdForNode(baseId)
|
||||
}
|
||||
}
|
||||
181
apps/sim/executor/orchestrators/parallel.ts
Normal file
181
apps/sim/executor/orchestrators/parallel.ts
Normal file
@@ -0,0 +1,181 @@
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import type { NormalizedBlockOutput } from '@/executor/types'
|
||||
import type { ParallelConfigWithNodes } from '@/executor/types/parallel'
|
||||
import {
|
||||
calculateBranchCount,
|
||||
extractBaseBlockId,
|
||||
extractBranchIndex,
|
||||
parseDistributionItems,
|
||||
} from '@/executor/utils/subflow-utils'
|
||||
import type { SerializedParallel } from '@/serializer/types'
|
||||
import type { DAG } from '../dag/builder'
|
||||
import type { ExecutionState, ParallelScope } from '../execution/state'
|
||||
|
||||
const logger = createLogger('ParallelOrchestrator')
|
||||
|
||||
export interface ParallelBranchMetadata {
|
||||
branchIndex: number
|
||||
branchTotal: number
|
||||
distributionItem?: any
|
||||
parallelId: string
|
||||
}
|
||||
|
||||
export interface ParallelAggregationResult {
|
||||
allBranchesComplete: boolean
|
||||
results?: NormalizedBlockOutput[][]
|
||||
completedBranches?: number
|
||||
totalBranches?: number
|
||||
}
|
||||
|
||||
export class ParallelOrchestrator {
|
||||
constructor(
|
||||
private dag: DAG,
|
||||
private state: ExecutionState
|
||||
) {}
|
||||
|
||||
initializeParallelScope(
|
||||
parallelId: string,
|
||||
totalBranches: number,
|
||||
terminalNodesCount = 1
|
||||
): ParallelScope {
|
||||
const scope: ParallelScope = {
|
||||
parallelId,
|
||||
totalBranches,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: totalBranches * terminalNodesCount,
|
||||
}
|
||||
this.state.setParallelScope(parallelId, scope)
|
||||
logger.debug('Initialized parallel scope', {
|
||||
parallelId,
|
||||
totalBranches,
|
||||
terminalNodesCount,
|
||||
totalExpectedNodes: scope.totalExpectedNodes,
|
||||
})
|
||||
return scope
|
||||
}
|
||||
|
||||
handleParallelBranchCompletion(
|
||||
parallelId: string,
|
||||
nodeId: string,
|
||||
output: NormalizedBlockOutput
|
||||
): boolean {
|
||||
const scope = this.state.getParallelScope(parallelId)
|
||||
if (!scope) {
|
||||
logger.warn('Parallel scope not found for branch completion', { parallelId, nodeId })
|
||||
return false
|
||||
}
|
||||
|
||||
const branchIndex = extractBranchIndex(nodeId)
|
||||
if (branchIndex === null) {
|
||||
logger.warn('Could not extract branch index from node ID', { nodeId })
|
||||
return false
|
||||
}
|
||||
|
||||
if (!scope.branchOutputs.has(branchIndex)) {
|
||||
scope.branchOutputs.set(branchIndex, [])
|
||||
}
|
||||
scope.branchOutputs.get(branchIndex)!.push(output)
|
||||
scope.completedCount++
|
||||
logger.debug('Recorded parallel branch output', {
|
||||
parallelId,
|
||||
branchIndex,
|
||||
nodeId,
|
||||
completedCount: scope.completedCount,
|
||||
totalExpected: scope.totalExpectedNodes,
|
||||
})
|
||||
|
||||
const allComplete = scope.completedCount >= scope.totalExpectedNodes
|
||||
if (allComplete) {
|
||||
logger.debug('All parallel branches completed', {
|
||||
parallelId,
|
||||
totalBranches: scope.totalBranches,
|
||||
completedNodes: scope.completedCount,
|
||||
})
|
||||
}
|
||||
return allComplete
|
||||
}
|
||||
|
||||
aggregateParallelResults(parallelId: string): ParallelAggregationResult {
|
||||
const scope = this.state.getParallelScope(parallelId)
|
||||
if (!scope) {
|
||||
logger.error('Parallel scope not found for aggregation', { parallelId })
|
||||
return { allBranchesComplete: false }
|
||||
}
|
||||
|
||||
const results: NormalizedBlockOutput[][] = []
|
||||
for (let i = 0; i < scope.totalBranches; i++) {
|
||||
const branchOutputs = scope.branchOutputs.get(i) || []
|
||||
results.push(branchOutputs)
|
||||
}
|
||||
this.state.setBlockOutput(parallelId, {
|
||||
results,
|
||||
})
|
||||
logger.debug('Aggregated parallel results', {
|
||||
parallelId,
|
||||
totalBranches: scope.totalBranches,
|
||||
nodesPerBranch: results[0]?.length || 0,
|
||||
totalOutputs: scope.completedCount,
|
||||
})
|
||||
return {
|
||||
allBranchesComplete: true,
|
||||
results,
|
||||
completedBranches: scope.totalBranches,
|
||||
totalBranches: scope.totalBranches,
|
||||
}
|
||||
}
|
||||
extractBranchMetadata(nodeId: string): ParallelBranchMetadata | null {
|
||||
const branchIndex = extractBranchIndex(nodeId)
|
||||
if (branchIndex === null) {
|
||||
return null
|
||||
}
|
||||
|
||||
const baseId = extractBaseBlockId(nodeId)
|
||||
const parallelId = this.findParallelIdForNode(baseId)
|
||||
if (!parallelId) {
|
||||
return null
|
||||
}
|
||||
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
|
||||
if (!parallelConfig) {
|
||||
return null
|
||||
}
|
||||
const { totalBranches, distributionItem } = this.getParallelConfigInfo(
|
||||
parallelConfig,
|
||||
branchIndex
|
||||
)
|
||||
return {
|
||||
branchIndex,
|
||||
branchTotal: totalBranches,
|
||||
distributionItem,
|
||||
parallelId,
|
||||
}
|
||||
}
|
||||
|
||||
getParallelScope(parallelId: string): ParallelScope | undefined {
|
||||
return this.state.getParallelScope(parallelId)
|
||||
}
|
||||
|
||||
findParallelIdForNode(baseNodeId: string): string | undefined {
|
||||
for (const [parallelId, config] of this.dag.parallelConfigs) {
|
||||
const parallelConfig = config as ParallelConfigWithNodes
|
||||
if (parallelConfig.nodes?.includes(baseNodeId)) {
|
||||
return parallelId
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
private getParallelConfigInfo(
|
||||
parallelConfig: SerializedParallel,
|
||||
branchIndex: number
|
||||
): { totalBranches: number; distributionItem?: any } {
|
||||
const distributionItems = parseDistributionItems(parallelConfig)
|
||||
const totalBranches = calculateBranchCount(parallelConfig, distributionItems)
|
||||
|
||||
let distributionItem: any
|
||||
if (Array.isArray(distributionItems) && branchIndex < distributionItems.length) {
|
||||
distributionItem = distributionItems[branchIndex]
|
||||
}
|
||||
return { totalBranches, distributionItem }
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user