mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 18:25:28 -05:00
feat(ui): use runGraph in canvas
This commit is contained in:
@@ -2,7 +2,6 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
|
||||
import type { Selector } from '@reduxjs/toolkit';
|
||||
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppStore, RootState } from 'app/store/store';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
@@ -49,15 +48,14 @@ import type {
|
||||
RgbaColor,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { RGBA_BLACK } from 'features/controlLayers/store/types';
|
||||
import { zImageOutput } from 'features/nodes/types/common';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types';
|
||||
import { QueueError } from 'services/events/errors';
|
||||
import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { CanvasEntityAdapter } from './CanvasEntity/types';
|
||||
|
||||
@@ -266,7 +264,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
* controller.abort();
|
||||
* ```
|
||||
*/
|
||||
runGraphAndReturnImageOutput = (arg: {
|
||||
runGraphAndReturnImageOutput = async (arg: {
|
||||
graph: Graph;
|
||||
outputNodeId: string;
|
||||
destination?: string;
|
||||
@@ -276,203 +274,37 @@ export class CanvasStateApiModule extends CanvasModuleBase {
|
||||
}): Promise<ImageDTO> => {
|
||||
const { graph, outputNodeId, destination, prepend, timeout, signal } = arg;
|
||||
|
||||
if (!graph.hasNode(outputNodeId)) {
|
||||
throw new Error(`Graph does not contain node with id: ${outputNodeId}`);
|
||||
}
|
||||
const dependencies = buildRunGraphDependencies(this.store, this.manager.socket);
|
||||
|
||||
/**
|
||||
* We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a
|
||||
* race condition:
|
||||
* - The queue item id is not available until the graph is enqueued
|
||||
* - The graph may complete before we can set up the listeners to handle the completion event
|
||||
*
|
||||
* The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued,
|
||||
* so we will use that to filter events.
|
||||
*/
|
||||
const origin = getPrefixedId(graph.id);
|
||||
|
||||
const batch: EnqueueBatchArg = {
|
||||
const { output } = await runGraph({
|
||||
graph,
|
||||
outputNodeId,
|
||||
dependencies,
|
||||
destination,
|
||||
prepend,
|
||||
batch: {
|
||||
graph: graph.getGraph(),
|
||||
origin,
|
||||
destination,
|
||||
runs: 1,
|
||||
},
|
||||
};
|
||||
|
||||
let didSuceed = false;
|
||||
|
||||
/**
|
||||
* If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout
|
||||
* if the graph completes or errors before the timeout.
|
||||
*/
|
||||
let timeoutId: number | null = null;
|
||||
const _clearTimeout = () => {
|
||||
if (timeoutId !== null) {
|
||||
window.clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
};
|
||||
|
||||
// There's a bit of a catch-22 here: we need to set the cancelGraph callback before we enqueue the graph, but we
|
||||
// can't set it until we have the batch_id from the enqueue request. So we'll set a dummy function here and update
|
||||
// it later.
|
||||
let cancelGraph: () => void = () => {
|
||||
this.log.warn('cancelGraph called before cancelGraph is set');
|
||||
};
|
||||
|
||||
const resultPromise = new Promise<ImageDTO>((resolve, reject) => {
|
||||
const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => {
|
||||
// Ignore events that are not for this graph
|
||||
if (event.origin !== origin) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore events that are not from the output node
|
||||
if (event.invocation_source_id !== outputNodeId) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If we get here, the event is for the correct graph and output node.
|
||||
|
||||
// Clear the timeout and socket listeners
|
||||
_clearTimeout();
|
||||
clearListeners();
|
||||
|
||||
// The result must be an image output
|
||||
const { result } = event;
|
||||
if (result.type !== 'image_output') {
|
||||
reject(new Error(`Graph output node did not return an image output, got: ${result}`));
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the result image DTO
|
||||
const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name));
|
||||
if (getImageDTOResult.isErr()) {
|
||||
reject(getImageDTOResult.error);
|
||||
return;
|
||||
}
|
||||
|
||||
didSuceed = true;
|
||||
|
||||
// Ok!
|
||||
resolve(getImageDTOResult.value);
|
||||
};
|
||||
|
||||
const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => {
|
||||
// Ignore events that are not for this graph
|
||||
if (event.origin !== origin) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore events where the status is pending or in progress - no need to do anything for these
|
||||
if (event.status === 'pending' || event.status === 'in_progress') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.status === 'completed') {
|
||||
/**
|
||||
* The invocation_complete event should have been received before the queue item completed event, and the
|
||||
* event listeners are cleared in the invocation_complete handler. If we get here, it means we never got
|
||||
* the completion event for the output node! This should is a fail case.
|
||||
*
|
||||
* TODO(psyche): In the unexpected case where events are received out of order, this logic doesn't do what
|
||||
* we expect. If we got a queue item completed event before the output node completion event, we'd erroneously
|
||||
* triggers this error.
|
||||
*
|
||||
* For now, we'll just log a warning instead of rejecting the promise. This should be super rare anyways.
|
||||
*/
|
||||
// reject(new Error('Queue item completed without output node completion event'));
|
||||
this.log.warn('Queue item completed without output node completion event');
|
||||
return;
|
||||
}
|
||||
|
||||
// event.status is 'failed', 'canceled' - something has gone awry
|
||||
_clearTimeout();
|
||||
clearListeners();
|
||||
|
||||
if (event.status === 'failed') {
|
||||
// We expect the event to have error details, but technically it's possible that it doesn't
|
||||
const { error_type, error_message, error_traceback } = event;
|
||||
if (error_type && error_message && error_traceback) {
|
||||
reject(new QueueError(error_type, error_message, error_traceback));
|
||||
} else {
|
||||
reject(new Error('Queue item failed, but no error details were provided'));
|
||||
}
|
||||
} else {
|
||||
// event.status is 'canceled'
|
||||
reject(new Error('Graph canceled'));
|
||||
}
|
||||
};
|
||||
|
||||
// We are ready to enqueue the graph
|
||||
const enqueueRequest = this.store.dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batch, {
|
||||
// Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
|
||||
// updates.
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
// We do not need RTK to track this request in the store
|
||||
track: false,
|
||||
})
|
||||
);
|
||||
|
||||
// Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block
|
||||
// instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors.
|
||||
enqueueRequest
|
||||
.unwrap()
|
||||
.then((data) => {
|
||||
// The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect.
|
||||
// TODO(psyche): Fix the OpenAPI schema.
|
||||
const batch_id = data.batch.batch_id;
|
||||
assert(batch_id, 'Enqueue result is missing batch_id');
|
||||
cancelGraph = () => {
|
||||
this.store.dispatch(
|
||||
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false })
|
||||
);
|
||||
};
|
||||
})
|
||||
.catch((error) => {
|
||||
reject(error);
|
||||
});
|
||||
|
||||
this.manager.socket.on('invocation_complete', invocationCompleteHandler);
|
||||
this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler);
|
||||
|
||||
const clearListeners = () => {
|
||||
this.manager.socket.off('invocation_complete', invocationCompleteHandler);
|
||||
this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler);
|
||||
};
|
||||
|
||||
if (timeout) {
|
||||
timeoutId = window.setTimeout(() => {
|
||||
if (didSuceed) {
|
||||
// If we already succeeded, we don't need to do anything
|
||||
return;
|
||||
}
|
||||
this.log.trace('Graph canceled by timeout');
|
||||
clearListeners();
|
||||
cancelGraph();
|
||||
reject(new Error('Graph timed out'));
|
||||
}, timeout);
|
||||
}
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener('abort', () => {
|
||||
if (didSuceed) {
|
||||
// If we already succeeded, we don't need to do anything
|
||||
return;
|
||||
}
|
||||
this.log.trace('Graph canceled by signal');
|
||||
_clearTimeout();
|
||||
clearListeners();
|
||||
cancelGraph();
|
||||
reject(new Error('Graph canceled'));
|
||||
});
|
||||
}
|
||||
timeout,
|
||||
signal,
|
||||
});
|
||||
|
||||
return resultPromise;
|
||||
// Extract the image from the result - we expect a single image
|
||||
const imageDTO = await this.getImageDTOFromResult(output);
|
||||
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper function to extract ImageDTO from graph execution result.
|
||||
* Expects the result to be an ImageOutput.
|
||||
*/
|
||||
private getImageDTOFromResult = async (result: S['GraphExecutionState']['results'][string]): Promise<ImageDTO> => {
|
||||
// Validate that the result is an ImageOutput using zod schema
|
||||
const parseResult = zImageOutput.safeParse(result);
|
||||
if (!parseResult.success) {
|
||||
throw new Error(`Graph output is not a valid ImageOutput. Got: ${JSON.stringify(result)}`);
|
||||
}
|
||||
|
||||
const imageOutput = parseResult.data;
|
||||
return await getImageDTO(imageOutput.image.image_name);
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -187,7 +187,7 @@ export type ProgressImage = z.infer<typeof zProgressImage>;
|
||||
// #endregion
|
||||
|
||||
// #region ImageOutput
|
||||
const zImageOutput = z.object({
|
||||
export const zImageOutput = z.object({
|
||||
image: zImageField,
|
||||
width: z.number().int().gt(0),
|
||||
height: z.number().int().gt(0),
|
||||
|
||||
@@ -47,7 +47,7 @@ type RunGraphReturn = {
|
||||
/**
|
||||
* Creates production dependencies for runGraph using Redux store and socket.
|
||||
*/
|
||||
export const createProductionDependencies = (
|
||||
export const buildRunGraphDependencies = (
|
||||
store: AppStore,
|
||||
socket: {
|
||||
on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void;
|
||||
@@ -95,7 +95,7 @@ export const createProductionDependencies = (
|
||||
* @example
|
||||
*
|
||||
* ```ts
|
||||
* const dependencies = createProductionDependencies(store, socket);
|
||||
* const dependencies = buildRunGraphDependencies(store, socket);
|
||||
* const graph = new Graph();
|
||||
* const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } });
|
||||
* const controller = new AbortController();
|
||||
|
||||
Reference in New Issue
Block a user