mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Fix bug in CogView4 noise schedule handling that was resulting in low-quality images.
This commit is contained in:
committed by
psychedelicious
parent
3166b5d2ea
commit
5e75bc570a
@@ -1,7 +1,8 @@
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import CogView4Transformer2DModel
|
||||
from diffusers import CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
@@ -116,13 +117,34 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return cfg_scale
|
||||
|
||||
def _prepare_timesteps(self, num_steps: int, image_seq_len: int) -> list[float]:
|
||||
"""Prepare the timestep schedule."""
|
||||
# The default FlowMatchEulerDiscreteScheduler for CogView4 can be found here:
|
||||
def _init_scheduler(self) -> FlowMatchEulerDiscreteScheduler:
|
||||
# The default FlowMatchEulerDiscreteScheduler configs are copied from:
|
||||
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
|
||||
# We re-implement this logic here to avoid all the complexity of working with the diffusers schedulers.
|
||||
# Note that the timestep schedule initialization is pretty similar to that used for Flux. The main difference is
|
||||
# that we use a linear timestep shift instead of the exponential shift used in Flux.
|
||||
return FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
shift=1.0,
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.25,
|
||||
max_shift=0.75,
|
||||
base_image_seq_len=256,
|
||||
max_image_seq_len=4096,
|
||||
invert_sigmas=False,
|
||||
shift_terminal=None,
|
||||
use_karras_sigmas=False,
|
||||
use_exponential_sigmas=False,
|
||||
use_beta_sigmas=False,
|
||||
time_shift_type="linear",
|
||||
)
|
||||
|
||||
def _prepare_timesteps_and_sigmas(
|
||||
self, scheduler: FlowMatchEulerDiscreteScheduler, num_steps: int, image_seq_len: int
|
||||
) -> tuple[list[float], list[float]]:
|
||||
"""Prepare the timestep schedule."""
|
||||
# The logic to prepare the timestep schedule is based on:
|
||||
# https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
|
||||
|
||||
# TODO(ryand): Should we remove the dependency on the FlowMatchEulerDiscreteScheduler? It just makes this logic
|
||||
# harder to understand than it needs to be.
|
||||
|
||||
def calculate_timestep_shift(
|
||||
image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
|
||||
@@ -131,15 +153,17 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
|
||||
|
||||
def apply_linear_timestep_shift(mu: float, sigma: float, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
return mu / (mu + (1 / timesteps - 1) ** sigma)
|
||||
|
||||
# Add +1 step to account for the final timestep of 0.0.
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
# scheduler = self._init_scheduler()
|
||||
timesteps = np.linspace(scheduler.config.num_train_timesteps, 1.0, num_steps)
|
||||
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
||||
sigmas = timesteps / scheduler.config.num_train_timesteps
|
||||
mu = calculate_timestep_shift(image_seq_len)
|
||||
timesteps = apply_linear_timestep_shift(mu, 1.0, timesteps)
|
||||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, mu=mu)
|
||||
|
||||
return timesteps.tolist()
|
||||
# We have to add the final timestep of 0.0. diffusers uses a different convention and omits the final state from
|
||||
# the list.
|
||||
return scheduler.timesteps.tolist() + [0], scheduler.sigmas.tolist() + [0]
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
@@ -178,12 +202,16 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (
|
||||
transformer_info.model.config.patch_size**2
|
||||
)
|
||||
timesteps = self._prepare_timesteps(num_steps=self.steps, image_seq_len=image_seq_len)
|
||||
scheduler = self._init_scheduler()
|
||||
timesteps, sigmas = self._prepare_timesteps_and_sigmas(
|
||||
scheduler, num_steps=self.steps, image_seq_len=image_seq_len
|
||||
)
|
||||
# TODO(ryand): Add timestep schedule clipping.
|
||||
total_steps = len(timesteps) - 1
|
||||
|
||||
# Prepare the CFG scale list.
|
||||
cfg_scale = self._prepare_cfg_scale(total_steps)
|
||||
# TODO(ryand): Implement this.
|
||||
# cfg_scale = self._prepare_cfg_scale(total_steps)
|
||||
|
||||
# Generate initial latent noise.
|
||||
noise = self._get_noise(
|
||||
@@ -215,10 +243,13 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
assert isinstance(transformer, CogView4Transformer2DModel)
|
||||
|
||||
# Denoising loop
|
||||
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||
for step_idx in tqdm(range(total_steps)):
|
||||
t_curr = timesteps[step_idx]
|
||||
sigma_curr = sigmas[step_idx]
|
||||
sigma_prev = sigmas[step_idx + 1]
|
||||
|
||||
# Expand the timestep to match the latent model input.
|
||||
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
|
||||
timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])
|
||||
timestep = torch.tensor([t_curr], device=device).expand(latents.shape[0])
|
||||
|
||||
# TODO(ryand): Support both sequential and batched CFG inference.
|
||||
noise_pred_cond = transformer(
|
||||
@@ -251,7 +282,7 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
latents_dtype = latents.dtype
|
||||
# TODO(ryand): Is casting to float32 necessary for precision/stability? I copied this from SD3.
|
||||
latents = latents.to(dtype=torch.float32)
|
||||
latents = latents + (t_prev - t_curr) * noise_pred
|
||||
latents = latents + (sigma_prev - sigma_curr) * noise_pred
|
||||
latents = latents.to(dtype=latents_dtype)
|
||||
|
||||
step_callback(
|
||||
|
||||
Reference in New Issue
Block a user