Sharded SDXL Inference (#6328)

* initial sharding fixes

* sigma device fix

* emptyline space fix

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Tobias Fischer
2024-09-21 01:26:43 -04:00
committed by GitHub
parent b91aa1c3d1
commit c1bbd15bd9
2 changed files with 21 additions and 12 deletions

View File

@@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin
from examples.stable_diffusion import ResnetBlock, Mid
import numpy as np
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type
import argparse, tempfile
from abc import ABC, abstractmethod
from pathlib import Path
@@ -282,19 +282,29 @@ class SDXL:
return self.first_stage_model.decode(1.0 / 0.13025 * x)
class VanillaCFG:
class Guider(ABC):
def __init__(self, scale:float):
self.scale = scale
def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
@abstractmethod
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
pass
class VanillaCFG(Guider):
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
c_out = {}
for k in c:
assert k in ["vector", "crossattn", "concat"]
c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
return Tensor.cat(x, x), Tensor.cat(s, s), c_out
def __call__(self, x:Tensor) -> Tensor:
x_u, x_c = x.chunk(2)
x_u, x_c = denoiser(Tensor.cat(x, x), Tensor.cat(s, s), c_out).chunk(2)
x_pred = x_u + self.scale*(x_c - x_u)
return x_pred
class SplitVanillaCFG(Guider):
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
x_u = denoiser(x, s, uc)
x_c = denoiser(x, s, c)
x_pred = x_u + self.scale*(x_c - x_u)
return x_pred
@@ -302,13 +312,12 @@ class VanillaCFG:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
class DPMPP2MSampler:
def __init__(self, cfg_scale:float):
def __init__(self, cfg_scale:float, guider_cls:Type[Guider]=VanillaCFG):
self.discretization = LegacyDDPMDiscretization()
self.guider = VanillaCFG(cfg_scale)
self.guider = guider_cls(cfg_scale)
def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
denoised = self.guider(denoised)
denoised = self.guider(denoiser, x, sigma, c, uc)
t, t_next = sigma.log().neg(), next_sigma.log().neg()
h = t_next - t
@@ -329,7 +338,7 @@ class DPMPP2MSampler:
return x, denoised
def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
sigmas = self.discretization(num_steps)
sigmas = self.discretization(num_steps).to(x.device)
x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
num_sigmas = len(sigmas)

View File

@@ -7,7 +7,7 @@ import math
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)