Switch to using a custom scheduler implementation for SD3 rather than the diffusers FlowMatchEulerDiscreteScheduler. It is easier to work with and enables us to re-use the clip_timestep_schedule_fractional() utility from FLUX.

Ryan Dick
2024-11-07 22:46:52 +00:00
parent a5f8c23dee
commit a0fefcd43f
2 changed files with 28 additions and 73 deletions
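For orientation, the change below builds the timestep schedule directly instead of asking a diffusers scheduler for it: a linear 1 → 0 schedule with one extra entry for the final t = 0.0, clipped to the [denoising_start, denoising_end] window. The sketch below is a minimal standalone approximation of that idea; the function name and the clipping helper are hypothetical stand-ins (the real code uses FLUX's clip_timestep_schedule_fractional(), whose exact fractional behaviour is not reproduced here), not the InvokeAI implementation.

```python
import torch


def make_flow_matching_schedule(steps: int, denoising_start: float, denoising_end: float) -> list[float]:
    """Sketch of a custom SD3 schedule: linear timesteps from 1.0 to 0.0, clipped to the denoising window."""
    # One extra entry so the schedule ends at the final timestep t = 0.0.
    timesteps: list[float] = torch.linspace(1, 0, steps + 1).tolist()

    # Simplified stand-in for clip_timestep_schedule_fractional(): keep the part of the schedule
    # inside [1 - denoising_start, 1 - denoising_end] and pin the window boundaries themselves.
    t_start = 1.0 - denoising_start  # t = 1.0 corresponds to pure noise
    t_end = 1.0 - denoising_end      # t = 0.0 corresponds to the fully denoised latent
    clipped = [t for t in timesteps if t_end <= t <= t_start]
    if not clipped or clipped[0] < t_start:
        clipped.insert(0, t_start)
    if clipped[-1] > t_end:
        clipped.append(t_end)
    return clipped
```

With denoising_start=0.0 and denoising_end=1.0 this is just torch.linspace(1, 0, steps + 1); for image-to-image the schedule starts partway down the noise level, which matches the t_0-based noising of init_latents in the diff below.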

View File

@@ -3,7 +3,6 @@ from typing import Callable, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
-from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
@@ -23,6 +22,7 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -74,45 +74,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
-# TODO(ryand): Write unit tests for _init_scheduler(). I had to fix a bug in the original implementation.
-@staticmethod
-def _init_scheduler(
-scheduler: FlowMatchEulerDiscreteScheduler,
-steps: int,
-denoising_start: float,
-denoising_end: float,
-device: torch.device,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-"""Helper function to initialize the scheduler and prepare the timesteps.
-Based on DenoiseLatentsInvocation.init_scheduler(), but simplified since we currently only support
-FlowMatchEulerDiscreteScheduler for SD3.
-"""
-scheduler.set_timesteps(num_inference_steps=steps, device=device)
-timesteps = scheduler.timesteps
-assert isinstance(timesteps, torch.Tensor)
-# Skip greater order timesteps.
-_timesteps = timesteps[:: scheduler.order]
-# Get the start timestep index.
-eps = 1e-6
-t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
-t_start_idx = len(list(filter(lambda ts: ts >= t_start_val - eps, _timesteps)))
-# Get the end timestep index.
-t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
-t_end_idx = len(list(filter(lambda ts: ts >= t_end_val - eps, _timesteps[t_start_idx:])))
-# Apply the order to the indexes.
-t_start_idx *= scheduler.order
-t_end_idx *= scheduler.order
-# Note that the returned timesteps list could be empty, but we still return an init_timestep value.
-init_timestep = timesteps[t_start_idx : t_start_idx + 1]
-timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
-return timesteps, init_timestep
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
@@ -257,18 +218,15 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
-# Prepare the scheduler.
-scheduler = FlowMatchEulerDiscreteScheduler()
-timesteps, init_timestep = self._init_scheduler(
-scheduler=scheduler,
-steps=self.steps,
-denoising_start=self.denoising_start,
-denoising_end=self.denoising_end,
-device=device,
-)
+# Prepare the timestep schedule.
+# We add an extra step to the end to account for the final timestep of 0.0.
+timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
+# Clip the timesteps schedule based on denoising_start and denoising_end.
+timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
+total_steps = len(timesteps) - 1
# Prepare the CFG scale list.
-cfg_scale = self._prepare_cfg_scale(len(timesteps))
+cfg_scale = self._prepare_cfg_scale(total_steps)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
@@ -291,11 +249,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Prepare input latent image.
if init_latents is not None:
# Noise the init_latents by the appropriate amount for the first timestep.
-latents = scheduler.scale_noise(init_latents, init_timestep, noise)
+t_0 = timesteps[0]
+latents = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
+# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
+# denoising steps.
+if len(timesteps) <= 1:
+return latents
# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: InpaintExtension | None = None
@@ -307,7 +273,6 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
-total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
@@ -324,11 +289,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
-for step_idx, t in tqdm(list(enumerate(timesteps))):
+for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
-timestep = t.expand(latent_model_input.shape[0])
+# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
+timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
@@ -346,27 +312,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
-latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
-# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
-# needed.
-if latents.dtype != latents_dtype:
-if torch.backends.mps.is_available():
-# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-latents = latents.to(latents_dtype)
+latents = latents.to(dtype=torch.float32)
+latents = latents + (t_prev - t_curr) * noise_pred
+latents = latents.to(dtype=latents_dtype)
if inpaint_extension is not None:
-t_prev = timesteps[step_idx + 1] if step_idx < len(timesteps) - 1 else 0.0
-latents = inpaint_extension.merge_intermediate_latents_with_init_latents(
-latents, scheduler, t_prev=torch.tensor([t_prev], device=device)
-)
+latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
-timestep=int(t),
+timestep=int(t_curr),
latents=latents,
),
)
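Taken together, the loop above amounts to a first-order Euler integration of the flow-matching ODE, which is what replaces scheduler.step(). Below is a minimal sketch of just that update rule; the function and noise_pred_fn are hypothetical stand-ins for the CFG-combined SD3 transformer prediction, not the code in this commit.

```python
from typing import Callable

import torch


def euler_denoise(
    latents: torch.Tensor,
    noise_pred_fn: Callable[[torch.Tensor, float], torch.Tensor],
    timesteps: list[float],
) -> torch.Tensor:
    """Integrate from timesteps[0] down to timesteps[-1] (normally 0.0) with plain Euler steps."""
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        noise_pred = noise_pred_fn(latents, t_curr)
        # x_{t_prev} = x_{t_curr} + (t_prev - t_curr) * v; t_prev < t_curr, so each step moves toward t = 0.
        latents = latents + (t_prev - t_curr) * noise_pred
    return latents
```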

View File

@@ -1,5 +1,4 @@
import torch
-from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
class InpaintExtension:
@@ -26,7 +25,7 @@ class InpaintExtension:
# `InpaintExtension._apply_mask_gradient_adjustment()`.
def merge_intermediate_latents_with_init_latents(
-self, intermediate_latents: torch.Tensor, scheduler: FlowMatchEulerDiscreteScheduler, t_prev: torch.Tensor
+self, intermediate_latents: torch.Tensor, t_prev: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
@@ -38,10 +37,8 @@ class InpaintExtension:
-# Noise the init latents for the current timestep.
-noised_init_latents = self._init_latents
-# Note: scheduler.timesteps does not include the final timestep of 0.0. So, if we are in the final timestep, we
-# simply use self._init_latents directly.
-if t_prev[0] > 1e-6:
-noised_init_latents = scheduler.scale_noise(sample=self._init_latents, timestep=t_prev, noise=self._noise)
+# Noise the init latents for the current timestep.
+noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
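The same linear flow-matching relation also replaces scheduler.scale_noise() here: a clean latent is noised to level t by straight-line interpolation, x_t = t * noise + (1 - t) * x_0. A small self-contained sketch of the merge follows; the argument names are hypothetical (the real method reads these tensors from the extension's state).

```python
import torch


def merge_with_init_latents(
    intermediate_latents: torch.Tensor,
    init_latents: torch.Tensor,
    noise: torch.Tensor,
    inpaint_mask: torch.Tensor,  # 1.0 in regions being inpainted, 0.0 where the init image is kept
    t_prev: float,
) -> torch.Tensor:
    # Noise the init latents to the level expected at t_prev (t = 1.0 is pure noise, t = 0.0 is clean).
    noised_init_latents = noise * t_prev + (1.0 - t_prev) * init_latents
    # Keep the denoised result inside the mask and the correctly-noised init latents outside it.
    return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask)
```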