refactor(ui): graph building respects selected tab

This commit is contained in:
psychedelicious
2025-06-30 16:34:33 +10:00
parent e00ccba7d3
commit a035645ed3
14 changed files with 223 additions and 287 deletions

View File

@@ -1,7 +1,6 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
type CanvasStagingAreaState = {
@@ -20,26 +19,16 @@ export const canvasSessionSlice = createSlice({
name: 'canvasSession',
initialState: getInitialState(),
reducers: {
generateSessionIdCreated: {
reducer: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.generateSessionId = id;
},
prepare: () => ({
payload: { id: getPrefixedId('generate') },
}),
generateSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.generateSessionId = id;
},
generateSessionReset: (state) => {
state.generateSessionId = null;
},
canvasSessionIdCreated: {
reducer: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.canvasSessionId = id;
},
prepare: () => ({
payload: { id: getPrefixedId('canvas') },
}),
canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.canvasSessionId = id;
},
canvasSessionReset: (state) => {
state.canvasSessionId = null;
@@ -52,7 +41,7 @@ export const canvasSessionSlice = createSlice({
},
});
export const { generateSessionIdCreated, generateSessionReset, canvasSessionIdCreated, canvasSessionReset } =
export const { generateSessionIdChanged, generateSessionReset, canvasSessionIdChanged, canvasSessionReset } =
canvasSessionSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
@@ -8,23 +6,18 @@ import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildChatGPT4oGraph = async (
state: RootState,
manager: CanvasManager | null
): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
if (generationMode !== 'txt2img' && generationMode !== 'img2img') {
throw new UnsupportedGenerationModeError(t('toast.chatGPT4oIncompatibleGenerationMode'));
@@ -86,9 +79,8 @@ export const buildChatGPT4oGraph = async (
}
if (generationMode === 'img2img') {
assert(manager, 'Need manager to do img2img');
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
const adapters = arg.canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await arg.canvasManager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,
silent: true,
});

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
@@ -11,15 +9,13 @@ import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChec
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getSizes,
selectCanvasOutputFields,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -27,12 +23,8 @@ import { assert } from 'tsafe';
const log = logger('system');
export const buildCogView4Graph = async (
state: RootState,
manager: CanvasManager | null
): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building CogView4 graph');
const params = selectParamsSlice(state);
@@ -112,10 +104,9 @@ export const buildCogView4Graph = async (
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'cogview4_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager, 'Need manager to do img2img');
canvasOutput = await addImageToImage({
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'cogview4_i2l',
denoise,
@@ -128,11 +119,10 @@ export const buildCogView4Graph = async (
});
g.upsertMetadata({ generation_mode: 'cogview4_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager, 'Need manager to do inpaint');
canvasOutput = await addInpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'cogview4_i2l',
denoise,
@@ -146,11 +136,10 @@ export const buildCogView4Graph = async (
});
g.upsertMetadata({ generation_mode: 'cogview4_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager, 'Need manager to do outpaint');
canvasOutput = await addOutpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'cogview4_i2l',
denoise,

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
@@ -15,19 +13,14 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getSizes,
selectCanvasOutputFields,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import {
type GraphBuilderReturn,
type ImageOutputNodes,
UnsupportedGenerationModeError,
} from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
@@ -38,9 +31,8 @@ import { addIPAdapters } from './addIPAdapters';
const log = logger('system');
export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building FLUX graph');
const params = selectParamsSlice(state);
@@ -171,12 +163,11 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (isFLUXFill) {
assert(manager, 'Need manager to do FLUX Fill');
if (isFLUXFill && (generationMode === 'inpaint' || generationMode === 'outpaint')) {
canvasOutput = await addFLUXFill({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
denoise,
originalSize,
@@ -186,10 +177,9 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'flux_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager, 'Need manager to do img2img');
canvasOutput = await addImageToImage({
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
@@ -202,11 +192,10 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
});
g.upsertMetadata({ generation_mode: 'flux_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager, 'Need manager to do inpaint');
canvasOutput = await addInpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
@@ -220,11 +209,10 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
});
g.upsertMetadata({ generation_mode: 'flux_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager, 'Need manager to do outpaint');
canvasOutput = await addOutpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
@@ -241,13 +229,13 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
assert<Equals<typeof generationMode, never>>(false);
}
if (manager) {
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets({
manager,
manager: arg.canvasManager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -261,7 +249,7 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
}
await addControlLoRA({
manager,
manager: arg.canvasManager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -295,9 +283,9 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager |
});
let totalReduxesAdded = fluxReduxResult.addedFLUXReduxes;
if (manager) {
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
const regionsResult = await addRegions({
manager,
manager: arg.canvasManager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
@@ -11,22 +9,15 @@ import type { ImageField } from 'features/nodes/types/common';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { getGenerationMode } from './getGenerationMode';
const log = logger('system');
export const buildFluxKontextGraph = async (
state: RootState,
manager: CanvasManager | null
): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state } = arg;
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
@@ -61,29 +52,25 @@ export const buildFluxKontextGraph = async (
};
}
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph'));
const fluxKontextImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
input_image,
prompt_upsampling: true,
...selectCanvasOutputFields(state),
});
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Flux Kontext');
const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph'));
const fluxKontextImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
input_image,
prompt_upsampling: true,
...selectCanvasOutputFields(state),
});
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' },
};
};

View File

@@ -1,28 +1,20 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen3Graph = async (
state: RootState,
manager: CanvasManager | null
): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildImagen3Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state } = arg;
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
@@ -41,33 +33,30 @@ export const buildImagen3Graph = async (
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
// When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
enhance_prompt: true,
...selectCanvasOutputFields(state),
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
...selectCanvasMetadata(state),
});
return {
g,
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
};
}
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
// When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
enhance_prompt: true,
...selectCanvasOutputFields(state),
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
...selectCanvasMetadata(state),
});
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen3');
return {
g,
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
};
};

View File

@@ -1,28 +1,20 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen4Graph = async (
state: RootState,
manager: CanvasManager | null
): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildImagen4Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state } = arg;
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
@@ -41,34 +33,30 @@ export const buildImagen4Graph = async (
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character');
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
const imagen4 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen4_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
enhance_prompt: true,
...selectCanvasOutputFields(state),
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
...selectCanvasMetadata(state),
});
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
const imagen4 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen4_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
enhance_prompt: true,
...selectCanvasOutputFields(state),
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
...selectCanvasMetadata(state),
});
return {
g,
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen4');
return {
g,
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
};
};

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
@@ -16,15 +14,13 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getSizes,
selectCanvasOutputFields,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
@@ -33,9 +29,8 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSD1Graph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building SD1/SD2 graph');
const params = selectParamsSlice(state);
@@ -171,10 +166,9 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'txt2img' });
} else if (generationMode === 'img2img') {
assert(manager, 'Need manager to do img2img');
canvasOutput = await addImageToImage({
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -187,11 +181,10 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
});
g.upsertMetadata({ generation_mode: 'img2img' });
} else if (generationMode === 'inpaint') {
assert(manager, 'Need manager to do inpaint');
canvasOutput = await addInpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -205,11 +198,10 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
});
g.upsertMetadata({ generation_mode: 'inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager, 'Need manager to do outpaint');
canvasOutput = await addOutpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -226,13 +218,13 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
assert<Equals<typeof generationMode, never>>(false);
}
if (manager) {
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets({
manager,
manager: arg.canvasManager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -250,7 +242,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters({
manager,
manager: arg.canvasManager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -276,9 +268,9 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n
});
let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
if (manager) {
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
const regionsResult = await addRegions({
manager,
manager: arg.canvasManager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
@@ -10,24 +8,21 @@ import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChec
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getSizes,
selectCanvasOutputFields,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildSD3Graph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building SD3 graph');
const model = selectMainModelConfig(state);
@@ -133,10 +128,9 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'sd3_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager, 'Need manager to do img2img');
canvasOutput = await addImageToImage({
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
@@ -149,11 +143,10 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n
});
g.upsertMetadata({ generation_mode: 'sd3_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager, 'Need manager to do inpaint');
canvasOutput = await addInpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
@@ -167,11 +160,10 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n
});
g.upsertMetadata({ generation_mode: 'sd3_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager, 'Need manager to do outpaint');
canvasOutput = await addOutpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,

View File

@@ -1,6 +1,4 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
@@ -16,15 +14,13 @@ import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefi
import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
getSizes,
selectCanvasOutputFields,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
@@ -33,9 +29,8 @@ import { addRegions } from './addRegions';
const log = logger('system');
export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | null): Promise<GraphBuilderReturn> => {
const tab = selectActiveTab(state);
const generationMode = await getGenerationMode(manager, tab);
export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state } = arg;
log.debug({ generationMode }, 'Building SDXL graph');
const model = selectMainModelConfig(state);
@@ -178,10 +173,9 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
g.upsertMetadata({ generation_mode: 'sdxl_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager, 'Need manager to do img2img');
canvasOutput = await addImageToImage({
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -194,11 +188,10 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
});
g.upsertMetadata({ generation_mode: 'sdxl_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager, 'Need manager to do inpaint');
canvasOutput = await addInpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -212,11 +205,10 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
});
g.upsertMetadata({ generation_mode: 'sdxl_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager, 'Need manager to do outpaint');
canvasOutput = await addOutpaint({
state,
g,
manager,
manager: arg.canvasManager,
l2i,
i2lNodeType: 'i2l',
denoise,
@@ -233,13 +225,13 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
assert<Equals<typeof generationMode, never>>(false);
}
if (manager) {
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets({
manager,
manager: arg.canvasManager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -257,7 +249,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
id: getPrefixedId('t2i_adapter_collector'),
});
const t2iAdapterResult = await addT2IAdapters({
manager,
manager: arg.canvasManager,
entities: canvas.controlLayers.entities,
g,
rect: canvas.bbox.rect,
@@ -283,9 +275,9 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager |
});
let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
if (manager) {
if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') {
const regionsResult = await addRegions({
manager,
manager: arg.canvasManager,
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,

View File

@@ -1,6 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { pick } from 'es-toolkit/compat';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { CanvasState, ParamsState } from 'features/controlLayers/store/types';
@@ -13,7 +14,7 @@ import { selectListStylePresetsRequestState } from 'services/api/endpoints/style
import type { Invocation, S } from 'services/api/types';
import { assert } from 'tsafe';
import type { MainModelLoaderNodes } from './types';
import type { GraphBuilderArg, MainModelLoaderNodes } from './types';
/**
* Gets the board field, based on the autoAddBoardId setting.
@@ -165,3 +166,19 @@ export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] ===
export const isCanvasOutputEvent = (data: S['InvocationCompleteEvent']) => {
return isCanvasOutputNodeId(data.invocation_source_id);
};
export const getGraphBuilderArg = async (
state: RootState,
canvasManager: CanvasManager | null
): Promise<GraphBuilderArg> => {
const tab = selectActiveTab(state);
if (tab === 'generate') {
return { generationMode: 'txt2img', state };
} else if (tab === 'canvas') {
assert(canvasManager !== null, 'CanvasManager should not be null in canvas tab');
const generationMode = await canvasManager.compositor.getGenerationMode();
return { generationMode, state, canvasManager };
} else {
assert(false, `Unknown tab: ${tab}`);
}
};

View File

@@ -1,3 +1,6 @@
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { GenerationMode } from 'features/controlLayers/store/types';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -27,6 +30,17 @@ export type MainModelLoaderNodes =
export type VaeSourceNodes = 'seamless' | 'vae_loader';
export type GraphBuilderArg =
| {
generationMode: Extract<GenerationMode, 'txt2img'>;
state: RootState;
}
| {
generationMode: Exclude<GenerationMode, 'txt2img'>;
state: RootState;
canvasManager: CanvasManager;
};
export type GraphBuilderReturn = {
g: Graph;
seedFieldIdentifier?: FieldIdentifier;

View File

@@ -7,7 +7,8 @@ import { extractMessageFromAssertionError } from 'common/util/extractMessageFrom
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasSessionIdChanged, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
@@ -18,6 +19,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
@@ -33,40 +35,42 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
dispatch(enqueueRequestedCanvas());
let destination = selectCanvasSessionId(getState());
if (!destination) {
dispatch(canvasSessionIdCreated());
destination = selectCanvasSessionId(getState());
}
assert(destination !== null);
const state = getState();
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
let destination = selectCanvasSessionId(state);
if (destination === null) {
destination = getPrefixedId('canvas');
dispatch(canvasSessionIdChanged({ id: destination }));
}
const buildGraphResult = await withResultAsync(async () => {
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, canvasManager };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, canvasManager);
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, canvasManager);
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(state, canvasManager);
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(state, canvasManager);
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(state, canvasManager);
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return await buildImagen3Graph(state, canvasManager);
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return await buildImagen4Graph(state, canvasManager);
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, canvasManager);
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return await buildFluxKontextGraph(state, canvasManager);
return buildFluxKontextGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}

View File

@@ -5,7 +5,8 @@ import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { generateSessionIdChanged, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
@@ -16,6 +17,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
@@ -32,40 +34,41 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
dispatch(enqueueRequestedGenerate());
let destination = selectGenerateSessionId(getState());
if (!destination) {
dispatch(generateSessionIdCreated());
destination = selectGenerateSessionId(getState());
}
assert(destination !== null);
const state = getState();
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
let destination = selectGenerateSessionId(state);
if (destination === null) {
destination = getPrefixedId('generate');
dispatch(generateSessionIdChanged({ id: destination }));
}
const buildGraphResult = await withResultAsync(async () => {
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, null);
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, null);
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(state, null);
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(state, null);
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(state, null);
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return await buildImagen3Graph(state, null);
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return await buildImagen4Graph(state, null);
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, null);
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return await buildFluxKontextGraph(state, null);
return buildFluxKontextGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}