Simplify CogView4 timestep schedule generation in preparation for timestep schedule clipping.

commit ace5e748f4 (parent 4fae8ad163)
Author: Ryan Dick
Committed by: psychedelicious
Date: 2025-03-10 16:45:30 +00:00


@@ -1,8 +1,7 @@
 from typing import Callable

-import numpy as np
 import torch
-from diffusers import CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers import CogView4Transformer2DModel
 from tqdm import tqdm

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@@ -117,34 +116,13 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         return cfg_scale

-    def _init_scheduler(self) -> FlowMatchEulerDiscreteScheduler:
-        # The default FlowMatchEulerDiscreteScheduler configs are copied from:
-        # https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
-        return FlowMatchEulerDiscreteScheduler(
-            num_train_timesteps=1000,
-            shift=1.0,
-            use_dynamic_shifting=True,
-            base_shift=0.25,
-            max_shift=0.75,
-            base_image_seq_len=256,
-            max_image_seq_len=4096,
-            invert_sigmas=False,
-            shift_terminal=None,
-            use_karras_sigmas=False,
-            use_exponential_sigmas=False,
-            use_beta_sigmas=False,
-            time_shift_type="linear",
-        )
-
-    def _prepare_timesteps_and_sigmas(
-        self, scheduler: FlowMatchEulerDiscreteScheduler, num_steps: int, image_seq_len: int
-    ) -> tuple[list[float], list[float]]:
-        """Prepare the timestep schedule."""
-        # The logic to prepare the timestep schedule is based on:
+    def _convert_timesteps_to_sigmas(self, image_seq_len: int, timesteps: torch.Tensor) -> list[float]:
+        # The logic to prepare the timestep / sigma schedule is based on:
         # https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
-        # TODO(ryand): Should we remove the dependency on the FlowMatchEulerDiscreteScheduler? It just makes this logic
-        # harder to understand than it needs to be.
+        # The default FlowMatchEulerDiscreteScheduler configs are based on:
+        # https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
+        # This implementation differs slightly from the original for the sake of simplicity (differs in terminal value
+        # handling, not quantizing timesteps to integers, etc.).

         def calculate_timestep_shift(
             image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
         ) -> float:
@@ -153,17 +131,14 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             mu = m * max_shift + base_shift
             return mu

-        # Add +1 step to account for the final timestep of 0.0.
-        # scheduler = self._init_scheduler()
-        timesteps = np.linspace(scheduler.config.num_train_timesteps, 1.0, num_steps)
-        timesteps = timesteps.astype(np.int64).astype(np.float32)
-        sigmas = timesteps / scheduler.config.num_train_timesteps
-        mu = calculate_timestep_shift(image_seq_len)
-        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, mu=mu)
+        def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
+            return mu / (mu + (1 / t - 1) ** sigma)

-        # We have to add the final timestep of 0.0. diffusers uses a different convention and omits the final state from
-        # the list.
-        return scheduler.timesteps.tolist() + [0], scheduler.sigmas.tolist() + [0]
+        mu = calculate_timestep_shift(image_seq_len)
+        sigmas = timesteps / 1000.0
+        sigmas = time_shift_linear(mu, 1.0, sigmas)
+        return sigmas.tolist()

     def _run_diffusion(
         self,
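For reference, the new sigma schedule can be reproduced standalone. Below is a minimal sketch, not part of the commit: the `m = ...` line inside calculate_timestep_shift is not visible in these hunks, so the sqrt form below is assumed from the diffusers CogView4 pipeline linked above, and image_seq_len=1024 assumes a 512x512 image with LATENT_SCALE_FACTOR=8 and patch_size=2.

import torch

def calculate_timestep_shift(
    image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
) -> float:
    # Assumed body: sqrt interpolation per the referenced diffusers pipeline.
    m = (image_seq_len / base_seq_len) ** 0.5
    mu = m * max_shift + base_shift
    return mu

def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return mu / (mu + (1 / t - 1) ** sigma)

# Mirrors the new _convert_timesteps_to_sigmas body for a 4-step schedule.
timesteps = torch.linspace(1000, 0, 4 + 1)         # tensor([1000., 750., 500., 250., 0.])
mu = calculate_timestep_shift(image_seq_len=1024)  # sqrt(1024/256) * 0.75 + 0.25 = 1.75
sigmas = time_shift_linear(mu, 1.0, timesteps / 1000.0)
print([round(s, 3) for s in sigmas.tolist()])      # [1.0, 0.84, 0.636, 0.368, 0.0]

Note that the terminal timestep of 0.0 maps cleanly to sigma 0.0: `1 / t` evaluates to inf there, so the shifted value vanishes rather than producing a NaN.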
@@ -198,14 +173,15 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         target_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
         crops_coords_top_left = torch.tensor([(0, 0)], dtype=pos_prompt_embeds.dtype, device=device)

-        # Prepare the timestep schedule.
-        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (
-            transformer_info.model.config.patch_size**2
-        )
-        scheduler = self._init_scheduler()
-        timesteps, sigmas = self._prepare_timesteps_and_sigmas(
-            scheduler, num_steps=self.steps, image_seq_len=image_seq_len
-        )
+        # Prepare the timestep / sigma schedule.
+        patch_size = transformer_info.model.config.patch_size  # type: ignore
+        assert isinstance(patch_size, int)
+        image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
+        # We add an extra step to the end to account for the final timestep of 0.0.
+        timesteps_torch = torch.linspace(1000, 0, self.steps + 1)
+        sigmas = self._convert_timesteps_to_sigmas(image_seq_len, timesteps_torch)
+        timesteps: list[float] = timesteps_torch.tolist()
+        # TODO(ryand): Add timestep schedule clipping.

         total_steps = len(timesteps) - 1
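Because the schedule now carries an explicit terminal entry, `total_steps = len(timesteps) - 1` yields one (current, next) pair per update. The sketch below is illustrative only, not the loop from this file: it shows how an (N+1)-entry flow-matching sigma schedule is typically consumed with Euler steps, with `fake_model` standing in for the transformer's velocity prediction.

import torch

def euler_denoise_sketch(latents: torch.Tensor, sigmas: list[float]) -> torch.Tensor:
    def fake_model(x: torch.Tensor, sigma: float) -> torch.Tensor:
        # Placeholder for the CogView4 transformer's velocity prediction.
        return torch.zeros_like(x)

    total_steps = len(sigmas) - 1  # N+1 schedule entries -> N update steps
    for i in range(total_steps):
        sigma_curr, sigma_next = sigmas[i], sigmas[i + 1]
        velocity = fake_model(latents, sigma_curr)
        # Euler step for flow matching: move along the predicted velocity.
        latents = latents + (sigma_next - sigma_curr) * velocity
    return latents

# Using the 4-step schedule from the earlier sketch:
x = euler_denoise_sketch(torch.randn(1, 16, 64, 64), [1.0, 0.84, 0.636, 0.368, 0.0])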