feat(ui): graph building for FLUX in linear UI

This commit is contained in:
Mary Hipp
2024-09-11 19:25:41 -04:00
committed by psychedelicious
parent 00de20d102
commit a300b6ebdd
9 changed files with 230 additions and 25 deletions

View File

@@ -18,6 +18,7 @@ import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
import { buildFLUXGraph } from '../../../../../features/nodes/util/graph/generation/buildFLUXGraph';
const log = logger('generation');
@@ -47,7 +48,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
};
let buildGraphResult: Result<
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> },
Error
>;
@@ -62,6 +63,9 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}

View File

@@ -9,8 +9,8 @@ export const prepareLinearUIBatch = (
state: RootState,
g: Graph,
prepend: boolean,
noise: Invocation<'noise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt'>,
noise: Invocation<'noise' | 'flux_denoise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
origin: 'generation' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): BatchConfig => {

View File

@@ -8,15 +8,15 @@ import type { Invocation } from 'services/api/types';
export const addImageToImage = async (
g: Graph,
manager: CanvasManager,
l2i: Invocation<'l2i'>,
denoise: Invocation<'denoise_latents'>,
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'seamless' | 'vae_loader'>,
l2i: Invocation<'l2i' | 'flux_vae_decode'>,
denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>,
originalSize: Dimensions,
scaledSize: Dimensions,
bbox: CanvasState['bbox'],
denoising_start: number,
fp32: boolean
): Promise<Invocation<'img_resize' | 'l2i'>> => {
): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'>> => {
denoise.denoising_start = denoising_start;
const { image_name } = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);
@@ -29,7 +29,11 @@ export const addImageToImage = async (
image: { image_name },
...scaledSize,
});
const i2l = g.addNode({ id: 'i2l', type: 'i2l', fp32 });
const i2l = vaeSource.type === "flux_model_loader" ?
g.addNode({ id: 'flux_vae_encode', type: 'flux_vae_encode' }) :
g.addNode({ id: 'i2l', type: 'i2l', fp32 });
const resizeImageToOriginalSize = g.addNode({
type: 'img_resize',
id: getPrefixedId('initial_image_resize_out'),
@@ -45,7 +49,7 @@ export const addImageToImage = async (
return resizeImageToOriginalSize;
} else {
// No need to resize, just decode
const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name }, fp32 });
const i2l = vaeSource.type === "flux_model_loader" ? g.addNode({ id: 'flux_vae_encode', type: 'flux_vae_encode' }) : g.addNode({ id: 'i2l', type: 'i2l', fp32 });
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(i2l, 'latents', denoise, 'latents');
return l2i;

View File

