From 8dc6d0b5ae98e4e97032e823d2f893d53500b5c0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 28 Jun 2025 18:47:40 +1000 Subject: [PATCH] feat(ui): use runGraph in canvas --- .../konva/CanvasStateApiModule.ts | 230 +++--------------- .../web/src/features/nodes/types/common.ts | 2 +- .../web/src/services/api/run-graph.ts | 4 +- 3 files changed, 34 insertions(+), 202 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 26fc3d8c24..f03f90fc96 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -2,7 +2,6 @@ import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library'; import type { Selector } from '@reduxjs/toolkit'; import { addAppListener } from 'app/store/middleware/listenerMiddleware'; import type { AppStore, RootState } from 'app/store/store'; -import { withResultAsync } from 'common/util/result'; import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer'; import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; @@ -49,15 +48,14 @@ import type { RgbaColor, } from 'features/controlLayers/store/types'; import { RGBA_BLACK } from 'features/controlLayers/store/types'; +import { zImageOutput } from 'features/nodes/types/common'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import { atom, computed } from 'nanostores'; import type { Logger } from 'roarr'; import { getImageDTO } from 'services/api/endpoints/images'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; -import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types'; -import { QueueError } from 'services/events/errors'; +import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph'; +import type { ImageDTO, S } from 'services/api/types'; import type { Param0 } from 'tsafe'; -import { assert } from 'tsafe'; import type { CanvasEntityAdapter } from './CanvasEntity/types'; @@ -266,7 +264,7 @@ export class CanvasStateApiModule extends CanvasModuleBase { * controller.abort(); * ``` */ - runGraphAndReturnImageOutput = (arg: { + runGraphAndReturnImageOutput = async (arg: { graph: Graph; outputNodeId: string; destination?: string; @@ -276,203 +274,37 @@ export class CanvasStateApiModule extends CanvasModuleBase { }): Promise => { const { graph, outputNodeId, destination, prepend, timeout, signal } = arg; - if (!graph.hasNode(outputNodeId)) { - throw new Error(`Graph does not contain node with id: ${outputNodeId}`); - } + const dependencies = buildRunGraphDependencies(this.store, this.manager.socket); - /** - * We will use the origin to handle events from the graph. Ideally we'd just use the queue item's id, but there's a - * race condition: - * - The queue item id is not available until the graph is enqueued - * - The graph may complete before we can set up the listeners to handle the completion event - * - * The origin is the only unique identifier we have that is guaranteed to be available before the graph is enqueued, - * so we will use that to filter events. - */ - const origin = getPrefixedId(graph.id); - - const batch: EnqueueBatchArg = { + const { output } = await runGraph({ + graph, + outputNodeId, + dependencies, + destination, prepend, - batch: { - graph: graph.getGraph(), - origin, - destination, - runs: 1, - }, - }; - - let didSuceed = false; - - /** - * If a timeout is provided, we will cancel the graph if it takes too long - but we need a way to clear the timeout - * if the graph completes or errors before the timeout. - */ - let timeoutId: number | null = null; - const _clearTimeout = () => { - if (timeoutId !== null) { - window.clearTimeout(timeoutId); - timeoutId = null; - } - }; - - // There's a bit of a catch-22 here: we need to set the cancelGraph callback before we enqueue the graph, but we - // can't set it until we have the batch_id from the enqueue request. So we'll set a dummy function here and update - // it later. - let cancelGraph: () => void = () => { - this.log.warn('cancelGraph called before cancelGraph is set'); - }; - - const resultPromise = new Promise((resolve, reject) => { - const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => { - // Ignore events that are not for this graph - if (event.origin !== origin) { - return; - } - - // Ignore events that are not from the output node - if (event.invocation_source_id !== outputNodeId) { - return; - } - - // If we get here, the event is for the correct graph and output node. - - // Clear the timeout and socket listeners - _clearTimeout(); - clearListeners(); - - // The result must be an image output - const { result } = event; - if (result.type !== 'image_output') { - reject(new Error(`Graph output node did not return an image output, got: ${result}`)); - return; - } - - // Get the result image DTO - const getImageDTOResult = await withResultAsync(() => getImageDTO(result.image.image_name)); - if (getImageDTOResult.isErr()) { - reject(getImageDTOResult.error); - return; - } - - didSuceed = true; - - // Ok! - resolve(getImageDTOResult.value); - }; - - const queueItemStatusChangedHandler = (event: S['QueueItemStatusChangedEvent']) => { - // Ignore events that are not for this graph - if (event.origin !== origin) { - return; - } - - // Ignore events where the status is pending or in progress - no need to do anything for these - if (event.status === 'pending' || event.status === 'in_progress') { - return; - } - - if (event.status === 'completed') { - /** - * The invocation_complete event should have been received before the queue item completed event, and the - * event listeners are cleared in the invocation_complete handler. If we get here, it means we never got - * the completion event for the output node! This should is a fail case. - * - * TODO(psyche): In the unexpected case where events are received out of order, this logic doesn't do what - * we expect. If we got a queue item completed event before the output node completion event, we'd erroneously - * triggers this error. - * - * For now, we'll just log a warning instead of rejecting the promise. This should be super rare anyways. - */ - // reject(new Error('Queue item completed without output node completion event')); - this.log.warn('Queue item completed without output node completion event'); - return; - } - - // event.status is 'failed', 'canceled' - something has gone awry - _clearTimeout(); - clearListeners(); - - if (event.status === 'failed') { - // We expect the event to have error details, but technically it's possible that it doesn't - const { error_type, error_message, error_traceback } = event; - if (error_type && error_message && error_traceback) { - reject(new QueueError(error_type, error_message, error_traceback)); - } else { - reject(new Error('Queue item failed, but no error details were provided')); - } - } else { - // event.status is 'canceled' - reject(new Error('Graph canceled')); - } - }; - - // We are ready to enqueue the graph - const enqueueRequest = this.store.dispatch( - queueApi.endpoints.enqueueBatch.initiate(batch, { - // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status - // updates. - ...enqueueMutationFixedCacheKeyOptions, - // We do not need RTK to track this request in the store - track: false, - }) - ); - - // Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block - // instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors. - enqueueRequest - .unwrap() - .then((data) => { - // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect. - // TODO(psyche): Fix the OpenAPI schema. - const batch_id = data.batch.batch_id; - assert(batch_id, 'Enqueue result is missing batch_id'); - cancelGraph = () => { - this.store.dispatch( - queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false }) - ); - }; - }) - .catch((error) => { - reject(error); - }); - - this.manager.socket.on('invocation_complete', invocationCompleteHandler); - this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler); - - const clearListeners = () => { - this.manager.socket.off('invocation_complete', invocationCompleteHandler); - this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler); - }; - - if (timeout) { - timeoutId = window.setTimeout(() => { - if (didSuceed) { - // If we already succeeded, we don't need to do anything - return; - } - this.log.trace('Graph canceled by timeout'); - clearListeners(); - cancelGraph(); - reject(new Error('Graph timed out')); - }, timeout); - } - - if (signal) { - signal.addEventListener('abort', () => { - if (didSuceed) { - // If we already succeeded, we don't need to do anything - return; - } - this.log.trace('Graph canceled by signal'); - _clearTimeout(); - clearListeners(); - cancelGraph(); - reject(new Error('Graph canceled')); - }); - } + timeout, + signal, }); - return resultPromise; + // Extract the image from the result - we expect a single image + const imageDTO = await this.getImageDTOFromResult(output); + + return imageDTO; + }; + + /** + * Helper function to extract ImageDTO from graph execution result. + * Expects the result to be an ImageOutput. + */ + private getImageDTOFromResult = async (result: S['GraphExecutionState']['results'][string]): Promise => { + // Validate that the result is an ImageOutput using zod schema + const parseResult = zImageOutput.safeParse(result); + if (!parseResult.success) { + throw new Error(`Graph output is not a valid ImageOutput. Got: ${JSON.stringify(result)}`); + } + + const imageOutput = parseResult.data; + return await getImageDTO(imageOutput.image.image_name); }; /** diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index d293c5df05..268bfc9b5f 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -187,7 +187,7 @@ export type ProgressImage = z.infer; // #endregion // #region ImageOutput -const zImageOutput = z.object({ +export const zImageOutput = z.object({ image: zImageField, width: z.number().int().gt(0), height: z.number().int().gt(0), diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts index 0823313a7e..84354672da 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.ts @@ -47,7 +47,7 @@ type RunGraphReturn = { /** * Creates production dependencies for runGraph using Redux store and socket. */ -export const createProductionDependencies = ( +export const buildRunGraphDependencies = ( store: AppStore, socket: { on: (event: 'queue_item_status_changed', handler: (event: S['QueueItemStatusChangedEvent']) => void) => void; @@ -95,7 +95,7 @@ export const createProductionDependencies = ( * @example * * ```ts - * const dependencies = createProductionDependencies(store, socket); + * const dependencies = buildRunGraphDependencies(store, socket); * const graph = new Graph(); * const outputNode = graph.addNode({ id: 'my-resize-node', type: 'img_resize', image: { image_name: 'my-image.png' } }); * const controller = new AbortController();