Fix bug in CogView4 noise schedule handling that was resulting in low-quality images.

This commit is contained in:
Ryan Dick
2025-03-07 23:55:54 +00:00
committed by psychedelicious
parent 3166b5d2ea
commit 5e75bc570a

View File

@@ -1,7 +1,8 @@
from typing import Callable
import numpy as np
import torch
from diffusers import CogView4Transformer2DModel
from diffusers import CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@@ -116,13 +117,34 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return cfg_scale
def _prepare_timesteps(self, num_steps: int, image_seq_len: int) -> list[float]:
"""Prepare the timestep schedule."""
# The default FlowMatchEulerDiscreteScheduler for CogView4 can be found here:
def _init_scheduler(self) -> FlowMatchEulerDiscreteScheduler:
# The default FlowMatchEulerDiscreteScheduler configs are copied from:
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
# We re-implement this logic here to avoid all the complexity of working with the diffusers schedulers.
# Note that the timestep schedule initialization is pretty similar to that used for Flux. The main difference is
# that we use a linear timestep shift instead of the exponential shift used in Flux.
return FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
shift=1.0,
use_dynamic_shifting=True,
base_shift=0.25,
max_shift=0.75,
base_image_seq_len=256,
max_image_seq_len=4096,
invert_sigmas=False,
shift_terminal=None,
use_karras_sigmas=False,
use_exponential_sigmas=False,
use_beta_sigmas=False,
time_shift_type="linear",
)
def _prepare_timesteps_and_sigmas(
self, scheduler: FlowMatchEulerDiscreteScheduler, num_steps: int, image_seq_len: int
) -> tuple[list[float], list[float]]:
"""Prepare the timestep schedule."""
# The logic to prepare the timestep schedule is based on:
# https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
# TODO(ryand): Should we remove the dependency on the FlowMatchEulerDiscreteScheduler? It just makes this logic
# harder to understand than it needs to be.
def calculate_timestep_shift(
image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
@@ -131,15 +153,17 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
mu = m * max_shift + base_shift
return mu
def apply_linear_timestep_shift(mu: float, sigma: float, timesteps: torch.Tensor) -> torch.Tensor:
return mu / (mu + (1 / timesteps - 1) ** sigma)
# Add +1 step to account for the final timestep of 0.0.
timesteps = torch.linspace(1, 0, num_steps + 1)
# scheduler = self._init_scheduler()
timesteps = np.linspace(scheduler.config.num_train_timesteps, 1.0, num_steps)
timesteps = timesteps.astype(np.int64).astype(np.float32)
sigmas = timesteps / scheduler.config.num_train_timesteps
mu = calculate_timestep_shift(image_seq_len)
timesteps = apply_linear_timestep_shift(mu, 1.0, timesteps)
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, mu=mu)
return timesteps.tolist()
# We have to add the final timestep of 0.0. diffusers uses a different convention and omits the final state from
# the list.
return scheduler.timesteps.tolist() + [0], scheduler.sigmas.tolist() + [0]
def _run_diffusion(
self,
@@ -178,12 +202,16 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (
transformer_info.model.config.patch_size**2
)
timesteps = self._prepare_timesteps(num_steps=self.steps, image_seq_len=image_seq_len)
scheduler = self._init_scheduler()
timesteps, sigmas = self._prepare_timesteps_and_sigmas(
scheduler, num_steps=self.steps, image_seq_len=image_seq_len
)
# TODO(ryand): Add timestep schedule clipping.
total_steps = len(timesteps) - 1
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(total_steps)
# TODO(ryand): Implement this.
# cfg_scale = self._prepare_cfg_scale(total_steps)
# Generate initial latent noise.
noise = self._get_noise(
@@ -215,10 +243,13 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, CogView4Transformer2DModel)
# Denoising loop
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
for step_idx in tqdm(range(total_steps)):
t_curr = timesteps[step_idx]
sigma_curr = sigmas[step_idx]
sigma_prev = sigmas[step_idx + 1]
# Expand the timestep to match the latent model input.
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])
timestep = torch.tensor([t_curr], device=device).expand(latents.shape[0])
# TODO(ryand): Support both sequential and batched CFG inference.
noise_pred_cond = transformer(
@@ -251,7 +282,7 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
latents_dtype = latents.dtype
# TODO(ryand): Is casting to float32 necessary for precision/stability? I copied this from SD3.
latents = latents.to(dtype=torch.float32)
latents = latents + (t_prev - t_curr) * noise_pred
latents = latents + (sigma_prev - sigma_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
step_callback(