feat(ui): use DI to make runGraph testable

This commit is contained in:
psychedelicious
2025-06-28 17:16:04 +10:00
parent b9ce5389ef
commit 2c6d22664e

View File

@@ -3,20 +3,35 @@ import type { AppStore } from 'app/store/store';
import { withResult, withResultAsync } from 'common/util/result';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { S } from 'services/api/types';
import { QueueError } from 'services/events/errors';
import type { AppSocket } from 'services/events/types';
import { assert } from 'tsafe';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue';
import type { EnqueueBatchArg, S } from './types';
import type { EnqueueBatchArg } from './types';
const log = logger('queue');
interface QueueStatusEventHandler {
subscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
unsubscribe: (handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
}
interface GraphExecutor {
enqueueBatch: (batch: EnqueueBatchArg) => Promise<{ item_ids: number[] }>;
getQueueItem: (id: number) => Promise<S['SessionQueueItem']>;
cancelQueueItem: (id: number) => Promise<S['SessionQueueItem']>;
}
interface GraphRunnerDependencies {
executor: GraphExecutor;
eventHandler: QueueStatusEventHandler;
}
type RunGraphArg = {
graph: Graph;
outputNodeId: string;
store: AppStore;
socket: AppSocket;
dependencies: GraphRunnerDependencies;
destination?: string;
prepend?: boolean;
timeout?: number;
@@ -28,6 +43,36 @@ type RunGraphReturn = {
output: S['GraphExecutionState']['results'][string];
};
/**
* Creates production dependencies for runGraph using Redux store and socket.
*/
export const createProductionDependencies = (
store: AppStore,
socket: {
on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
off: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
}
): GraphRunnerDependencies => ({
executor: {
enqueueBatch: (batch) =>
store
.dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
)
.unwrap(),
getQueueItem: (id) => store.dispatch(queueApi.endpoints.getQueueItem.initiate(id, { subscribe: false })).unwrap(),
cancelQueueItem: (id) =>
store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: id }, { track: false })).unwrap(),
},
eventHandler: {
subscribe: (handler) => socket.on('queue_item_status_changed', handler),
unsubscribe: (handler) => socket.off('queue_item_status_changed', handler),
},
});
/**
* Run a graph and return the output of a specific node.
*
@@ -38,8 +83,7 @@ type RunGraphReturn = {
*
* @param arg.graph The graph to execute as an instance of the Graph class.
* @param arg.outputNodeId The id of the node whose output will be retrieved.
* @param arg.store The Redux store to use for dispatching actions and accessing state.
* @param arg.socket The socket to use for listening to events.
* @param arg.dependencies The dependencies for queue operations and event handling.
* @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
* @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
@@ -50,12 +94,14 @@ type RunGraphReturn = {
* @example
*
* ```ts
* const dependencies = createProductionDependencies(store, socket);
* const graph = new Graph();
* const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } });
* const controller = new AbortController();
* const imageDTO = await this.manager.stateApi.runGraphAndReturnImageOutput({
* const result = await runGraph({
* graph,
* outputNodeId: outputNode.id,
* dependencies,
* prepend: true,
* signal: controller.signal,
* });
@@ -65,7 +111,7 @@ type RunGraphReturn = {
*/
export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
const promise = new Promise<RunGraphReturn>((resolve, reject) => {
const { graph, outputNodeId, store, socket, destination, prepend, timeout, signal } = arg;
const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg;
if (!graph.hasNode(outputNodeId)) {
reject(new Error(`Graph does not contain output node ${outputNodeId}.`));
@@ -125,7 +171,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
log.trace('Graph canceled by timeout');
cleanup();
if (queueItemId !== null) {
cancelQueueItem(queueItemId, store);
dependencies.executor.cancelQueueItem(queueItemId);
}
reject(new Error('Graph timed out'));
}, timeout);
@@ -143,7 +189,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
log.trace('Graph canceled by signal');
cleanup();
if (queueItemId !== null) {
cancelQueueItem(queueItemId, store);
dependencies.executor.cancelQueueItem(queueItemId);
}
reject(new Error('Graph canceled'));
};
@@ -173,7 +219,7 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
isResolved = true;
cleanup();
const queueItemResult = await withResultAsync(() => getQueueItem(event.item_id, store));
const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id));
if (queueItemResult.isErr()) {
reject(queueItemResult.error);
return;
@@ -213,26 +259,14 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
}
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
dependencies.eventHandler.subscribe(onQueueItemStatusChanged);
cleanupFunctions.add(() => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
dependencies.eventHandler.unsubscribe(onQueueItemStatusChanged);
});
// We are ready to enqueue the graph
const enqueueRequest = store.dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
// Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
// updates.
...enqueueMutationFixedCacheKeyOptions,
// We do not need RTK to track this request in the store
track: false,
})
);
// Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block
// instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors.
enqueueRequest
.unwrap()
dependencies.executor
.enqueueBatch(batch)
.then((data) => {
// We queue a single run of the batch, so we expect only one item_id in the response.
assert(data.item_ids.length === 1);
@@ -240,17 +274,17 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
queueItemId = data.item_ids[0];
})
.catch((error) => {
reject(error);
if (!isResolved) {
isResolved = true;
cleanup();
reject(error);
}
});
});
return promise;
};
const getQueueItem = (queueItemId: number, store: AppStore): Promise<S['SessionQueueItem']> => {
return store.dispatch(queueApi.endpoints.getQueueItem.initiate(queueItemId, { subscribe: false })).unwrap();
};
const getOutputFromSession = (
session: S['SessionQueueItem']['session'],
nodeId: string
@@ -267,20 +301,14 @@ const getOutputFromSession = (
return result;
};
const cancelQueueItem = (queueItemId: number, store: AppStore): Promise<S['SessionQueueItem']> => {
return store
.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id: queueItemId }, { track: false }))
.unwrap();
};
class NodeNotFoundError extends Error {
session: S['SessionQueueItem']['session'];
nodeId: string;
constructor(nodeId: string, session: S['SessionQueueItem']['session']) {
super();
const availableNodes = Object.keys(session.source_prepared_mapping);
super(`Node '${nodeId}' not found in session. Available nodes: ${availableNodes.join(', ')}`);
this.name = this.constructor.name;
this.message = `Node '${nodeId}' not found in session.`;
this.session = session;
this.nodeId = nodeId;
}
@@ -291,9 +319,9 @@ class ResultNotFoundError extends Error {
nodeId: string;
constructor(nodeId: string, session: S['SessionQueueItem']['session']) {
super();
const availableResults = Object.keys(session.results);
super(`Result for node '${nodeId}' not found in session. Available results: ${availableResults.join(', ')}`);
this.name = this.constructor.name;
this.message = `Result for node '${nodeId}' not found in session.`;
this.session = session;
this.nodeId = nodeId;
}