diff --git a/invokeai/frontend/web/src/services/api/run-graph.test.ts b/invokeai/frontend/web/src/services/api/run-graph.test.ts index 2f11740310..9069579d35 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.test.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.test.ts @@ -48,6 +48,7 @@ const createMockGraph = (id = TEST_ID, hasIterateNodes = false): Graph => { nodes: mockNodes, edges: {}, }), + getNodes: vi.fn().mockReturnValue(Object.values(mockNodes)), } as unknown as Graph; }; diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts index eebcba39d1..db4c87d7fe 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.ts @@ -156,48 +156,35 @@ export const buildRunGraphDependencies = ( /** * Internal business logic for running a graph. + * + * This function is not intended to be used directly. Use `runGraph` instead. + * + * @param arg The arguments for running the graph. + * @param _resolve The resolve function for the promise. Do not call this directly; use the `settle` function instead. + * @param _reject The reject function for the promise. Do not call this directly; use the `settle` function instead. */ const _runGraph = async ( arg: RunGraphArg, - resolve: (value: RunGraphReturn) => void, - reject: (error: Error) => void + _resolve: (value: RunGraphReturn) => void, + _reject: (error: Error) => void ): Promise => { const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg; - if (!graph.hasNode(outputNodeId)) { - reject(new OutputNodeNotFoundInGraphError(outputNodeId, graph)); - return; - } - - const g = graph.getGraph(); - - if (Object.values(g.nodes).some((node) => node.type === 'iterate')) { - reject(new IterateNodeFoundInGraphError(graph)); - return; - } - /** - * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a - * race condition for fast-running graphs: - * - We enqueue the batch and wait for the respose from the network request, which will include the queue item id. - * - The queue item is executed. - * - We get status change events for the queue item, but we don't have the queue item id yet, so we miss the event. + * We will use the origin to filter out socket events unrelated to this graph. * - * The origin is the only unique identifier that we can set before enqueuing the graph. We set it to something - * unique and use it to filter for events relevant to this graph. + * Ideally we'd use the queue item's id, but there's a race condition for fast-running graphs: + * - We enqueue the batch, which initiates a network request. + * - The queue item is created and quickly completed. + * - The enqueue batch request returns, which includes the queue item id. + * - We set up listeners for the queue item status change events, but the queue item is already completed, so we + * miss the status change event and are left waiting forever. + * + * The origin is a unique identifier that we can set before enqueuing the graph. This allows us to set up listeners + * _before_ enqueuing the graph, ensuring that we don't miss any events. */ const origin = getPrefixedId(graph.id); - const batch: EnqueueBatchArg = { - prepend, - batch: { - graph: g, - origin, - destination, - runs: 1, - }, - }; - /** * The queue item id is set to null initially, but will be updated once the graph is enqueued. It will be used to * retrieve the queue item. @@ -227,6 +214,14 @@ const _runGraph = async ( */ let isSettling = false; const settlementMutex = new Mutex(); + + /** + * Wraps all logic that settles the promise. Return a Result to indicate success or failure. This function will + * handle the cleanup of listeners, timeouts, etc. and resolve or reject the promise based on the result. + * + * Once the graph execution is finished, all remaining logic should be wrapped in this function to avoid race + * conditions or multiple resolutions/rejections of the promise. + */ const settle = async (settlement: () => Promise> | Result) => { await settlementMutex.runExclusive(async () => { // If we are already settling, ignore this call to avoid multiple resolutions or rejections. @@ -243,13 +238,27 @@ const _runGraph = async ( const result = await Promise.resolve(settlement()); if (result.isOk()) { - resolve(result.value); + _resolve(result.value); } else { - reject(result.error); + _reject(result.error); } }); }; + if (!graph.hasNode(outputNodeId)) { + await settle(() => { + return ErrResult(new OutputNodeNotFoundInGraphError(outputNodeId, graph)); + }); + return; + } + + if (graph.getNodes().some((node) => node.type === 'iterate')) { + await settle(() => { + return ErrResult(new IterateNodeFoundInGraphError(graph)); + }); + return; + } + // If a timeout value is provided, we create a timer to reject the promise. if (timeout !== undefined) { const timeoutId = setTimeout(async () => { @@ -345,7 +354,18 @@ const _runGraph = async ( dependencies.eventHandler.unsubscribe(onQueueItemStatusChanged); }); - const enqueueResult = await withResultAsync(() => dependencies.executor.enqueueBatch(batch)); + const enqueueResult = await withResultAsync(() => { + const batch: EnqueueBatchArg = { + prepend, + batch: { + graph: graph.getGraph(), + origin, + destination, + runs: 1, + }, + }; + return dependencies.executor.enqueueBatch(batch); + }); if (enqueueResult.isErr()) { // The enqueue operation itself failed - we cannot proceed. await settle(() => ErrResult(enqueueResult.error));