mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Simplify CogView4 timesteps schedule generation in preparation for timestep schedule slipping.
This commit is contained in:
committed by
psychedelicious
parent
4fae8ad163
commit
ace5e748f4
@@ -1,8 +1,7 @@
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers import CogView4Transformer2DModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
@@ -117,34 +116,13 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return cfg_scale
|
||||
|
||||
def _init_scheduler(self) -> FlowMatchEulerDiscreteScheduler:
|
||||
# The default FlowMatchEulerDiscreteScheduler configs are copied from:
|
||||
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
|
||||
return FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
shift=1.0,
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.25,
|
||||
max_shift=0.75,
|
||||
base_image_seq_len=256,
|
||||
max_image_seq_len=4096,
|
||||
invert_sigmas=False,
|
||||
shift_terminal=None,
|
||||
use_karras_sigmas=False,
|
||||
use_exponential_sigmas=False,
|
||||
use_beta_sigmas=False,
|
||||
time_shift_type="linear",
|
||||
)
|
||||
|
||||
def _prepare_timesteps_and_sigmas(
|
||||
self, scheduler: FlowMatchEulerDiscreteScheduler, num_steps: int, image_seq_len: int
|
||||
) -> tuple[list[float], list[float]]:
|
||||
"""Prepare the timestep schedule."""
|
||||
# The logic to prepare the timestep schedule is based on:
|
||||
def _convert_timesteps_to_sigmas(self, image_seq_len: int, timesteps: torch.Tensor) -> list[float]:
|
||||
# The logic to prepare the timestep / sigma schedule is based on:
|
||||
# https://github.com/huggingface/diffusers/blob/b38450d5d2e5b87d5ff7088ee5798c85587b9635/src/diffusers/pipelines/cogview4/pipeline_cogview4.py#L575-L595
|
||||
|
||||
# TODO(ryand): Should we remove the dependency on the FlowMatchEulerDiscreteScheduler? It just makes this logic
|
||||
# harder to understand than it needs to be.
|
||||
# The default FlowMatchEulerDiscreteScheduler configs are based on:
|
||||
# https://huggingface.co/THUDM/CogView4-6B/blob/fb6f57289c73ac6d139e8d81bd5a4602d1877847/scheduler/scheduler_config.json
|
||||
# This implementation differs slightly from the original for the sake of simplicity (differs in terminal value
|
||||
# handling, not quantizing timesteps to integers, etc.).
|
||||
|
||||
def calculate_timestep_shift(
|
||||
image_seq_len: int, base_seq_len: int = 256, base_shift: float = 0.25, max_shift: float = 0.75
|
||||
@@ -153,17 +131,14 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
|
||||
|
||||
# Add +1 step to account for the final timestep of 0.0.
|
||||
# scheduler = self._init_scheduler()
|
||||
timesteps = np.linspace(scheduler.config.num_train_timesteps, 1.0, num_steps)
|
||||
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
||||
sigmas = timesteps / scheduler.config.num_train_timesteps
|
||||
mu = calculate_timestep_shift(image_seq_len)
|
||||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, mu=mu)
|
||||
def time_shift_linear(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
# We have to add the final timestep of 0.0. diffusers uses a different convention and omits the final state from
|
||||
# the list.
|
||||
return scheduler.timesteps.tolist() + [0], scheduler.sigmas.tolist() + [0]
|
||||
mu = calculate_timestep_shift(image_seq_len)
|
||||
sigmas = timesteps / 1000.0
|
||||
sigmas = time_shift_linear(mu, 1.0, sigmas)
|
||||
|
||||
return sigmas.tolist()
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
@@ -198,14 +173,15 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
target_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([(0, 0)], dtype=pos_prompt_embeds.dtype, device=device)
|
||||
|
||||
# Prepare the timestep schedule.
|
||||
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (
|
||||
transformer_info.model.config.patch_size**2
|
||||
)
|
||||
scheduler = self._init_scheduler()
|
||||
timesteps, sigmas = self._prepare_timesteps_and_sigmas(
|
||||
scheduler, num_steps=self.steps, image_seq_len=image_seq_len
|
||||
)
|
||||
# Prepare the timestep / sigma schedule.
|
||||
patch_size = transformer_info.model.config.patch_size # type: ignore
|
||||
assert isinstance(patch_size, int)
|
||||
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (patch_size**2)
|
||||
# We add an extra step to the end to account for the final timestep of 0.0.
|
||||
timesteps_torch = torch.linspace(1000, 0, self.steps + 1)
|
||||
sigmas = self._convert_timesteps_to_sigmas(image_seq_len, timesteps_torch)
|
||||
timesteps: list[float] = timesteps_torch.tolist()
|
||||
|
||||
# TODO(ryand): Add timestep schedule clipping.
|
||||
total_steps = len(timesteps) - 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user