diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts index 919df641a9..ee335ac68f 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.ts @@ -15,6 +15,27 @@ import type { EnqueueBatchArg } from './types'; const log = logger('queue'); +type Deferred = { + promise: Promise; + resolve: (value: T) => void; + reject: (error: Error) => void; +}; + +/** + * Create a promise and expose its resolve and reject callbacks. + */ +const createDeferred = (): Deferred => { + let resolve!: (value: T) => void; + let reject!: (error: Error) => void; + + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + + return { promise, resolve, reject }; +}; + interface QueueStatusEventHandler { subscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; unsubscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; @@ -46,36 +67,6 @@ type RunGraphReturn = { output: S['GraphExecutionState']['results'][string]; }; -/** - * Creates production dependencies for runGraph using Redux store and socket. - */ -export const buildRunGraphDependencies = ( - store: AppStore, - socket: { - on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; - off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; - } -): GraphRunnerDependencies => ({ - executor: { - enqueueBatch: (batch) => - store - .dispatch( - queueApi.endpoints.enqueueBatch.initiate(batch, { - ...enqueueMutationFixedCacheKeyOptions, - track: false, - }) - ) - .unwrap(), - getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(), - cancelQueueItem: (id) => - store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(), - }, - eventHandler: { - subscribe: (handler) => socket.on('queue_item_status_changed', handler), - unsubscribe: (handler) => socket.off('queue_item_status_changed', handler), - }, -}); - /** * Run a graph and return the output of a specific node. * @@ -113,206 +104,245 @@ export const buildRunGraphDependencies = ( * ``` */ export const runGraph = (arg: RunGraphArg): Promise => { - const promise = new Promise((resolve, reject) => { - const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg; + // A deferred promise works around the antipattern of async promise executors. + const { promise, resolve, reject } = createDeferred(); + _runGraph(arg, resolve, reject); + return promise; +}; - if (!graph.hasNode(outputNodeId)) { - reject(new OutputNodeNotFoundError(outputNodeId, graph)); - return; - } +/** + * Creates production dependencies for runGraph using Redux store and socket. + */ +export const buildRunGraphDependencies = ( + store: AppStore, + socket: { + on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; + off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; + } +): GraphRunnerDependencies => ({ + executor: { + enqueueBatch: (batch) => + store + .dispatch( + queueApi.endpoints.enqueueBatch.initiate(batch, { + ...enqueueMutationFixedCacheKeyOptions, + track: false, + }) + ) + .unwrap(), + getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(), + cancelQueueItem: (id) => + store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(), + }, + eventHandler: { + subscribe: (handler) => socket.on('queue_item_status_changed', handler), + unsubscribe: (handler) => socket.off('queue_item_status_changed', handler), + }, +}); - const g = graph.getGraph(); +/** + * Internal business logic for running a graph. + */ +const _runGraph = async ( + arg: RunGraphArg, + resolve: (value: RunGraphReturn) => void, + reject: (error: Error) => void +): Promise => { + const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg; - if (Object.values(g.nodes).some((node) => node.type === 'iterate')) { - reject(new IterateNodeFoundError(graph)); - return; - } + if (!graph.hasNode(outputNodeId)) { + reject(new OutputNodeNotFoundError(outputNodeId, graph)); + return; + } - /** - * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a - * race condition for fast-running graphs: - * - We enqueue the batch and wait for the respose from the network request, which will include the queue item id. - * - The queue item is executed. - * - We get status change events for the queue item, but we don't have the queue item id yet, so we miss the event. - * - * The origin is the only unique identifier that we can set before enqueuing the graph. We set it to something - * unique and use it to filter for events relevant to this graph. - */ - const origin = getPrefixedId(graph.id); + const g = graph.getGraph(); - const batch: EnqueueBatchArg = { - prepend, - batch: { - graph: g, - origin, - destination, - runs: 1, - }, - }; + if (Object.values(g.nodes).some((node) => node.type === 'iterate')) { + reject(new IterateNodeFoundError(graph)); + return; + } - /** - * The queue item id is set to null initially, but will be updated once the graph is enqueued. It will be used to - * retrieve the queue item. - */ - let queueItemId: number | null = null; + /** + * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a + * race condition for fast-running graphs: + * - We enqueue the batch and wait for the respose from the network request, which will include the queue item id. + * - The queue item is executed. + * - We get status change events for the queue item, but we don't have the queue item id yet, so we miss the event. + * + * The origin is the only unique identifier that we can set before enqueuing the graph. We set it to something + * unique and use it to filter for events relevant to this graph. + */ + const origin = getPrefixedId(graph.id); - /** - * Set of cleanup functions for listeners, timeouts, etc that need to be called when the graph is settled. - */ - const cleanupFunctions: Set<() => void> = new Set(); - const cleanup = () => { - for (const func of cleanupFunctions) { - try { - func(); - } catch (error) { - log.warn({ error: parseify(error) }, 'Error during cleanup'); - } + const batch: EnqueueBatchArg = { + prepend, + batch: { + graph: g, + origin, + destination, + runs: 1, + }, + }; + + /** + * The queue item id is set to null initially, but will be updated once the graph is enqueued. It will be used to + * retrieve the queue item. + */ + let queueItemId: number | null = null; + + /** + * Set of cleanup functions for listeners, timeouts, etc that need to be called when the graph is settled. + */ + const cleanupFunctions: Set<() => void> = new Set(); + const cleanup = () => { + for (const func of cleanupFunctions) { + try { + func(); + } catch (error) { + log.warn({ error: parseify(error) }, 'Error during cleanup'); } - }; - - /** - * We use a mutex to ensure that the promise is resolved or rejected only once, even if multiple events - * are received or the settle function is called multiple times. - * - * A flag allows pending locks to bail if the promise has already been settled. - */ - let isSettling = false; - const settlementMutex = new Mutex(); - const settle = async (settlement: () => Promise> | Result) => { - await settlementMutex.runExclusive(async () => { - // If we are already settling, ignore this call to avoid multiple resolutions or rejections. - // We don't want to _cancel_ pending locks as this would raise. - if (isSettling) { - return; - } - isSettling = true; - - // Clean up listeners, timeouts, etc. ASAP. - cleanup(); - - // Normalize the settlement function to always return a promise. - const result = await Promise.resolve(settlement()); - - if (result.isOk()) { - resolve(result.value); - } else { - reject(result.error); - } - }); - }; - - // If a timeout value is provided, we create a timer to reject the promise. - if (timeout !== undefined) { - const timeoutId = setTimeout(async () => { - await settle(() => { - log.trace('Graph canceled by timeout'); - if (queueItemId !== null) { - // It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning - // and move on to reject. - dependencies.executor.cancelQueueItem(queueItemId).catch((error) => { - log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout'); - }); - } - return ErrResult(new SessionTimeoutError(queueItemId)); - }); - }, timeout); - - cleanupFunctions.add(() => { - clearTimeout(timeoutId); - }); } + }; - // If a signal is provided, we add an abort handler to reject the promise if the signal is aborted. - if (signal !== undefined) { - const abortHandler = () => { - settle(() => { - log.trace('Graph canceled by signal'); - if (queueItemId !== null) { - // It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning - // and move on to reject. - dependencies.executor.cancelQueueItem(queueItemId).catch((error) => { - log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort'); - }); - } - return ErrResult(new SessionAbortedError(queueItemId)); - }); - }; - - signal.addEventListener('abort', abortHandler); - cleanupFunctions.add(() => { - signal.removeEventListener('abort', abortHandler); - }); - } - - // Handle the queue item status change events. - const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => { - // Ignore events that are not for this graph - if (event.origin !== origin) { + /** + * We use a mutex to ensure that the promise is resolved or rejected only once, even if multiple events + * are received or the settle function is called multiple times. + * + * A flag allows pending locks to bail if the promise has already been settled. + */ + let isSettling = false; + const settlementMutex = new Mutex(); + const settle = async (settlement: () => Promise> | Result) => { + await settlementMutex.runExclusive(async () => { + // If we are already settling, ignore this call to avoid multiple resolutions or rejections. + // We don't want to _cancel_ pending locks as this would raise. + if (isSettling) { return; } + isSettling = true; - // Ignore events where the status is pending or in progress - no need to do anything for these - if (event.status === 'pending' || event.status === 'in_progress') { - return; + // Clean up listeners, timeouts, etc. ASAP. + cleanup(); + + // Normalize the settlement function to always return a promise. + const result = await Promise.resolve(settlement()); + + if (result.isOk()) { + resolve(result.value); + } else { + reject(result.error); } - - await settle(async () => { - // We need to handle any errors, including retrieving the queue item - const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id)); - if (queueItemResult.isErr()) { - return ErrResult(queueItemResult.error); - } - - const queueItem = queueItemResult.value; - - const { status, session, error_type, error_message, error_traceback } = queueItem; - - // We are confident that the queue item is not pending or in progress, at this time. - assert(status !== 'pending' && status !== 'in_progress'); - - if (status === 'completed') { - const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId)); - if (getOutputResult.isErr()) { - return ErrResult(getOutputResult.error); - } - const output = getOutputResult.value; - return OkResult({ session, output }); - } - - if (status === 'failed') { - return ErrResult(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback)); - } - - if (status === 'canceled') { - return ErrResult(new SessionCancelationError(queueItemId, session)); - } - - assert>(false); - }); - }; - - dependencies.eventHandler.subscribe(onQueueItemStatusChanged); - cleanupFunctions.add(() => { - dependencies.eventHandler.unsubscribe(onQueueItemStatusChanged); }); + }; - // We are ready to enqueue the graph - dependencies.executor - .enqueueBatch(batch) - .then((data) => { - // We queue a single run of the batch, so we know there is only one item_id in the response. - assert(data.item_ids.length === 1); - assert(data.item_ids[0] !== undefined); - queueItemId = data.item_ids[0]; - }) - .catch(async (error) => { - await settle(() => { - return ErrResult(error); - }); + // If a timeout value is provided, we create a timer to reject the promise. + if (timeout !== undefined) { + const timeoutId = setTimeout(async () => { + await settle(() => { + log.trace('Graph canceled by timeout'); + if (queueItemId !== null) { + // It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning + // and move on to reject. + dependencies.executor.cancelQueueItem(queueItemId).catch((error) => { + log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout'); + }); + } + return ErrResult(new SessionTimeoutError(queueItemId)); }); + }, timeout); + + cleanupFunctions.add(() => { + clearTimeout(timeoutId); + }); + } + + // If a signal is provided, we add an abort handler to reject the promise if the signal is aborted. + if (signal !== undefined) { + const abortHandler = () => { + settle(() => { + log.trace('Graph canceled by signal'); + if (queueItemId !== null) { + // It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning + // and move on to reject. + dependencies.executor.cancelQueueItem(queueItemId).catch((error) => { + log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort'); + }); + } + return ErrResult(new SessionAbortedError(queueItemId)); + }); + }; + + signal.addEventListener('abort', abortHandler); + cleanupFunctions.add(() => { + signal.removeEventListener('abort', abortHandler); + }); + } + + // Handle the queue item status change events. + const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => { + // Ignore events that are not for this graph + if (event.origin !== origin) { + return; + } + + // Ignore events where the status is pending or in progress - no need to do anything for these + if (event.status === 'pending' || event.status === 'in_progress') { + return; + } + + await settle(async () => { + // We need to handle any errors, including retrieving the queue item + const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id)); + if (queueItemResult.isErr()) { + return ErrResult(queueItemResult.error); + } + + const queueItem = queueItemResult.value; + + const { status, session, error_type, error_message, error_traceback } = queueItem; + + // We are confident that the queue item is not pending or in progress, at this time. + assert(status !== 'pending' && status !== 'in_progress'); + + if (status === 'completed') { + const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId)); + if (getOutputResult.isErr()) { + return ErrResult(getOutputResult.error); + } + const output = getOutputResult.value; + return OkResult({ session, output }); + } + + if (status === 'failed') { + return ErrResult(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback)); + } + + if (status === 'canceled') { + return ErrResult(new SessionCancelationError(queueItemId, session)); + } + + assert>(false); + }); + }; + + dependencies.eventHandler.subscribe(onQueueItemStatusChanged); + cleanupFunctions.add(() => { + dependencies.eventHandler.unsubscribe(onQueueItemStatusChanged); }); - return promise; + const enqueueResult = await withResultAsync(() => dependencies.executor.enqueueBatch(batch)); + if (enqueueResult.isErr()) { + // The enqueue operation itself failed - we cannot proceed. + await settle(() => ErrResult(enqueueResult.error)); + return; + } + + // Retrieve the queue item id from the enqueue result. + const { item_ids } = enqueueResult.value; + // We expect exactly one item id to be returned. We control the batch config, so we can safely assert this. + assert(item_ids.length === 1); + assert(item_ids[0] !== undefined); + queueItemId = item_ids[0]; }; const getOutputFromSession = (