refactor(ui): simplify graph builders (WIP)

This commit is contained in:
psychedelicious
2025-07-07 16:17:03 +10:00
parent 067026a0d0
commit c143f63ef0
19 changed files with 445 additions and 352 deletions

View File

@@ -1,14 +1,11 @@
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from 'app/constants';
import type { RootState } from 'app/store/store';
import { generateSeeds } from 'common/util/generateSeeds';
import randomInt from 'common/util/randomInt';
import { range } from 'es-toolkit/compat';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg } from 'services/api/types';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const getExtendedPrompts = (arg: {
@@ -31,13 +28,13 @@ export const prepareLinearUIBatch = (arg: {
state: RootState;
g: Graph;
prepend: boolean;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
positivePromptNode: Invocation<'string'>;
seedNode?: Invocation<'integer'>;
origin: string;
destination: string;
}): EnqueueBatchArg => {
const { state, g, prepend, seedFieldIdentifier, positivePromptFieldIdentifier, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
const { state, g, prepend, positivePromptNode, seedNode, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed } = state.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
assert(model, 'No model found in state when preparing batch');
@@ -47,55 +44,27 @@ export const prepareLinearUIBatch = (arg: {
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
// add seeds first to ensure the output order groups the prompts
if (seedFieldIdentifier && seedBehaviour === 'PER_PROMPT') {
if (seedNode && seedBehaviour === 'PER_PROMPT') {
const seeds = generateSeeds({
count: prompts.length * iterations,
// Imagen3's support for seeded generation is iffy, we are just not going to use it in linear UI generations.
start:
model.base === 'imagen3' || model.base === 'imagen4'
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: shouldRandomizeSeed
? undefined
: seed,
start: shouldRandomizeSeed ? undefined : seed,
});
firstBatchDatumList.push({
node_path: seedFieldIdentifier.nodeId,
field_name: seedFieldIdentifier.fieldName,
node_path: seedNode.id,
field_name: 'value',
items: seeds,
});
// add to metadata
g.removeMetadata(['seed']);
firstBatchDatumList.push({
node_path: g.getMetadataNode().id,
field_name: 'seed',
items: seeds,
});
} else if (seedFieldIdentifier && seedBehaviour === 'PER_ITERATION') {
} else if (seedNode && seedBehaviour === 'PER_ITERATION') {
// seedBehaviour = SeedBehaviour.PerRun
const seeds = generateSeeds({
count: iterations,
// Imagen3's support for seeded generation is iffy, we are just not going to use it in linear UI generations.
start:
model.base === 'imagen3' || model.base === 'imagen4'
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: shouldRandomizeSeed
? undefined
: seed,
start: shouldRandomizeSeed ? undefined : seed,
});
secondBatchDatumList.push({
node_path: seedFieldIdentifier.nodeId,
field_name: seedFieldIdentifier.fieldName,
items: seeds,
});
// add to metadata
g.removeMetadata(['seed']);
secondBatchDatumList.push({
node_path: g.getMetadataNode().id,
field_name: 'seed',
node_path: seedNode.id,
field_name: 'value',
items: seeds,
});
data.push(secondBatchDatumList);
@@ -105,35 +74,11 @@ export const prepareLinearUIBatch = (arg: {
// zipped batch of prompts
firstBatchDatumList.push({
node_path: positivePromptFieldIdentifier.nodeId,
field_name: positivePromptFieldIdentifier.fieldName,
node_path: positivePromptNode.id,
field_name: 'value',
items: extendedPrompts,
});
// add to metadata
g.removeMetadata(['positive_prompt']);
firstBatchDatumList.push({
node_path: g.getMetadataNode().id,
field_name: 'positive_prompt',
items: extendedPrompts,
});
if (shouldConcatPrompts && model.base === 'sdxl') {
firstBatchDatumList.push({
node_path: positivePromptFieldIdentifier.nodeId,
field_name: 'style',
items: extendedPrompts,
});
// add to metadata
g.removeMetadata(['positive_style_prompt']);
firstBatchDatumList.push({
node_path: g.getMetadataNode().id,
field_name: 'positive_style_prompt',
items: extendedPrompts,
});
}
data.push(firstBatchDatumList);
const enqueueBatchArg: EnqueueBatchArg = {

View File

@@ -7,6 +7,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getDenoisingStartAndEnd } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types';
type AddFLUXFillArg = {
@@ -28,9 +29,9 @@ export const addFLUXFill = async ({
originalSize,
scaledSize,
}: AddFLUXFillArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
// FLUX Fill always fully denoises
denoise.denoising_start = 0;
denoise.denoising_end = 1;
const { denoising_start, denoising_end } = getDenoisingStartAndEnd(state);
denoise.denoising_start = denoising_start;
denoise.denoising_end = denoising_end;
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);

View File

@@ -1,8 +1,10 @@
import { objectEquals } from '@observ33r/object-equals';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasState, Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getDenoisingStartAndEnd } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
DenoiseLatentsNodes,
LatentToImageNodes,
@@ -13,6 +15,7 @@ import type { Invocation } from 'services/api/types';
type AddImageToImageArg = {
g: Graph;
state: RootState;
manager: CanvasManager;
l2i: Invocation<LatentToImageNodes>;
i2l: Invocation<'i2l' | 'flux_vae_encode' | 'sd3_i2l' | 'cogview4_i2l'>;
@@ -21,11 +24,11 @@ type AddImageToImageArg = {
originalSize: Dimensions;
scaledSize: Dimensions;
bbox: CanvasState['bbox'];
denoising_start: number;
};
export const addImageToImage = async ({
g,
state,
manager,
l2i,
i2l,
@@ -34,9 +37,11 @@ export const addImageToImage = async ({
originalSize,
scaledSize,
bbox,
denoising_start,
}: AddImageToImageArg): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode' | 'sd3_l2i' | 'cogview4_l2i'>> => {
const { denoising_start, denoising_end } = getDenoisingStartAndEnd(state);
denoise.denoising_start = denoising_start;
denoise.denoising_end = denoising_end;
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,

View File

@@ -7,7 +7,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { isMainModelWithoutUnet } from 'features/nodes/util/graph/graphBuilderUtils';
import { getDenoisingStartAndEnd, isMainModelWithoutUnet } from 'features/nodes/util/graph/graphBuilderUtils';
import type {
DenoiseLatentsNodes,
LatentToImageNodes,
@@ -17,8 +17,8 @@ import type {
import type { ImageDTO, Invocation } from 'services/api/types';
type AddInpaintArg = {
state: RootState;
g: Graph;
state: RootState;
manager: CanvasManager;
l2i: Invocation<LatentToImageNodes>;
i2l: Invocation<'i2l' | 'flux_vae_encode' | 'sd3_i2l' | 'cogview4_i2l'>;
@@ -27,13 +27,12 @@ type AddInpaintArg = {
modelLoader: Invocation<MainModelLoaderNodes>;
originalSize: Dimensions;
scaledSize: Dimensions;
denoising_start: number;
seed: Invocation<'integer'>;
};
export const addInpaint = async ({
state,
g,
state,
manager,
l2i,
i2l,
@@ -42,10 +41,11 @@ export const addInpaint = async ({
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
}: AddInpaintArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
const { denoising_start, denoising_end } = getDenoisingStartAndEnd(state);
denoise.denoising_start = denoising_start;
denoise.denoising_end = denoising_end;
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);

View File

@@ -7,7 +7,11 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getInfill, isMainModelWithoutUnet } from 'features/nodes/util/graph/graphBuilderUtils';
import {
getDenoisingStartAndEnd,
getInfill,
isMainModelWithoutUnet,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type {
DenoiseLatentsNodes,
ImageToLatentsNodes,
@@ -28,7 +32,6 @@ type AddOutpaintArg = {
modelLoader: Invocation<MainModelLoaderNodes>;
originalSize: Dimensions;
scaledSize: Dimensions;
denoising_start: number;
seed: Invocation<'integer'>;
};
@@ -43,10 +46,11 @@ export const addOutpaint = async ({
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
}: AddOutpaintArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
const { denoising_start, denoising_end } = getDenoisingStartAndEnd(state);
denoise.denoising_start = denoising_start;
denoise.denoising_end = denoising_end;
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);

View File

@@ -2,11 +2,12 @@ import { objectEquals } from '@observ33r/object-equals';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { LatentToImageNodes } from 'features/nodes/util/graph/types';
import type { DenoiseLatentsNodes, LatentToImageNodes } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
type AddTextToImageArg = {
g: Graph;
denoise: Invocation<DenoiseLatentsNodes>;
l2i: Invocation<LatentToImageNodes>;
originalSize: Dimensions;
scaledSize: Dimensions;
@@ -14,10 +15,14 @@ type AddTextToImageArg = {
export const addTextToImage = ({
g,
denoise,
l2i,
originalSize,
scaledSize,
}: AddTextToImageArg): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode' | 'sd3_l2i' | 'cogview4_l2i'> => {
denoise.denoising_start = 0;
denoise.denoising_end = 1;
if (!objectEquals(scaledSize, originalSize)) {
// We need to resize the output image back to the original size
const resizeImageToOriginalSize = g.addNode({

View File

@@ -6,11 +6,7 @@ import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'featu
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
selectCanvasOutputFields,
selectOriginalAndScaledSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { selectCanvasOutputFields, selectOriginalAndScaledSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
@@ -33,10 +29,9 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
const refImages = selectRefImagesSlice(state);
const { originalSize, scaledSize, aspectRatio } = selectOriginalAndScaledSizes(state);
const { positivePrompt } = selectPresetModifiedPrompts(state);
assert(model, 'No model found in state');
assert(model.base === 'chatgpt-4o', 'Model is not a ChatGPT 4o model');
assert(model, 'No model selected');
assert(model.base === 'chatgpt-4o', 'Selected model is not a ChatGPT 4o API model');
const validRefImages = refImages.entities
.filter((entity) => entity.isEnabled)
@@ -60,24 +55,35 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
assert(isChatGPT4oAspectRatioID(aspectRatio.id), 'ChatGPT 4o does not support this aspect ratio');
const g = new Graph(getPrefixedId('chatgpt_4o_txt2img_graph'));
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const gptImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'chatgpt_4o_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: aspectRatio.id,
reference_images,
...selectCanvasOutputFields(state),
});
g.addEdge(
positivePrompt,
'value',
gptImage,
// @ts-expect-error: These nodes are not available in the OSS application
'positive_prompt'
);
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: originalSize.width,
height: originalSize.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
positivePrompt,
};
}
@@ -89,25 +95,36 @@ export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise<GraphBu
silent: true,
});
const g = new Graph(getPrefixedId('chatgpt_4o_img2img_graph'));
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const gptImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'chatgpt_4o_edit_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
base_image: { image_name },
reference_images,
...selectCanvasOutputFields(state),
});
g.addEdge(
positivePrompt,
'value',
gptImage,
// @ts-expect-error: These nodes are not available in the OSS application
'positive_prompt'
);
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: gptImage.id, fieldName: 'positive_prompt' },
positivePrompt,
};
}

View File

@@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
@@ -25,43 +25,51 @@ const log = logger('system');
export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building CogView4 graph');
const model = selectMainModelConfig(state);
assert(model, 'No model selected');
assert(model.base === 'cogview4', 'Selected model is not a CogView4 model');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
const { bbox } = canvas;
const { model, cfgScale: cfg_scale, seed: _seed, steps } = params;
assert(model, 'No model found in state');
const { cfgScale: cfg_scale, seed: _seed, steps } = params;
const { originalSize, scaledSize } = selectOriginalAndScaledSizes(state);
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const prompts = selectPresetModifiedPrompts(state);
const g = new Graph(getPrefixedId('cogview4_graph'));
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const modelLoader = g.addNode({
type: 'cogview4_model_loader',
id: getPrefixedId('cogview4_model_loader'),
model,
});
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const posCond = g.addNode({
type: 'cogview4_text_encoder',
id: getPrefixedId('pos_prompt'),
prompt: positivePrompt,
});
const negCond = g.addNode({
type: 'cogview4_text_encoder',
id: getPrefixedId('neg_prompt'),
prompt: negativePrompt,
prompt: prompts.negative,
});
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const denoise = g.addNode({
type: 'cogview4_denoise',
id: getPrefixedId('denoise_latents'),
@@ -69,8 +77,6 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
width: scaledSize.width,
height: scaledSize.height,
steps,
denoising_start: 0,
denoising_end: 1,
});
const l2i = g.addNode({
type: 'cogview4_l2i',
@@ -81,15 +87,17 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
id: getPrefixedId('cogview4_i2l'),
});
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'glm_encoder', posCond, 'glm_encoder');
g.addEdge(modelLoader, 'glm_encoder', negCond, 'glm_encoder');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(positivePrompt, 'value', posCond, 'prompt');
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
@@ -99,39 +107,44 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
cfg_scale,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
model: Graph.getModelMetadataField(modelConfig),
steps,
});
const denoising_start = 1 - params.img2imgStrength;
g.addEdgeToMetadata(seed, 'value', 'seed');
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (generationMode === 'txt2img') {
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
canvasOutput = addTextToImage({
g,
denoise,
l2i,
originalSize,
scaledSize,
});
g.upsertMetadata({ generation_mode: 'cogview4_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
state,
manager,
denoise,
l2i,
i2l,
denoise,
vaeSource: modelLoader,
originalSize,
scaledSize,
bbox,
denoising_start,
});
g.upsertMetadata({ generation_mode: 'cogview4_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -140,15 +153,14 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'cogview4_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -157,7 +169,6 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'cogview4_outpaint' });
@@ -178,9 +189,10 @@ export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise<GraphBui
g.updateNode(canvasOutput, selectCanvasOutputFields(state));
g.setMetadataReceivingNode(canvasOutput);
return {
g,
seedFieldIdentifier: { nodeId: seed.id, fieldName: 'value' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
seed,
positivePrompt,
};
};

View File

@@ -16,11 +16,7 @@ import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
selectCanvasOutputFields,
selectOriginalAndScaledSizes,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { selectCanvasOutputFields, selectOriginalAndScaledSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
@@ -37,6 +33,10 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building FLUX graph');
const model = selectMainModelConfig(state);
assert(model, 'No model selected');
assert(model.base === 'flux', 'Selected model is not a FLUX model');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
@@ -45,21 +45,8 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
const { originalSize, scaledSize } = selectOriginalAndScaledSizes(state);
const model = selectMainModelConfig(state);
const { guidance: baseGuidance, seed: _seed, steps, fluxVAE, t5EncoderModel, clipEmbedModel } = params;
const {
guidance: baseGuidance,
seed: _seed,
steps,
fluxVAE,
t5EncoderModel,
clipEmbedModel,
img2imgStrength,
optimizedDenoisingEnabled,
} = params;
assert(model, 'No model found in state');
assert(model.base === 'flux', 'Model is not a FLUX model');
assert(t5EncoderModel, 'No T5 Encoder model found in state');
assert(clipEmbedModel, 'No CLIP Embed model found in state');
assert(fluxVAE, 'No FLUX VAE model found in state');
@@ -98,14 +85,8 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
guidance = 30;
}
const { positivePrompt } = selectPresetModifiedPrompts(state);
const g = new Graph(getPrefixedId('flux_graph'));
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const modelLoader = g.addNode({
type: 'flux_model_loader',
id: getPrefixedId('flux_model_loader'),
@@ -115,23 +96,29 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
vae_model: fluxVAE,
});
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const posCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('flux_text_encoder'),
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('pos_cond_collect'),
});
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const denoise = g.addNode({
type: 'flux_denoise',
id: getPrefixedId('flux_denoise'),
guidance,
num_steps: steps,
denoising_start: 0,
denoising_end: 1,
width: scaledSize.width,
height: scaledSize.height,
});
@@ -145,6 +132,35 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
id: getPrefixedId('flux_vae_encode'),
});
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'vae', denoise, 'controlnet_vae');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(denoise, 'latents', l2i, 'latents');
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
g.upsertMetadata({
guidance,
width: originalSize.width,
height: originalSize.height,
model: Graph.getModelMetadataField(model),
steps,
vae: fluxVAE,
t5_encoder: t5EncoderModel,
clip_embed_model: clipEmbedModel,
});
g.addEdgeToMetadata(seed, 'value', 'seed');
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
if (isFluxKontextDev) {
const validFLUXKontextConfigs = selectRefImagesSlice(state)
.entities.filter((entity) => entity.isEnabled)
@@ -170,49 +186,13 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
}
}
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'vae', denoise, 'controlnet_vae');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
g.addEdge(denoise, 'latents', l2i, 'latents');
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
g.upsertMetadata({
guidance,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
steps,
vae: fluxVAE,
t5_encoder: t5EncoderModel,
clip_embed_model: clipEmbedModel,
});
let denoising_start: number;
if (optimizedDenoisingEnabled) {
// We rescale the img2imgStrength (with exponent 0.2) to effectively use the entire range [0, 1] and make the scale
// more user-friendly for FLUX. Without this, most of the 'change' is concentrated in the high denoise strength
// range (>0.9).
denoising_start = 1 - img2imgStrength ** 0.2;
} else {
denoising_start = 1 - img2imgStrength;
}
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (isFLUXFill && (generationMode === 'inpaint' || generationMode === 'outpaint')) {
assert(manager !== null);
canvasOutput = await addFLUXFill({
state,
g,
state,
manager,
l2i,
denoise,
@@ -220,12 +200,19 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
scaledSize,
});
} else if (generationMode === 'txt2img') {
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
canvasOutput = addTextToImage({
g,
denoise,
l2i,
originalSize,
scaledSize,
});
g.upsertMetadata({ generation_mode: 'flux_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
state,
manager,
l2i,
i2l,
@@ -234,14 +221,13 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
originalSize,
scaledSize,
bbox,
denoising_start,
});
g.upsertMetadata({ generation_mode: 'flux_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -250,15 +236,14 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'flux_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -267,7 +252,6 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'flux_outpaint' });
@@ -375,9 +359,10 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
g.updateNode(canvasOutput, selectCanvasOutputFields(state));
g.setMetadataReceivingNode(canvasOutput);
return {
g,
seedFieldIdentifier: { nodeId: seed.id, fieldName: 'value' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
seed,
positivePrompt,
};
};

View File

@@ -8,7 +8,7 @@ import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/va
import type { ImageField } from 'features/nodes/types/common';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
@@ -19,22 +19,20 @@ const log = logger('system');
export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state, manager } = arg;
const model = selectMainModelConfig(state);
assert(model, 'No model selected');
assert(model.base === 'flux-kontext', 'Selected model is not a FLUX Kontext API model');
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
}
log.debug({ generationMode, manager: manager?.id }, 'Building FLUX Kontext graph');
const model = selectMainModelConfig(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
const { bbox } = canvas;
const { positivePrompt } = selectPresetModifiedPrompts(state);
assert(model, 'No model found in state');
assert(model.base === 'flux-kontext', 'Model is not a Flux Kontext model');
const validRefImages = refImages.entities
.filter((entity) => entity.isEnabled)
@@ -54,24 +52,35 @@ export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn
}
const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph'));
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const fluxKontextImage = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
aspect_ratio: bbox.aspectRatio.id,
input_image,
prompt_upsampling: true,
...selectCanvasOutputFields(state),
});
g.addEdge(
positivePrompt,
'value',
fluxKontextImage,
// @ts-expect-error: These nodes are not available in the OSS application
'positive_prompt'
);
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
g.upsertMetadata({
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(model),
width: bbox.rect.width,
height: bbox.rect.height,
});
return {
g,
positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' },
positivePrompt,
};
};

View File

@@ -15,39 +15,51 @@ const log = logger('system');
export const buildImagen3Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building Imagen3 graph');
const model = selectMainModelConfig(state);
assert(model, 'No model selected');
assert(model.base === 'imagen3', 'Selected model is not an Imagen3 API model');
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
}
log.debug({ generationMode, manager: manager?.id }, 'Building Imagen3 graph');
const canvas = selectCanvasSlice(state);
const { bbox } = canvas;
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const model = selectMainModelConfig(state);
assert(model, 'No model found for Imagen3 graph');
assert(model.base === 'imagen3', 'Imagen3 graph requires Imagen3 model');
const prompts = selectPresetModifiedPrompts(state);
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
assert(prompts.positive.length > 0, 'Imagen3 requires positive prompt to have at least one character');
const g = new Graph(getPrefixedId('imagen3_txt2img_graph'));
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const imagen3 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen3_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
aspect_ratio: bbox.aspectRatio.id,
// When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed.
enhance_prompt: true,
...selectCanvasOutputFields(state),
});
g.addEdge(
positivePrompt,
'value',
imagen3,
// @ts-expect-error: These nodes are not available in the OSS application
'positive_prompt'
);
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
@@ -56,7 +68,6 @@ export const buildImagen3Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
return {
g,
seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' },
positivePrompt,
};
};

