refactor(ui): graph building respects selected tab

This commit is contained in:
psychedelicious
2025-06-30 16:34:33 +10:00
parent e00ccba7d3
commit a035645ed3
14 changed files with 223 additions and 287 deletions

View File

@@ -7,7 +7,8 @@ import { extractMessageFromAssertionError } from 'common/util/extractMessageFrom
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasSessionIdChanged, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
@@ -18,6 +19,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
@@ -33,40 +35,42 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
dispatch(enqueueRequestedCanvas());
let destination = selectCanvasSessionId(getState());
if (!destination) {
dispatch(canvasSessionIdCreated());
destination = selectCanvasSessionId(getState());
}
assert(destination !== null);
const state = getState();
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
let destination = selectCanvasSessionId(state);
if (destination === null) {
destination = getPrefixedId('canvas');
dispatch(canvasSessionIdChanged({ id: destination }));
}
const buildGraphResult = await withResultAsync(async () => {
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, canvasManager };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, canvasManager);
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, canvasManager);
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(state, canvasManager);
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(state, canvasManager);
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(state, canvasManager);
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return await buildImagen3Graph(state, canvasManager);
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return await buildImagen4Graph(state, canvasManager);
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, canvasManager);
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return await buildFluxKontextGraph(state, canvasManager);
return buildFluxKontextGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}

View File

@@ -5,7 +5,8 @@ import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { generateSessionIdChanged, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
@@ -16,6 +17,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
@@ -32,40 +34,41 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
dispatch(enqueueRequestedGenerate());
let destination = selectGenerateSessionId(getState());
if (!destination) {
dispatch(generateSessionIdCreated());
destination = selectGenerateSessionId(getState());
}
assert(destination !== null);
const state = getState();
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
let destination = selectGenerateSessionId(state);
if (destination === null) {
destination = getPrefixedId('generate');
dispatch(generateSessionIdChanged({ id: destination }));
}
const buildGraphResult = await withResultAsync(async () => {
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, null);
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, null);
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(state, null);
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(state, null);
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(state, null);
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return await buildImagen3Graph(state, null);
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return await buildImagen4Graph(state, null);
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, null);
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return await buildFluxKontextGraph(state, null);
return buildFluxKontextGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}