From 4fa83a622875bbe4a4a09a8328fe1e0987bf9995 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 28 Jun 2025 22:15:44 +1000 Subject: [PATCH] feat(ui): better error handling for runGraph --- .../web/src/services/api/run-graph.test.ts | 122 ++++++++--------- .../web/src/services/api/run-graph.ts | 126 ++++++++++++++---- .../web/src/services/events/errors.ts | 23 ---- 3 files changed, 153 insertions(+), 118 deletions(-) delete mode 100644 invokeai/frontend/web/src/services/events/errors.ts diff --git a/invokeai/frontend/web/src/services/api/run-graph.test.ts b/invokeai/frontend/web/src/services/api/run-graph.test.ts index ef4aec9b90..3d18d9a214 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.test.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.test.ts @@ -1,10 +1,18 @@ +import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { S } from 'services/api/types'; -import { QueueError } from 'services/events/errors'; import type { PartialDeep } from 'type-fest'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import { runGraph } from './run-graph'; +import { + IterateNodeFoundError, + OutputNodeNotFoundError, + runGraph, + SessionAbortedError, + SessionCancelationError, + SessionExecutionError, + SessionTimeoutError, +} from './run-graph'; // Mock dependencies vi.mock('app/logging/logger', () => ({ @@ -21,8 +29,11 @@ vi.mock('features/controlLayers/konva/util', () => ({ getPrefixedId: (prefix: string) => `${prefix}:mock-id-123`, })); +const TEST_ID = 'test-graph'; +const TEST_ORIGIN = getPrefixedId(TEST_ID); + // Helper functions for creating mock objects -const createMockGraph = (id = 'test-graph', hasIterateNodes = false): Graph => { +const createMockGraph = (id = TEST_ID, hasIterateNodes = false): Graph => { const mockNodes = hasIterateNodes ? { node1: { type: 'iterate' }, node2: { type: 'resize' } } : { node1: { type: 'resize' }, node2: { type: 'add' } }; @@ -113,7 +124,7 @@ describe('runGraph', () => { }); describe('validation', () => { - it('should reject if graph does not contain output node', async () => { + it('should reject with OutputNodeNotFoundError if graph does not contain output node', async () => { mockGraph.hasNode = vi.fn().mockReturnValue(false); const promise = runGraph({ @@ -122,10 +133,10 @@ describe('runGraph', () => { dependencies: { executor: mockExecutor, eventHandler: mockEventHandler }, }); - await expect(promise).rejects.toThrow('Graph does not contain output node non-existent-node.'); + await expect(promise).rejects.toThrow(OutputNodeNotFoundError); }); - it('should reject if graph contains iterate nodes', async () => { + it('should reject with IterateNodeFoundError if graph contains iterate nodes', async () => { const graphWithIterateNodes = createMockGraph('test', true); const promise = runGraph({ @@ -134,7 +145,7 @@ describe('runGraph', () => { dependencies: { executor: mockExecutor, eventHandler: mockEventHandler }, }); - await expect(promise).rejects.toThrow('Iterate nodes are not supported by this utility.'); + await expect(promise).rejects.toThrow(IterateNodeFoundError); }); }); @@ -155,7 +166,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); @@ -169,7 +180,7 @@ describe('runGraph', () => { prepend: undefined, batch: { graph: mockGraph.getGraph(), - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, destination: undefined, runs: 1, }, @@ -193,7 +204,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); @@ -203,7 +214,7 @@ describe('runGraph', () => { prepend: true, batch: { graph: mockGraph.getGraph(), - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, destination: 'test-destination', runs: 1, }, @@ -212,7 +223,7 @@ describe('runGraph', () => { }); describe('error handling', () => { - it('should reject with QueueError on failed status with error details', async () => { + it('should reject with SessionExecutionError on failed status with error details', async () => { const mockQueueItem = createMockQueueItem('failed', { error_type: 'ValidationError', error_message: 'Invalid input', @@ -231,41 +242,14 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'failed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); - await expect(promise).rejects.toThrow(QueueError); - await expect(promise).rejects.toThrow('Invalid input'); + await expect(promise).rejects.toThrow(SessionExecutionError); }); - it('should reject with generic error on failed status without error details', async () => { - const mockQueueItem = createMockQueueItem('failed', { - error_type: null, - error_message: null, - error_traceback: null, - }); - mockExecutor.enqueueBatch.mockResolvedValue({ item_ids: [1] }); - mockExecutor.getQueueItem.mockResolvedValue(mockQueueItem); - - const promise = runGraph({ - graph: mockGraph, - outputNodeId: 'output-node', - dependencies: { executor: mockExecutor, eventHandler: mockEventHandler }, - }); - - setImmediate(() => { - mockEventHandler._triggerEvent({ - item_id: 1, - status: 'failed', - origin: 'test-graph:mock-id-123', - } as S['QueueItemStatusChangedEvent']); - }); - - await expect(promise).rejects.toThrow('Queue item failed, but no error details were provided'); - }); - - it('should reject on canceled status', async () => { + it('should reject with SessionCancelationError on canceled status', async () => { const mockQueueItem = createMockQueueItem('canceled'); mockExecutor.enqueueBatch.mockResolvedValue({ item_ids: [1] }); mockExecutor.getQueueItem.mockResolvedValue(mockQueueItem); @@ -280,14 +264,15 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'canceled', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); - await expect(promise).rejects.toThrow('Graph canceled'); + await expect(promise).rejects.toThrow(SessionCancelationError); }); it('should reject if enqueueBatch fails', async () => { + // The error we are testing here is provided by the API client. We do not know the exact error type. const error = new Error('Enqueue failed'); mockExecutor.enqueueBatch.mockRejectedValue(error); @@ -301,6 +286,7 @@ describe('runGraph', () => { }); it('should reject if getQueueItem fails', async () => { + // The error we are testing here is provided by the API client. We do not know the exact error type. mockExecutor.enqueueBatch.mockResolvedValue({ item_ids: [1] }); mockExecutor.getQueueItem.mockRejectedValue(new Error('Get queue item failed')); @@ -314,7 +300,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); @@ -331,7 +317,7 @@ describe('runGraph', () => { vi.useRealTimers(); }); - it('should timeout and cancel queue item if timeout is exceeded', async () => { + it('should timeout, reject with a SessionTimeoutError, and cancel queue item if timeout is exceeded', async () => { let resolveEnqueue: (value: { item_ids: number[] }) => void = () => {}; const enqueuePromise = new Promise<{ item_ids: number[] }>((resolve) => { resolveEnqueue = resolve; @@ -354,7 +340,7 @@ describe('runGraph', () => { // Fast-forward time to trigger timeout vi.advanceTimersByTime(1001); - await expect(promise).rejects.toThrow('Graph timed out'); + await expect(promise).rejects.toThrow(SessionTimeoutError); expect(mockExecutor.cancelQueueItem).toHaveBeenCalledWith(1); }); @@ -374,7 +360,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); vi.advanceTimersByTime(500); @@ -384,7 +370,7 @@ describe('runGraph', () => { expect(mockExecutor.cancelQueueItem).not.toHaveBeenCalled(); }); - it('should timeout without canceling if queue item ID is not yet available', async () => { + it('should timeout and reject with a SessionTimeoutError without canceling if queue item ID is not yet available', async () => { // Don't resolve enqueueBatch to simulate slow enqueue const slowEnqueuePromise = new Promise(() => {}); // Never resolves mockExecutor.enqueueBatch.mockReturnValue(slowEnqueuePromise); @@ -399,14 +385,14 @@ describe('runGraph', () => { // Fast-forward time to trigger timeout before enqueue completes vi.advanceTimersByTime(1001); - await expect(promise).rejects.toThrow('Graph timed out'); + await expect(promise).rejects.toThrow(SessionTimeoutError); // Should not attempt to cancel since queue item ID is not available expect(mockExecutor.cancelQueueItem).not.toHaveBeenCalled(); }); }); describe('abort signal handling', () => { - it('should cancel queue item when abort signal is triggered', async () => { + it('should reject with a SessionAbortedError and cancel the queue item when signal is aborted', async () => { const controller = new AbortController(); mockExecutor.enqueueBatch.mockResolvedValue({ item_ids: [1] }); mockExecutor.cancelQueueItem.mockResolvedValue(createMockQueueItem('canceled')); @@ -422,7 +408,7 @@ describe('runGraph', () => { controller.abort(); }); - await expect(promise).rejects.toThrow('Graph canceled'); + await expect(promise).rejects.toThrow(SessionAbortedError); expect(mockExecutor.cancelQueueItem).toHaveBeenCalledWith(1); }); @@ -443,7 +429,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); const result = await promise; @@ -454,7 +440,7 @@ describe('runGraph', () => { expect(mockExecutor.cancelQueueItem).not.toHaveBeenCalled(); }); - it('should cancel without queue item if aborted before enqueue completion', async () => { + it('should reject with SessionAbortedError and not cancel the queue item if aborted before enqueue completion', async () => { const controller = new AbortController(); // Don't resolve enqueueBatch to simulate slow enqueue const slowEnqueuePromise = new Promise(() => {}); // Never resolves @@ -471,7 +457,7 @@ describe('runGraph', () => { controller.abort(); }); - await expect(promise).rejects.toThrow('Graph canceled'); + await expect(promise).rejects.toThrow(SessionAbortedError); // Should not attempt to cancel since queue item ID is not available expect(mockExecutor.cancelQueueItem).not.toHaveBeenCalled(); }); @@ -501,7 +487,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); @@ -526,13 +512,13 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'pending', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); mockEventHandler._triggerEvent({ item_id: 1, status: 'in_progress', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); // This should trigger completion @@ -540,7 +526,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); }); @@ -565,7 +551,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); await promise; @@ -618,7 +604,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); await expect(promise).rejects.toThrow("Node 'output-node' not found in session"); @@ -652,7 +638,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); await expect(promise).rejects.toThrow("Result for node 'output-node' not found in session"); @@ -680,7 +666,7 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); // Now resolve the enqueue @@ -706,26 +692,26 @@ describe('runGraph', () => { mockEventHandler._triggerEvent({ item_id: 1, status: 'pending', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); mockEventHandler._triggerEvent({ item_id: 1, status: 'in_progress', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); // Trigger another completed event (should be ignored since already resolved) mockEventHandler._triggerEvent({ item_id: 1, status: 'completed', - origin: 'test-graph:mock-id-123', + origin: TEST_ORIGIN, } as S['QueueItemStatusChangedEvent']); const result = await promise; @@ -755,7 +741,7 @@ describe('runGraph', () => { // Abort the operation while enqueueBatch is still pending controller.abort(); - await expect(promise).rejects.toThrow('Graph canceled'); + await expect(promise).rejects.toThrow(SessionAbortedError); // Verify cleanup happened - event handler should be unsubscribed expect(mockEventHandler.unsubscribe).toHaveBeenCalled(); diff --git a/invokeai/frontend/web/src/services/api/run-graph.ts b/invokeai/frontend/web/src/services/api/run-graph.ts index 8bbc82aeb7..a7f6257eb9 100644 --- a/invokeai/frontend/web/src/services/api/run-graph.ts +++ b/invokeai/frontend/web/src/services/api/run-graph.ts @@ -5,7 +5,6 @@ import { parseify } from 'common/util/serialize'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { S } from 'services/api/types'; -import { QueueError } from 'services/events/errors'; import { assert } from 'tsafe'; import { enqueueMutationFixedCacheKeyOptions, queueApi } from './endpoints/queue'; @@ -115,14 +114,14 @@ export const runGraph = (arg: RunGraphArg): Promise => { const { graph, outputNodeId, dependencies, destination, prepend, timeout, signal } = arg; if (!graph.hasNode(outputNodeId)) { - reject(new Error(`Graph does not contain output node ${outputNodeId}.`)); + reject(new OutputNodeNotFoundError(outputNodeId, graph)); return; } const g = graph.getGraph(); if (Object.values(g.nodes).some((node) => node.type === 'iterate')) { - reject(new Error('Iterate nodes are not supported by this utility.')); + reject(new IterateNodeFoundError(graph)); return; } @@ -190,7 +189,7 @@ export const runGraph = (arg: RunGraphArg): Promise => { log.warn({ error: parseify(error) }, 'Failed to cancel queue item during timeout'); }); } - reject(new Error('Graph timed out')); + reject(new SessionTimeoutError(queueItemId)); }, timeout); cleanupFunctions.add(() => { @@ -214,7 +213,7 @@ export const runGraph = (arg: RunGraphArg): Promise => { log.warn({ error: parseify(error) }, 'Failed to cancel queue item during abort'); }); } - reject(new Error('Graph canceled')); + reject(new SessionAbortedError(queueItemId)); }; signal.addEventListener('abort', abortHandler); @@ -255,7 +254,7 @@ export const runGraph = (arg: RunGraphArg): Promise => { const { status, session, error_type, error_message, error_traceback } = queueItem; if (status === 'completed') { - const getOutputResult = withResult(() => getOutputFromSession(session, outputNodeId)); + const getOutputResult = withResult(() => getOutputFromSession(queueItemId, session, outputNodeId)); if (getOutputResult.isErr()) { reject(getOutputResult.error); return; @@ -267,19 +266,12 @@ export const runGraph = (arg: RunGraphArg): Promise => { } if (status === 'failed') { - // We expect the event to have error details, but technically it's possible that it doesn't - if (error_type && error_message && error_traceback) { - reject(new QueueError(error_type, error_message, error_traceback)); - return; - } - - // If we don't have error details, we can't provide a useful error message - reject(new Error('Queue item failed, but no error details were provided')); + reject(new SessionExecutionError(queueItemId, session, error_type, error_message, error_traceback)); return; } if (status === 'canceled') { - reject(new Error('Graph canceled')); + reject(new SessionCancelationError(queueItemId, session)); return; } }; @@ -314,41 +306,121 @@ export const runGraph = (arg: RunGraphArg): Promise => { }; const getOutputFromSession = ( + queueItemId: number | null, session: S['SessionQueueItem']['session'], nodeId: string ): S['SessionQueueItem']['session']['results'][string] => { const { results, source_prepared_mapping } = session; const preparedNodeId = source_prepared_mapping[nodeId]?.[0]; if (!preparedNodeId) { - throw new NodeNotFoundError(nodeId, session); + throw new NodeNotFoundError(queueItemId, session, nodeId); } const result = results[preparedNodeId]; if (!result) { - throw new ResultNotFoundError(nodeId, session); + throw new ResultNotFoundError(queueItemId, session, nodeId); } return result; }; -class NodeNotFoundError extends Error { - session: S['SessionQueueItem']['session']; - nodeId: string; +export class OutputNodeNotFoundError extends Error { + outputNodeId: string; + graph: Graph; - constructor(nodeId: string, session: S['SessionQueueItem']['session']) { - super(`Node '${nodeId}' not found in session.`); + constructor(outputNodeId: string, graph: Graph) { + super(`Output node '${outputNodeId}' not found in the graph.`); + this.name = this.constructor.name; + this.outputNodeId = outputNodeId; + this.graph = graph; + } +} + +export class IterateNodeFoundError extends Error { + graph: Graph; + + constructor(graph: Graph) { + super('Iterate node(s) found in the graph.'); + this.name = this.constructor.name; + this.graph = graph; + } +} + +export class QueueItemError extends Error { + queueItemId: number | null; + + constructor(queueItemId: number | null, message?: string) { + super(message ?? 'Queue item error occurred'); + this.name = this.constructor.name; + this.queueItemId = queueItemId; + } +} + +export class SessionError extends QueueItemError { + session: S['SessionQueueItem']['session']; + + constructor(queueItemId: number | null, session: S['SessionQueueItem']['session'], message?: string) { + super(queueItemId, message ?? 'Session error occurred'); this.name = this.constructor.name; this.session = session; + } +} + +export class NodeNotFoundError extends SessionError { + nodeId: string; + + constructor(queueItemId: number | null, session: S['SessionQueueItem']['session'], nodeId: string) { + super(queueItemId, session, `Node '${nodeId}' not found in session.`); + this.name = this.constructor.name; this.nodeId = nodeId; } } -class ResultNotFoundError extends Error { - session: S['SessionQueueItem']['session']; +export class ResultNotFoundError extends SessionError { nodeId: string; - constructor(nodeId: string, session: S['SessionQueueItem']['session']) { - super(`Result for node '${nodeId}' not found in session.`); + constructor(queueItemId: number | null, session: S['SessionQueueItem']['session'], nodeId: string) { + super(queueItemId, session, `Result for node '${nodeId}' not found in session.`); this.name = this.constructor.name; - this.session = session; this.nodeId = nodeId; } } + +export class SessionExecutionError extends SessionError { + error_type?: string | null; + error_message?: string | null; + error_traceback?: string | null; + + constructor( + queueItemId: number | null, + session: S['SessionQueueItem']['session'], + error_type?: string | null, + error_message?: string | null, + error_traceback?: string | null + ) { + super(queueItemId, session, 'Session execution failed'); + this.name = this.constructor.name; + this.error_type = error_type; + this.error_traceback = error_traceback; + this.error_message = error_message; + } +} + +export class SessionCancelationError extends SessionError { + constructor(queueItemId: number | null, session: S['SessionQueueItem']['session']) { + super(queueItemId, session, 'Session execution was canceled'); + this.name = this.constructor.name; + } +} + +export class SessionAbortedError extends QueueItemError { + constructor(queueItemId: number | null) { + super(queueItemId, 'Session execution was aborted via signal'); + this.name = this.constructor.name; + } +} + +export class SessionTimeoutError extends QueueItemError { + constructor(queueItemId: number | null) { + super(queueItemId, 'Session execution timed out'); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/services/events/errors.ts b/invokeai/frontend/web/src/services/events/errors.ts deleted file mode 100644 index 24100939e9..0000000000 --- a/invokeai/frontend/web/src/services/events/errors.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * A custom error class for queue event errors. These errors have a type, message and traceback. - */ - -export class QueueError extends Error { - type: string; - traceback: string; - - constructor(type: string, message: string, traceback: string) { - super(message); - this.name = 'QueueError'; - this.type = type; - this.traceback = traceback; - - if (Error.captureStackTrace) { - Error.captureStackTrace(this, QueueError); - } - } - - toString() { - return `${this.name} [${this.type}]: ${this.message}\nTraceback:\n${this.traceback}`; - } -}