View File

@@ -15,39 +15,51 @@ const log = logger('system');
export const buildImagen4Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building Imagen4 graph');
const model = selectMainModelConfig(state);
assert(model, 'No model selected');
assert(model.base === 'imagen4', 'Selected model is not a Imagen4 API model');
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
}
log.debug({ generationMode, manager: manager?.id }, 'Building Imagen4 graph');
const canvas = selectCanvasSlice(state);
const { bbox } = canvas;
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const model = selectMainModelConfig(state);
assert(model, 'No model found for Imagen4 graph');
assert(model.base === 'imagen4', 'Imagen4 graph requires Imagen4 model');
const prompts = selectPresetModifiedPrompts(state);
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character');
assert(prompts.positive.length > 0, 'Imagen4 requires positive prompt to have at least one character');
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const imagen4 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen4_generate_image',
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
aspect_ratio: bbox.aspectRatio.id,
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
enhance_prompt: true,
...selectCanvasOutputFields(state),
});
g.addEdge(
positivePrompt,
'value',
imagen4,
// @ts-expect-error: These nodes are not available in the OSS application
'positive_prompt'
);
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
@@ -56,7 +68,6 @@ export const buildImagen4Graph = (arg: GraphBuilderArg): GraphBuilderReturn => {
return {
g,
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
positivePrompt,
};
};

