feat(ui): use runGraph in canvas

This commit is contained in:
psychedelicious
2025-06-28 18:47:40 +10:00
parent 40e9624954
commit 8dc6d0b5ae
3 changed files with 34 additions and 202 deletions

View File

@@ -2,7 +2,6 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Selector } from '@reduxjs/toolkit';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppStore, RootState } from 'app/store/store';
import { withResultAsync } from 'common/util/result';
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
@@ -49,15 +48,14 @@ import type {
RgbaColor,
} from 'features/controlLayers/store/types';
import { RGBA_BLACK } from 'features/controlLayers/store/types';
import { zImageOutput } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types';
import { QueueError } from 'services/events/errors';
import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph';
import type { ImageDTO, S } from 'services/api/types';
import type { Param0 } from 'tsafe';
import { assert } from 'tsafe';
import type { CanvasEntityAdapter } from './CanvasEntity/types';
@@ -266,7 +264,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
* controller.abort();
* ```
*/
runGraphAndReturnImageOutput = (arg: {
runGraphAndReturnImageOutput = async (arg: {
graph: Graph;
outputNodeId: string;
destination?: string;
@@ -276,203 +274,37 @@ export class CanvasStateApiModule extends CanvasModuleBase {
}): Promise<ImageDTO> => {
const { graph, outputNodeId, destination, prepend, timeout, signal } = arg;
if (!graph.hasNode(outputNodeId)) {
throw new Error(`Graph does not contain node with id: ${outputNodeId}`);
}
const dependencies = buildRunGraphDependencies(this.store, this.manager.socket);
/**
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
* race condition:
* - The queue item id is not available until the graph is enqueued
* - The graph may complete before we can set up the listeners to handle the completion event
*
* The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued,
* so we will use that to filter events.
*/
const origin = getPrefixedId(graph.id);
const batch: EnqueueBatchArg = {
const { output } = await runGraph({
graph,
outputNodeId,
dependencies,
destination,
prepend,
batch: {
graph: graph.getGraph(),
origin,
destination,
runs: 1,
},
};
let didSuceed = false;
/**
* If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
* if the graph completes or errors before the timeout.
*/
let timeoutId: number | null = null;
const _clearTimeout = () => {
if (timeoutId !== null) {
window.clearTimeout(timeoutId);
timeoutId = null;
}
};
// There's a bit of a catch-22 here: we need to set the cancelGraph callback before we enqueue the graph, but we
// can't set it until we have the batch_id from the enqueue request. So we'll set a dummy function here and update
// it later.
let cancelGraph: () => void = () => {
this.log.warn('cancelGraph called before cancelGraph is set');
};
const resultPromise = new Promise<ImageDTO>((resolve, reject) => {
const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => {
// Ignore events that are not for this graph
if (event.origin !== origin) {
return;
}
// Ignore events that are not from the output node
if (event.invocation_source_id !== outputNodeId) {
return;
}
// If we get here, the event is for the correct graph and output node.
// Clear the timeout and socket listeners
_clearTimeout();
clearListeners();
// The result must be an image output
const { result } = event;
if (result.type !== 'image_output') {
reject(new Error(`Graph output node did not return an image output, got: ${result}`));
return;
}
// Get the result image DTO
const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name));
if (getImageDTOResult.isErr()) {
reject(getImageDTOResult.error);
return;
}
didSuceed = true;
// Ok!
resolve(getImageDTOResult.value);
};
const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => {
// Ignore events that are not for this graph
if (event.origin !== origin) {
return;
}
// Ignore events where the status is pending or in progress - no need to do anything for these
if (event.status === 'pending' || event.status === 'in_progress') {
return;
}
if (event.status === 'completed') {
/**
* The invocation_complete event should have been received before the queue item completed event, and the
* event listeners are cleared in the invocation_complete handler. If we get here, it means we never got
* the completion event for the output node! This should is a fail case.
*
* TODO(psyche): In the unexpected case where events are received out of order, this logic doesn't do what
* we expect. If we got a queue item completed event before the output node completion event, we'd erroneously
* triggers this error.
*
* For now, we'll just log a warning instead of rejecting the promise. This should be super rare anyways.
*/
// reject(new Error('Queue item completed without output node completion event'));
this.log.warn('Queue item completed without output node completion event');
return;
}
// event.status is 'failed', 'canceled' - something has gone awry
_clearTimeout();
clearListeners();
if (event.status === 'failed') {
// We expect the event to have error details, but technically it's possible that it doesn't
const { error_type, error_message, error_traceback } = event;
if (error_type && error_message && error_traceback) {
reject(new QueueError(error_type, error_message, error_traceback));
} else {
reject(new Error('Queue item failed, but no error details were provided'));
}
} else {
// event.status is 'canceled'
reject(new Error('Graph canceled'));
}
};
// We are ready to enqueue the graph
const enqueueRequest = this.store.dispatch(
queueApi.endpoints.enqueueBatch.initiate(batch, {
// Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
// updates.
...enqueueMutationFixedCacheKeyOptions,
// We do not need RTK to track this request in the store
track: false,
})
);
// Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block
// instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors.
enqueueRequest
.unwrap()
.then((data) => {
// The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect.
// TODO(psyche): Fix the OpenAPI schema.
const batch_id = data.batch.batch_id;
assert(batch_id, 'Enqueue result is missing batch_id');
cancelGraph = () => {
this.store.dispatch(
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false })
);
};
})
.catch((error) => {
reject(error);
});
this.manager.socket.on('invocation_complete', invocationCompleteHandler);
this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler);
const clearListeners = () => {
this.manager.socket.off('invocation_complete', invocationCompleteHandler);
this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
};
if (timeout) {
timeoutId = window.setTimeout(() => {
if (didSuceed) {
// If we already succeeded, we don't need to do anything
return;
}
this.log.trace('Graph canceled by timeout');
clearListeners();
cancelGraph();
reject(new Error('Graph timed out'));
}, timeout);
}
if (signal) {
signal.addEventListener('abort', () => {
if (didSuceed) {
// If we already succeeded, we don't need to do anything
return;
}
this.log.trace('Graph canceled by signal');
_clearTimeout();
clearListeners();
cancelGraph();
reject(new Error('Graph canceled'));
});
}
timeout,
signal,
});
return resultPromise;
// Extract the image from the result - we expect a single image
const imageDTO = await this.getImageDTOFromResult(output);
return imageDTO;
};
/**
* Helper function to extract ImageDTO from graph execution result.
* Expects the result to be an ImageOutput.
*/
private getImageDTOFromResult = async (result: S['GraphExecutionState']['results'][string]): Promise<ImageDTO> => {
// Validate that the result is an ImageOutput using zod schema
const parseResult = zImageOutput.safeParse(result);
if (!parseResult.success) {
throw new Error(`Graph output is not a valid ImageOutput. Got: ${JSON.stringify(result)}`);
}
const imageOutput = parseResult.data;
return await getImageDTO(imageOutput.image.image_name);
};
/**

View File

@@ -187,7 +187,7 @@ export type ProgressImage = z.infer<typeof zProgressImage>;
// #endregion
// #region ImageOutput
const zImageOutput = z.object({
export const zImageOutput = z.object({
image: zImageField,
width: z.number().int().gt(0),
height: z.number().int().gt(0),

View File

@@ -47,7 +47,7 @@ type RunGraphReturn = {
/**
* Creates production dependencies for runGraph using Redux store and socket.
*/
export const createProductionDependencies = (
export const buildRunGraphDependencies = (
store: AppStore,
socket: {
on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
@@ -95,7 +95,7 @@ export const createProductionDependencies = (
* @example
*
* ```ts
* const dependencies = createProductionDependencies(store, socket);
* const dependencies = buildRunGraphDependencies(store, socket);
* const graph = new Graph();
* const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } });
* const controller = new AbortController();