mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Sharded SDXL Inference (#6328)
* initial sharding fixes * sigma device fix * emptyline space fix --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin
|
||||
from examples.stable_diffusion import ResnetBlock, Mid
|
||||
import numpy as np
|
||||
|
||||
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union
|
||||
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type
|
||||
import argparse, tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -282,19 +282,29 @@ class SDXL:
|
||||
return self.first_stage_model.decode(1.0 / 0.13025 * x)
|
||||
|
||||
|
||||
class VanillaCFG:
|
||||
class Guider(ABC):
|
||||
def __init__(self, scale:float):
|
||||
self.scale = scale
|
||||
|
||||
def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
|
||||
@abstractmethod
|
||||
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
|
||||
pass
|
||||
|
||||
class VanillaCFG(Guider):
|
||||
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
|
||||
c_out = {}
|
||||
for k in c:
|
||||
assert k in ["vector", "crossattn", "concat"]
|
||||
c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
|
||||
return Tensor.cat(x, x), Tensor.cat(s, s), c_out
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
x_u, x_c = denoiser(Tensor.cat(x, x), Tensor.cat(s, s), c_out).chunk(2)
|
||||
x_pred = x_u + self.scale*(x_c - x_u)
|
||||
return x_pred
|
||||
|
||||
class SplitVanillaCFG(Guider):
|
||||
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
|
||||
x_u = denoiser(x, s, uc)
|
||||
x_c = denoiser(x, s, c)
|
||||
x_pred = x_u + self.scale*(x_c - x_u)
|
||||
return x_pred
|
||||
|
||||
@@ -302,13 +312,12 @@ class VanillaCFG:
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
|
||||
class DPMPP2MSampler:
|
||||
def __init__(self, cfg_scale:float):
|
||||
def __init__(self, cfg_scale:float, guider_cls:Type[Guider]=VanillaCFG):
|
||||
self.discretization = LegacyDDPMDiscretization()
|
||||
self.guider = VanillaCFG(cfg_scale)
|
||||
self.guider = guider_cls(cfg_scale)
|
||||
|
||||
def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
|
||||
denoised = self.guider(denoised)
|
||||
denoised = self.guider(denoiser, x, sigma, c, uc)
|
||||
|
||||
t, t_next = sigma.log().neg(), next_sigma.log().neg()
|
||||
h = t_next - t
|
||||
@@ -329,7 +338,7 @@ class DPMPP2MSampler:
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
|
||||
sigmas = self.discretization(num_steps)
|
||||
sigmas = self.discretization(num_steps).to(x.device)
|
||||
x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
num_sigmas = len(sigmas)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import math
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
|
||||
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user