diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts index 892a7476a9..151d8309a9 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.ts @@ -3,20 +3,35 @@ import type { AppStore } from 'app/store/store'; import { withResult, withResultAsync } from 'common/util/result'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { S } from 'services/api/types'; import { QueueError } from 'services/events/errors'; -import type { AppSocket } from 'services/events/types'; import { assert } from 'tsafe'; import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue'; -import type { EnqueueBatchArg, S } from './types'; +import type { EnqueueBatchArg } from './types'; const log = logger('queue'); +interface QueueStatusEventHandler { + subscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; + unsubscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; +} + +interface GraphExecutor { + enqueueBatch: (batch: EnqueueBatchArg) => Promise<{ item_ids: number[] }>; + getQueueItem: (id: number) => Promise; + cancelQueueItem: (id: number) => Promise; +} + +interface GraphRunnerDependencies { + executor: GraphExecutor; + eventHandler: QueueStatusEventHandler; +} + type RunGraphArg = { graph: Graph; outputNodeId: string; - store: AppStore; - socket: AppSocket; + dependencies: GraphRunnerDependencies; destination?: string; prepend?: boolean; timeout?: number; @@ -28,6 +43,36 @@ type RunGraphReturn = { output: S['GraphExecutionState']['results'][string]; }; +/** + * Creates production dependencies for runGraph using Redux store and socket. + */ +export const createProductionDependencies = ( + store: AppStore, + socket: { + on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; + off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; + } +): GraphRunnerDependencies => ({ + executor: { + enqueueBatch: (batch) => + store + .dispatch( + queueApi.endpoints.enqueueBatch.initiate(batch, { + ...enqueueMutationFixedCacheKeyOptions, + track: false, + }) + ) + .unwrap(), + getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(), + cancelQueueItem: (id) => + store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(), + }, + eventHandler: { + subscribe: (handler) => socket.on('queue_item_status_changed', handler), + unsubscribe: (handler) => socket.off('queue_item_status_changed', handler), + }, +}); + /** * Run a graph and return the output of a specific node. * @@ -38,8 +83,7 @@ type RunGraphReturn = { * * @param arg.graph The graph to execute as an instance of the Graph class. * @param arg.outputNodeId The id of the node whose output will be retrieved. - * @param arg.store The Redux store to use for dispatching actions and accessing state. - * @param arg.socket The socket to use for listening to events. + * @param arg.dependencies The dependencies for queue operations and event handling. * @param arg.destination The destination to assign to the batch. If omitted, the destination is not set. * @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue. * @param arg.timeout The timeout for the batch. If omitted, there is no timeout. @@ -50,12 +94,14 @@ type RunGraphReturn = { * @example * * ```ts + * const dependencies = createProductionDependencies(store, socket); * const graph = new Graph(); * const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } }); * const controller = new AbortController(); - * const imageDTO = await this.manager.stateApi.runGraphAndReturnImageOutput({ + * const result = await runGraph({ * graph, * outputNodeId: outputNode.id, + * dependencies, * prepend: true, * signal: controller.signal, * }); @@ -65,7 +111,7 @@ type RunGraphReturn = { */ export const runGraph = (arg: RunGraphArg): Promise => { const promise = new Promise((resolve, reject) => { - const { graph, outputNodeId, store, socket, destination, prepend, timeout, signal } = arg; + const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg; if (!graph.hasNode(outputNodeId)) { reject(new Error(`Graph does not contain output node ${outputNodeId}.`)); @@ -125,7 +171,7 @@ export const runGraph = (arg: RunGraphArg): Promise => { log.trace('Graph canceled by timeout'); cleanup(); if (queueItemId !== null) { - cancelQueueItem(queueItemId, store); + dependencies.executor.cancelQueueItem(queueItemId); } reject(new Error('Graph timed out')); }, timeout); @@ -143,7 +189,7 @@ export const runGraph = (arg: RunGraphArg): Promise => { log.trace('Graph canceled by signal'); cleanup(); if (queueItemId !== null) { - cancelQueueItem(queueItemId, store); + dependencies.executor.cancelQueueItem(queueItemId); } reject(new Error('Graph canceled')); }; @@ -173,7 +219,7 @@ export const runGraph = (arg: RunGraphArg): Promise => { isResolved = true; cleanup(); - const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store)); + const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id)); if (queueItemResult.isErr()) { reject(queueItemResult.error); return; @@ -213,26 +259,14 @@ export const runGraph = (arg: RunGraphArg): Promise => { } }; - socket.on('queue_item_status_changed', onQueueItemStatusChanged); + dependencies.eventHandler.subscribe(onQueueItemStatusChanged); cleanupFunctions.add(() => { - socket.off('queue_item_status_changed', onQueueItemStatusChanged); + dependencies.eventHandler.unsubscribe(onQueueItemStatusChanged); }); // We are ready to enqueue the graph - const enqueueRequest = store.dispatch( - queueApi.endpoints.enqueueBatch.initiate(batch, { - // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status - // updates. - ...enqueueMutationFixedCacheKeyOptions, - // We do not need RTK to track this request in the store - track: false, - }) - ); - - // Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block - // instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors. - enqueueRequest - .unwrap() + dependencies.executor + .enqueueBatch(batch) .then((data) => { // We queue a single run of the batch, so we expect only one item_id in the response. assert(data.item_ids.length === 1); @@ -240,17 +274,17 @@ export const runGraph = (arg: RunGraphArg): Promise => { queueItemId = data.item_ids[0]; }) .catch((error) => { - reject(error); + if (!isResolved) { + isResolved = true; + cleanup(); + reject(error); + } }); }); return promise; }; -const getQueueItem = (queueItemId: number, store: AppStore): Promise => { - return store.dispatch(queueApi.endpoints.getQueueItem.initiate(queueItemId, { subscribe: false })).unwrap(); -}; - const getOutputFromSession = ( session: S['SessionQueueItem']['session'], nodeId: string @@ -267,20 +301,14 @@ const getOutputFromSession = ( return result; }; -const cancelQueueItem = (queueItemId: number, store: AppStore): Promise => { - return store - .dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: queueItemId }, { track: false })) - .unwrap(); -}; - class NodeNotFoundError extends Error { session: S['SessionQueueItem']['session']; nodeId: string; constructor(nodeId: string, session: S['SessionQueueItem']['session']) { - super(); + const availableNodes = Object.keys(session.source_prepared_mapping); + super(`Node '${nodeId}' not found in session. Available nodes: ${availableNodes.join(', ')}`); this.name = this.constructor.name; - this.message = `Node '${nodeId}' not found in session.`; this.session = session; this.nodeId = nodeId; } @@ -291,9 +319,9 @@ class ResultNotFoundError extends Error { nodeId: string; constructor(nodeId: string, session: S['SessionQueueItem']['session']) { - super(); + const availableResults = Object.keys(session.results); + super(`Result for node '${nodeId}' not found in session. Available results: ${availableResults.join(', ')}`); this.name = this.constructor.name; - this.message = `Result for node '${nodeId}' not found in session.`; this.session = session; this.nodeId = nodeId; }