mirror of
https://github.com/simstudioai/sim.git
synced 2026-04-28 03:00:29 -04:00
fix(parallel): remove broken node-counting completion + resolver claim cross-block (#4045)
* fix(parallel): remove broken node-counting completion in parallel blocks * fix resolver claim --------- Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
This commit is contained in:
@@ -22,8 +22,6 @@ export interface ParallelScope {
|
||||
parallelId: string
|
||||
totalBranches: number
|
||||
branchOutputs: Map<number, NormalizedBlockOutput[]>
|
||||
completedCount: number
|
||||
totalExpectedNodes: number
|
||||
items?: any[]
|
||||
/** Error message if parallel validation failed (e.g., exceeded max branches) */
|
||||
validationError?: string
|
||||
|
||||
@@ -58,9 +58,7 @@ export class NodeExecutionOrchestrator {
|
||||
|
||||
const parallelId = node.metadata.parallelId
|
||||
if (parallelId && !this.parallelOrchestrator.getParallelScope(ctx, parallelId)) {
|
||||
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
|
||||
const nodesInParallel = parallelConfig?.nodes?.length || 1
|
||||
await this.parallelOrchestrator.initializeParallelScope(ctx, parallelId, nodesInParallel)
|
||||
await this.parallelOrchestrator.initializeParallelScope(ctx, parallelId)
|
||||
}
|
||||
|
||||
if (node.metadata.isSentinel) {
|
||||
@@ -157,8 +155,7 @@ export class NodeExecutionOrchestrator {
|
||||
if (!this.parallelOrchestrator.getParallelScope(ctx, parallelId)) {
|
||||
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
|
||||
if (parallelConfig) {
|
||||
const nodesInParallel = parallelConfig.nodes?.length || 1
|
||||
await this.parallelOrchestrator.initializeParallelScope(ctx, parallelId, nodesInParallel)
|
||||
await this.parallelOrchestrator.initializeParallelScope(ctx, parallelId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,20 +234,9 @@ export class NodeExecutionOrchestrator {
|
||||
): Promise<void> {
|
||||
const scope = this.parallelOrchestrator.getParallelScope(ctx, parallelId)
|
||||
if (!scope) {
|
||||
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
|
||||
const nodesInParallel = parallelConfig?.nodes?.length || 1
|
||||
await this.parallelOrchestrator.initializeParallelScope(ctx, parallelId, nodesInParallel)
|
||||
await this.parallelOrchestrator.initializeParallelScope(ctx, parallelId)
|
||||
}
|
||||
const allComplete = this.parallelOrchestrator.handleParallelBranchCompletion(
|
||||
ctx,
|
||||
parallelId,
|
||||
node.id,
|
||||
output
|
||||
)
|
||||
if (allComplete) {
|
||||
await this.parallelOrchestrator.aggregateParallelResults(ctx, parallelId)
|
||||
}
|
||||
|
||||
this.parallelOrchestrator.handleParallelBranchCompletion(ctx, parallelId, node.id, output)
|
||||
this.state.setBlockOutput(node.id, output)
|
||||
}
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ describe('ParallelOrchestrator', () => {
|
||||
)
|
||||
const ctx = createContext()
|
||||
|
||||
const initializePromise = orchestrator.initializeParallelScope(ctx, 'parallel-1', 1)
|
||||
const initializePromise = orchestrator.initializeParallelScope(ctx, 'parallel-1')
|
||||
await Promise.resolve()
|
||||
|
||||
expect(onBlockStart).toHaveBeenCalledTimes(1)
|
||||
|
||||
@@ -47,17 +47,13 @@ export class ParallelOrchestrator {
|
||||
private contextExtensions: ContextExtensions | null = null
|
||||
) {}
|
||||
|
||||
async initializeParallelScope(
|
||||
ctx: ExecutionContext,
|
||||
parallelId: string,
|
||||
terminalNodesCount = 1
|
||||
): Promise<ParallelScope> {
|
||||
async initializeParallelScope(ctx: ExecutionContext, parallelId: string): Promise<ParallelScope> {
|
||||
const parallelConfig = this.dag.parallelConfigs.get(parallelId)
|
||||
if (!parallelConfig) {
|
||||
throw new Error(`Parallel config not found: ${parallelId}`)
|
||||
}
|
||||
|
||||
if (terminalNodesCount === 0 || parallelConfig.nodes.length === 0) {
|
||||
if (parallelConfig.nodes.length === 0) {
|
||||
const errorMessage =
|
||||
'Parallel has no executable blocks inside. Add or enable at least one block in the parallel.'
|
||||
logger.error(errorMessage, { parallelId })
|
||||
@@ -108,8 +104,6 @@ export class ParallelOrchestrator {
|
||||
parallelId,
|
||||
totalBranches: 0,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 0,
|
||||
items: [],
|
||||
isEmpty: true,
|
||||
}
|
||||
@@ -186,8 +180,6 @@ export class ParallelOrchestrator {
|
||||
parallelId,
|
||||
totalBranches: branchCount,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: branchCount * terminalNodesCount,
|
||||
items,
|
||||
}
|
||||
|
||||
@@ -253,8 +245,6 @@ export class ParallelOrchestrator {
|
||||
parallelId,
|
||||
totalBranches: 0,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 0,
|
||||
items: [],
|
||||
validationError: errorMessage,
|
||||
}
|
||||
@@ -277,32 +267,34 @@ export class ParallelOrchestrator {
|
||||
return resolveArrayInput(ctx, config.distribution, this.resolver)
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores a node's output in the branch outputs for later aggregation.
|
||||
* Aggregation is triggered by the sentinel-end node via the edge mechanism,
|
||||
* not by counting individual node completions. This avoids incorrect completion
|
||||
* detection when branches have conditional paths (error edges, conditions).
|
||||
*/
|
||||
handleParallelBranchCompletion(
|
||||
ctx: ExecutionContext,
|
||||
parallelId: string,
|
||||
nodeId: string,
|
||||
output: NormalizedBlockOutput
|
||||
): boolean {
|
||||
): void {
|
||||
const scope = ctx.parallelExecutions?.get(parallelId)
|
||||
if (!scope) {
|
||||
logger.warn('Parallel scope not found for branch completion', { parallelId, nodeId })
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
const branchIndex = extractBranchIndex(nodeId)
|
||||
if (branchIndex === null) {
|
||||
logger.warn('Could not extract branch index from node ID', { nodeId })
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
if (!scope.branchOutputs.has(branchIndex)) {
|
||||
scope.branchOutputs.set(branchIndex, [])
|
||||
}
|
||||
scope.branchOutputs.get(branchIndex)!.push(output)
|
||||
scope.completedCount++
|
||||
|
||||
const allComplete = scope.completedCount >= scope.totalExpectedNodes
|
||||
return allComplete
|
||||
}
|
||||
|
||||
async aggregateParallelResults(
|
||||
|
||||
@@ -228,8 +228,6 @@ export interface ExecutionContext {
|
||||
parallelId: string
|
||||
totalBranches: number
|
||||
branchOutputs: Map<number, any[]>
|
||||
completedCount: number
|
||||
totalExpectedNodes: number
|
||||
parallelType?: 'count' | 'collection'
|
||||
items?: any[]
|
||||
}
|
||||
|
||||
@@ -43,8 +43,6 @@ describe('getIterationContext', () => {
|
||||
parallelId: 'p1',
|
||||
totalBranches: 3,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 3,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -135,8 +133,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'outer-p',
|
||||
totalBranches: 4,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 4,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -164,8 +160,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'parallel-1',
|
||||
totalBranches: 5,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 5,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -232,8 +226,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'outer-p',
|
||||
totalBranches: 3,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 3,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -275,8 +267,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'P1',
|
||||
totalBranches: 2,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 2,
|
||||
},
|
||||
],
|
||||
[
|
||||
@@ -285,8 +275,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'P2',
|
||||
totalBranches: 2,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 2,
|
||||
},
|
||||
],
|
||||
[
|
||||
@@ -295,8 +283,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'P2__obranch-1',
|
||||
totalBranches: 2,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 2,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -363,8 +349,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'parallel-1',
|
||||
totalBranches: 5,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 5,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -423,8 +407,6 @@ describe('buildUnifiedParentIterations', () => {
|
||||
parallelId: 'parallel-1',
|
||||
totalBranches: 3,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 3,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -478,8 +460,6 @@ describe('buildContainerIterationContext', () => {
|
||||
parallelId: 'parallel-1',
|
||||
totalBranches: 5,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 5,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -541,8 +521,6 @@ describe('buildContainerIterationContext', () => {
|
||||
parallelId: 'P2__obranch-1',
|
||||
totalBranches: 5,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 5,
|
||||
},
|
||||
],
|
||||
]),
|
||||
@@ -568,8 +546,6 @@ describe('buildContainerIterationContext', () => {
|
||||
parallelId: 'outer-parallel',
|
||||
totalBranches: 3,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 3,
|
||||
},
|
||||
],
|
||||
]),
|
||||
|
||||
@@ -6,6 +6,11 @@ import type { ResolutionContext } from './reference'
|
||||
|
||||
vi.mock('@sim/logger', () => loggerMock)
|
||||
|
||||
interface BlockDef {
|
||||
id: string
|
||||
name: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a minimal workflow for testing.
|
||||
*/
|
||||
@@ -18,7 +23,8 @@ function createTestWorkflow(
|
||||
distribution?: any
|
||||
parallelType?: 'count' | 'collection'
|
||||
}
|
||||
> = {}
|
||||
> = {},
|
||||
blockDefs: BlockDef[] = []
|
||||
) {
|
||||
const normalizedParallels: Record<
|
||||
string,
|
||||
@@ -37,9 +43,18 @@ function createTestWorkflow(
|
||||
parallelType: parallel.parallelType,
|
||||
}
|
||||
}
|
||||
const blocks = blockDefs.map((b) => ({
|
||||
id: b.id,
|
||||
position: { x: 0, y: 0 },
|
||||
config: { tool: 'test', params: {} },
|
||||
inputs: {},
|
||||
outputs: {},
|
||||
metadata: { id: 'function', name: b.name },
|
||||
enabled: true,
|
||||
}))
|
||||
return {
|
||||
version: '1.0',
|
||||
blocks: [],
|
||||
blocks,
|
||||
connections: [],
|
||||
loops: {},
|
||||
parallels: normalizedParallels,
|
||||
@@ -54,8 +69,6 @@ function createParallelScope(items: any[]) {
|
||||
parallelId: 'parallel-1',
|
||||
totalBranches: items.length,
|
||||
branchOutputs: new Map(),
|
||||
completedCount: 0,
|
||||
totalExpectedNodes: 1,
|
||||
items,
|
||||
}
|
||||
}
|
||||
@@ -65,13 +78,16 @@ function createParallelScope(items: any[]) {
|
||||
*/
|
||||
function createTestContext(
|
||||
currentNodeId: string,
|
||||
parallelExecutions?: Map<string, any>
|
||||
parallelExecutions?: Map<string, any>,
|
||||
blockOutputs?: Record<string, any>
|
||||
): ResolutionContext {
|
||||
return {
|
||||
executionContext: {
|
||||
parallelExecutions: parallelExecutions ?? new Map(),
|
||||
},
|
||||
executionState: {},
|
||||
executionState: {
|
||||
getBlockOutput: (id: string) => blockOutputs?.[id],
|
||||
},
|
||||
currentNodeId,
|
||||
} as ResolutionContext
|
||||
}
|
||||
@@ -385,4 +401,119 @@ describe('ParallelResolver', () => {
|
||||
expect(resolver.resolve('<parallel.currentItem>', createTestContext('block-3₍1₎'))).toBe('p4')
|
||||
})
|
||||
})
|
||||
|
||||
describe('named parallel references', () => {
|
||||
it.concurrent('should resolve result from anywhere after parallel completes', () => {
|
||||
const workflow = createTestWorkflow(
|
||||
{ 'parallel-1': { nodes: ['block-1'], distribution: ['a', 'b'] } },
|
||||
[{ id: 'parallel-1', name: 'Parallel 1' }]
|
||||
)
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const results = [[{ response: 'a' }], [{ response: 'b' }]]
|
||||
const ctx = createTestContext('block-outside', new Map(), {
|
||||
'parallel-1': { results },
|
||||
})
|
||||
|
||||
expect(resolver.resolve('<parallel1.result>', ctx)).toEqual(results)
|
||||
expect(resolver.resolve('<parallel1.results>', ctx)).toEqual(results)
|
||||
})
|
||||
|
||||
it.concurrent('should resolve result with nested path', () => {
|
||||
const workflow = createTestWorkflow(
|
||||
{ 'parallel-1': { nodes: ['block-1'], distribution: ['a', 'b'] } },
|
||||
[{ id: 'parallel-1', name: 'Parallel 1' }]
|
||||
)
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const results = [[{ response: 'a' }], [{ response: 'b' }]]
|
||||
const ctx = createTestContext('block-outside', new Map(), {
|
||||
'parallel-1': { results },
|
||||
})
|
||||
|
||||
expect(resolver.resolve('<parallel1.result.0>', ctx)).toEqual([{ response: 'a' }])
|
||||
expect(resolver.resolve('<parallel1.result.1.0.response>', ctx)).toBe('b')
|
||||
})
|
||||
|
||||
it.concurrent('should resolve result with empty currentNodeId', () => {
|
||||
const workflow = createTestWorkflow(
|
||||
{ 'parallel-1': { nodes: ['block-1'], distribution: ['a', 'b'] } },
|
||||
[{ id: 'parallel-1', name: 'Parallel 1' }]
|
||||
)
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const results = [[{ output: 'x' }], [{ output: 'y' }]]
|
||||
const ctx = createTestContext('', new Map(), {
|
||||
'parallel-1': { results },
|
||||
})
|
||||
|
||||
expect(resolver.resolve('<parallel1.results>', ctx)).toEqual(results)
|
||||
})
|
||||
|
||||
it.concurrent('should return undefined when no output stored yet', () => {
|
||||
const workflow = createTestWorkflow(
|
||||
{ 'parallel-1': { nodes: ['block-1'], distribution: ['a'] } },
|
||||
[{ id: 'parallel-1', name: 'Parallel 1' }]
|
||||
)
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const ctx = createTestContext('block-outside', new Map())
|
||||
|
||||
expect(resolver.resolve('<parallel1.results>', ctx)).toBeUndefined()
|
||||
})
|
||||
|
||||
it.concurrent('should resolve iteration properties via named reference', () => {
|
||||
const workflow = createTestWorkflow(
|
||||
{
|
||||
'parallel-1': {
|
||||
nodes: ['block-1'],
|
||||
distribution: ['x', 'y', 'z'],
|
||||
parallelType: 'collection',
|
||||
},
|
||||
},
|
||||
[{ id: 'parallel-1', name: 'Parallel 1' }]
|
||||
)
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const ctx = createTestContext('block-1₍1₎')
|
||||
|
||||
expect(resolver.resolve('<parallel1.index>', ctx)).toBe(1)
|
||||
expect(resolver.resolve('<parallel1.currentItem>', ctx)).toBe('y')
|
||||
expect(resolver.resolve('<parallel1.items>', ctx)).toEqual(['x', 'y', 'z'])
|
||||
})
|
||||
|
||||
it.concurrent('should throw InvalidFieldError for unknown property on named ref', () => {
|
||||
const workflow = createTestWorkflow(
|
||||
{
|
||||
'parallel-1': {
|
||||
nodes: ['block-1'],
|
||||
distribution: ['a'],
|
||||
parallelType: 'collection',
|
||||
},
|
||||
},
|
||||
[{ id: 'parallel-1', name: 'Parallel 1' }]
|
||||
)
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const ctx = createTestContext('block-1₍0₎')
|
||||
|
||||
expect(() => resolver.resolve('<parallel1.unknownProp>', ctx)).toThrow(InvalidFieldError)
|
||||
})
|
||||
|
||||
it.concurrent('should not resolve named ref when no matching block exists', () => {
|
||||
const workflow = createTestWorkflow({ 'parallel-1': { nodes: ['block-1'] } }, [
|
||||
{ id: 'parallel-1', name: 'Parallel 1' },
|
||||
])
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
expect(resolver.canResolve('<parallel99.index>')).toBe(false)
|
||||
})
|
||||
|
||||
it.concurrent('should resolve generic parallel results from inside a branch', () => {
|
||||
const workflow = createTestWorkflow({
|
||||
'parallel-1': { nodes: ['block-1'], distribution: ['a', 'b'] },
|
||||
})
|
||||
const resolver = new ParallelResolver(workflow)
|
||||
const results = [[{ response: 'a' }], [{ response: 'b' }]]
|
||||
const ctx = createTestContext('block-1₍0₎', new Map(), {
|
||||
'parallel-1': { results },
|
||||
})
|
||||
|
||||
expect(resolver.resolve('<parallel.results>', ctx)).toEqual(results)
|
||||
expect(resolver.resolve('<parallel.result>', ctx)).toEqual(results)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -28,6 +28,7 @@ export class ParallelResolver implements Resolver {
|
||||
}
|
||||
}
|
||||
|
||||
private static OUTPUT_PROPERTIES = new Set(['result', 'results'])
|
||||
private static KNOWN_PROPERTIES = new Set(['index', 'currentItem', 'items'])
|
||||
|
||||
canResolve(reference: string): boolean {
|
||||
@@ -73,6 +74,10 @@ export class ParallelResolver implements Resolver {
|
||||
)
|
||||
}
|
||||
|
||||
if (rest.length > 0 && ParallelResolver.OUTPUT_PROPERTIES.has(rest[0])) {
|
||||
return this.resolveOutput(targetParallelId, rest.slice(1), context)
|
||||
}
|
||||
|
||||
// Look up config using the original (non-cloned) ID
|
||||
const originalParallelId = stripOuterBranchSuffix(targetParallelId)
|
||||
const parallelConfig = this.workflow.parallels?.[originalParallelId]
|
||||
@@ -116,7 +121,9 @@ export class ParallelResolver implements Resolver {
|
||||
|
||||
if (!ParallelResolver.KNOWN_PROPERTIES.has(property)) {
|
||||
const isCollection = parallelConfig.parallelType === 'collection'
|
||||
const availableFields = isCollection ? ['index', 'currentItem', 'items'] : ['index']
|
||||
const availableFields = isCollection
|
||||
? ['index', 'currentItem', 'items', 'result']
|
||||
: ['index', 'result']
|
||||
throw new InvalidFieldError(firstPart, property, availableFields)
|
||||
}
|
||||
|
||||
@@ -216,6 +223,22 @@ export class ParallelResolver implements Resolver {
|
||||
return undefined
|
||||
}
|
||||
|
||||
private resolveOutput(
|
||||
parallelId: string,
|
||||
pathParts: string[],
|
||||
context: ResolutionContext
|
||||
): unknown {
|
||||
const output = context.executionState.getBlockOutput(parallelId)
|
||||
if (!output || typeof output !== 'object') {
|
||||
return undefined
|
||||
}
|
||||
const value = (output as Record<string, unknown>).results
|
||||
if (pathParts.length > 0) {
|
||||
return navigatePath(value, pathParts)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
private getDistributionItems(parallelConfig: SerializedParallel): unknown[] {
|
||||
const rawItems = parallelConfig.distribution ?? []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user