Files
InvokeAI/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts
2025-07-16 08:21:03 -04:00

148 lines
5.5 KiB
TypeScript

import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasSessionIdChanged, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
// Module-level logger shared by everything in this file; created with the 'generation' argument.
const log = logger('generation');
/**
 * Redux action creator for the `app/enqueueRequestedCanvas` action type, created without a payload type.
 *
 * `enqueueCanvas` dispatches it first thing, before it reads the state snapshot or builds a graph.
 */
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
/**
 * Builds the generation graph for the model selected in `state`.
 *
 * The graph builder is picked by the model's `base`. `sd-1` and `sd-2` share the SD1 builder.
 *
 * @param state Redux state snapshot; read for `params.model` and handed to the graph builder.
 * @param canvasManager Canvas manager; supplies the generation mode and is passed to the builder as `manager`.
 * @returns The value produced by the selected graph builder.
 * @throws AssertionError when no model is selected or when the model's base has no graph builder. Errors raised by
 * the builders themselves propagate unchanged.
 */
const buildCanvasGraph = async (state: ReturnType<AppStore['getState']>, canvasManager: CanvasManager) => {
  const model = state.params.model;
  assert(model, 'No model found in state');

  const base = model.base;
  const generationMode = await canvasManager.compositor.getGenerationMode();
  const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };

  // NOTE(review): some builders are awaited and some are returned directly. Presumably the un-awaited ones are
  // synchronous - confirm before normalizing. Either form yields the same resolved value from this async function.
  switch (base) {
    case 'sdxl':
      return await buildSDXLGraph(graphBuilderArg);
    case 'sd-1':
    case 'sd-2':
      return await buildSD1Graph(graphBuilderArg);
    case 'sd-3':
      return await buildSD3Graph(graphBuilderArg);
    case 'flux':
      return await buildFLUXGraph(graphBuilderArg);
    case 'cogview4':
      return await buildCogView4Graph(graphBuilderArg);
    case 'imagen3':
      return buildImagen3Graph(graphBuilderArg);
    case 'imagen4':
      return buildImagen4Graph(graphBuilderArg);
    case 'chatgpt-4o':
      return await buildChatGPT4oGraph(graphBuilderArg);
    case 'flux-kontext':
      return buildFluxKontextGraph(graphBuilderArg);
    default:
      assert(false, `No graph builders for base ${base}`);
  }
};

/**
 * Builds a graph for the current canvas, wraps it in a batch and enqueues it.
 *
 * Steps, in order:
 * 1. Dispatch `enqueueRequestedCanvas`.
 * 2. Resolve the canvas session id used as the batch `destination`, creating and storing a new one if none exists.
 * 3. Build the graph. On failure: log, show a toast, and return `undefined`.
 * 4. Prepare the batch config. On failure: log, show a toast, and return `undefined`.
 * 5. Dispatch the `enqueueBatch` mutation and wait for its unwrapped result.
 *
 * @param store App store used for `dispatch` and `getState`.
 * @param canvasManager Canvas manager passed through to graph building.
 * @param prepend Forwarded to `prepareLinearUIBatch` as `prepend`.
 * @returns `{ batchConfig, enqueueResult }` on success; `undefined` when graph building or batch preparation fails.
 * @throws Whatever `req.unwrap()` rejects with when the enqueue request fails; that rejection is not caught here.
 */
const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
  const { dispatch, getState } = store;

  dispatch(enqueueRequestedCanvas());

  // NOTE(review): this snapshot is taken before the `canvasSessionIdChanged` dispatch below, so the graph builders and
  // `prepareLinearUIBatch` see the pre-dispatch state. The new id is passed explicitly as `destination`; confirm that
  // nothing downstream reads the session id from `state` instead.
  const state = getState();

  let destination = selectCanvasSessionId(state);
  if (destination === null) {
    // No session yet: mint an id and persist it in the store.
    destination = getPrefixedId('canvas');
    dispatch(canvasSessionIdChanged({ id: destination }));
  }

  const buildGraphResult = await withResultAsync(() => buildCanvasGraph(state, canvasManager));

  if (buildGraphResult.isErr()) {
    let title = 'Failed to build graph';
    let status: AlertStatus = 'error';
    let description: string | null = null;
    if (buildGraphResult.error instanceof AssertionError) {
      description = extractMessageFromAssertionError(buildGraphResult.error);
    } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
      // An unsupported generation mode is surfaced as a warning with the error's own message.
      title = 'Unsupported generation mode';
      description = buildGraphResult.error.message;
      status = 'warning';
    }
    const error = serializeError(buildGraphResult.error);
    log.error({ error }, 'Failed to build graph');
    toast({
      status,
      title,
      description,
    });
    return;
  }

  const { g, seed, positivePrompt } = buildGraphResult.value;

  const prepareBatchResult = withResult(() =>
    prepareLinearUIBatch({
      state,
      g,
      prepend,
      seedNode: seed,
      positivePromptNode: positivePrompt,
      origin: 'canvas',
      destination,
    })
  );

  if (prepareBatchResult.isErr()) {
    const batchError = prepareBatchResult.error;
    log.error({ error: serializeError(batchError) }, 'Failed to prepare batch');
    // Tell the user as well, mirroring the graph-build failure path; otherwise the enqueue silently does nothing.
    toast({
      status: 'error',
      title: 'Failed to prepare batch',
      description: batchError instanceof AssertionError ? extractMessageFromAssertionError(batchError) : null,
    });
    return;
  }

  const batchConfig = prepareBatchResult.value;

  const req = dispatch(
    queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
      ...enqueueMutationFixedCacheKeyOptions,
      track: false,
    })
  );

  const enqueueResult = await req.unwrap();

  return { batchConfig, enqueueResult };
};
/**
 * React hook that returns a memoized function for enqueuing the current canvas.
 *
 * The returned function takes `prepend` and forwards it to `enqueueCanvas` together with the app store and the canvas
 * manager. When no canvas manager is available it logs an error and returns `undefined` without enqueuing.
 *
 * @returns A callback `(prepend: boolean) => Promise<...> | undefined`, re-created only when the canvas manager or the
 * store changes.
 */
export const useEnqueueCanvas = () => {
  const store = useAppStore();
  const canvasManager = useCanvasManagerSafe();

  return useCallback(
    (prepend: boolean) => {
      if (canvasManager) {
        return enqueueCanvas(store, canvasManager, prepend);
      }

      // Nothing to enqueue against; report it and bail out.
      log.error('Canvas manager is not available');
      return undefined;
    },
    [canvasManager, store]
  );
};