diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts index df439d1abf..12b9dabfb7 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.ts @@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { QueueError } from 'services/events/errors'; import type { AppSocket } from 'services/events/types'; -import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue'; @@ -22,7 +21,6 @@ type RunGraphArg = { prepend?: boolean; timeout?: number; signal?: AbortSignal; - pollingInterval?: number; // Optional polling interval for checking the queue item status }; type RunGraphReturn = { @@ -31,18 +29,23 @@ type RunGraphReturn = { }; /** - * Run a graph and return an image output. The specified output node must return an image output, else the promise - * will reject with an error. + * Run a graph and return the output of a specific node. * - * @param arg The arguments for the function. - * @param arg.graph The graph to execute. + * The batch will be enqueued with runs set to 1, meaning it will only run once. + * + * Iterate nodes, which cause graph expansion, are not supported by this utility because they cause a single node + * to have multiple outputs. An error will be thrown if the graph contains any iterate nodes. + * + * @param arg.graph The graph to execute as an instance of the Graph class. * @param arg.outputNodeId The id of the node whose output will be retrieved. + * @param arg.store The Redux store to use for dispatching actions and accessing state. + * @param arg.socket The socket to use for listening to events. * @param arg.destination The destination to assign to the batch. If omitted, the destination is not set. * @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue. * @param arg.timeout The timeout for the batch. If omitted, there is no timeout. * @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled! * - * @returns A promise that resolves to the image output or rejects with an error. + * @returns A promise that resolves to the output and completed session, or rejects with an error if the graph fails or is canceled. * * @example * @@ -61,51 +64,97 @@ type RunGraphReturn = { * ``` */ export const runGraph = (arg: RunGraphArg): Promise => { - const { graph, outputNodeId, destination, prepend, timeout, signal, store, socket, pollingInterval } = arg; - - if (!graph.hasNode(outputNodeId)) { - throw new Error(`Graph does not contain node with id: ${outputNodeId}`); - } - - /** - * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a - * race condition: - * - The queue item id is not available until the graph is enqueued - * - The graph may complete before we can set up the listeners to handle the completion event - * - * The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued, - * so we will use that to filter events. - */ - const origin = getPrefixedId(graph.id); - - const batch: EnqueueBatchArg = { - prepend, - batch: { - graph: graph.getGraph(), - origin, - destination, - runs: 1, - }, - }; - const promise = new Promise((resolve, reject) => { + const { graph, outputNodeId, store, socket, destination, prepend, timeout, signal } = arg; + + if (!graph.hasNode(outputNodeId)) { + reject(new Error(`Graph does not contain output node ${outputNodeId}.`)); + return; + } + + const g = graph.getGraph(); + + if (Object.values(g.nodes).some((node) => node.type === 'iterate')) { + reject(new Error('Iterate nodes are not supported by this utility.')); + return; + } + /** - * Track execution state. + * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a + * race condition: + * - The queue item id is not available until the graph is enqueued. + * - The graph may complete before we get a response back from enqueuing, so our listeners would miss the event. + * + * The origin is the only unique identifier that we can set before enqueuing the graph, so we use it to filter + * queue item status change events. */ - let didSuceed = false; + const origin = getPrefixedId(graph.id); + + const batch: EnqueueBatchArg = { + prepend, + batch: { + graph: g, + origin, + destination, + runs: 1, + }, + }; + + /** + * Flag to indicate whether the graph has already been resolved. This is used to prevent multiple resolutions. + */ + let isResolved = false; + /** * The queue item id is set to null initially, but will be updated once the graph is enqueued. */ let queueItemId: number | null = null; - /** - * If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout - * if the graph completes or errors before the timeout. - */ - let timeoutId: number | null = null; - let pollingIntervalId: number | null = null; + const cleanupFunctions: Set<() => void> = new Set(); + const cleanup = () => { + for (const func of cleanupFunctions) { + func(); + } + }; + + if (timeout !== undefined) { + const timeoutId = window.setTimeout(() => { + if (isResolved) { + return; + } + log.trace('Graph canceled by timeout'); + cleanup(); + if (queueItemId !== null) { + cancelQueueItem(queueItemId, store); + } + reject(new Error('Graph timed out')); + }, timeout); + + cleanupFunctions.add(() => { + window.clearTimeout(timeoutId); + }); + } + + if (signal !== undefined) { + signal.addEventListener('abort', () => { + if (isResolved) { + return; + } + log.trace('Graph canceled by signal'); + cleanup(); + if (queueItemId !== null) { + cancelQueueItem(queueItemId, store); + } + reject(new Error('Graph canceled')); + }); + // TODO(psyche): Do we need to somehow clean up the signal? Not sure what is required here. + } + + const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => { + if (isResolved) { + return; + } - const queueItemStatusChangedHandler = async (event: S['QueueItemStatusChangedEvent']) => { // Ignore events that are not for this graph if (event.origin !== origin) { return; @@ -116,17 +165,21 @@ export const runGraph = (arg: RunGraphArg): Promise => { return; } - // Once we get here, the event is for the correct graph and the status is either 'completed', 'failed', or 'canceled'. + // The queue item is finished + isResolved = true; cleanup(); - if (event.status === 'completed') { - const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store)); - if (queueItemResult.isErr()) { - reject(queueItemResult.error); - return; - } - const queueItem = queueItemResult.value; - const { session } = queueItem; + const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store)); + if (queueItemResult.isErr()) { + reject(queueItemResult.error); + return; + } + + const queueItem = queueItemResult.value; + + const { status, session, error_type, error_message, error_traceback } = queueItem; + + if (status === 'completed') { const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId)); if (getOutputResult.isErr()) { reject(getOutputResult.error); @@ -134,62 +187,32 @@ export const runGraph = (arg: RunGraphArg): Promise => { } const output = getOutputResult.value; - didSuceed = true; resolve({ session, output }); return; } - if (event.status === 'failed') { + if (status === 'failed') { // We expect the event to have error details, but technically it's possible that it doesn't - const { error_type, error_message, error_traceback } = event; if (error_type && error_message && error_traceback) { reject(new QueueError(error_type, error_message, error_traceback)); - } else { - reject(new Error('Queue item failed, but no error details were provided')); + return; } + + // If we don't have error details, we can't provide a useful error message + reject(new Error('Queue item failed, but no error details were provided')); return; } - if (event.status === 'canceled') { + if (status === 'canceled') { reject(new Error('Graph canceled')); return; } - - assert>(false); }; - if (pollingInterval !== undefined) { - const pollForResult = async () => { - if (queueItemId === null) { - return; - } - const _queueItemId = queueItemId; - const getQueueItemResult = await withResultAsync(() => getQueueItem(_queueItemId, store)); - if (getQueueItemResult.isErr()) { - reject(getQueueItemResult.error); - return; - } - const queueItem = getQueueItemResult.value; - if (queueItem.status === 'pending' || queueItem.status === 'in_progress') { - return; - } - - cleanup(); - - const { session } = queueItem; - const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId)); - if (getOutputResult.isErr()) { - reject(getOutputResult.error); - return; - } - const output = getOutputResult.value; - didSuceed = true; - resolve({ session, output }); - return; - }; - - pollingIntervalId = window.setInterval(pollForResult, pollingInterval); - } + socket.on('queue_item_status_changed', onQueueItemStatusChanged); + cleanupFunctions.add(() => { + socket.off('queue_item_status_changed', onQueueItemStatusChanged); + }); // We are ready to enqueue the graph const enqueueRequest = store.dispatch( @@ -215,60 +238,6 @@ export const runGraph = (arg: RunGraphArg): Promise => { .catch((error) => { reject(error); }); - - socket.on('queue_item_status_changed', queueItemStatusChangedHandler); - - const _cleanupTimeout = () => { - if (timeoutId !== null) { - window.clearTimeout(timeoutId); - timeoutId = null; - } - }; - const _cleanupPollingInterval = () => { - if (pollingIntervalId !== null) { - window.clearInterval(pollingIntervalId); - pollingIntervalId = null; - } - }; - const _cleanupListeners = () => { - socket.off('queue_item_status_changed', queueItemStatusChangedHandler); - }; - - const cleanup = () => { - _cleanupTimeout(); - _cleanupPollingInterval(); - _cleanupListeners(); - }; - - if (timeout) { - timeoutId = window.setTimeout(() => { - if (didSuceed) { - // If we already succeeded, we don't need to do anything - return; - } - log.trace('Graph canceled by timeout'); - cleanup(); - if (queueItemId !== null) { - cancelQueueItem(queueItemId, store); - } - reject(new Error('Graph timed out')); - }, timeout); - } - - if (signal) { - signal.addEventListener('abort', () => { - if (didSuceed) { - // If we already succeeded, we don't need to do anything - return; - } - log.trace('Graph canceled by signal'); - cleanup(); - if (queueItemId !== null) { - cancelQueueItem(queueItemId, store); - } - reject(new Error('Graph canceled')); - }); - } }); return promise;