@@ -13,10 +13,10 @@ export const addInpaint = async (
state: RootState,
g: Graph,
manager: CanvasManager,
l2i: Invocation<'l2i'>,
denoise: Invocation<'denoise_latents'>,
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'seamless' | 'vae_loader'>,
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader'>,
l2i: Invocation<'l2i' | 'flux_vae_decode'>,
denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>,
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>,
originalSize: Dimensions,
scaledSize: Dimensions,
denoising_start: number,
@@ -84,7 +84,9 @@ export const addInpaint = async (
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
if (modelLoader.type !== "flux_model_loader") {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(resizeImageToScaledSize, 'image', createGradientMask, 'image');
g.addEdge(resizeMaskToScaledSize, 'image', createGradientMask, 'mask');
@@ -138,7 +140,9 @@ export const addInpaint = async (
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
if (modelLoader.type !== "flux_model_loader") {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');

View File

@@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
*/
export const addNSFWChecker = (
g: Graph,
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'>
): Invocation<'img_nsfw'> => {
const nsfw = g.addNode({
type: 'img_nsfw',

View File

@@ -14,10 +14,10 @@ export const addOutpaint = async (
state: RootState,
g: Graph,
manager: CanvasManager,
l2i: Invocation<'l2i'>,
denoise: Invocation<'denoise_latents'>,
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'seamless' | 'vae_loader'>,
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader'>,
l2i: Invocation<'l2i' | 'flux_vae_decode'>,
denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>,
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>,
originalSize: Dimensions,
scaledSize: Dimensions,
denoising_start: number,
@@ -86,7 +86,10 @@ export const addOutpaint = async (
g.addEdge(infill, 'image', createGradientMask, 'image');
g.addEdge(resizeInputMaskToScaledSize, 'image', createGradientMask, 'mask');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
if (modelLoader.type !== "flux_model_loader") {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
// Decode infilled image and connect to denoise
@@ -169,7 +172,10 @@ export const addOutpaint = async (
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
if (modelLoader.type !== "flux_model_loader") {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');

View File

@@ -6,10 +6,10 @@ import type { Invocation } from 'services/api/types';
export const addTextToImage = (
g: Graph,
l2i: Invocation<'l2i'>,
l2i: Invocation<'l2i' | 'flux_vae_decode'>,
originalSize: Dimensions,
scaledSize: Dimensions
): Invocation<'img_resize' | 'l2i'> => {
): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'> => {
if (!isEqual(scaledSize, originalSize)) {
// We need to resize the output image back to the original size
const resizeImageToOriginalSize = g.addNode({

View File

@@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
*/
export const addWatermarker = (
g: Graph,
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'>
): Invocation<'img_watermark'> => {
const watermark = g.addNode({
type: 'img_watermark',

View File

@@ -0,0 +1,187 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getPresetModifiedPrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
export const buildFLUXGraph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'flux_text_encoder'> }> => {
const generationMode = manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building FLUX graph');
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const canvas = selectCanvasSlice(state);
const { bbox } = canvas;
const { originalSize, scaledSize } = getSizes(bbox);
const {
model,
guidance,
seed,
steps,
fluxVAE,
t5EncoderModel,
clipEmbedModel,
img2imgStrength,
} = params;
assert(model, 'No model found in state');
assert(t5EncoderModel, 'No T5 Encoder model found in state');
assert(clipEmbedModel, 'No CLIP Embed model found in state');
assert(fluxVAE, 'No FLUX VAE model found in state');
const { positivePrompt, } = getPresetModifiedPrompts(state);
const g = new Graph(getPrefixedId('flux_graph'));
const modelLoader = g.addNode({
type: 'flux_model_loader',
id: getPrefixedId('flux_model_loader'),
model,
t5_encoder_model: t5EncoderModel,
clip_embed_model: clipEmbedModel,
vae_model: fluxVAE
});
const posCond = g.addNode({
type: 'flux_text_encoder',
id: getPrefixedId('flux_text_encoder'),
prompt: positivePrompt,
});
const noise = g.addNode({
type: 'flux_denoise',
id: getPrefixedId('flux_denoise'),
guidance,
num_steps: steps,
seed,
denoising_start: 1 - img2imgStrength,
denoising_end: 1,
width: scaledSize.width,
height: scaledSize.height
});
const l2i = g.addNode({
type: 'flux_vae_decode',
id: getPrefixedId('flux_vae_decode'),
});
let canvasOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'> = l2i;
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning');
g.addEdge(noise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'flux');
g.upsertMetadata({
generation_mode: 'flux_txt2img',
guidance,
width: scaledSize.width,
height: scaledSize.height,
positive_prompt: positivePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
vae: fluxVAE,
t5_encoder: t5EncoderModel,
clip_embed_model: clipEmbedModel
});
if (generationMode === 'txt2img') {
canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize);
} else if (generationMode === 'img2img') {
canvasOutput = await addImageToImage(
g,
manager,
l2i,
noise,
modelLoader,
originalSize,
scaledSize,
bbox,
1 - params.img2imgStrength,
false
);
} else if (generationMode === 'inpaint') {
canvasOutput = await addInpaint(
state,
g,
manager,
l2i,
noise,
modelLoader,
modelLoader,
originalSize,
scaledSize,
1 - params.img2imgStrength,
false
);
} else if (generationMode === 'outpaint') {
canvasOutput = await addOutpaint(
state,
g,
manager,
l2i,
noise,
modelLoader,
modelLoader,
originalSize,
scaledSize,
1 - params.img2imgStrength,
false
);
}
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}
if (state.system.shouldUseWatermarker) {
canvasOutput = addWatermarker(g, canvasOutput);
}
const shouldSaveToGallery = !canvasSettings.sendToCanvas || canvasSettings.autoSave;
g.updateNode(canvasOutput, {
id: getPrefixedId('canvas_output'),
is_intermediate: !shouldSaveToGallery,
use_cache: false,
board: getBoardField(state),
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise, posCond };
};