mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 21:25:04 -05:00
refactor(ui): better race condition handling in runGraph
This commit is contained in:
@@ -57,7 +57,7 @@ export class Err<E> {
|
||||
* @template T The type of the value in the `Ok` case.
|
||||
* @template E The type of the error in the `Err` case.
|
||||
*/
|
||||
type Result<T, E> = Ok<T> | Err<E>;
|
||||
export type Result<T, E = Error> = Ok<T> | Err<E>;
|
||||
|
||||
/**
|
||||
* Creates a successful result.
|
||||
@@ -85,12 +85,11 @@ export function ErrResult<E>(error: E): Err<E> {
|
||||
* @param {() => T} fn The function to execute.
|
||||
* @returns {Result<T>} An `Ok` result if the function succeeds, or an `Err` result if it throws an error.
|
||||
*/
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export function withResult<T>(fn: () => T): Result<T, any> {
|
||||
export function withResult<T>(fn: () => T): Result<T> {
|
||||
try {
|
||||
return new Ok(fn());
|
||||
} catch (error) {
|
||||
return new Err(error);
|
||||
return new Err(error instanceof Error ? error : new WrappedError(error));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,12 +99,21 @@ export function withResult<T>(fn: () => T): Result<T, any> {
|
||||
* @param {() => Promise<T>} fn The asynchronous function to execute.
|
||||
* @returns {Promise<Result<T>>} A `Promise` resolving to an `Ok` result if the function succeeds, or an `Err` result if it throws an error.
|
||||
*/
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T, any>> {
|
||||
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T>> {
|
||||
try {
|
||||
const result = await fn();
|
||||
return new Ok(result);
|
||||
} catch (error) {
|
||||
return new Err(error);
|
||||
return new Err(error instanceof Error ? error : new WrappedError(error));
|
||||
}
|
||||
}
|
||||
|
||||
export class WrappedError extends Error {
|
||||
data: unknown;
|
||||
|
||||
constructor(data: unknown) {
|
||||
super('Wrapped Error');
|
||||
this.name = this.constructor.name;
|
||||
this.data = data;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { Mutex } from 'async-mutex';
|
||||
import type { Result } from 'common/util/result';
|
||||
import { ErrResult, OkResult, withResult, withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue';
|
||||
@@ -147,12 +150,6 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Flag to indicate whether we have finished with the business logic of executing the graph. This is used to
|
||||
* prevent multiple promise resolutions. This flag must be set to true before the promise is resolved or rejected.
|
||||
*/
|
||||
let isFinished = false;
|
||||
|
||||
/**
|
||||
* The queue item id is set to null initially, but will be updated once the graph is enqueued. It will be used to
|
||||
* retrieve the queue item.
|
||||
@@ -173,23 +170,51 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
}
|
||||
};
|
||||
|
||||
// If a timeout value is provided, we create a timer to reject the promise.
|
||||
if (timeout !== undefined) {
|
||||
const timeoutId = setTimeout(() => {
|
||||
if (isFinished) {
|
||||
/**
|
||||
* We use a mutex to ensure that the promise is resolved or rejected only once, even if multiple events
|
||||
* are received or the settle function is called multiple times.
|
||||
*
|
||||
* A flag allows pending locks to bail if the promise has already been settled.
|
||||
*/
|
||||
let isSettling = false;
|
||||
const settlementMutex = new Mutex();
|
||||
const settle = async (settlement: () => Promise<Result<RunGraphReturn, Error>> | Result<RunGraphReturn, Error>) => {
|
||||
await settlementMutex.runExclusive(async () => {
|
||||
// If we are already settling, ignore this call to avoid multiple resolutions or rejections.
|
||||
// We don't want to _cancel_ pending locks as this would raise.
|
||||
if (isSettling) {
|
||||
return;
|
||||
}
|
||||
isFinished = true;
|
||||
log.trace('Graph canceled by timeout');
|
||||
isSettling = true;
|
||||
|
||||
// Clean up listeners, timeouts, etc. ASAP.
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
|
||||
// and move on to reject.
|
||||
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
|
||||
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout');
|
||||
});
|
||||
|
||||
// Normalize the settlement function to always return a promise.
|
||||
const result = await Promise.resolve(settlement());
|
||||
|
||||
if (result.isOk()) {
|
||||
resolve(result.value);
|
||||
} else {
|
||||
reject(result.error);
|
||||
}
|
||||
reject(new SessionTimeoutError(queueItemId));
|
||||
});
|
||||
};
|
||||
|
||||
// If a timeout value is provided, we create a timer to reject the promise.
|
||||
if (timeout !== undefined) {
|
||||
const timeoutId = setTimeout(async () => {
|
||||
await settle(() => {
|
||||
log.trace('Graph canceled by timeout');
|
||||
if (queueItemId !== null) {
|
||||
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
|
||||
// and move on to reject.
|
||||
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
|
||||
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout');
|
||||
});
|
||||
}
|
||||
return ErrResult(new SessionTimeoutError(queueItemId));
|
||||
});
|
||||
}, timeout);
|
||||
|
||||
cleanupFunctions.add(() => {
|
||||
@@ -200,20 +225,17 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
// If a signal is provided, we add an abort handler to reject the promise if the signal is aborted.
|
||||
if (signal !== undefined) {
|
||||
const abortHandler = () => {
|
||||
if (isFinished) {
|
||||
return;
|
||||
}
|
||||
isFinished = true;
|
||||
log.trace('Graph canceled by signal');
|
||||
cleanup();
|
||||
if (queueItemId !== null) {
|
||||
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
|
||||
// and move on to reject.
|
||||
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
|
||||
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort');
|
||||
});
|
||||
}
|
||||
reject(new SessionAbortedError(queueItemId));
|
||||
settle(() => {
|
||||
log.trace('Graph canceled by signal');
|
||||
if (queueItemId !== null) {
|
||||
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
|
||||
// and move on to reject.
|
||||
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
|
||||
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort');
|
||||
});
|
||||
}
|
||||
return ErrResult(new SessionAbortedError(queueItemId));
|
||||
});
|
||||
};
|
||||
|
||||
signal.addEventListener('abort', abortHandler);
|
||||
@@ -224,10 +246,6 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
|
||||
// Handle the queue item status change events.
|
||||
const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => {
|
||||
if (isFinished) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore events that are not for this graph
|
||||
if (event.origin !== origin) {
|
||||
return;
|
||||
@@ -238,42 +256,39 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
return;
|
||||
}
|
||||
|
||||
// The queue item is finished - retrieve it, extract results and resolve or reject the promise
|
||||
isFinished = true;
|
||||
cleanup();
|
||||
|
||||
// We need to handle any errors, including retrieving the queue item
|
||||
const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id));
|
||||
if (queueItemResult.isErr()) {
|
||||
reject(queueItemResult.error);
|
||||
return;
|
||||
}
|
||||
|
||||
const queueItem = queueItemResult.value;
|
||||
|
||||
const { status, session, error_type, error_message, error_traceback } = queueItem;
|
||||
|
||||
if (status === 'completed') {
|
||||
const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId));
|
||||
if (getOutputResult.isErr()) {
|
||||
reject(getOutputResult.error);
|
||||
return;
|
||||
await settle(async () => {
|
||||
// We need to handle any errors, including retrieving the queue item
|
||||
const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id));
|
||||
if (queueItemResult.isErr()) {
|
||||
return ErrResult(queueItemResult.error);
|
||||
}
|
||||
const output = getOutputResult.value;
|
||||
|
||||
resolve({ session, output });
|
||||
return;
|
||||
}
|
||||
const queueItem = queueItemResult.value;
|
||||
|
||||
if (status === 'failed') {
|
||||
reject(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback));
|
||||
return;
|
||||
}
|
||||
const { status, session, error_type, error_message, error_traceback } = queueItem;
|
||||
|
||||
if (status === 'canceled') {
|
||||
reject(new SessionCancelationError(queueItemId, session));
|
||||
return;
|
||||
}
|
||||
// We are confident that the queue item is not pending or in progress, at this time.
|
||||
assert(status !== 'pending' && status !== 'in_progress');
|
||||
|
||||
if (status === 'completed') {
|
||||
const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId));
|
||||
if (getOutputResult.isErr()) {
|
||||
return ErrResult(getOutputResult.error);
|
||||
}
|
||||
const output = getOutputResult.value;
|
||||
return OkResult({ session, output });
|
||||
}
|
||||
|
||||
if (status === 'failed') {
|
||||
return ErrResult(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback));
|
||||
}
|
||||
|
||||
if (status === 'canceled') {
|
||||
return ErrResult(new SessionCancelationError(queueItemId, session));
|
||||
}
|
||||
|
||||
assert<Equals<never, typeof status>>(false);
|
||||
});
|
||||
};
|
||||
|
||||
dependencies.eventHandler.subscribe(onQueueItemStatusChanged);
|
||||
@@ -285,20 +300,15 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
|
||||
dependencies.executor
|
||||
.enqueueBatch(batch)
|
||||
.then((data) => {
|
||||
// We queue a single run of the batch, so we expect only one item_id in the response.
|
||||
// We queue a single run of the batch, so we know there is only one item_id in the response.
|
||||
assert(data.item_ids.length === 1);
|
||||
assert(data.item_ids[0] !== undefined, 'Enqueue result is missing first queue item id');
|
||||
assert(data.item_ids[0] !== undefined);
|
||||
queueItemId = data.item_ids[0];
|
||||
})
|
||||
.catch((error) => {
|
||||
if (isFinished) {
|
||||
// Not sure how it could happen that we are settled at this point, but if it does, we don't want to
|
||||
// reject the promise again.
|
||||
return;
|
||||
}
|
||||
isFinished = true;
|
||||
cleanup();
|
||||
reject(error);
|
||||
.catch(async (error) => {
|
||||
await settle(() => {
|
||||
return ErrResult(error);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -364,6 +374,20 @@ export class SessionError extends QueueItemError {
|
||||
}
|
||||
}
|
||||
|
||||
export class UnexpectedStatusError extends SessionError {
|
||||
status: S['SessionQueueItem']['status'];
|
||||
|
||||
constructor(
|
||||
queueItemId: number | null,
|
||||
session: S['SessionQueueItem']['session'],
|
||||
status: S['SessionQueueItem']['status']
|
||||
) {
|
||||
super(queueItemId, session, `Session has unexpected status ${status}.`);
|
||||
this.name = this.constructor.name;
|
||||
this.status = status;
|
||||
}
|
||||
}
|
||||
|
||||
export class NodeNotFoundError extends SessionError {
|
||||
nodeId: string;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user