View File

@@ -31,14 +31,18 @@ const log = logger('system');
export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building SD1/SD2 graph');
const model = selectMainModelConfig(state);
assert(model, 'No model selected');
assert(model.base === 'sd-1' || model.base === 'sd-2', 'Selected model is not a SDXL model');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
const refImages = selectRefImagesSlice(state);
const { bbox } = canvas;
const model = selectMainModelConfig(state);
const {
cfgScale: cfg_scale,
@@ -52,10 +56,8 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
vae,
} = params;
assert(model, 'No model found in state');
const fp32 = vaePrecision === 'fp32';
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const prompts = selectPresetModifiedPrompts(state);
const { originalSize, scaledSize } = selectOriginalAndScaledSizes(state);
const g = new Graph(getPrefixedId('sd1_graph'));
@@ -64,6 +66,10 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
type: 'integer',
value: _seed,
});
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const modelLoader = g.addNode({
type: 'main_model_loader',
id: getPrefixedId('sd1_model_loader'),
@@ -77,7 +83,6 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
const posCond = g.addNode({
type: 'compel',
id: getPrefixedId('pos_cond'),
prompt: positivePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
@@ -86,7 +91,7 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
const negCond = g.addNode({
type: 'compel',
id: getPrefixedId('neg_cond'),
prompt: negativePrompt,
prompt: prompts.negative,
});
const negCondCollect = g.addNode({
type: 'collect',
@@ -128,27 +133,28 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
})
: null;
g.addEdge(seed, 'value', noise, 'seed');
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(clipSkip, 'clip', posCond, 'clip');
g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(positivePrompt, 'value', posCond, 'prompt');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(seed, 'value', noise, 'seed');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
assert(model.base === 'sd-1' || model.base === 'sd-2');
g.upsertMetadata({
cfg_scale,
cfg_rescale_multiplier,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
model: Graph.getModelMetadataField(model),
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
@@ -156,6 +162,8 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
clip_skip: skipped_layers,
vae: vae ?? undefined,
});
g.addEdgeToMetadata(seed, 'value', 'seed');
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
@@ -167,17 +175,22 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
> = seamless ?? vaeLoader ?? modelLoader;
g.addEdge(vaeSource, 'vae', l2i, 'vae');
const denoising_start = 1 - params.img2imgStrength;
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (generationMode === 'txt2img') {
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
canvasOutput = addTextToImage({
g,
denoise,
l2i,
originalSize,
scaledSize,
});
g.upsertMetadata({ generation_mode: 'txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
state,
manager,
l2i,
i2l,
@@ -186,14 +199,13 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
originalSize,
scaledSize,
bbox,
denoising_start,
});
g.upsertMetadata({ generation_mode: 'img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -202,15 +214,14 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -219,7 +230,6 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'outpaint' });
@@ -314,9 +324,10 @@ export const buildSD1Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
g.updateNode(canvasOutput, selectCanvasOutputFields(state));
g.setMetadataReceivingNode(canvasOutput);
return {
g,
seedFieldIdentifier: { nodeId: seed.id, fieldName: 'value' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
seed,
positivePrompt,
};
};

View File

@@ -23,6 +23,7 @@ const log = logger('system');
export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building SD3 graph');
const model = selectMainModelConfig(state);
@@ -34,27 +35,13 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
const { bbox } = canvas;
const {
cfgScale: cfg_scale,
seed: _seed,
steps,
vae,
t5EncoderModel,
clipLEmbedModel,
clipGEmbedModel,
optimizedDenoisingEnabled,
img2imgStrength,
} = params;
const { cfgScale: cfg_scale, seed: _seed, steps, vae, t5EncoderModel, clipLEmbedModel, clipGEmbedModel } = params;
const { originalSize, scaledSize } = selectOriginalAndScaledSizes(state);
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const prompts = selectPresetModifiedPrompts(state);
const g = new Graph(getPrefixedId('sd3_graph'));
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const modelLoader = g.addNode({
type: 'sd3_model_loader',
id: getPrefixedId('sd3_model_loader'),
@@ -64,18 +51,27 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
clip_g_model: clipGEmbedModel,
vae_model: vae,
});
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const posCond = g.addNode({
type: 'sd3_text_encoder',
id: getPrefixedId('pos_cond'),
prompt: positivePrompt,
});
const negCond = g.addNode({
type: 'sd3_text_encoder',
id: getPrefixedId('neg_cond'),
prompt: negativePrompt,
prompt: prompts.negative,
});
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const denoise = g.addNode({
type: 'sd3_denoise',
id: getPrefixedId('sd3_denoise'),
@@ -95,7 +91,6 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
id: getPrefixedId('sd3_i2l'),
});
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'clip_l', posCond, 'clip_l');
g.addEdge(modelLoader, 'clip_l', negCond, 'clip_l');
@@ -103,43 +98,43 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
g.addEdge(modelLoader, 'clip_g', negCond, 'clip_g');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 't5_encoder', negCond, 't5_encoder');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(positivePrompt, 'value', posCond, 'prompt');
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
g.addEdge(seed, 'value', denoise, 'seed');
g.addEdge(denoise, 'latents', l2i, 'latents');
g.upsertMetadata({
cfg_scale,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
negative_prompt: prompts.negative,
model: Graph.getModelMetadataField(model),
steps,
vae: vae ?? undefined,
});
g.addEdge(modelLoader, 'vae', l2i, 'vae');
let denoising_start: number;
if (optimizedDenoisingEnabled) {
// We rescale the img2imgStrength (with exponent 0.2) to effectively use the entire range [0, 1] and make the scale
// more user-friendly for SD3.5. Without this, most of the 'change' is concentrated in the high denoise strength
// range (>0.9).
denoising_start = 1 - img2imgStrength ** 0.2;
} else {
denoising_start = 1 - img2imgStrength;
}
g.addEdgeToMetadata(seed, 'value', 'seed');
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (generationMode === 'txt2img') {
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
canvasOutput = addTextToImage({
g,
denoise,
l2i,
originalSize,
scaledSize,
});
g.upsertMetadata({ generation_mode: 'sd3_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
state,
manager,
l2i,
i2l,
@@ -148,14 +143,13 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
originalSize,
scaledSize,
bbox,
denoising_start,
});
g.upsertMetadata({ generation_mode: 'sd3_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -164,15 +158,14 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'sd3_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -181,7 +174,6 @@ export const buildSD3Graph = async (arg: GraphBuilderArg): Promise<GraphBuilderR
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'sd3_outpaint' });

View File

@@ -31,11 +31,12 @@ const log = logger('system');
export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
const { generationMode, state, manager } = arg;
log.debug({ generationMode, manager: manager?.id }, 'Building SDXL graph');
const model = selectMainModelConfig(state);
assert(model, 'No model found in state');
assert(model.base === 'sdxl');
assert(model, 'No model selected');
assert(model.base === 'sdxl', 'Selected model is not a SDXL Kontext model');
const params = selectParamsSlice(state);
const canvas = selectCanvasSlice(state);
@@ -53,47 +54,49 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
vaePrecision,
vae,
refinerModel,
refinerStart,
} = params;
assert(model, 'No model found in state');
const fp32 = vaePrecision === 'fp32';
const { originalSize, scaledSize } = selectOriginalAndScaledSizes(state);
const { positivePrompt, negativePrompt, positiveStylePrompt, negativeStylePrompt } =
selectPresetModifiedPrompts(state);
const prompts = selectPresetModifiedPrompts(state);
const g = new Graph(getPrefixedId('sdxl_graph'));
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const modelLoader = g.addNode({
type: 'sdxl_model_loader',
id: getPrefixedId('sdxl_model_loader'),
model,
});
const positivePrompt = g.addNode({
id: getPrefixedId('positive_prompt'),
type: 'string',
});
const posCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('pos_cond'),
prompt: positivePrompt,
style: positiveStylePrompt,
});
const posCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('pos_cond_collect'),
});
const negCond = g.addNode({
type: 'sdxl_compel_prompt',
id: getPrefixedId('neg_cond'),
prompt: negativePrompt,
style: negativeStylePrompt,
prompt: prompts.negative,
style: prompts.negativeStyle,
});
const negCondCollect = g.addNode({
type: 'collect',
id: getPrefixedId('neg_cond_collect'),
});
const seed = g.addNode({
id: getPrefixedId('seed'),
type: 'integer',
value: _seed,
});
const noise = g.addNode({
type: 'noise',
id: getPrefixedId('noise'),
@@ -108,8 +111,6 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
cfg_rescale_multiplier,
scheduler,
steps,
denoising_start: 0,
denoising_end: refinerModel ? refinerStart : 1,
});
const l2i = g.addNode({
type: 'l2i',
@@ -130,16 +131,20 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
})
: null;
g.addEdge(seed, 'value', noise, 'seed');
g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2');
g.addEdge(positivePrompt, 'value', posCond, 'prompt');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(negCondCollect, 'collection', denoise, 'negative_conditioning');
g.addEdge(seed, 'value', noise, 'seed');
g.addEdge(noise, 'noise', denoise, 'noise');
g.addEdge(denoise, 'latents', l2i, 'latents');
@@ -148,16 +153,24 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
cfg_rescale_multiplier,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(model),
steps,
rand_device: shouldUseCpuNoise ? 'cpu' : 'cuda',
scheduler,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
negative_prompt: prompts.negative,
negative_style_prompt: prompts.negativeStyle,
vae: vae ?? undefined,
});
g.addEdgeToMetadata(seed, 'value', 'seed');
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_prompt');
if (prompts.useMainPromptsForStyle) {
g.addEdge(positivePrompt, 'value', posCond, 'style');
g.addEdgeToMetadata(positivePrompt, 'value', 'positive_style_prompt');
} else {
posCond.style = prompts.positiveStyle;
g.upsertMetadata({ positive_style_prompt: prompts.positiveStyle });
}
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
@@ -172,19 +185,22 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
await addSDXLRefiner(state, g, denoise, seamless, posCond, negCond, l2i);
}
const denoising_start = refinerModel
? Math.min(refinerStart, 1 - params.img2imgStrength)
: 1 - params.img2imgStrength;
let canvasOutput: Invocation<ImageOutputNodes> = l2i;
if (generationMode === 'txt2img') {
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
canvasOutput = addTextToImage({
g,
denoise,
l2i,
originalSize,
scaledSize,
});
g.upsertMetadata({ generation_mode: 'sdxl_txt2img' });
} else if (generationMode === 'img2img') {
assert(manager !== null);
canvasOutput = await addImageToImage({
g,
state,
manager,
l2i,
i2l,
@@ -193,14 +209,13 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
originalSize,
scaledSize,
bbox,
denoising_start,
});
g.upsertMetadata({ generation_mode: 'sdxl_img2img' });
} else if (generationMode === 'inpaint') {
assert(manager !== null);
canvasOutput = await addInpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -209,15 +224,14 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'sdxl_inpaint' });
} else if (generationMode === 'outpaint') {
assert(manager !== null);
canvasOutput = await addOutpaint({
state,
g,
state,
manager,
l2i,
i2l,
@@ -226,7 +240,6 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
modelLoader,
originalSize,
scaledSize,
denoising_start,
seed,
});
g.upsertMetadata({ generation_mode: 'sdxl_outpaint' });
@@ -322,7 +335,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
g.setMetadataReceivingNode(canvasOutput);
return {
g,
seedFieldIdentifier: { nodeId: seed.id, fieldName: 'value' },
positivePromptFieldIdentifier: { nodeId: posCond.id, fieldName: 'prompt' },
seed,
positivePrompt,
};
};

