From c1bbd15bd9eb1d8460aa996c8d9f2a71d12400ac Mon Sep 17 00:00:00 2001 From: Tobias Fischer Date: Sat, 21 Sep 2024 01:26:43 -0400 Subject: [PATCH] Sharded SDXL Inference (#6328) * initial sharding fixes * sigma device fix * emptyline space fix --------- Co-authored-by: chenyu --- examples/sdxl.py | 31 ++++++++++++++++++++----------- extra/models/unet.py | 2 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/sdxl.py b/examples/sdxl.py index 5196430b83..131525b6ac 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin from examples.stable_diffusion import ResnetBlock, Mid import numpy as np -from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union +from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type import argparse, tempfile from abc import ABC, abstractmethod from pathlib import Path @@ -282,19 +282,29 @@ class SDXL: return self.first_stage_model.decode(1.0 / 0.13025 * x) -class VanillaCFG: +class Guider(ABC): def __init__(self, scale:float): self.scale = scale - def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]: + @abstractmethod + def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor: + pass + +class VanillaCFG(Guider): + def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor: c_out = {} for k in c: assert k in ["vector", "crossattn", "concat"] c_out[k] = Tensor.cat(uc[k], c[k], dim=0) - return Tensor.cat(x, x), Tensor.cat(s, s), c_out - def __call__(self, x:Tensor) -> Tensor: - x_u, x_c = x.chunk(2) + x_u, x_c = denoiser(Tensor.cat(x, x), Tensor.cat(s, s), c_out).chunk(2) + x_pred = x_u + self.scale*(x_c - x_u) + return x_pred + +class SplitVanillaCFG(Guider): + def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor: + x_u = denoiser(x, s, uc) + x_c = denoiser(x, s, c) x_pred = x_u + self.scale*(x_c - x_u) return x_pred @@ -302,13 +312,12 @@ class VanillaCFG: # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287 class DPMPP2MSampler: - def __init__(self, cfg_scale:float): + def __init__(self, cfg_scale:float, guider_cls:Type[Guider]=VanillaCFG): self.discretization = LegacyDDPMDiscretization() - self.guider = VanillaCFG(cfg_scale) + self.guider = guider_cls(cfg_scale) def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]: - denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc)) - denoised = self.guider(denoised) + denoised = self.guider(denoiser, x, sigma, c, uc) t, t_next = sigma.log().neg(), next_sigma.log().neg() h = t_next - t @@ -329,7 +338,7 @@ class DPMPP2MSampler: return x, denoised def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor: - sigmas = self.discretization(num_steps) + sigmas = self.discretization(num_steps).to(x.device) x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0) num_sigmas = len(sigmas) diff --git a/extra/models/unet.py b/extra/models/unet.py index fad41443cb..92d4496320 100644 --- a/extra/models/unet.py +++ b/extra/models/unet.py @@ -7,7 +7,7 @@ import math # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207 def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() + freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp() args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)