feat(ui): iterate on runGraph

This commit is contained in:
psychedelicious
2025-06-28 16:57:41 +10:00
parent e379ac12c3
commit d1cbf56695

View File

@@ -5,7 +5,6 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { QueueError } from 'services/events/errors';
import type { AppSocket } from 'services/events/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue';
@@ -22,7 +21,6 @@ type RunGraphArg = {
prepend?: boolean;
timeout?: number;
signal?: AbortSignal;
pollingInterval?: number; // Optional polling interval for checking the queue item status
};
type RunGraphReturn = {
@@ -31,18 +29,23 @@ type RunGraphReturn = {
};
/**
* Run a graph and return an image output. The specified output node must return an image output, else the promise
* will reject with an error.
* Run a graph and return the output of a specific node.
*
* @param arg The arguments for the function.
* @param arg.graph The graph to execute.
* The batch will be enqueued with runs set to 1, meaning it will only run once.
*
* Iterate nodes, which cause graph expansion, are not supported by this utility because they cause a single node
* to have multiple outputs. An error will be thrown if the graph contains any iterate nodes.
*
* @param arg.graph The graph to execute as an instance of the Graph class.
* @param arg.outputNodeId The id of the node whose output will be retrieved.
* @param arg.store The Redux store to use for dispatching actions and accessing state.
* @param arg.socket The socket to use for listening to events.
* @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
* @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
* @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled!
*
* @returns A promise that resolves to the image output or rejects with an error.
* @returns A promise that resolves to the output and completed session, or rejects with an error if the graph fails or is canceled.
*
* @example
*
@@ -61,51 +64,97 @@ type RunGraphReturn = {
* ```
*/
export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
const { graph, outputNodeId, destination, prepend, timeout, signal, store, socket, pollingInterval } = arg;
if (!graph.hasNode(outputNodeId)) {
throw new Error(`Graph does not contain node with id: ${outputNodeId}`);
}
/**
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
* race condition:
* - The queue item id is not available until the graph is enqueued
* - The graph may complete before we can set up the listeners to handle the completion event
*
* The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued,
* so we will use that to filter events.
*/
const origin = getPrefixedId(graph.id);
const batch: EnqueueBatchArg = {
prepend,
batch: {
graph: graph.getGraph(),
origin,
destination,
runs: 1,
},
};
const promise = new Promise<RunGraphReturn>((resolve, reject) => {
const { graph, outputNodeId, store, socket, destination, prepend, timeout, signal } = arg;
if (!graph.hasNode(outputNodeId)) {
reject(new Error(`Graph does not contain output node ${outputNodeId}.`));
return;
}
const g = graph.getGraph();
if (Object.values(g.nodes).some((node) => node.type === 'iterate')) {
reject(new Error('Iterate nodes are not supported by this utility.'));
return;
}
/**
* Track execution state.
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
* race condition:
* - The queue item id is not available until the graph is enqueued.
* - The graph may complete before we get a response back from enqueuing, so our listeners would miss the event.
*
* The origin is the only unique identifier that we can set before enqueuing the graph, so we use it to filter
* queue item status change events.
*/
let didSuceed = false;
const origin = getPrefixedId(graph.id);
const batch: EnqueueBatchArg = {
prepend,
batch: {
graph: g,
origin,
destination,
runs: 1,
},
};
/**
* Flag to indicate whether the graph has already been resolved. This is used to prevent multiple resolutions.
*/
let isResolved = false;
/**
* The queue item id is set to null initially, but will be updated once the graph is enqueued.
*/
let queueItemId: number | null = null;
/**
* If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
* if the graph completes or errors before the timeout.
*/
let timeoutId: number | null = null;
let pollingIntervalId: number | null = null;
const cleanupFunctions: Set<() => void> = new Set();
const cleanup = () => {
for (const func of cleanupFunctions) {
func();
}
};
if (timeout !== undefined) {
const timeoutId = window.setTimeout(() => {
if (isResolved) {
return;
}
log.trace('Graph canceled by timeout');
cleanup();
if (queueItemId !== null) {
cancelQueueItem(queueItemId, store);
}
reject(new Error('Graph timed out'));
}, timeout);
cleanupFunctions.add(() => {
window.clearTimeout(timeoutId);
});
}
if (signal !== undefined) {
signal.addEventListener('abort', () => {
if (isResolved) {
return;
}
log.trace('Graph canceled by signal');
cleanup();
if (queueItemId !== null) {
cancelQueueItem(queueItemId, store);
}
reject(new Error('Graph canceled'));
});
// TODO(psyche): Do we need to somehow clean up the signal? Not sure what is required here.
}
const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => {
if (isResolved) {
return;
}
const queueItemStatusChangedHandler = async (event: S['QueueItemStatusChangedEvent']) => {
// Ignore events that are not for this graph
if (event.origin !== origin) {
return;
@@ -116,17 +165,21 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
return;
}
// Once we get here, the event is for the correct graph and the status is either 'completed', 'failed', or 'canceled'.
// The queue item is finished
isResolved = true;
cleanup();
if (event.status === 'completed') {
const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store));
if (queueItemResult.isErr()) {
reject(queueItemResult.error);
return;
}
const queueItem = queueItemResult.value;
const { session } = queueItem;
const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store));
if (queueItemResult.isErr()) {
reject(queueItemResult.error);
return;
}
const queueItem = queueItemResult.value;
const { status, session, error_type, error_message, error_traceback } = queueItem;
if (status === 'completed') {
const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId));
if (getOutputResult.isErr()) {
reject(getOutputResult.error);
@@ -134,62 +187,32 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
}
const output = getOutputResult.value;
didSuceed = true;
resolve({ session, output });
return;
}
if (event.status === 'failed') {
if (status === 'failed') {
// We expect the event to have error details, but technically it's possible that it doesn't
const { error_type, error_message, error_traceback } = event;
if (error_type && error_message && error_traceback) {
reject(new QueueError(error_type, error_message, error_traceback));
} else {
reject(new Error('Queue item failed, but no error details were provided'));
return;
}
// If we don't have error details, we can't provide a useful error message
reject(new Error('Queue item failed, but no error details were provided'));
return;
}
if (event.status === 'canceled') {
if (status === 'canceled') {
reject(new Error('Graph canceled'));
return;
}
assert<Equals<never, typeof event.status>>(false);
};
if (pollingInterval !== undefined) {
const pollForResult = async () => {
if (queueItemId === null) {
return;
}
const _queueItemId = queueItemId;
const getQueueItemResult = await withResultAsync(() => getQueueItem(_queueItemId, store));
if (getQueueItemResult.isErr()) {
reject(getQueueItemResult.error);
return;
}
const queueItem = getQueueItemResult.value;
if (queueItem.status === 'pending' || queueItem.status === 'in_progress') {
return;
}
cleanup();
const { session } = queueItem;
const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId));
if (getOutputResult.isErr()) {
reject(getOutputResult.error);
return;
}
const output = getOutputResult.value;
didSuceed = true;
resolve({ session, output });
return;
};
pollingIntervalId = window.setInterval(pollForResult, pollingInterval);
}
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
cleanupFunctions.add(() => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
});
// We are ready to enqueue the graph
const enqueueRequest = store.dispatch(
@@ -215,60 +238,6 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
.catch((error) => {
reject(error);
});
socket.on('queue_item_status_changed', queueItemStatusChangedHandler);
const _cleanupTimeout = () => {
if (timeoutId !== null) {
window.clearTimeout(timeoutId);
timeoutId = null;
}
};
const _cleanupPollingInterval = () => {
if (pollingIntervalId !== null) {
window.clearInterval(pollingIntervalId);
pollingIntervalId = null;
}
};
const _cleanupListeners = () => {
socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
};
const cleanup = () => {
_cleanupTimeout();
_cleanupPollingInterval();
_cleanupListeners();
};
if (timeout) {
timeoutId = window.setTimeout(() => {
if (didSuceed) {
// If we already succeeded, we don't need to do anything
return;
}
log.trace('Graph canceled by timeout');
cleanup();
if (queueItemId !== null) {
cancelQueueItem(queueItemId, store);
}
reject(new Error('Graph timed out'));
}, timeout);
}
if (signal) {
signal.addEventListener('abort', () => {
if (didSuceed) {
// If we already succeeded, we don't need to do anything
return;
}
log.trace('Graph canceled by signal');
cleanup();
if (queueItemId !== null) {
cancelQueueItem(queueItemId, store);
}
reject(new Error('Graph canceled'));
});
}
});
return promise;