// Mirror of https://github.com/invoke-ai/InvokeAI.git
// Synced 2026-02-14 04:45:04 -05:00
// 148 lines · 5.5 KiB · TypeScript
import type { AlertStatus } from '@invoke-ai/ui-library';
|
|
import { createAction } from '@reduxjs/toolkit';
|
|
import { logger } from 'app/logging/logger';
|
|
import type { AppStore } from 'app/store/store';
|
|
import { useAppStore } from 'app/store/storeHooks';
|
|
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
|
import { withResult, withResultAsync } from 'common/util/result';
|
|
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
|
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
|
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
|
import { canvasSessionIdChanged, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
|
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
|
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
|
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
|
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
|
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
|
|
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
|
|
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
|
|
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
|
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
|
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
|
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
|
|
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
|
import { toast } from 'features/toast/toast';
|
|
import { useCallback } from 'react';
|
|
import { serializeError } from 'serialize-error';
|
|
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
|
import { assert, AssertionError } from 'tsafe';
|
|
|
|
// Module-level logger, created with the 'generation' key; used below to report enqueue failures.
const log = logger('generation');
|
|
/**
 * Action creator for the `app/enqueueRequestedCanvas` action. `enqueueCanvas` dispatches it, with no
 * arguments, before doing any other work.
 */
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
|
|
|
|
/**
 * Builds a generation graph for the canvas from the currently selected model and enqueues it as a batch.
 *
 * Steps, in order:
 * 1. Dispatches `enqueueRequestedCanvas`.
 * 2. Resolves the destination: the canvas session id read from the store, or - when that is `null` -
 *    a newly generated `canvas`-prefixed id, which is then dispatched via `canvasSessionIdChanged`.
 * 3. Builds the graph with the builder that matches the model's `base`.
 * 4. Prepares the batch config and dispatches the `enqueueBatch` mutation.
 *
 * @param store - The app store; its `dispatch` and `getState` are used.
 * @param canvasManager - Supplies the generation mode and is handed to the graph builder.
 * @param prepend - Forwarded to `prepareLinearUIBatch`.
 * @returns The batch config together with the unwrapped enqueue result, or `undefined` when building
 * the graph or preparing the batch fails. Both failures are logged; only the graph-build failure
 * also raises a toast.
 *
 * The awaited `unwrap()` call is not wrapped in any error handling here, so a rejection from it
 * propagates to the caller.
 */
const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
  const { dispatch, getState } = store;

  dispatch(enqueueRequestedCanvas());

  // NOTE(review): this snapshot is taken before a new session id may be dispatched just below, so
  // `state` will not contain a freshly created id. The id is passed on explicitly as `destination`;
  // confirm that nothing downstream reads it from `state` instead.
  const state = getState();

  const existingSessionId = selectCanvasSessionId(state);
  const destination = existingSessionId === null ? getPrefixedId('canvas') : existingSessionId;
  if (existingSessionId === null) {
    dispatch(canvasSessionIdChanged({ id: destination }));
  }

  // Picks the graph builder for the selected model's base. Any throw (including the assertions) is
  // captured by `withResultAsync` below rather than escaping this function.
  const buildGraph = async () => {
    const { model } = state.params;
    assert(model, 'No model found in state');
    const modelBase = model.base;

    const generationMode = await canvasManager.compositor.getGenerationMode();
    const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };

    switch (modelBase) {
      case 'sdxl':
        return await buildSDXLGraph(graphBuilderArg);
      // sd-1 and sd-2 share the same builder.
      case 'sd-1':
      case 'sd-2':
        return await buildSD1Graph(graphBuilderArg);
      case 'sd-3':
        return await buildSD3Graph(graphBuilderArg);
      case 'flux':
        return await buildFLUXGraph(graphBuilderArg);
      case 'cogview4':
        return await buildCogView4Graph(graphBuilderArg);
      case 'imagen3':
        return buildImagen3Graph(graphBuilderArg);
      case 'imagen4':
        return buildImagen4Graph(graphBuilderArg);
      case 'chatgpt-4o':
        return await buildChatGPT4oGraph(graphBuilderArg);
      case 'flux-kontext':
        return buildFluxKontextGraph(graphBuilderArg);
      default:
        assert(false, `No graph builders for base ${modelBase}`);
    }
  };

  const graphResult = await withResultAsync(buildGraph);

  if (graphResult.isErr()) {
    const failure = graphResult.error;

    // Start from the generic error toast and refine it for the two recognised error classes.
    const toastOptions: { status: AlertStatus; title: string; description: string | null } = {
      status: 'error',
      title: 'Failed to build graph',
      description: null,
    };

    if (failure instanceof AssertionError) {
      toastOptions.description = extractMessageFromAssertionError(failure);
    } else if (failure instanceof UnsupportedGenerationModeError) {
      // An unsupported generation mode is surfaced as a warning, not an error.
      toastOptions.title = 'Unsupported generation mode';
      toastOptions.description = failure.message;
      toastOptions.status = 'warning';
    }

    log.error({ error: serializeError(failure) }, 'Failed to build graph');
    toast(toastOptions);
    return;
  }

  const { g, seed: seedNode, positivePrompt: positivePromptNode } = graphResult.value;

  const batchResult = withResult(() =>
    prepareLinearUIBatch({
      state,
      g,
      prepend,
      seedNode,
      positivePromptNode,
      origin: 'canvas',
      destination,
    })
  );

  if (batchResult.isErr()) {
    // This failure is only logged; no toast is raised on this path.
    log.error({ error: serializeError(batchResult.error) }, 'Failed to prepare batch');
    return;
  }

  const batchConfig = batchResult.value;

  const enqueueResult = await dispatch(
    queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
      ...enqueueMutationFixedCacheKeyOptions,
      track: false,
    })
  ).unwrap();

  return { batchConfig, enqueueResult };
};
|
|
|
|
/**
 * Hook returning a memoised callback that enqueues the canvas.
 *
 * The callback takes `prepend` and forwards it, with the app store and the canvas manager, to
 * `enqueueCanvas`, returning that call's promise. When no canvas manager is available it logs an
 * error and returns `undefined` without enqueuing anything.
 */
export const useEnqueueCanvas = () => {
  const appStore = useAppStore();
  const manager = useCanvasManagerSafe();

  return useCallback(
    (prepend: boolean) => {
      if (manager) {
        return enqueueCanvas(appStore, manager, prepend);
      }
      log.error('Canvas manager is not available');
      return undefined;
    },
    [manager, appStore]
  );
};
|