mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-14 05:04:56 -05:00
feat(ui): use DI to make runGraph testable
This commit is contained in:
@@ -3,20 +3,35 @@ import type { AppStore } from 'app/store/store';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { S } from 'services/api/types';
|
||||
import { QueueError } from 'services/events/errors';
|
||||
import type { AppSocket } from 'services/events/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue';
|
||||
import type { EnqueueBatchArg, S } from './types';
|
||||
import type { EnqueueBatchArg } from './types';
|
||||
|
||||
const log = logger('queue');
|
||||
|
||||
interface QueueStatusEventHandler {
|
||||
subscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
|
||||
unsubscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
|
||||
}
|
||||
|
||||
interface GraphExecutor {
|
||||
enqueueBatch: (batch: EnqueueBatchArg) => Promise<{ item_ids: number[] }>;
|
||||
getQueueItem: (id: number) => Promise<S['SessionQueueItem']>;
|
||||
cancelQueueItem: (id: number) => Promise<S['SessionQueueItem']>;
|
||||
}
|
||||
|
||||
interface GraphRunnerDependencies {
|
||||
executor: GraphExecutor;
|
||||
eventHandler: QueueStatusEventHandler;
|
||||
}
|
||||
|
||||
type RunGraphArg = {
|
||||
graph: Graph;
|
||||
outputNodeId: string;
|
||||
store: AppStore;
|
||||
socket: AppSocket;
|
||||
dependencies: GraphRunnerDependencies;
|
||||
destination?: string;
|
||||
prepend?: boolean;
|
||||
timeout?: number;
|
||||
@@ -28,6 +43,36 @@ type RunGraphReturn = {
|
||||
output: S['GraphExecutionState']['results'][string];
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates production dependencies for runGraph using Redux store and socket.
|
||||
*/
|
||||
export const createProductionDependencies = (
|
||||
store: AppStore,
|
||||
socket: {
|
||||
on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
|
||||
off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
|
||||
}
|
||||
): GraphRunnerDependencies => ({
|
||||
executor: {
|
||||
enqueueBatch: (batch) =>
|
||||
store
|
||||
.dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batch, {
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
track: false,
|
||||
})
|
||||
)
|
||||
.unwrap(),
|
||||
getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(),
|
||||
cancelQueueItem: (id) =>
|
||||
store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(),
|
||||
},
|
||||
eventHandler: {
|
||||
subscribe: (handler) => socket.on('queue_item_status_changed', handler),
|
||||
unsubscribe: (handler) => socket.off('queue_item_status_changed', handler),
|
||||
},
|
||||
});
|
||||
|
||||
/**
|
||||
* Run a graph and return the output of a specific node.
|
||||
*
|
||||
@@ -38,8 +83,7 @@ type RunGraphReturn = {
|
||||
*
|
||||
* @param arg.graph The graph to execute as an instance of the Graph class.
|
||||
* @param arg.outputNodeId The id of the node whose output will be retrieved.
|
||||
* @param arg.store The Redux store to use for dispatching actions and accessing state.
|
||||
* @param arg.socket The socket to use for listening to events.
|
||||
* @param arg.dependencies The dependencies for queue operations and event handling.
|
||||
* @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
|
||||
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
|
||||
* @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
|
||||
@@ -50,12 +94,14 @@ type RunGraphReturn = {
|
||||
* @example
|
||||
*
|
||||
* ```ts
|
||||
* const dependencies = createProductionDependencies(store, socket);
|
||||
* const graph = new Graph();
|
||||
* const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } });
|
||||
* const controller = new AbortController();
|
||||
* const imageDTO = await this.manager.stateApi.runGraphAndReturnImageOutput({
|
||||
* const result = await runGraph({
|
||||
* graph,
|
||||
* outputNodeId: outputNode.id,
|
||||
* dependencies,
|
||||
* prepend: true,
|
||||
* signal: controller.signal,
|
||||
* });
|
||||
@@ -65,7 +111,7 @@ type RunGraphReturn = {
|
||||
*/
|
||||
export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
const promise = new Promise<RunGraphReturn>((resolve, reject) => {
|
||||
const { graph, outputNodeId, store, socket, destination, prepend, timeout, signal } = arg;
|
||||
const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg;
|
||||
|
||||
if (!graph.hasNode(outputNodeId)) {
|
||||
reject(new Error(`Graph does not contain output node ${outputNodeId}.`));
|
||||
@@ -125,7 +171,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
log.trace('Graph canceled by timeout');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
cancelQueueItem(queueItemId, store);
|
||||
dependencies.executor.cancelQueueItem(queueItemId);
|
||||
}
|
||||
reject(new Error('Graph timed out'));
|
||||
}, timeout);
|
||||
@@ -143,7 +189,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
log.trace('Graph canceled by signal');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
cancelQueueItem(queueItemId, store);
|
||||
dependencies.executor.cancelQueueItem(queueItemId);
|
||||
}
|
||||
reject(new Error('Graph canceled'));
|
||||
};
|
||||
@@ -173,7 +219,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
isResolved = true;
|
||||
cleanup();
|
||||
|
||||
const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store));
|
||||
const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id));
|
||||
if (queueItemResult.isErr()) {
|
||||
reject(queueItemResult.error);
|
||||
return;
|
||||
@@ -213,26 +259,14 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
}
|
||||
};
|
||||
|
||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
dependencies.eventHandler.subscribe(onQueueItemStatusChanged);
|
||||
cleanupFunctions.add(() => {
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
dependencies.eventHandler.unsubscribe(onQueueItemStatusChanged);
|
||||
});
|
||||
|
||||
// We are ready to enqueue the graph
|
||||
const enqueueRequest = store.dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batch, {
|
||||
// Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
|
||||
// updates.
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
// We do not need RTK to track this request in the store
|
||||
track: false,
|
||||
})
|
||||
);
|
||||
|
||||
// Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block
|
||||
// instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors.
|
||||
enqueueRequest
|
||||
.unwrap()
|
||||
dependencies.executor
|
||||
.enqueueBatch(batch)
|
||||
.then((data) => {
|
||||
// We queue a single run of the batch, so we expect only one item_id in the response.
|
||||
assert(data.item_ids.length === 1);
|
||||
@@ -240,17 +274,17 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
queueItemId = data.item_ids[0];
|
||||
})
|
||||
.catch((error) => {
|
||||
reject(error);
|
||||
if (!isResolved) {
|
||||
isResolved = true;
|
||||
cleanup();
|
||||
reject(error);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return promise;
|
||||
};
|
||||
|
||||
const getQueueItem = (queueItemId: number, store: AppStore): Promise<S['SessionQueueItem']> => {
|
||||
return store.dispatch(queueApi.endpoints.getQueueItem.initiate(queueItemId, { subscribe: false })).unwrap();
|
||||
};
|
||||
|
||||
const getOutputFromSession = (
|
||||
session: S['SessionQueueItem']['session'],
|
||||
nodeId: string
|
||||
@@ -267,20 +301,14 @@ const getOutputFromSession = (
|
||||
return result;
|
||||
};
|
||||
|
||||
const cancelQueueItem = (queueItemId: number, store: AppStore): Promise<S['SessionQueueItem']> => {
|
||||
return store
|
||||
.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: queueItemId }, { track: false }))
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
class NodeNotFoundError extends Error {
|
||||
session: S['SessionQueueItem']['session'];
|
||||
nodeId: string;
|
||||
|
||||
constructor(nodeId: string, session: S['SessionQueueItem']['session']) {
|
||||
super();
|
||||
const availableNodes = Object.keys(session.source_prepared_mapping);
|
||||
super(`Node '${nodeId}' not found in session. Available nodes: ${availableNodes.join(', ')}`);
|
||||
this.name = this.constructor.name;
|
||||
this.message = `Node '${nodeId}' not found in session.`;
|
||||
this.session = session;
|
||||
this.nodeId = nodeId;
|
||||
}
|
||||
@@ -291,9 +319,9 @@ class ResultNotFoundError extends Error {
|
||||
nodeId: string;
|
||||
|
||||
constructor(nodeId: string, session: S['SessionQueueItem']['session']) {
|
||||
super();
|
||||
const availableResults = Object.keys(session.results);
|
||||
super(`Result for node '${nodeId}' not found in session. Available results: ${availableResults.join(', ')}`);
|
||||
this.name = this.constructor.name;
|
||||
this.message = `Result for node '${nodeId}' not found in session.`;
|
||||
this.session = session;
|
||||
this.nodeId = nodeId;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user