mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-31 12:57:56 -05:00
364 lines
16 KiB
Python
364 lines
16 KiB
Python
from typing import Callable, Optional
|
|
|
|
import torch
|
|
import torchvision.transforms as tv_transforms
|
|
from diffusers.models.transformers.transformer_cogview4 import CogView4Transformer2DModel
|
|
from torchvision.transforms.functional import resize as tv_resize
|
|
from tqdm import tqdm
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
|
from invokeai.app.invocations.fields import (
|
|
CogView4ConditioningField,
|
|
DenoiseMaskField,
|
|
FieldDescriptions,
|
|
Input,
|
|
InputField,
|
|
LatentsField,
|
|
WithBoard,
|
|
WithMetadata,
|
|
)
|
|
from invokeai.app.invocations.model import TransformerField
|
|
from invokeai.app.invocations.primitives import LatentsOutput
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
|
from invokeai.backend.model_manager.config import BaseModelType
|
|
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
|
|
@invocation(
    "cogview4_denoise",
    title="Denoise - CogView4",
    tags=["image", "cogview4"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run the denoising process with a CogView4 model."""

    # --- Optional image-to-image / inpainting inputs ---------------------------------------------

    # Supplying latents switches the node into image-to-image mode.
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # Inpainting mask for image-to-image. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
    )

    # --- Fraction of the schedule to run (both constrained to [0, 1]) ----------------------------

    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    denoising_end: float = InputField(
        default=1.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_end,
    )

    # --- Model and conditioning -------------------------------------------------------------------

    transformer: TransformerField = InputField(
        description=FieldDescriptions.cogview4_model,
        input=Input.Connection,
        title="Transformer",
    )
    positive_conditioning: CogView4ConditioningField = InputField(
        description=FieldDescriptions.positive_cond,
        input=Input.Connection,
    )
    negative_conditioning: CogView4ConditioningField = InputField(
        description=FieldDescriptions.negative_cond,
        input=Input.Connection,
    )
    # Either one scale applied to every step, or one scale per step.
    cfg_scale: float | list[float] = InputField(
        default=3.5,
        description=FieldDescriptions.cfg_scale,
        title="CFG Scale",
    )

    # --- Output geometry and sampling -------------------------------------------------------------

    width: int = InputField(
        default=1024,
        multiple_of=32,
        description="Width of the generated image.",
    )
    height: int = InputField(
        default=1024,
        multiple_of=32,
        description="Height of the generated image.",
    )
    steps: int = InputField(
        default=25,
        gt=0,
        description=FieldDescriptions.steps,
    )
    seed: int = InputField(
        default=0,
        description="Randomness seed for reproducibility.",
    )
|
|
|
|
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
    """Denoise, persist the resulting latents, and wrap them in a `LatentsOutput`.

    Args:
        context (InvocationContext): The invocation context, used to run the diffusion process and to save the
            resulting tensor.

    Returns:
        LatentsOutput: Output built from the saved tensor's name and the latents. The seed is passed as None.
    """
    # Detach the result and move it to the CPU before handing it to tensor storage.
    result_latents = self._run_diffusion(context).detach().to("cpu")

    saved_name = context.tensors.save(tensor=result_latents)
    return LatentsOutput.build(latents_name=saved_name, latents=result_latents, seed=None)
|
|
|
|
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
|
"""Prepare the inpaint mask.
|
|
- Loads the mask
|
|
- Resizes if necessary
|
|
- Casts to same device/dtype as latents
|
|
|
|
Args:
|
|
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
|
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
|
|
inpaint mask.
|
|
|
|
Returns:
|
|
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
|
|
represent the regions to be preserved.
|
|
"""
|
|
if self.denoise_mask is None:
|
|
return None
|
|
mask = context.tensors.load(self.denoise_mask.mask_name)
|
|
|
|
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
|
|
# 1.0 represents the regions to be preserved.
|
|
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
|
|
mask = 1.0 - mask
|
|
|
|
_, _, latent_height, latent_width = latents.shape
|
|
mask = tv_resize(
|
|
img=mask,
|
|
size=[latent_height, latent_width],
|
|
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
|
antialias=False,
|
|
)
|
|
|
|
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
|
return mask
|
|
|
|
def _load_text_conditioning(
    self,
    context: InvocationContext,
    conditioning_name: str,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Load a stored CogView4 conditioning and return its GLM embeddings.

    Args:
        context (InvocationContext): The invocation context, for loading the conditioning data.
        conditioning_name (str): Name under which the conditioning data was stored.
        dtype (torch.dtype): dtype the conditioning is cast to.
        device (torch.device): Device the conditioning is moved to.

    Returns:
        torch.Tensor: The `glm_embeds` of the single stored conditioning, after the dtype/device cast.
    """
    stored = context.conditioning.load(conditioning_name)
    entries = stored.conditionings

    # Exactly one CogView4 conditioning entry is expected.
    assert len(entries) == 1
    conditioning_info = entries[0]
    assert isinstance(conditioning_info, CogView4ConditioningInfo)

    return conditioning_info.to(dtype=dtype, device=device).glm_embeds
|
|
|
|
def _get_noise(
    self,
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    seed: int,
) -> torch.Tensor:
    """Generate seeded Gaussian noise shaped like the latent image.

    Args:
        batch_size (int): Size of the first (batch) dimension.
        num_channels_latents (int): Size of the second (channel) dimension.
        height (int): Image height; divided by LATENT_SCALE_FACTOR to get the latent height.
        width (int): Image width; divided by LATENT_SCALE_FACTOR to get the latent width.
        dtype (torch.dtype): dtype of the returned tensor.
        device (torch.device): Device of the returned tensor.
        seed (int): Seed for the random generator.

    Returns:
        torch.Tensor: Noise of shape (batch_size, num_channels_latents, height // LATENT_SCALE_FACTOR,
            width // LATENT_SCALE_FACTOR).
    """
    # Noise is always sampled on the CPU in float16 and only then cast, so that the values are consistent across
    # devices/dtypes.
    sampling_device = "cpu"
    sampling_dtype = torch.float16

    latent_shape = (
        batch_size,
        num_channels_latents,
        int(height) // LATENT_SCALE_FACTOR,
        int(width) // LATENT_SCALE_FACTOR,
    )
    seeded_generator = torch.Generator(device=sampling_device).manual_seed(seed)

    sampled = torch.randn(
        latent_shape,
        device=sampling_device,
        dtype=sampling_dtype,
        generator=seeded_generator,
    )
    return sampled.to(device=device, dtype=dtype)
|
|
|
|
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
|
|
"""Prepare the CFG scale list.
|
|
|
|
Args:
|
|
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
|
|
on the scheduler used (e.g. higher order schedulers).
|
|
|
|
Returns:
|
|
list[float]: _description_
|
|
"""
|
|
if isinstance(self.cfg_scale, float):
|
|
cfg_scale = [self.cfg_scale] * num_timesteps
|
|
elif isinstance(self.cfg_scale, list):
|
|
assert len(self.cfg_scale) == num_timesteps
|
|
cfg_scale = self.cfg_scale
|
|
else:
|
|
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
|
|
|
|
return cfg_scale
|
|
|
|
def _convert_timesteps_to_sigmas(self, image_seq_len: int, timesteps: torch.Tensor) -> list[float]:
|
|
# The logic to prepare the timestep / sigma schedule is based on:
|
|
# https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
|
|
# The default FlowMatchEulerDiscreteScheduler configs are based on:
|
|
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
|
|
# This implementation differs slightly from the original for the sake of simplicity (differs in terminal value
|
|
# handling, not quantizing timesteps to integers, etc.).
|
|
|
|
def calculate_timestep_shift(
|
|
image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
|
|
) -> float:
|
|
m = (image_seq_len / base_seq_len) ** 0.5
|
|
mu = m * max_shift + base_shift
|
|
return mu
|
|
|
|
def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
|
return mu / (mu + (1 / t - 1) ** sigma)
|
|
|
|
mu = calculate_timestep_shift(image_seq_len)
|
|
sigmas = time_shift_linear(mu, 1.0, timesteps)
|
|
return sigmas.tolist()
|
|
|
|
def _run_diffusion(
    self,
    context: InvocationContext,
):
    """Run the CogView4 denoising loop and return the resulting latents.

    Loads the transformer and the positive/negative conditioning, builds the timestep/sigma schedule (clipped to
    `denoising_start`/`denoising_end`), prepares the starting latents (pure noise, or noised input latents for
    image-to-image), then takes one Euler step per schedule interval with classifier-free guidance. When a denoise
    mask is present, the latents are merged with the initial latents after every step.

    Args:
        context (InvocationContext): The invocation context, used to load the model, conditioning and tensors, and
            to report progress.

    Returns:
        The latents after the last step. If the clipped schedule contains at most one timestep, the prepared
        starting latents are returned without running any denoising step.

    Raises:
        ValueError: If no input latents are provided and `denoising_start` is greater than 1e-5.
    """
    inference_dtype = torch.bfloat16
    device = TorchDevice.choose_torch_device()

    transformer_info = context.models.load(self.transformer.transformer)
    assert isinstance(transformer_info.model, CogView4Transformer2DModel)

    # Load/process the conditioning data.
    # TODO(ryand): Make CFG optional.
    do_classifier_free_guidance = True
    conditioning_kwargs = {"context": context, "dtype": inference_dtype, "device": device}
    pos_prompt_embeds = self._load_text_conditioning(
        conditioning_name=self.positive_conditioning.conditioning_name, **conditioning_kwargs
    )
    neg_prompt_embeds = self._load_text_conditioning(
        conditioning_name=self.negative_conditioning.conditioning_name, **conditioning_kwargs
    )

    # Prepare misc. conditioning variables.
    # TODO(ryand): We could expose these as params (like with SDXL). But, we should experiment to see if they are
    # useful first.
    embeds_dtype = pos_prompt_embeds.dtype
    original_size = torch.tensor([(self.height, self.width)], dtype=embeds_dtype, device=device)
    target_size = torch.tensor([(self.height, self.width)], dtype=embeds_dtype, device=device)
    crops_coords_top_left = torch.tensor([(0, 0)], dtype=embeds_dtype, device=device)

    # Prepare the timestep / sigma schedule.
    patch_size = transformer_info.model.config.patch_size  # type: ignore
    assert isinstance(patch_size, int)
    latent_height = self.height // LATENT_SCALE_FACTOR
    latent_width = self.width // LATENT_SCALE_FACTOR
    image_seq_len = (latent_height * latent_width) // (patch_size**2)

    # One extra entry accounts for the final timestep of 0.0.
    timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
    # Clip the timesteps schedule based on denoising_start and denoising_end.
    timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
    sigmas = self._convert_timesteps_to_sigmas(image_seq_len, torch.tensor(timesteps))
    total_steps = len(timesteps) - 1

    # Prepare the CFG scale list.
    cfg_scale = self._prepare_cfg_scale(total_steps)

    # Load the input latents, if provided.
    init_latents = None
    if self.latents:
        init_latents = context.tensors.load(self.latents.latents_name)
    if init_latents is not None:
        init_latents = init_latents.to(device=device, dtype=inference_dtype)

    # Generate initial latent noise.
    num_channels_latents = transformer_info.model.config.in_channels  # type: ignore
    assert isinstance(num_channels_latents, int)
    noise = self._get_noise(
        batch_size=1,
        num_channels_latents=num_channels_latents,
        height=self.height,
        width=self.width,
        dtype=inference_dtype,
        device=device,
        seed=self.seed,
    )

    # Prepare input latent image.
    if init_latents is None:
        # No init_latents, so this is not image-to-image (i.e. we are starting from pure noise).
        if self.denoising_start > 1e-5:
            raise ValueError("denoising_start should be 0 when initial latents are not provided.")
        latents = noise
    else:
        # Noise the init_latents by the appropriate amount for the first timestep.
        s_0 = sigmas[0]
        latents = s_0 * noise + (1.0 - s_0) * init_latents

    # With at most one timestep there is no interval to step over: the input latents have only been noised, and
    # no denoising steps are taken.
    if len(timesteps) <= 1:
        return latents

    # Prepare inpaint extension.
    inpaint_mask = self._prep_inpaint_mask(context, latents)
    inpaint_extension: RectifiedFlowInpaintExtension | None = None
    if inpaint_mask is not None:
        assert init_latents is not None
        inpaint_extension = RectifiedFlowInpaintExtension(
            init_latents=init_latents,
            inpaint_mask=inpaint_mask,
            noise=noise,
        )

    step_callback = self._build_step_callback(context)

    # Report the starting state before any step has been taken.
    step_callback(
        PipelineIntermediateState(
            step=0,
            order=1,
            total_steps=total_steps,
            timestep=int(timesteps[0]),
            latents=latents,
        ),
    )

    with transformer_info.model_on_device() as (_, transformer):
        assert isinstance(transformer, CogView4Transformer2DModel)

        def predict_noise(
            prompt_embeds: torch.Tensor, model_input: torch.Tensor, model_timestep: torch.Tensor
        ) -> torch.Tensor:
            # Single transformer forward pass for one set of prompt embeddings.
            return transformer(
                hidden_states=model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=model_timestep,
                original_size=original_size,
                target_size=target_size,
                crop_coords=crops_coords_top_left,
                return_dict=False,
            )[0]

        # Denoising loop
        for step_idx in tqdm(range(total_steps)):
            t_curr = timesteps[step_idx]
            sigma_curr = sigmas[step_idx]
            sigma_prev = sigmas[step_idx + 1]

            # Expand the timestep to match the latent model input.
            # Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
            timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])

            # TODO(ryand): Support both sequential and batched CFG inference.
            noise_pred_cond = predict_noise(pos_prompt_embeds, latents, timestep)

            # Apply CFG.
            if do_classifier_free_guidance:
                noise_pred_uncond = predict_noise(neg_prompt_embeds, latents, timestep)
                noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
            else:
                noise_pred = noise_pred_cond

            # Compute the previous noisy sample x_t -> x_t-1.
            # TODO(ryand): Is casting to float32 necessary for precision/stability? I copied this from SD3.
            latents_dtype = latents.dtype
            stepped = latents.to(dtype=torch.float32) + (sigma_prev - sigma_curr) * noise_pred
            latents = stepped.to(dtype=latents_dtype)

            if inpaint_extension is not None:
                latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)

            step_callback(
                PipelineIntermediateState(
                    step=step_idx + 1,
                    order=1,
                    total_steps=total_steps,
                    timestep=int(t_curr),
                    latents=latents,
                ),
            )

    return latents
|
|
|
|
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
    """Create the per-step progress callback bound to `context`.

    Args:
        context (InvocationContext): The invocation context whose `util.sd_step_callback` receives each state.

    Returns:
        Callable[[PipelineIntermediateState], None]: A function that forwards an intermediate state to
            `context.util.sd_step_callback`, tagged with the CogView4 base model type.
    """

    def _report_progress(intermediate_state: PipelineIntermediateState) -> None:
        context.util.sd_step_callback(intermediate_state, BaseModelType.CogView4)

    return _report_progress
|