mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 12:44:56 -05:00
feat(ui): better exception naming and docstrings in runGraph
This commit is contained in:
@@ -5,14 +5,14 @@ import type { PartialDeep } from 'type-fest';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import {
|
||||
IterateNodeFoundError,
|
||||
NodeNotFoundError,
|
||||
OutputNodeNotFoundError,
|
||||
ResultNotFoundError,
|
||||
IterateNodeFoundInGraphError,
|
||||
OutputNodeNotFoundInCompletedSessionError,
|
||||
OutputNodeNotFoundInGraphError,
|
||||
ResultNotFoundInCompletedSessionError,
|
||||
runGraph,
|
||||
SessionAbortedError,
|
||||
SessionCancelationError,
|
||||
SessionExecutionError,
|
||||
SessionCanceledError,
|
||||
SessionFailedError,
|
||||
SessionTimeoutError,
|
||||
} from './run-graph';
|
||||
|
||||
@@ -135,7 +135,7 @@ describe('runGraph', () => {
|
||||
dependencies: { executor: mockExecutor, eventHandler: mockEventHandler },
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(OutputNodeNotFoundError);
|
||||
await expect(promise).rejects.toThrow(OutputNodeNotFoundInGraphError);
|
||||
});
|
||||
|
||||
it('should reject with IterateNodeFoundError if graph contains iterate nodes', async () => {
|
||||
@@ -147,7 +147,7 @@ describe('runGraph', () => {
|
||||
dependencies: { executor: mockExecutor, eventHandler: mockEventHandler },
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(IterateNodeFoundError);
|
||||
await expect(promise).rejects.toThrow(IterateNodeFoundInGraphError);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -248,7 +248,7 @@ describe('runGraph', () => {
|
||||
} as S['QueueItemStatusChangedEvent']);
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(SessionExecutionError);
|
||||
await expect(promise).rejects.toThrow(SessionFailedError);
|
||||
});
|
||||
|
||||
it('should reject with SessionCancelationError on canceled status', async () => {
|
||||
@@ -270,7 +270,7 @@ describe('runGraph', () => {
|
||||
} as S['QueueItemStatusChangedEvent']);
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow(SessionCancelationError);
|
||||
await expect(promise).rejects.toThrow(SessionCanceledError);
|
||||
});
|
||||
|
||||
it('should reject if enqueueBatch fails', async () => {
|
||||
@@ -609,7 +609,7 @@ describe('runGraph', () => {
|
||||
origin: TEST_ORIGIN,
|
||||
} as S['QueueItemStatusChangedEvent']);
|
||||
|
||||
await expect(promise).rejects.toThrow(NodeNotFoundError);
|
||||
await expect(promise).rejects.toThrow(OutputNodeNotFoundInCompletedSessionError);
|
||||
});
|
||||
|
||||
it('should reject with ResultNotFoundError if result not found for prepared node', async () => {
|
||||
@@ -643,7 +643,7 @@ describe('runGraph', () => {
|
||||
origin: TEST_ORIGIN,
|
||||
} as S['QueueItemStatusChangedEvent']);
|
||||
|
||||
await expect(promise).rejects.toThrow(ResultNotFoundError);
|
||||
await expect(promise).rejects.toThrow(ResultNotFoundInCompletedSessionError);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -79,11 +79,25 @@ type RunGraphReturn = {
|
||||
* @param arg.outputNodeId The id of the node whose output will be retrieved.
|
||||
* @param arg.dependencies The dependencies for queue operations and event handling.
|
||||
* @param arg.destination The destination to assign to the batch. If omitted, the destination is not set.
|
||||
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the end of the queue.
|
||||
* @param arg.timeout The timeout for the batch. If omitted, there is no timeout.
|
||||
* @param arg.signal An optional signal to cancel the operation. If omitted, the operation cannot be canceled.
|
||||
* @param arg.prepend Whether to prepend the graph to the front of the queue. If omitted, the graph is appended to the
|
||||
* end of the queue.
|
||||
* @param arg.timeout The timeout for the run in milliseconds. The promise rejects with a SessionTimeoutError when
|
||||
* the run times out. If the queue item was enqueued, a best effort is made to cancel it. **If omitted, there is
|
||||
* no timeout and the run will wait indefinitely for completion.**
|
||||
* @param arg.signal An optional signal to cancel the operation. The promise rejects with a SessionAbortedError when
|
||||
* the run is canceled via signal. If the queue item was enqueued, a best effort is made to cancel it. **If omitted,
|
||||
* the run cannot easily be canceled.**
|
||||
*
|
||||
* @returns A promise that resolves to the output and completed session, or rejects with an error if the graph fails or is canceled.
|
||||
* @returns A promise that resolves to the output and completed session, or rejects with an error:
|
||||
* - `OutputNodeNotFoundInGraphError` if the output node is not found in the provided graph.
|
||||
* - `IterateNodeFoundInGraphError` if the graph contains any iterate nodes, which are not supported.
|
||||
* - `UnexpectedStatusError` if the session has an unexpected status (not completed, failed, canceled).
|
||||
* - `OutputNodeNotFoundInCompletedSessionError` if the output node is not found in the completed session.
|
||||
* - `ResultNotFoundInCompletedSessionError` if the result for the output node is not found in the completed session.
|
||||
* - `SessionFailedError` if the session execution fails, including error type, message, and traceback.
|
||||
* - `SessionCanceledError` if the session execution is canceled via the queue.
|
||||
* - `SessionAbortedError` if the session execution is aborted via signal.
|
||||
* - `SessionTimeoutError` if the session execution times out.
|
||||
*
|
||||
* @example
|
||||
*
|
||||
@@ -151,14 +165,14 @@ const _runGraph = async (
|
||||
const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg;
|
||||
|
||||
if (!graph.hasNode(outputNodeId)) {
|
||||
reject(new OutputNodeNotFoundError(outputNodeId, graph));
|
||||
reject(new OutputNodeNotFoundInGraphError(outputNodeId, graph));
|
||||
return;
|
||||
}
|
||||
|
||||
const g = graph.getGraph();
|
||||
|
||||
if (Object.values(g.nodes).some((node) => node.type === 'iterate')) {
|
||||
reject(new IterateNodeFoundError(graph));
|
||||
reject(new IterateNodeFoundInGraphError(graph));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -315,11 +329,11 @@ const _runGraph = async (
|
||||
}
|
||||
|
||||
if (status === 'failed') {
|
||||
return ErrResult(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback));
|
||||
return ErrResult(new SessionFailedError(queueItemId, session, error_type, error_message, error_traceback));
|
||||
}
|
||||
|
||||
if (status === 'canceled') {
|
||||
return ErrResult(new SessionCancelationError(queueItemId, session));
|
||||
return ErrResult(new SessionCanceledError(queueItemId, session));
|
||||
}
|
||||
|
||||
assert<Equals<never, typeof status>>(false);
|
||||
@@ -354,16 +368,16 @@ const getOutputFromSession = (
|
||||
const { results, source_prepared_mapping } = session;
|
||||
const preparedNodeId = source_prepared_mapping[nodeId]?.[0];
|
||||
if (!preparedNodeId) {
|
||||
throw new NodeNotFoundError(queueItemId, session, nodeId);
|
||||
throw new OutputNodeNotFoundInCompletedSessionError(queueItemId, session, nodeId);
|
||||
}
|
||||
const result = results[preparedNodeId];
|
||||
if (!result) {
|
||||
throw new ResultNotFoundError(queueItemId, session, nodeId);
|
||||
throw new ResultNotFoundInCompletedSessionError(queueItemId, session, nodeId);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
export class OutputNodeNotFoundError extends Error {
|
||||
export class OutputNodeNotFoundInGraphError extends Error {
|
||||
outputNodeId: string;
|
||||
graph: Graph;
|
||||
|
||||
@@ -375,7 +389,7 @@ export class OutputNodeNotFoundError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
export class IterateNodeFoundError extends Error {
|
||||
export class IterateNodeFoundInGraphError extends Error {
|
||||
graph: Graph;
|
||||
|
||||
constructor(graph: Graph) {
|
||||
@@ -385,7 +399,7 @@ export class IterateNodeFoundError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
export class QueueItemError extends Error {
|
||||
class BaseQueueItemError extends Error {
|
||||
queueItemId: number | null;
|
||||
|
||||
constructor(queueItemId: number | null, message?: string) {
|
||||
@@ -395,7 +409,7 @@ export class QueueItemError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionError extends QueueItemError {
|
||||
class BaseSessionError extends BaseQueueItemError {
|
||||
session: S['SessionQueueItem']['session'];
|
||||
|
||||
constructor(queueItemId: number | null, session: S['SessionQueueItem']['session'], message?: string) {
|
||||
@@ -405,7 +419,7 @@ export class SessionError extends QueueItemError {
|
||||
}
|
||||
}
|
||||
|
||||
export class UnexpectedStatusError extends SessionError {
|
||||
export class UnexpectedStatusError extends BaseSessionError {
|
||||
status: S['SessionQueueItem']['status'];
|
||||
|
||||
constructor(
|
||||
@@ -419,7 +433,7 @@ export class UnexpectedStatusError extends SessionError {
|
||||
}
|
||||
}
|
||||
|
||||
export class NodeNotFoundError extends SessionError {
|
||||
export class OutputNodeNotFoundInCompletedSessionError extends BaseSessionError {
|
||||
nodeId: string;
|
||||
|
||||
constructor(queueItemId: number | null, session: S['SessionQueueItem']['session'], nodeId: string) {
|
||||
@@ -429,7 +443,7 @@ export class NodeNotFoundError extends SessionError {
|
||||
}
|
||||
}
|
||||
|
||||
export class ResultNotFoundError extends SessionError {
|
||||
export class ResultNotFoundInCompletedSessionError extends BaseSessionError {
|
||||
nodeId: string;
|
||||
|
||||
constructor(queueItemId: number | null, session: S['SessionQueueItem']['session'], nodeId: string) {
|
||||
@@ -439,7 +453,7 @@ export class ResultNotFoundError extends SessionError {
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionExecutionError extends SessionError {
|
||||
export class SessionFailedError extends BaseSessionError {
|
||||
error_type?: string | null;
|
||||
error_message?: string | null;
|
||||
error_traceback?: string | null;
|
||||
@@ -459,21 +473,21 @@ export class SessionExecutionError extends SessionError {
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionCancelationError extends SessionError {
|
||||
export class SessionCanceledError extends BaseSessionError {
|
||||
constructor(queueItemId: number | null, session: S['SessionQueueItem']['session']) {
|
||||
super(queueItemId, session, 'Session execution was canceled');
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionAbortedError extends QueueItemError {
|
||||
export class SessionAbortedError extends BaseQueueItemError {
|
||||
constructor(queueItemId: number | null) {
|
||||
super(queueItemId, 'Session execution was aborted via signal');
|
||||
this.name = this.constructor.name;
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionTimeoutError extends QueueItemError {
|
||||
export class SessionTimeoutError extends BaseQueueItemError {
|
||||
constructor(queueItemId: number | null) {
|
||||
super(queueItemId, 'Session execution timed out');
|
||||
this.name = this.constructor.name;
|
||||
|
||||
Reference in New Issue
Block a user