Switch to using a custom scheduler implementation for SD3 rather than the diffusers FlowMatchEulerDiscreteScheduler. It is easier to work with and enables us to re-use the clip_timestep_schedule_fractional() utility from FLUX.
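The replacement sampling logic is compact enough to sketch on its own. Below is a minimal standalone illustration of the three pieces this commit inlines (a linear [1, 0] timestep schedule, flow-matching noising x_t = t * noise + (1 - t) * x_0, and an explicit Euler step along the predicted velocity); the noise_for_img2img() and denoise() functions and the model callable are illustrative stand-ins, not InvokeAI or diffusers APIs.

import torch


def noise_for_img2img(init_latents: torch.Tensor, noise: torch.Tensor, t: float) -> torch.Tensor:
    # Flow-matching noising: pure noise at t=1.0, clean latents at t=0.0.
    return t * noise + (1.0 - t) * init_latents


def denoise(model, latents: torch.Tensor, steps: int) -> torch.Tensor:
    # Linear schedule from 1.0 down to 0.0, with one extra entry so that the
    # final timestep of 0.0 is included.
    timesteps: list[float] = torch.linspace(1, 0, steps + 1).tolist()

    # Explicit Euler integration of the flow ODE: the model predicts a
    # velocity, and the sample moves from t_curr to t_prev along it.
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        velocity = model(latents, t_curr)
        latents = latents + (t_prev - t_curr) * velocity
    return latents

Because every timestep is now a plain float in [0, 1], the FLUX clip_timestep_schedule_fractional() utility can trim the schedule directly; that is the reuse the message above refers to.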
@@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple
 import torch
 import torchvision.transforms as tv_transforms
 from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
-from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from torchvision.transforms.functional import resize as tv_resize
 from tqdm import tqdm
 
@@ -23,6 +22,7 @@ from invokeai.app.invocations.model import TransformerField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
 from invokeai.backend.model_manager.config import BaseModelType
 from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -74,45 +74,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
 
-    # TODO(ryand): Write unit tests for _init_scheduler(). I had to fix a bug in the original implementation.
-    @staticmethod
-    def _init_scheduler(
-        scheduler: FlowMatchEulerDiscreteScheduler,
-        steps: int,
-        denoising_start: float,
-        denoising_end: float,
-        device: torch.device,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Helper function to initialize the scheduler and prepare the timesteps.
-
-        Based on DenoiseLatentsInvocation.init_scheduler(), but simplified since we currently only support
-        FlowMatchEulerDiscreteScheduler for SD3.
-        """
-        scheduler.set_timesteps(num_inference_steps=steps, device=device)
-        timesteps = scheduler.timesteps
-        assert isinstance(timesteps, torch.Tensor)
-
-        # Skip greater order timesteps.
-        _timesteps = timesteps[:: scheduler.order]
-
-        # Get the start timestep index.
-        eps = 1e-6
-        t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
-        t_start_idx = len(list(filter(lambda ts: ts >= t_start_val - eps, _timesteps)))
-
-        # Get the end timestep index.
-        t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
-        t_end_idx = len(list(filter(lambda ts: ts >= t_end_val - eps, _timesteps[t_start_idx:])))
-
-        # Apply the order to the indexes.
-        t_start_idx *= scheduler.order
-        t_end_idx *= scheduler.order
-
-        # Note that the returned timesteps list could be empty, but we still return an init_timestep value.
-        init_timestep = timesteps[t_start_idx : t_start_idx + 1]
-        timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
-        return timesteps, init_timestep
-
     def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
         """Prepare the inpaint mask.
         - Loads the mask
@@ -257,18 +218,15 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
         pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
 
-        # Prepare the scheduler.
-        scheduler = FlowMatchEulerDiscreteScheduler()
-        timesteps, init_timestep = self._init_scheduler(
-            scheduler=scheduler,
-            steps=self.steps,
-            denoising_start=self.denoising_start,
-            denoising_end=self.denoising_end,
-            device=device,
-        )
+        # Prepare the timestep schedule.
+        # We add an extra step to the end to account for the final timestep of 0.0.
+        timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
+        # Clip the timesteps schedule based on denoising_start and denoising_end.
+        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
+        total_steps = len(timesteps) - 1
 
         # Prepare the CFG scale list.
-        cfg_scale = self._prepare_cfg_scale(len(timesteps))
+        cfg_scale = self._prepare_cfg_scale(total_steps)
 
         # Load the input latents, if provided.
         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
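For reference, clip_timestep_schedule_fractional() is characterized here only by its call site: it trims the descending [1, 0] schedule to the (denoising_start, denoising_end) window. A rough approximation of that behavior, assuming fractional boundary values are pinned rather than snapped to existing steps (a guess at the FLUX utility's semantics, not its actual implementation):

def clip_timestep_schedule_approx(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
    # Map denoising_start/denoising_end (0.0 = pure noise, 1.0 = fully
    # denoised) onto the descending [1.0 ... 0.0] timestep scale.
    t_start = 1.0 - denoising_start
    t_end = 1.0 - denoising_end
    # Keep interior steps and pin the exact fractional boundaries.
    interior = [t for t in timesteps if t_end < t < t_start]
    return [t_start] + interior + [t_end]

With denoising_start=0.0 and denoising_end=1.0 this returns the schedule unchanged; with denoising_start=0.7 the schedule starts at 0.3, which is the value the hunk below uses to noise init_latents via timesteps[0].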
@@ -291,11 +249,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         # Prepare input latent image.
         if init_latents is not None:
             # Noise the init_latents by the appropriate amount for the first timestep.
-            latents = scheduler.scale_noise(init_latents, init_timestep, noise)
+            t_0 = timesteps[0]
+            latents = t_0 * noise + (1.0 - t_0) * init_latents
         else:
             # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
             if self.denoising_start > 1e-5:
                 raise ValueError("denoising_start should be 0 when initial latents are not provided.")
             latents = noise
 
+        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
+        # denoising steps.
+        if len(timesteps) <= 1:
+            return latents
+
         # Prepare inpaint extension.
         inpaint_mask = self._prep_inpaint_mask(context, latents)
         inpaint_extension: InpaintExtension | None = None
@@ -307,7 +273,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             noise=noise,
         )
 
-        total_steps = len(timesteps)
         step_callback = self._build_step_callback(context)
 
         step_callback(
@@ -324,11 +289,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         assert isinstance(transformer, SD3Transformer2DModel)
 
         # 6. Denoising loop
-        for step_idx, t in tqdm(list(enumerate(timesteps))):
+        for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
            # Expand the latents if we are doing CFG.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # Expand the timestep to match the latent model input.
-            timestep = t.expand(latent_model_input.shape[0])
+            # Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
+            timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
 
            noise_pred = transformer(
                hidden_states=latent_model_input,
@@ -346,27 +312,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
            # Compute the previous noisy sample x_t -> x_t-1.
            latents_dtype = latents.dtype
-            latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
-
-            # TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
-            # needed.
-            if latents.dtype != latents_dtype:
-                if torch.backends.mps.is_available():
-                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                    latents = latents.to(latents_dtype)
+            latents = latents.to(dtype=torch.float32)
+            latents = latents + (t_prev - t_curr) * noise_pred
+            latents = latents.to(dtype=latents_dtype)
 
            if inpaint_extension is not None:
-                t_prev = timesteps[step_idx + 1] if step_idx < len(timesteps) - 1 else 0.0
-                latents = inpaint_extension.merge_intermediate_latents_with_init_latents(
-                    latents, scheduler, t_prev=torch.tensor([t_prev], device=device)
-                )
+                latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)
 
            step_callback(
                PipelineIntermediateState(
                    step=step_idx + 1,
                    order=1,
                    total_steps=total_steps,
-                    timestep=int(t),
+                    timestep=int(t_curr),
                    latents=latents,
                ),
            )
@@ -1,5 +1,4 @@
 import torch
-from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 
 
 class InpaintExtension:
@@ -26,7 +25,7 @@ class InpaintExtension:
     # `InpaintExtension._apply_mask_gradient_adjustment()`.
 
     def merge_intermediate_latents_with_init_latents(
-        self, intermediate_latents: torch.Tensor, scheduler: FlowMatchEulerDiscreteScheduler, t_prev: torch.Tensor
+        self, intermediate_latents: torch.Tensor, t_prev: float
     ) -> torch.Tensor:
         """Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
         update the intermediate latents to keep the regions that are not being inpainted on the correct noise
@@ -38,10 +37,8 @@ class InpaintExtension:
-        # Noise the init latents for the current timestep.
-        noised_init_latents = self._init_latents
-
-        # Note: scheduler.timesteps does not include the final timestep of 0.0. So, if we are in the final timestep, we
-        # simply use self._init_latents directly.
-        if t_prev[0] > 1e-6:
-            noised_init_latents = scheduler.scale_noise(sample=self._init_latents, timestep=t_prev, noise=self._noise)
+        # Noise the init latents for the current timestep.
+        noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents
 
         # Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
         return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
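The two inlined formulas are easy to sanity-check in isolation. A short illustrative script (not part of the commit) verifying the noising endpoints and the mask-merge semantics:

import torch

torch.manual_seed(0)
init = torch.randn(4, 4)
noise = torch.randn(4, 4)


def noised(t_prev: float) -> torch.Tensor:
    # Inlined flow-matching noising from this commit: x_t = t * noise + (1 - t) * x_0.
    return noise * t_prev + (1.0 - t_prev) * init


# Endpoints behave as expected: pure noise at t=1.0, clean latents at t=0.0.
assert torch.equal(noised(1.0), noise)
assert torch.equal(noised(0.0), init)

# The merge keeps inpainted regions (mask == 1.0) from the intermediate latents
# and holds everything else on the noised init-latent trajectory.
inpaint_mask = (torch.rand(4, 4) > 0.5).float()
intermediate = torch.randn(4, 4)
merged = intermediate * inpaint_mask + noised(0.5) * (1.0 - inpaint_mask)
assert torch.equal(merged[inpaint_mask == 0.0], noised(0.5)[inpaint_mask == 0.0])

Since t_prev is now a plain float, the final step (t_prev == 0.0) needs no special casing: the formula reduces to self._init_latents exactly, which is what the removed scheduler.scale_noise() branch was working around.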