refactor(ui): better race condition handling in runGraph

This commit is contained in:
psychedelicious
2025-06-28 23:16:47 +10:00
parent bbd21b1eb2
commit 4b84e34599
2 changed files with 120 additions and 88 deletions

View File

@@ -57,7 +57,7 @@ export class Err<E> {
* @template T The type of the value in the `Ok` case.
* @template E The type of the error in the `Err` case.
*/
type Result<T, E> = Ok<T> | Err<E>;
export type Result<T, E = Error> = Ok<T> | Err<E>;
/**
* Creates a successful result.
@@ -85,12 +85,11 @@ export function ErrResult<E>(error: E): Err<E> {
* @param {() => T} fn The function to execute.
* @returns {Result<T>} An `Ok` result if the function succeeds, or an `Err` result if it throws an error.
*/
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export function withResult<T>(fn: () => T): Result<T, any> {
export function withResult<T>(fn: () => T): Result<T> {
try {
return new Ok(fn());
} catch (error) {
return new Err(error);
return new Err(error instanceof Error ? error : new WrappedError(error));
}
}
@@ -100,12 +99,21 @@ export function withResult<T>(fn: () => T): Result<T, any> {
* @param {() => Promise<T>} fn The asynchronous function to execute.
* @returns {Promise<Result<T>>} A `Promise` resolving to an `Ok` result if the function succeeds, or an `Err` result if it throws an error.
*/
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T, any>> {
export async function withResultAsync<T>(fn: () => Promise<T>): Promise<Result<T>> {
try {
const result = await fn();
return new Ok(result);
} catch (error) {
return new Err(error);
return new Err(error instanceof Error ? error : new WrappedError(error));
}
}
export class WrappedError extends Error {
data: unknown;
constructor(data: unknown) {
super('Wrapped Error');
this.name = this.constructor.name;
this.data = data;
}
}

View File

@@ -1,10 +1,13 @@
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { withResult, withResultAsync } from 'common/util/result';
import { Mutex } from 'async-mutex';
import type { Result } from 'common/util/result';
import { ErrResult, OkResult, withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue';
@@ -147,12 +150,6 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
},
};
/**
* Flag to indicate whether we have finished with the business logic of executing the graph. This is used to
* prevent multiple promise resolutions. This flag must be set to true before the promise is resolved or rejected.
*/
let isFinished = false;
/**
* The queue item id is set to null initially, but will be updated once the graph is enqueued. It will be used to
* retrieve the queue item.
@@ -173,23 +170,51 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
}
};
// If a timeout value is provided, we create a timer to reject the promise.
if (timeout !== undefined) {
const timeoutId = setTimeout(() => {
if (isFinished) {
/**
* We use a mutex to ensure that the promise is resolved or rejected only once, even if multiple events
* are received or the settle function is called multiple times.
*
* A flag allows pending locks to bail if the promise has already been settled.
*/
let isSettling = false;
const settlementMutex = new Mutex();
const settle = async (settlement: () => Promise<Result<RunGraphReturn, Error>> | Result<RunGraphReturn, Error>) => {
await settlementMutex.runExclusive(async () => {
// If we are already settling, ignore this call to avoid multiple resolutions or rejections.
// We don't want to _cancel_ pending locks as this would raise.
if (isSettling) {
return;
}
isFinished = true;
log.trace('Graph canceled by timeout');
isSettling = true;
// Clean up listeners, timeouts, etc. ASAP.
cleanup();
if (queueItemId !== null) {
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
// and move on to reject.
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout');
});
// Normalize the settlement function to always return a promise.
const result = await Promise.resolve(settlement());
if (result.isOk()) {
resolve(result.value);
} else {
reject(result.error);
}
reject(new SessionTimeoutError(queueItemId));
});
};
// If a timeout value is provided, we create a timer to reject the promise.
if (timeout !== undefined) {
const timeoutId = setTimeout(async () => {
await settle(() => {
log.trace('Graph canceled by timeout');
if (queueItemId !== null) {
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
// and move on to reject.
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout');
});
}
return ErrResult(new SessionTimeoutError(queueItemId));
});
}, timeout);
cleanupFunctions.add(() => {
@@ -200,20 +225,17 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
// If a signal is provided, we add an abort handler to reject the promise if the signal is aborted.
if (signal !== undefined) {
const abortHandler = () => {
if (isFinished) {
return;
}
isFinished = true;
log.trace('Graph canceled by signal');
cleanup();
if (queueItemId !== null) {
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
// and move on to reject.
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort');
});
}
reject(new SessionAbortedError(queueItemId));
settle(() => {
log.trace('Graph canceled by signal');
if (queueItemId !== null) {
// It's possible the cancelation will fail, but we have no way to handle that gracefully. Log a warning
// and move on to reject.
dependencies.executor.cancelQueueItem(queueItemId).catch((error) => {
log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort');
});
}
return ErrResult(new SessionAbortedError(queueItemId));
});
};
signal.addEventListener('abort', abortHandler);
@@ -224,10 +246,6 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
// Handle the queue item status change events.
const onQueueItemStatusChanged = async (event: S['QueueItemStatusChangedEvent']) => {
if (isFinished) {
return;
}
// Ignore events that are not for this graph
if (event.origin !== origin) {
return;
@@ -238,42 +256,39 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
return;
}
// The queue item is finished - retrieve it, extract results and resolve or reject the promise
isFinished = true;
cleanup();
// We need to handle any errors, including retrieving the queue item
const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id));
if (queueItemResult.isErr()) {
reject(queueItemResult.error);
return;
}
const queueItem = queueItemResult.value;
const { status, session, error_type, error_message, error_traceback } = queueItem;
if (status === 'completed') {
const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId));
if (getOutputResult.isErr()) {
reject(getOutputResult.error);
return;
await settle(async () => {
// We need to handle any errors, including retrieving the queue item
const queueItemResult = await withResultAsync(() => dependencies.executor.getQueueItem(event.item_id));
if (queueItemResult.isErr()) {
return ErrResult(queueItemResult.error);
}
const output = getOutputResult.value;
resolve({ session, output });
return;
}
const queueItem = queueItemResult.value;
if (status === 'failed') {
reject(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback));
return;
}
const { status, session, error_type, error_message, error_traceback } = queueItem;
if (status === 'canceled') {
reject(new SessionCancelationError(queueItemId, session));
return;
}
// We are confident that the queue item is not pending or in progress, at this time.
assert(status !== 'pending' && status !== 'in_progress');
if (status === 'completed') {
const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId));
if (getOutputResult.isErr()) {
return ErrResult(getOutputResult.error);
}
const output = getOutputResult.value;
return OkResult({ session, output });
}
if (status === 'failed') {
return ErrResult(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback));
}
if (status === 'canceled') {
return ErrResult(new SessionCancelationError(queueItemId, session));
}
assert<Equals<never, typeof status>>(false);
});
};
dependencies.eventHandler.subscribe(onQueueItemStatusChanged);
@@ -285,20 +300,15 @@ export const runGraph = (arg: RunGraphArg): Promise<RunGraphReturn> => {
dependencies.executor
.enqueueBatch(batch)
.then((data) => {
// We queue a single run of the batch, so we expect only one item_id in the response.
// We queue a single run of the batch, so we know there is only one item_id in the response.
assert(data.item_ids.length === 1);
assert(data.item_ids[0] !== undefined, 'Enqueue result is missing first queue item id');
assert(data.item_ids[0] !== undefined);
queueItemId = data.item_ids[0];
})
.catch((error) => {
if (isFinished) {
// Not sure how it could happen that we are settled at this point, but if it does, we don't want to
// reject the promise again.
return;
}
isFinished = true;
cleanup();
reject(error);
.catch(async (error) => {
await settle(() => {
return ErrResult(error);
});
});
});
@@ -364,6 +374,20 @@ export class SessionError extends QueueItemError {
}
}
export class UnexpectedStatusError extends SessionError {
status: S['SessionQueueItem']['status'];
constructor(
queueItemId: number | null,
session: S['SessionQueueItem']['session'],
status: S['SessionQueueItem']['status']
) {
super(queueItemId, session, `Session has unexpected status ${status}.`);
this.name = this.constructor.name;
this.status = status;
}
}
export class NodeNotFoundError extends SessionError {
nodeId: string;