from typing import Callable, Optional, Tuple

import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    SD3ConditioningField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "sd3_denoise",
    title="Denoise - SD3",
    tags=["image", "sd3"],
    category="latents",
    version="1.1.1",
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run the denoising process with an SD3 model."""

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
    )
    positive_conditioning: SD3ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: SD3ConditioningField = InputField(
        description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        - Loads the mask
        - Resizes if necessary
        - Casts to same device/dtype as latents

        Args:
            context (InvocationContext): The invocation context, for loading the inpaint mask.
            latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
                inpaint mask.

        Returns:
            torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
                represent the regions to be preserved.
        """
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)

        # The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
        # 1.0 represents the regions to be preserved.
        # We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        conditioning_name: str,
        joint_attention_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Load the conditioning data.
        cond_data = context.conditioning.load(conditioning_name)
        assert len(cond_data.conditionings) == 1
        sd3_conditioning = cond_data.conditionings[0]
        assert isinstance(sd3_conditioning, SD3ConditioningInfo)
        sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)

        t5_embeds = sd3_conditioning.t5_embeds
        if t5_embeds is None:
            t5_embeds = torch.zeros(
                (1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
                device=device,
                dtype=dtype,
            )
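
        # Concatenate the CLIP-L and CLIP-G embeds, zero-pad them to the T5 hidden size, then join them with the T5
        # embeds along the sequence dimension.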
        clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
        clip_prompt_embeds = torch.nn.functional.pad(
            clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
        )

        prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
        pooled_prompt_embeds = torch.cat(
            [sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
        )

        return prompt_embeds, pooled_prompt_embeds

    def _get_noise(
        self,
        num_samples: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
        rand_device = "cpu"
        rand_dtype = torch.float16
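
        # The noise is generated at latent resolution: the requested image dimensions are divided by
        # LATENT_SCALE_FACTOR (the VAE's spatial downscaling factor) to get the latent height and width.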
        return torch.randn(
            num_samples,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
        """Prepare the CFG scale list.

        Args:
            num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
                on the scheduler used (e.g. higher order schedulers).

        Returns:
            list[float]: The CFG scale to apply at each timestep.
        """
        if isinstance(self.cfg_scale, float):
            cfg_scale = [self.cfg_scale] * num_timesteps
        elif isinstance(self.cfg_scale, list):
            assert len(self.cfg_scale) == num_timesteps
            cfg_scale = self.cfg_scale
        else:
            raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")

        return cfg_scale

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = TorchDevice.choose_torch_dtype()
        device = TorchDevice.choose_torch_device()

        transformer_info = context.models.load(self.transformer.transformer)

        # Load/process the conditioning data.
        # TODO(ryand): Make CFG optional.
        do_classifier_free_guidance = True
        pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
            context=context,
            conditioning_name=self.positive_conditioning.conditioning_name,
            joint_attention_dim=transformer_info.model.config.joint_attention_dim,
            dtype=inference_dtype,
            device=device,
        )
        neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
            context=context,
            conditioning_name=self.negative_conditioning.conditioning_name,
            joint_attention_dim=transformer_info.model.config.joint_attention_dim,
            dtype=inference_dtype,
            device=device,
        )
        # TODO(ryand): Support both sequential and batched CFG inference.
        prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)

        # Prepare the timestep schedule.
        # We add an extra step to the end to account for the final timestep of 0.0.
        timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
        # Clip the timesteps schedule based on denoising_start and denoising_end.
        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
        total_steps = len(timesteps) - 1

        # Prepare the CFG scale list.
        cfg_scale = self._prepare_cfg_scale(total_steps)

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Generate initial latent noise.
        num_channels_latents = transformer_info.model.config.in_channels
        assert isinstance(num_channels_latents, int)
        noise = self._get_noise(
            num_samples=1,
            num_channels_latents=num_channels_latents,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )

        # Prepare input latent image.
        if init_latents is not None:
            # Noise the init_latents by the appropriate amount for the first timestep.
            t_0 = timesteps[0]
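            # In rectified flow, the partially-noised latent at time t is a linear interpolation between pure noise
            # (t = 1) and the clean latents (t = 0).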
            latents = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise

        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
        # denoising steps.
        if len(timesteps) <= 1:
            return latents

        # Prepare inpaint extension.
        inpaint_mask = self._prep_inpaint_mask(context, latents)
        inpaint_extension: RectifiedFlowInpaintExtension | None = None
        if inpaint_mask is not None:
            assert init_latents is not None
            inpaint_extension = RectifiedFlowInpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        step_callback = self._build_step_callback(context)
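
        # Emit the initial state (step 0) so progress reporting starts before the first denoising step.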
        step_callback(
            PipelineIntermediateState(
                step=0,
                order=1,
                total_steps=total_steps,
                timestep=int(timesteps[0]),
                latents=latents,
            ),
        )

        with transformer_info.model_on_device() as (cached_weights, transformer):
            assert isinstance(transformer, SD3Transformer2DModel)

            # Denoising loop.
            for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
                # Expand the latents if we are doing CFG.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                # Expand the timestep to match the latent model input.
                # Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
                timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])

                noise_pred = transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]

                # Apply CFG.
                if do_classifier_free_guidance:
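                    # The batched prediction is ordered [negative, positive], matching the torch.cat of the
                    # conditioning above.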
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents_dtype = latents.dtype
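                # Perform the update in float32, then cast back to the working dtype.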
                latents = latents.to(dtype=torch.float32)
                latents = latents + (t_prev - t_curr) * noise_pred
                latents = latents.to(dtype=latents_dtype)

                if inpaint_extension is not None:
                    latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)

                step_callback(
                    PipelineIntermediateState(
                        step=step_idx + 1,
                        order=1,
                        total_steps=total_steps,
                        timestep=int(t_curr),
                        latents=latents,
                    ),
                )

        return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)

        return step_callback