fix(ui): control layers ignored in txt2img

This commit is contained in:
psychedelicious
2025-07-03 07:41:20 +10:00
parent e0d7fab524
commit f36d22f13c
13 changed files with 76 additions and 81 deletions

View File

@@ -17,13 +17,13 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
const { generationMode, state, manager } = arg;
if (generationMode !== 'txt2img' && generationMode !== 'img2img') {
throw new UnsupportedGenerationModeError(t('toast.chatGPT4oIncompatibleGenerationMode'));
}
log.debug({ generationMode }, 'Building GPT Image graph');
log.debug({ generationMode, manager: manager?.id }, 'Building ChatGPT 4o graph');
const model = selectMainModelConfig(state);
@@ -80,8 +80,9 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
}
if (generationMode === 'img2img') {
const adapters = arg.canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await arg.canvasManager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
assert(manager !== null);
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,
silent: true,
});

View File

@@ -24,8 +24,8 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building CogView4 graph');
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building CogView4 graph');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
@@ -104,9 +104,10 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'cogview4_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'cogview4_i2l',
denoise,
@@ -119,10 +120,11 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
});
g.upsertMetadata({ generation_mode: 'cogview4_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'cogview4_i2l',
denoise,
@@ -136,10 +138,11 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
});
g.upsertMetadata({ generation_mode: 'cogview4_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'cogview4_i2l',
denoise,

View File

@@ -32,8 +32,8 @@ import { addIPAdapters } from './addIPAdapters';
const log = logger('system');
export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building FLUX graph');
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building FLUX graph');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
@@ -164,10 +164,11 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (isFLUXFill && (generationMode === 'inpaint' || generationMode === 'outpaint')) {
assert(manager !== null);
canvasOutput = await addFLUXFill({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
denoise,
originalSize,
@@ -177,9 +178,10 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'flux_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
@@ -192,10 +194,11 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
});
g.upsertMetadata({ generation_mode: 'flux_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
@@ -209,10 +212,11 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
});
g.upsertMetadata({ generation_mode: 'flux_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
@@ -229,13 +233,13 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
assert<Equals<typeof generationMode, never>>(false);
}
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
if (manager !== null) {
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets({
manager: arg.canvasManager,
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -249,7 +253,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
}
await addControlLoRA({
manager: arg.canvasManager,
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -283,9 +287,9 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
});
let totalReduxesAdded = fluxReduxResult.addedFLUXReduxes;
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
if (manager !== null) {
const regionsResult = await addRegions({
manager: arg.canvasManager,
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,

View File

@@ -17,13 +17,13 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state } = arg;
const { generationMode, state, manager } = arg;
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
}
log.debug({ generationMode }, 'Building Flux Kontext graph');
log.debug({ generationMode, manager: manager?.id }, 'Building FLUX Kontext graph');
const model = selectMainModelConfig(state);

View File

@@ -14,13 +14,13 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen3Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state } = arg;
const { generationMode, state, manager } = arg;
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
}
log.debug({ generationMode }, 'Building Imagen3 graph');
log.debug({ generationMode, manager: manager?.id }, 'Building Imagen3 graph');
const canvas = selectCanvasSlice(state);

View File

@@ -14,13 +14,13 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen4Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state } = arg;
const { generationMode, state, manager } = arg;
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
}
log.debug({ generationMode }, 'Building Imagen4 graph');
log.debug({ generationMode, manager: manager?.id }, 'Building Imagen4 graph');
const canvas = selectCanvasSlice(state);

View File

@@ -30,8 +30,8 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building SD1/SD2 graph');
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building SD1/SD2 graph');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
@@ -166,9 +166,10 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -181,10 +182,11 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
});
g.upsertMetadata({ generation_mode: 'img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -198,10 +200,11 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
});
g.upsertMetadata({ generation_mode: 'inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -218,13 +221,13 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
assert<Equals<typeof generationMode, never>>(false);
}
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
if (manager !== null) {
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets({
manager: arg.canvasManager,
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -242,7 +245,7 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters({
manager: arg.canvasManager,
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -268,9 +271,9 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
});
let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
if (manager !== null) {
const regionsResult = await addRegions({
manager: arg.canvasManager,
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,

View File

@@ -22,8 +22,8 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building SD3 graph');
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building SD3 graph');
const model = selectMainModelConfig(state);
assert(model, 'No model found in state');
@@ -128,9 +128,10 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'sd3_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
@@ -143,10 +144,11 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
});
g.upsertMetadata({ generation_mode: 'sd3_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
@@ -160,10 +162,11 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
});
g.upsertMetadata({ generation_mode: 'sd3_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,

View File

@@ -30,8 +30,8 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building SDXL graph');
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building SDXL graph');
const model = selectMainModelConfig(state);
assert(model, 'No model found in state');
@@ -173,9 +173,10 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'sdxl_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -188,10 +189,11 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
});
g.upsertMetadata({ generation_mode: 'sdxl_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -205,10 +207,11 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
});
g.upsertMetadata({ generation_mode: 'sdxl_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
manager: arg.canvasManager,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -225,13 +228,13 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
assert<Equals<typeof generationMode, never>>(false);
}
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
if (manager !== null) {
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets({
manager: arg.canvasManager,
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -249,7 +252,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters({
manager: arg.canvasManager,
manager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -275,9 +278,9 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
});
let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
if (manager !== null) {
const regionsResult = await addRegions({
manager: arg.canvasManager,
manager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,

View File

@@ -1,7 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { pick } from 'es-toolkit/compat';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { CanvasState, ParamsState } from 'features/controlLayers/store/types';
@@ -14,7 +13,7 @@ import { selectListStylePresetsRequestState } from 'services/api/endpoints/style
import type { Invocation, S } from 'services/api/types';
import { assert } from 'tsafe';
import type { GraphBuilderArg, MainModelLoaderNodes } from './types';
import type { MainModelLoaderNodes } from './types';
/**
* Gets the board field, based on the autoAddBoardId setting.
@@ -166,19 +165,3 @@ export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] ===
export const isCanvasOutputEvent = (data: S['InvocationCompleteEvent']) => {
return isCanvasOutputNodeId(data.invocation_source_id);
};
export const getGraphBuilderArg = async (
state: RootState,
canvasManager: CanvasManager | null
): Promise<GraphBuilderArg> => {
const tab = selectActiveTab(state);
if (tab === 'generate') {
return { generationMode: 'txt2img', state };
} else if (tab === 'canvas') {
assert(canvasManager !== null, 'CanvasManager should not be null in canvas tab');
const generationMode = await canvasManager.compositor.getGenerationMode();
return { generationMode, state, canvasManager };
} else {
assert(false, `Unknown tab: ${tab}`);
}
};

View File

@@ -30,16 +30,11 @@ export type MainModelLoaderNodes =
export type VaeSourceNodes = 'seamless' | 'vae_loader';
export type GraphBuilderArg =
| {
generationMode: Extract<GenerationMode, 'txt2img'>;
state: RootState;
}
| {
generationMode: Exclude<GenerationMode, 'txt2img'>;
state: RootState;
canvasManager: CanvasManager;
};
export type GraphBuilderArg = {
generationMode: GenerationMode;
state: RootState;
manager: CanvasManager | null;
};
export type GraphBuilderReturn = {
g: Graph;

View File

@@ -49,7 +49,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
const base = model.base;
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, canvasManager };
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };
switch (base) {
case 'sdxl':

View File

@@ -47,7 +47,7 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
assert(model, 'No model found in state');
const base = model.base;
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state };
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {
case 'sdxl':