import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
  LATENTS_TO_IMAGE,
  METADATA_ACCUMULATOR,
  NEGATIVE_CONDITIONING,
  NOISE,
  POSITIVE_CONDITIONING,
  REFINER_SEAMLESS,
  SDXL_DENOISE_LATENTS,
  SDXL_MODEL_LOADER,
  SDXL_TEXT_TO_IMAGE_GRAPH,
  SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';

export const buildLinearSDXLTextToImageGraph = (
  state: RootState
): NonNullableGraph => {
  const log = logger('nodes');
  const {
    positivePrompt,
    negativePrompt,
    model,
    cfgScale: cfg_scale,
    scheduler,
    steps,
    width,
    height,
    clipSkip,
    shouldUseCpuNoise,
    shouldUseNoiseSettings,
    vaePrecision,
    seamlessXAxis,
    seamlessYAxis,
  } = state.generation;

  const {
    positiveStylePrompt,
    negativeStylePrompt,
    shouldUseSDXLRefiner,
    shouldConcatSDXLStylePrompt,
    refinerStart,
  } = state.sdxl;

  // Fall back to the default noise device unless custom noise settings are enabled
  const use_cpu = shouldUseNoiseSettings
    ? shouldUseCpuNoise
    : initialGenerationState.shouldUseCpuNoise;

  if (!model) {
    log.error('No model found in state');
    throw new Error('No model found in state');
  }

  // Construct Style Prompt
  const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
    craftSDXLStylePrompt(state, shouldConcatSDXLStylePrompt);

  // Model Loader ID
  let modelLoaderNodeId = SDXL_MODEL_LOADER;

  /**
   * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
   * full graph here as a template. Then use the parameters from app state and set friendlier node
   * ids.
   *
   * The only things we need extra logic for are handling the randomized seed, control net, and,
   * for img2img, the `fit` param. These are added to the graph at the end.
   */

  // copy-pasted graph from node editor, filled in with state values & friendly node ids
  const graph: NonNullableGraph = {
    id: SDXL_TEXT_TO_IMAGE_GRAPH,
    nodes: {
      [modelLoaderNodeId]: {
        type: 'sdxl_model_loader',
        id: modelLoaderNodeId,
        model,
      },
      [POSITIVE_CONDITIONING]: {
        type: 'sdxl_compel_prompt',
        id: POSITIVE_CONDITIONING,
        prompt: positivePrompt,
        style: craftedPositiveStylePrompt,
      },
      [NEGATIVE_CONDITIONING]: {
        type: 'sdxl_compel_prompt',
        id: NEGATIVE_CONDITIONING,
        prompt: negativePrompt,
        style: craftedNegativeStylePrompt,
      },
      [NOISE]: {
        type: 'noise',
        id: NOISE,
        width,
        height,
        use_cpu,
      },
      [SDXL_DENOISE_LATENTS]: {
        type: 'denoise_latents',
        id: SDXL_DENOISE_LATENTS,
        cfg_scale,
        scheduler,
        steps,
        denoising_start: 0,
        // When the refiner is enabled, stop early so the refiner can finish denoising
        denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
      },
      [LATENTS_TO_IMAGE]: {
        type: 'l2i',
        id: LATENTS_TO_IMAGE,
        fp32: vaePrecision === 'fp32',
      },
    },
    edges: [
      // Connect Model Loader to UNet, VAE & CLIP
      {
        source: {
          node_id: modelLoaderNodeId,
          field: 'unet',
        },
        destination: {
          node_id: SDXL_DENOISE_LATENTS,
          field: 'unet',
        },
      },
      {
        source: {
          node_id: modelLoaderNodeId,
          field: 'clip',
        },
        destination: {
          node_id: POSITIVE_CONDITIONING,
          field: 'clip',
        },
      },
      {
        source: {
          node_id: modelLoaderNodeId,
          field: 'clip2',
        },
        destination: {
          node_id: POSITIVE_CONDITIONING,
          field: 'clip2',
        },
      },
      {
        source: {
          node_id: modelLoaderNodeId,
          field: 'clip',
        },
        destination: {
          node_id: NEGATIVE_CONDITIONING,
          field: 'clip',
        },
      },
      {
        source: {
          node_id: modelLoaderNodeId,
          field: 'clip2',
        },
        destination: {
          node_id: NEGATIVE_CONDITIONING,
          field: 'clip2',
        },
      },
      // Connect everything to Denoise Latents
      {
        source: {
          node_id: POSITIVE_CONDITIONING,
          field: 'conditioning',
        },
        destination: {
          node_id: SDXL_DENOISE_LATENTS,
          field: 'positive_conditioning',
        },
      },
      {
        source: {
          node_id: NEGATIVE_CONDITIONING,
          field: 'conditioning',
        },
        destination: {
          node_id: SDXL_DENOISE_LATENTS,
          field: 'negative_conditioning',
        },
      },
      {
        source: {
          node_id: NOISE,
          field: 'noise',
        },
        destination: {
          node_id: SDXL_DENOISE_LATENTS,
          field: 'noise',
        },
      },
      // Decode Denoised Latents To Image
      {
        source: {
          node_id: SDXL_DENOISE_LATENTS,
          field: 'latents',
        },
        destination: {
          node_id: LATENTS_TO_IMAGE,
          field: 'latents',
        },
      },
    ],
  };

  // add metadata accumulator, which is only mostly populated - some fields are added later
  graph.nodes[METADATA_ACCUMULATOR] = {
    id: METADATA_ACCUMULATOR,
    type: 'metadata_accumulator',
    generation_mode: 'sdxl_txt2img',
    cfg_scale,
    height,
    width,
    positive_prompt: '', // set in addDynamicPromptsToGraph
    negative_prompt: negativePrompt,
    model,
    seed: 0, // set in addDynamicPromptsToGraph
    steps,
    rand_device: use_cpu ? 'cpu' : 'cuda',
    scheduler,
    vae: undefined,
    controlnets: [],
    loras: [],
    clip_skip: clipSkip,
    positive_style_prompt: positiveStylePrompt,
    negative_style_prompt: negativeStylePrompt,
  };

  graph.edges.push({
    source: {
      node_id: METADATA_ACCUMULATOR,
      field: 'metadata',
    },
    destination: {
      node_id: LATENTS_TO_IMAGE,
      field: 'metadata',
    },
  });

  // Add Seamless To Graph
  if (seamlessXAxis || seamlessYAxis) {
    addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
    modelLoaderNodeId = SEAMLESS;
  }

  // Add Refiner if enabled
  if (shouldUseSDXLRefiner) {
    addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
    modelLoaderNodeId = REFINER_SEAMLESS;
  }

  // optionally add custom VAE
  addVAEToGraph(state, graph, modelLoaderNodeId);

  // add LoRA support
  addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);

  // add controlnet, mutating `graph`
  addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

  // add dynamic prompts - also sets up core iteration and seed
  addDynamicPromptsToGraph(state, graph);

  // NSFW & watermark - must be last thing added to graph
  if (state.system.shouldUseNSFWChecker) {
    // must add before watermarker!
    addNSFWCheckerToGraph(state, graph);
  }

  if (state.system.shouldUseWatermarker) {
    // must add after nsfw checker!
    addWatermarkerToGraph(state, graph);
  }

  return graph;
};