Mirror of https://github.com/invoke-ai/InvokeAI.git
feat(ui): graph building for FLUX in linear UI
committed by psychedelicious · parent 00de20d102 · commit a300b6ebdd
@@ -18,6 +18,7 @@ import { serializeError } from 'serialize-error';
 import { queueApi } from 'services/api/endpoints/queue';
 import type { Invocation } from 'services/api/types';
 import { assert } from 'tsafe';
+import { buildFLUXGraph } from '../../../../../features/nodes/util/graph/generation/buildFLUXGraph';

 const log = logger('generation');

@@ -47,7 +48,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
   };

   let buildGraphResult: Result<
-    { g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
+    { g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> },
     Error
   >;

@@ -62,6 +63,9 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
       case `sd-2`:
         buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
         break;
+      case `flux`:
+        buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
+        break;
       default:
         assert(false, `No graph builders for base ${base}`);
     }

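The dispatch above wraps each builder in a Result so a failed build aborts the enqueue instead of throwing. A minimal, self-contained sketch of that pattern; the Result shape, withResultAsync helper, and builder signatures below are illustrative stand-ins, not the project's real utilities:

// Sketch only: stand-in Result type and helpers.
type Result<T, E> = { ok: true; value: T } | { ok: false; error: E };

const withResultAsync = async <T>(fn: () => Promise<T>): Promise<Result<T, Error>> => {
  try {
    return { ok: true, value: await fn() };
  } catch (e) {
    return { ok: false, error: e instanceof Error ? e : new Error(String(e)) };
  }
};

type BaseModel = 'sd-1' | 'sd-2' | 'sdxl' | 'flux';

// Hypothetical builders standing in for buildSD1Graph / buildFLUXGraph (the real ones
// also take the Redux state and the canvas manager).
declare const buildSD1Graph: () => Promise<unknown>;
declare const buildFLUXGraph: () => Promise<unknown>;

const buildForBase = (base: BaseModel): Promise<Result<unknown, Error>> => {
  switch (base) {
    case 'sd-1':
    case 'sd-2':
      return withResultAsync(() => buildSD1Graph());
    case 'flux':
      return withResultAsync(() => buildFLUXGraph());
    default:
      // Unhandled bases fail loudly, mirroring the assert in the listener.
      throw new Error(`No graph builders for base ${base}`);
  }
};
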
@@ -9,8 +9,8 @@ export const prepareLinearUIBatch = (
   state: RootState,
   g: Graph,
   prepend: boolean,
-  noise: Invocation<'noise'>,
-  posCond: Invocation<'compel' | 'sdxl_compel_prompt'>,
+  noise: Invocation<'noise' | 'flux_denoise'>,
+  posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
   origin: 'generation' | 'workflows' | 'upscaling',
   destination: 'canvas' | 'gallery'
 ): BatchConfig => {

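The batch helper now accepts the FLUX nodes directly, presumably so the linear batch can keep varying per-item fields on them (seed on the denoise node, prompt on the conditioning node) without rebuilding the graph. A hedged sketch of that idea, using an assumed batch-datum shape that is not copied from the codebase:

// Sketch of per-node batch overrides, assuming (not confirmed here) that the batch
// varies seed on the noise/denoise node and prompt on the conditioning node.
type BatchDatum = { node_path: string; field_name: string; items: (number | string)[] };

const buildBatchData = (
  noiseId: string,   // id of the 'noise' or 'flux_denoise' node
  posCondId: string, // id of the 'compel' / 'sdxl_compel_prompt' / 'flux_text_encoder' node
  seeds: number[],
  prompts: string[]
): BatchDatum[][] => [
  [
    { node_path: noiseId, field_name: 'seed', items: seeds },
    { node_path: posCondId, field_name: 'prompt', items: prompts },
  ],
];
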
@@ -8,15 +8,15 @@ import type { Invocation } from 'services/api/types';
 export const addImageToImage = async (
   g: Graph,
   manager: CanvasManager,
-  l2i: Invocation<'l2i'>,
-  denoise: Invocation<'denoise_latents'>,
-  vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'seamless' | 'vae_loader'>,
+  l2i: Invocation<'l2i' | 'flux_vae_decode'>,
+  denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
+  vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>,
   originalSize: Dimensions,
   scaledSize: Dimensions,
   bbox: CanvasState['bbox'],
   denoising_start: number,
   fp32: boolean
-): Promise<Invocation<'img_resize' | 'l2i'>> => {
+): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'>> => {
   denoise.denoising_start = denoising_start;

   const { image_name } = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);

@@ -29,7 +29,11 @@ export const addImageToImage = async (
     image: { image_name },
     ...scaledSize,
   });
-  const i2l = g.addNode({ id: 'i2l', type: 'i2l', fp32 });
+  const i2l = vaeSource.type === "flux_model_loader" ?
+    g.addNode({ id: 'flux_vae_encode', type: 'flux_vae_encode' }) :
+    g.addNode({ id: 'i2l', type: 'i2l', fp32 });
+
   const resizeImageToOriginalSize = g.addNode({
     type: 'img_resize',
     id: getPrefixedId('initial_image_resize_out'),

@@ -45,7 +49,7 @@ export const addImageToImage = async (
     return resizeImageToOriginalSize;
   } else {
     // No need to resize, just decode
-    const i2l = g.addNode({ id: 'i2l', type: 'i2l', image: { image_name }, fp32 });
+    const i2l = vaeSource.type === "flux_model_loader" ? g.addNode({ id: 'flux_vae_encode', type: 'flux_vae_encode' }) : g.addNode({ id: 'i2l', type: 'i2l', fp32 });
     g.addEdge(vaeSource, 'vae', i2l, 'vae');
     g.addEdge(i2l, 'latents', denoise, 'latents');
     return l2i;

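Both branches now pick the image-to-latents node off the loader's type tag. A self-contained sketch of that discriminated-union selection; the node shapes below are illustrative stand-ins for the real Invocation types:

// Illustrative stand-in types; InvokeAI's real Graph/Invocation types differ.
type LoaderNode =
  | { type: 'main_model_loader'; id: string }
  | { type: 'sdxl_model_loader'; id: string }
  | { type: 'flux_model_loader'; id: string };

type EncodeNode =
  | { type: 'i2l'; id: string; fp32: boolean }
  | { type: 'flux_vae_encode'; id: string };

// Narrowing on the `type` tag chooses the right image-to-latents node:
// FLUX uses its own VAE encode invocation, SD/SDXL keep the classic i2l node.
const pickEncodeNode = (vaeSource: LoaderNode, fp32: boolean): EncodeNode =>
  vaeSource.type === 'flux_model_loader'
    ? { type: 'flux_vae_encode', id: 'flux_vae_encode' }
    : { type: 'i2l', id: 'i2l', fp32 };
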
@@ -13,10 +13,10 @@ export const addInpaint = async (
   state: RootState,
   g: Graph,
   manager: CanvasManager,
-  l2i: Invocation<'l2i'>,
-  denoise: Invocation<'denoise_latents'>,
-  vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'seamless' | 'vae_loader'>,
-  modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader'>,
+  l2i: Invocation<'l2i' | 'flux_vae_decode'>,
+  denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
+  vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>,
+  modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>,
   originalSize: Dimensions,
   scaledSize: Dimensions,
   denoising_start: number,

@@ -84,7 +84,9 @@ export const addInpaint = async (
   g.addEdge(vaeSource, 'vae', i2l, 'vae');

   g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
-  g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  if (modelLoader.type !== "flux_model_loader") {
+    g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  }
   g.addEdge(resizeImageToScaledSize, 'image', createGradientMask, 'image');
   g.addEdge(resizeMaskToScaledSize, 'image', createGradientMask, 'mask');

@@ -138,7 +140,9 @@ export const addInpaint = async (
   g.addEdge(i2l, 'latents', denoise, 'latents');
   g.addEdge(vaeSource, 'vae', i2l, 'vae');
   g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
-  g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  if (modelLoader.type !== "flux_model_loader") {
+    g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  }
   g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
   g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');

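In both inpaint hunks the 'unet' edge is now guarded, presumably because the FLUX model loader does not expose a UNet output to feed the gradient-mask node. A small sketch of the guard in isolation; addEdge and the loader shape are stand-ins, not the project's real API:

// Illustrative sketch of the guard.
type GradientMaskModelLoader =
  | { type: 'main_model_loader' | 'sdxl_model_loader'; id: string } // assumed to expose a 'unet' output
  | { type: 'flux_model_loader'; id: string };                      // assumed not to

declare const addEdge: (fromId: string, fromField: string, toId: string, toField: string) => void;

const wireGradientMask = (modelLoader: GradientMaskModelLoader, createGradientMaskId: string): void => {
  // Only the SD/SDXL loaders contribute a UNet; the FLUX gradient mask is wired without one.
  if (modelLoader.type !== 'flux_model_loader') {
    addEdge(modelLoader.id, 'unet', createGradientMaskId, 'unet');
  }
};
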
@@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
  */
 export const addNSFWChecker = (
   g: Graph,
-  imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
+  imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'>
 ): Invocation<'img_nsfw'> => {
   const nsfw = g.addNode({
     type: 'img_nsfw',

@@ -14,10 +14,10 @@ export const addOutpaint = async (
   state: RootState,
   g: Graph,
   manager: CanvasManager,
-  l2i: Invocation<'l2i'>,
-  denoise: Invocation<'denoise_latents'>,
-  vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'seamless' | 'vae_loader'>,
-  modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader'>,
+  l2i: Invocation<'l2i' | 'flux_vae_decode'>,
+  denoise: Invocation<'denoise_latents' | 'flux_denoise'>,
+  vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>,
+  modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>,
   originalSize: Dimensions,
   scaledSize: Dimensions,
   denoising_start: number,

@@ -86,7 +86,10 @@ export const addOutpaint = async (
   g.addEdge(infill, 'image', createGradientMask, 'image');
   g.addEdge(resizeInputMaskToScaledSize, 'image', createGradientMask, 'mask');
   g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
-  g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  if (modelLoader.type !== "flux_model_loader") {
+    g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  }

   g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');

   // Decode infilled image and connect to denoise

@@ -169,7 +172,10 @@ export const addOutpaint = async (
   g.addEdge(i2l, 'latents', denoise, 'latents');
   g.addEdge(vaeSource, 'vae', i2l, 'vae');
   g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
-  g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  if (modelLoader.type !== "flux_model_loader") {
+    g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
+  }

   g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
   g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
   g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');

@@ -6,10 +6,10 @@ import type { Invocation } from 'services/api/types';

 export const addTextToImage = (
   g: Graph,
-  l2i: Invocation<'l2i'>,
+  l2i: Invocation<'l2i' | 'flux_vae_decode'>,
   originalSize: Dimensions,
   scaledSize: Dimensions
-): Invocation<'img_resize' | 'l2i'> => {
+): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'> => {
   if (!isEqual(scaledSize, originalSize)) {
     // We need to resize the output image back to the original size
     const resizeImageToOriginalSize = g.addNode({

@@ -10,7 +10,7 @@ import type { Invocation } from 'services/api/types';
  */
 export const addWatermarker = (
   g: Graph,
-  imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop'>
+  imageOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'>
 ): Invocation<'img_watermark'> => {
   const watermark = g.addNode({
     type: 'img_watermark',

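addNSFWChecker and addWatermarker only needed their imageOutput unions widened so a flux_vae_decode node can feed them. A sketch of how such an Invocation union might be modelled; the registry below is an assumption for illustration, not InvokeAI's real type:

// Stand-in for a generic Invocation<T>: a registry keyed by node type name,
// where Invocation<K> is the union of the selected node shapes.
type InvocationRegistry = {
  l2i: { type: 'l2i'; id: string };
  flux_vae_decode: { type: 'flux_vae_decode'; id: string };
  img_nsfw: { type: 'img_nsfw'; id: string };
  img_watermark: { type: 'img_watermark'; id: string };
};
type InvocationSketch<K extends keyof InvocationRegistry> = InvocationRegistry[K];

// Widening the accepted union lets the same post-processing helper sit after
// either the SD latents-to-image node or the new FLUX VAE decode node.
declare const addNSFWCheckerSketch: (
  imageOutput: InvocationSketch<'l2i' | 'img_nsfw' | 'img_watermark' | 'flux_vae_decode'>
) => InvocationSketch<'img_nsfw'>;

declare const sdDecode: InvocationSketch<'l2i'>;
declare const fluxDecode: InvocationSketch<'flux_vae_decode'>;
export const afterSd = addNSFWCheckerSketch(sdDecode);
export const afterFlux = addNSFWCheckerSketch(fluxDecode);
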
@@ -0,0 +1,187 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { getBoardField, getPresetModifiedPrompts, getSizes } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';

const log = logger('system');

export const buildFLUXGraph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'noise' | 'flux_denoise'>; posCond: Invocation<'flux_text_encoder'> }> => {
  const generationMode = manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building FLUX graph');

  const params = selectParamsSlice(state);
  const canvasSettings = selectCanvasSettingsSlice(state);
  const canvas = selectCanvasSlice(state);

  const { bbox } = canvas;

  const { originalSize, scaledSize } = getSizes(bbox);

  const {
    model,
    guidance,
    seed,
    steps,
    fluxVAE,
    t5EncoderModel,
    clipEmbedModel,
    img2imgStrength,
  } = params;

  assert(model, 'No model found in state');
  assert(t5EncoderModel, 'No T5 Encoder model found in state');
  assert(clipEmbedModel, 'No CLIP Embed model found in state');
  assert(fluxVAE, 'No FLUX VAE model found in state');

  const { positivePrompt, } = getPresetModifiedPrompts(state);

  const g = new Graph(getPrefixedId('flux_graph'));
  const modelLoader = g.addNode({
    type: 'flux_model_loader',
    id: getPrefixedId('flux_model_loader'),
    model,
    t5_encoder_model: t5EncoderModel,
    clip_embed_model: clipEmbedModel,
    vae_model: fluxVAE
  });

  const posCond = g.addNode({
    type: 'flux_text_encoder',
    id: getPrefixedId('flux_text_encoder'),
    prompt: positivePrompt,
  });

  const noise = g.addNode({
    type: 'flux_denoise',
    id: getPrefixedId('flux_denoise'),
    guidance,
    num_steps: steps,
    seed,
    denoising_start: 1 - img2imgStrength,
    denoising_end: 1,
    width: scaledSize.width,
    height: scaledSize.height
  });

  const l2i = g.addNode({
    type: 'flux_vae_decode',
    id: getPrefixedId('flux_vae_decode'),
  });

  let canvasOutput: Invocation<'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'> = l2i;

  g.addEdge(modelLoader, 'transformer', noise, 'transformer');
  g.addEdge(modelLoader, 'vae', l2i, 'vae');

  g.addEdge(modelLoader, 'clip', posCond, 'clip');
  g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
  g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');

  g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning');

  g.addEdge(noise, 'latents', l2i, 'latents');

  const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
  assert(modelConfig.base === 'flux');

  g.upsertMetadata({
    generation_mode: 'flux_txt2img',
    guidance,
    width: scaledSize.width,
    height: scaledSize.height,
    positive_prompt: positivePrompt,
    model: Graph.getModelMetadataField(modelConfig),
    seed,
    steps,
    vae: fluxVAE,
    t5_encoder: t5EncoderModel,
    clip_embed_model: clipEmbedModel
  });

  if (generationMode === 'txt2img') {
    canvasOutput = addTextToImage(g, l2i, originalSize, scaledSize);
  } else if (generationMode === 'img2img') {
    canvasOutput = await addImageToImage(
      g,
      manager,
      l2i,
      noise,
      modelLoader,
      originalSize,
      scaledSize,
      bbox,
      1 - params.img2imgStrength,
      false
    );
  } else if (generationMode === 'inpaint') {
    canvasOutput = await addInpaint(
      state,
      g,
      manager,
      l2i,
      noise,
      modelLoader,
      modelLoader,
      originalSize,
      scaledSize,
      1 - params.img2imgStrength,
      false
    );
  } else if (generationMode === 'outpaint') {
    canvasOutput = await addOutpaint(
      state,
      g,
      manager,
      l2i,
      noise,
      modelLoader,
      modelLoader,
      originalSize,
      scaledSize,
      1 - params.img2imgStrength,
      false
    );
  }

  if (state.system.shouldUseNSFWChecker) {
    canvasOutput = addNSFWChecker(g, canvasOutput);
  }

  if (state.system.shouldUseWatermarker) {
    canvasOutput = addWatermarker(g, canvasOutput);
  }

  const shouldSaveToGallery = !canvasSettings.sendToCanvas || canvasSettings.autoSave;

  g.updateNode(canvasOutput, {
    id: getPrefixedId('canvas_output'),
    is_intermediate: !shouldSaveToGallery,
    use_cache: false,
    board: getBoardField(state),
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise, posCond };
};
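For orientation, a rough sketch of how the new builder's return value feeds the existing linear enqueue path; every type and helper here is a simplified stand-in rather than a copy of the listener code:

// Rough sketch of the post-change enqueue flow.
type FluxBuildResult = {
  g: unknown;                                         // the built Graph
  noise: { type: 'flux_denoise'; id: string };        // node whose seed the batch varies
  posCond: { type: 'flux_text_encoder'; id: string }; // node whose prompt the batch varies
};

declare const buildFLUXGraphSketch: (state: unknown, manager: unknown) => Promise<FluxBuildResult>;
declare const prepareLinearUIBatchSketch: (
  state: unknown,
  g: unknown,
  prepend: boolean,
  noise: FluxBuildResult['noise'],
  posCond: FluxBuildResult['posCond'],
  origin: 'generation',
  destination: 'canvas' | 'gallery'
) => unknown;

export const enqueueFlux = async (state: unknown, manager: unknown, prepend: boolean) => {
  const { g, noise, posCond } = await buildFLUXGraphSketch(state, manager);
  // The same batch-preparation path used for SD graphs now accepts the FLUX nodes,
  // because its noise/posCond parameter unions were widened in this commit.
  return prepareLinearUIBatchSketch(state, g, prepend, noise, posCond, 'generation', 'canvas');
};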