mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 18:25:28 -05:00
feat(ui): iterate on runGraph
This commit is contained in:
@@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { QueueError } from 'services/events/errors';
|
||||
import type { AppSocket } from 'services/events/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue';
|
||||
@@ -22,7 +21,6 @@ type RunGraphArg = {
|
||||
prepend?: boolean;
|
||||
timeout?: number;
|
||||
signal?: AbortSignal;
|
||||
pollingInterval?: number; // Optional polling interval for checking the queue item status
|
||||
};
|
||||
|
||||
type RunGraphReturn = {
|
||||
@@ -31,18 +29,23 @@ type RunGraphReturn = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Run a graph and return an image output. The specified output node must return an image output, else the promise
|
||||
* will reject with an error.
|
||||
* Run a graph and return the output of a specific node.
|
||||
*
|
||||
* @param arg The arguments for the function.
|
||||
* @param arg.graph The graph to execute.
|
||||
* The batch will be enqueued with runs set to 1, meaning it will only run once.
|
||||
*
|
||||
* Iterate nodes, which cause graph expansion, are not supported by this utility because they cause a single node
|
||||
* to have multiple outputs. An error will be thrown if the graph contains any iterate nodes.
|
||||
*
|
||||
* @param arg.graph The graph to execute as an instance of the Graph class.
|
||||
* @param arg.outputNodeId The id of the node whose output will be retrieved.
|
||||
* @param arg.store The Redux store to use for dispatching actions and accessing state.
|
||||
* @param arg.socket The socket to use for listening to events.
|
||||
* @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
|
||||
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
|
||||
* @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
|
||||
* @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled!
|
||||
*
|
||||
* @returns A promise that resolves to the image output or rejects with an error.
|
||||
* @returns A promise that resolves to the output and completed session, or rejects with an error if the graph fails or is canceled.
|
||||
*
|
||||
* @example
|
||||
*
|
||||
@@ -61,51 +64,97 @@ type RunGraphReturn = {
|
||||
* ```
|
||||
*/
|
||||
export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
const { graph, outputNodeId, destination, prepend, timeout, signal, store, socket, pollingInterval } = arg;
|
||||
|
||||
if (!graph.hasNode(outputNodeId)) {
|
||||
throw new Error(`Graph does not contain node with id: ${outputNodeId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
|
||||
* race condition:
|
||||
* - The queue item id is not available until the graph is enqueued
|
||||
* - The graph may complete before we can set up the listeners to handle the completion event
|
||||
*
|
||||
* The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued,
|
||||
* so we will use that to filter events.
|
||||
*/
|
||||
const origin = getPrefixedId(graph.id);
|
||||
|
||||
const batch: EnqueueBatchArg = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: graph.getGraph(),
|
||||
origin,
|
||||
destination,
|
||||
runs: 1,
|
||||
},
|
||||
};
|
||||
|
||||
const promise = new Promise<RunGraphReturn>((resolve, reject) => {
|
||||
const { graph, outputNodeId, store, socket, destination, prepend, timeout, signal } = arg;
|
||||
|
||||
if (!graph.hasNode(outputNodeId)) {
|
||||
reject(new Error(`Graph does not contain output node ${outputNodeId}.`));
|
||||
return;
|
||||
}
|
||||
|
||||
const g = graph.getGraph();
|
||||
|
||||
if (Object.values(g.nodes).some((node) => node.type === 'iterate')) {
|
||||
reject(new Error('Iterate nodes are not supported by this utility.'));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Track execution state.
|
||||
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
|
||||
* race condition:
|
||||
* - The queue item id is not available until the graph is enqueued.
|
||||
* - The graph may complete before we get a response back from enqueuing, so our listeners would miss the event.
|
||||
*
|
||||
* The origin is the only unique identifier that we can set before enqueuing the graph, so we use it to filter
|
||||
* queue item status change events.
|
||||
*/
|
||||
let didSuceed = false;
|
||||
const origin = getPrefixedId(graph.id);
|
||||
|
||||
const batch: EnqueueBatchArg = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: g,
|
||||
origin,
|
||||
destination,
|
||||
runs: 1,
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Flag to indicate whether the graph has already been resolved. This is used to prevent multiple resolutions.
|
||||
*/
|
||||
let isResolved = false;
|
||||
|
||||
/**
|
||||
* The queue item id is set to null initially, but will be updated once the graph is enqueued.
|
||||
*/
|
||||
let queueItemId: number | null = null;
|
||||
/**
|
||||
* If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
|
||||
* if the graph completes or errors before the timeout.
|
||||
*/
|
||||
let timeoutId: number | null = null;
|
||||
|
||||
let pollingIntervalId: number | null = null;
|
||||
const cleanupFunctions: Set<() => void> = new Set();
|
||||
const cleanup = () => {
|
||||
for (const func of cleanupFunctions) {
|
||||
func();
|
||||
}
|
||||
};
|
||||
|
||||
if (timeout !== undefined) {
|
||||
const timeoutId = window.setTimeout(() => {
|
||||
if (isResolved) {
|
||||
return;
|
||||
}
|
||||
log.trace('Graph canceled by timeout');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
cancelQueueItem(queueItemId, store);
|
||||
}
|
||||
reject(new Error('Graph timed out'));
|
||||
}, timeout);
|
||||
|
||||
cleanupFunctions.add(() => {
|
||||
window.clearTimeout(timeoutId);
|
||||
});
|
||||
}
|
||||
|
||||
if (signal !== undefined) {
|
||||
signal.addEventListener('abort', () => {
|
||||
if (isResolved) {
|
||||
return;
|
||||
}
|
||||
log.trace('Graph canceled by signal');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
cancelQueueItem(queueItemId, store);
|
||||
}
|
||||
reject(new Error('Graph canceled'));
|
||||
});
|
||||
// TODO(psyche): Do we need to somehow clean up the signal? Not sure what is required here.
|
||||
}
|
||||
|
||||
const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => {
|
||||
if (isResolved) {
|
||||
return;
|
||||
}
|
||||
|
||||
const queueItemStatusChangedHandler = async (event: S['QueueItemStatusChangedEvent']) => {
|
||||
// Ignore events that are not for this graph
|
||||
if (event.origin !== origin) {
|
||||
return;
|
||||
@@ -116,17 +165,21 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Once we get here, the event is for the correct graph and the status is either 'completed', 'failed', or 'canceled'.
|
||||
// The queue item is finished
|
||||
isResolved = true;
|
||||
cleanup();
|
||||
|
||||
if (event.status === 'completed') {
|
||||
const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store));
|
||||
if (queueItemResult.isErr()) {
|
||||
reject(queueItemResult.error);
|
||||
return;
|
||||
}
|
||||
const queueItem = queueItemResult.value;
|
||||
const { session } = queueItem;
|
||||
const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store));
|
||||
if (queueItemResult.isErr()) {
|
||||
reject(queueItemResult.error);
|
||||
return;
|
||||
}
|
||||
|
||||
const queueItem = queueItemResult.value;
|
||||
|
||||
const { status, session, error_type, error_message, error_traceback } = queueItem;
|
||||
|
||||
if (status === 'completed') {
|
||||
const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId));
|
||||
if (getOutputResult.isErr()) {
|
||||
reject(getOutputResult.error);
|
||||
@@ -134,62 +187,32 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
}
|
||||
const output = getOutputResult.value;
|
||||
|
||||
didSuceed = true;
|
||||
resolve({ session, output });
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.status === 'failed') {
|
||||
if (status === 'failed') {
|
||||
// We expect the event to have error details, but technically it's possible that it doesn't
|
||||
const { error_type, error_message, error_traceback } = event;
|
||||
if (error_type && error_message && error_traceback) {
|
||||
reject(new QueueError(error_type, error_message, error_traceback));
|
||||
} else {
|
||||
reject(new Error('Queue item failed, but no error details were provided'));
|
||||
return;
|
||||
}
|
||||
|
||||
// If we don't have error details, we can't provide a useful error message
|
||||
reject(new Error('Queue item failed, but no error details were provided'));
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.status === 'canceled') {
|
||||
if (status === 'canceled') {
|
||||
reject(new Error('Graph canceled'));
|
||||
return;
|
||||
}
|
||||
|
||||
assert<Equals<never, typeof event.status>>(false);
|
||||
};
|
||||
|
||||
if (pollingInterval !== undefined) {
|
||||
const pollForResult = async () => {
|
||||
if (queueItemId === null) {
|
||||
return;
|
||||
}
|
||||
const _queueItemId = queueItemId;
|
||||
const getQueueItemResult = await withResultAsync(() => getQueueItem(_queueItemId, store));
|
||||
if (getQueueItemResult.isErr()) {
|
||||
reject(getQueueItemResult.error);
|
||||
return;
|
||||
}
|
||||
const queueItem = getQueueItemResult.value;
|
||||
if (queueItem.status === 'pending' || queueItem.status === 'in_progress') {
|
||||
return;
|
||||
}
|
||||
|
||||
cleanup();
|
||||
|
||||
const { session } = queueItem;
|
||||
const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId));
|
||||
if (getOutputResult.isErr()) {
|
||||
reject(getOutputResult.error);
|
||||
return;
|
||||
}
|
||||
const output = getOutputResult.value;
|
||||
didSuceed = true;
|
||||
resolve({ session, output });
|
||||
return;
|
||||
};
|
||||
|
||||
pollingIntervalId = window.setInterval(pollForResult, pollingInterval);
|
||||
}
|
||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
cleanupFunctions.add(() => {
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
});
|
||||
|
||||
// We are ready to enqueue the graph
|
||||
const enqueueRequest = store.dispatch(
|
||||
@@ -215,60 +238,6 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
.catch((error) => {
|
||||
reject(error);
|
||||
});
|
||||
|
||||
socket.on('queue_item_status_changed', queueItemStatusChangedHandler);
|
||||
|
||||
const _cleanupTimeout = () => {
|
||||
if (timeoutId !== null) {
|
||||
window.clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
};
|
||||
const _cleanupPollingInterval = () => {
|
||||
if (pollingIntervalId !== null) {
|
||||
window.clearInterval(pollingIntervalId);
|
||||
pollingIntervalId = null;
|
||||
}
|
||||
};
|
||||
const _cleanupListeners = () => {
|
||||
socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
|
||||
};
|
||||
|
||||
const cleanup = () => {
|
||||
_cleanupTimeout();
|
||||
_cleanupPollingInterval();
|
||||
_cleanupListeners();
|
||||
};
|
||||
|
||||
if (timeout) {
|
||||
timeoutId = window.setTimeout(() => {
|
||||
if (didSuceed) {
|
||||
// If we already succeeded, we don't need to do anything
|
||||
return;
|
||||
}
|
||||
log.trace('Graph canceled by timeout');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
cancelQueueItem(queueItemId, store);
|
||||
}
|
||||
reject(new Error('Graph timed out'));
|
||||
}, timeout);
|
||||
}
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener('abort', () => {
|
||||
if (didSuceed) {
|
||||
// If we already succeeded, we don't need to do anything
|
||||
return;
|
||||
}
|
||||
log.trace('Graph canceled by signal');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
cancelQueueItem(queueItemId, store);
|
||||
}
|
||||
reject(new Error('Graph canceled'));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return promise;
|
||||
|
||||
Reference in New Issue
Block a user