View File

@@ -1,7 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import {
selectImg2imgStrength,
selectMainModelConfig,
selectOptimizedDenoisingEnabled,
selectParamsSlice,
selectRefinerModel,
selectRefinerStart,
} from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { ParamsState } from 'features/controlLayers/store/types';
import type { BoardField } from 'features/nodes/types/common';
@@ -77,19 +84,21 @@ export const selectPresetModifiedPrompts = createSelector(
);
return {
positivePrompt: presetModifiedPositivePrompt,
negativePrompt: presetModifiedNegativePrompt,
positiveStylePrompt: shouldConcatPrompts ? presetModifiedPositivePrompt : positivePrompt2,
negativeStylePrompt: shouldConcatPrompts ? presetModifiedNegativePrompt : negativePrompt2,
positive: presetModifiedPositivePrompt,
negative: presetModifiedNegativePrompt,
positiveStyle: positivePrompt2,
negativeStyle: negativePrompt2,
useMainPromptsForStyle: shouldConcatPrompts,
};
}
}
return {
positivePrompt,
negativePrompt,
positiveStylePrompt: shouldConcatPrompts ? positivePrompt : positivePrompt2,
negativeStylePrompt: shouldConcatPrompts ? negativePrompt : negativePrompt2,
positive: positivePrompt,
negative: negativePrompt,
positiveStyle: positivePrompt2,
negativeStyle: negativePrompt2,
useMainPromptsForStyle: shouldConcatPrompts,
};
}
);
@@ -177,3 +186,66 @@ export const isMainModelWithoutUnet = (modelLoader: Invocation<MainModelLoaderNo
};
export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] === CANVAS_OUTPUT_PREFIX;
export const getDenoisingStartAndEnd = (state: RootState): { denoising_start: number; denoising_end: number } => {
const optimizedDenoisingEnabled = selectOptimizedDenoisingEnabled(state);
const denoisingStrength = selectImg2imgStrength(state);
const model = selectMainModelConfig(state);
const refinerModel = selectRefinerModel(state);
const refinerDenoisingStart = selectRefinerStart(state);
switch (model?.base) {
case 'sd-3': {
// We rescale the img2imgStrength (with exponent 0.2) to effectively use the entire range [0, 1] and make the scale
// more user-friendly for SD3.5. Without this, most of the 'change' is concentrated in the high denoise strength
// range (>0.9).
const exponent = optimizedDenoisingEnabled ? 0.2 : 1;
return {
denoising_start: 1 - denoisingStrength ** exponent,
denoising_end: 1,
};
}
case 'flux': {
if (model.variant === 'inpaint') {
// This is a FLUX Fill model - we always denoise fully
return {
denoising_start: 0,
denoising_end: 1,
};
} else {
// We rescale the img2imgStrength (with exponent 0.2) to effectively use the entire range [0, 1] and make the scale
// more user-friendly for SD3.5. Without this, most of the 'change' is concentrated in the high denoise strength
// range (>0.9).
const exponent = optimizedDenoisingEnabled ? 0.2 : 1;
return {
denoising_start: 1 - denoisingStrength ** exponent,
denoising_end: 1,
};
}
}
case 'sd-1':
case 'sd-2':
case 'cogview4': {
return {
denoising_start: 1 - denoisingStrength,
denoising_end: 1,
};
}
case 'sdxl': {
if (refinerModel) {
return {
denoising_start: Math.min(refinerDenoisingStart, 1 - denoisingStrength),
denoising_end: refinerDenoisingStart,
};
} else {
return {
denoising_start: 1 - denoisingStrength,
denoising_end: 1,
};
}
}
default: {
assert(false, `Unsupported base: ${model?.base}`);
}
}
};

View File

@@ -1,8 +1,8 @@
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { GenerationMode } from 'features/controlLayers/store/types';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { Invocation } from 'services/api/types';
export type ImageOutputNodes =
| 'l2i'
@@ -38,8 +38,8 @@ export type GraphBuilderArg = {
export type GraphBuilderReturn = {
g: Graph;
seedFieldIdentifier?: FieldIdentifier;
positivePromptFieldIdentifier: FieldIdentifier;
seed?: Invocation<'integer'>;
positivePrompt: Invocation<'string'>;
};
export class UnsupportedGenerationModeError extends Error {

View File

@@ -97,15 +97,15 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})

View File

@@ -95,15 +95,15 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})