Files
InvokeAI/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
2023-08-21 19:17:36 +10:00

537 lines
12 KiB
TypeScript

import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
ImageBlurInvocation,
ImageDTO,
ImageToLatentsInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_INPAINT_GRAPH,
CANVAS_OUTPUT,
CLIP_SKIP,
COLOR_CORRECT,
DENOISE_LATENTS,
INPAINT_IMAGE,
INPAINT_IMAGE_RESIZE_DOWN,
INPAINT_IMAGE_RESIZE_UP,
ITERATE,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
MASK_BLUR,
MASK_RESIZE_DOWN,
MASK_RESIZE_UP,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasInpaintGraph = (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
img2imgStrength: strength,
iterations,
seed,
shouldRandomizeSeed,
vaePrecision,
shouldUseNoiseSettings,
shouldUseCpuNoise,
maskBlur,
maskBlurMethod,
clipSkip,
} = state.generation;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: shouldUseCpuNoise;
const graph: NonNullableGraph = {
id: CANVAS_INPAINT_GRAPH,
nodes: {
[MAIN_MODEL_LOADER]: {
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
is_intermediate: true,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate: true,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate: true,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate: true,
prompt: negativePrompt,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
is_intermediate: true,
radius: maskBlur,
blur_type: maskBlurMethod,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate: true,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate: true,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
denoising_start: 1 - strength,
denoising_end: 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
},
[COLOR_CORRECT]: {
type: 'color_correct',
id: COLOR_CORRECT,
is_intermediate: true,
reference: canvasInitImage,
},
[CANVAS_OUTPUT]: {
type: 'img_paste',
id: CANVAS_OUTPUT,
is_intermediate: !shouldAutoSave,
base_image: canvasInitImage,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate: true,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate: true,
},
},
edges: [
// Connect Model Loader to CLIP Skip and UNet
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: CLIP_SKIP,
field: 'clip',
},
},
// Connect CLIP Skip to Conditioning
{
source: {
node_id: CLIP_SKIP,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: CLIP_SKIP,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
// Connect Everything To Inpaint Node
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: INPAINT_IMAGE,
field: 'latents',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'mask',
},
},
// Iterate
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
// Decode Inpainted Latents To Image
{
source: {
node_id: DENOISE_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
// Handle Scale Before Processing
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
const scaledWidth: number = scaledBoundingBoxDimensions.width;
const scaledHeight: number = scaledBoundingBoxDimensions.height;
// Add Scaling Nodes
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
};
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasMaskImage,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
width: scaledWidth,
height: scaledHeight,
};
// Connect Nodes
graph.edges.push(
// Scale Inpaint Image and Mask
{
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: INPAINT_IMAGE,
field: 'image',
},
},
{
source: {
node_id: MASK_RESIZE_UP,
field: 'image',
},
destination: {
node_id: MASK_BLUR,
field: 'image',
},
},
// Color Correct The Inpainted Result
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: INPAINT_IMAGE_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: INPAINT_IMAGE_RESIZE_DOWN,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'image',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
},
{
source: {
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'mask',
},
},
// Paste Back Onto Original Image
{
source: {
node_id: COLOR_CORRECT,
field: 'image',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'image',
},
},
{
source: {
node_id: MASK_RESIZE_DOWN,
field: 'image',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'mask',
},
}
);
} else {
// Add Images To Nodes
graph.nodes[NOISE] = {
...(graph.nodes[NOISE] as NoiseInvocation),
width: width,
height: height,
};
graph.nodes[INPAINT_IMAGE] = {
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
image: canvasInitImage,
};
graph.nodes[MASK_BLUR] = {
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
image: canvasMaskImage,
};
graph.edges.push(
// Color Correct The Inpainted Result
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'image',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: COLOR_CORRECT,
field: 'mask',
},
},
// Paste Back Onto Original Image
{
source: {
node_id: COLOR_CORRECT,
field: 'image',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'image',
},
},
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'mask',
},
}
);
}
// Handle Seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'value' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// Add VAE
addVAEToGraph(state, graph, MAIN_MODEL_LOADER);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, MAIN_MODEL_LOADER);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};