From 6e6bbc5fa6913d2c1e207c6561777e320bedd1c6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 24 Sep 2024 06:31:40 +1000 Subject: [PATCH] fix(ui): race condition when filtering There's a situation in which the enqueue response comes after the graph actually executes. This was unexpected when I first wrote the logic. I suppose it has to do with the async endpoint handling. --- .../CanvasEntity/CanvasEntityFilterer.ts | 1 + .../konva/CanvasStateApiModule.ts | 63 +++++++++++-------- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts index 3fd501a468..43206f273f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer.ts @@ -141,6 +141,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase { this.abortController = controller; const { graph, outputNodeId } = buildGraphResult.value; + const filterResult = await withResultAsync(() => this.manager.stateApi.runGraphAndReturnImageOutput({ graph, diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 9dd2308aed..c46e3f3a12 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -225,7 +225,7 @@ export class CanvasStateApiModule extends CanvasModuleBase { * controller.abort(); * ``` */ - runGraphAndReturnImageOutput = async (arg: { + runGraphAndReturnImageOutput = (arg: { graph: Graph; outputNodeId: string; destination?: string; @@ -268,27 +268,12 @@ export class CanvasStateApiModule extends CanvasModuleBase { } }; - /** - * First, enqueue the graph - we need the `batch_id` to cancel the graph. But to get the `batch_id`, we need to - * `await` the request. You might be tempted to `await` the request inside the result promise, but we should not - * `await` inside a promise executor. - * - * See: https://eslint.org/docs/latest/rules/no-async-promise-executor - */ - const enqueueRequest = this.store.dispatch( - queueApi.endpoints.enqueueBatch.initiate(batch, { - // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status - // updates. - fixedCacheKey: 'enqueueBatch', - // We do not need RTK to track this request in the store - track: false, - }) - ); - - // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect. - // TODO(psyche): Fix the OpenAPI schema. - const { batch_id } = (await enqueueRequest.unwrap()).batch; - assert(batch_id, 'Enqueue result is missing batch_id'); + // There's a bit of a catch-22 here: we need to set the cancelGraph callback before we enqueue the graph, but we + // can't set it until we have the batch_id from the enqueue request. So we'll set a dummy function here and update + // it later. + let cancelGraph: () => void = () => { + this.log.warn('cancelGraph called before cancelGraph is set'); + }; const resultPromise = new Promise((resolve, reject) => { const invocationCompleteHandler = async (event: S['InvocationCompleteEvent']) => { @@ -357,6 +342,36 @@ export class CanvasStateApiModule extends CanvasModuleBase { } }; + // We are ready to enqueue the graph + const enqueueRequest = this.store.dispatch( + queueApi.endpoints.enqueueBatch.initiate(batch, { + // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status + // updates. + fixedCacheKey: 'enqueueBatch', + // We do not need RTK to track this request in the store + track: false, + }) + ); + + // Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block + // instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors. + enqueueRequest + .unwrap() + .then((data) => { + // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect. + // TODO(psyche): Fix the OpenAPI schema. + const batch_id = data.batch.batch_id; + assert(batch_id, 'Enqueue result is missing batch_id'); + cancelGraph = () => { + this.store.dispatch( + queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false }) + ); + }; + }) + .catch((error) => { + reject(error); + }); + this.manager.socket.on('invocation_complete', invocationCompleteHandler); this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedHandler); @@ -365,10 +380,6 @@ export class CanvasStateApiModule extends CanvasModuleBase { this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedHandler); }; - const cancelGraph = () => { - this.store.dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batch_id] }, { track: false })); - }; - if (timeout) { timeoutId = window.setTimeout(() => { this.log.trace('Graph canceled by timeout');