mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
sdxl batched inference fixes (#6293)
This commit is contained in:
@@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin
|
||||
from examples.stable_diffusion import ResnetBlock, Mid
|
||||
import numpy as np
|
||||
|
||||
from typing import Dict, List, Callable, Optional, Any, Set, Tuple
|
||||
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union
|
||||
import argparse, tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -47,21 +47,14 @@ class DiffusionModel:
|
||||
self.diffusion_model = UNetModel(*args, **kwargs)
|
||||
|
||||
|
||||
class Embedder(ABC):
|
||||
input_key: str
|
||||
@abstractmethod
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
pass
|
||||
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
|
||||
class ConcatTimestepEmbedderND(Embedder):
|
||||
def __init__(self, outdim:int, input_key:str):
|
||||
self.outdim = outdim
|
||||
self.input_key = input_key
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
assert len(x.shape) == 2
|
||||
def __call__(self, x:Union[str,List[str],Tensor]):
|
||||
assert isinstance(x, Tensor) and len(x.shape) == 2
|
||||
emb = timestep_embedding(x.flatten(), self.outdim)
|
||||
emb = emb.reshape((x.shape[0],-1))
|
||||
return emb
|
||||
@@ -91,9 +84,8 @@ class Conditioner:
|
||||
emb_out = embedder(batch[embedder.input_key])
|
||||
|
||||
if isinstance(emb_out, Tensor):
|
||||
emb_out = [emb_out]
|
||||
else:
|
||||
assert isinstance(emb_out, (list, tuple))
|
||||
emb_out = (emb_out,)
|
||||
assert isinstance(emb_out, (list, tuple))
|
||||
|
||||
for emb in emb_out:
|
||||
if embedder.input_key in force_zero_embeddings:
|
||||
@@ -248,22 +240,23 @@ class SDXL:
|
||||
self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True)
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173
|
||||
def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
|
||||
def create_conditioning(self, pos_prompts:List[str], img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
|
||||
N = len(pos_prompts)
|
||||
batch_c : Dict = {
|
||||
"txt": pos_prompt,
|
||||
"txt": pos_prompts,
|
||||
"original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
|
||||
"crop_coords_top_left": Tensor([0,0]).repeat(N,1),
|
||||
"target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
|
||||
"aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
|
||||
}
|
||||
batch_uc: Dict = {
|
||||
"txt": "",
|
||||
"txt": [""]*N,
|
||||
"original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
|
||||
"crop_coords_top_left": Tensor([0,0]).repeat(N,1),
|
||||
"target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
|
||||
"aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
|
||||
}
|
||||
return model.conditioner(batch_c), model.conditioner(batch_uc, force_zero_embeddings=["txt"])
|
||||
return self.conditioner(batch_c), self.conditioner(batch_uc, force_zero_embeddings=["txt"])
|
||||
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42
|
||||
def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
|
||||
@@ -293,14 +286,14 @@ class VanillaCFG:
|
||||
def __init__(self, scale:float):
|
||||
self.scale = scale
|
||||
|
||||
def prepare_inputs(self, x:Tensor, s:float, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Tensor]:
|
||||
def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
|
||||
c_out = {}
|
||||
for k in c:
|
||||
assert k in ["vector", "crossattn", "concat"]
|
||||
c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
|
||||
return Tensor.cat(x, x), Tensor.cat(s, s), c_out
|
||||
|
||||
def __call__(self, x:Tensor, sigma:float) -> Tensor:
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
x_pred = x_u + self.scale*(x_c - x_u)
|
||||
return x_pred
|
||||
@@ -315,7 +308,7 @@ class DPMPP2MSampler:
|
||||
|
||||
def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
|
||||
denoised = self.guider(denoised, sigma)
|
||||
denoised = self.guider(denoised)
|
||||
|
||||
t, t_next = sigma.log().neg(), next_sigma.log().neg()
|
||||
h = t_next - t
|
||||
@@ -345,9 +338,9 @@ class DPMPP2MSampler:
|
||||
with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised=old_denoised,
|
||||
prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])),
|
||||
sigma=sigmas[i].reshape(x.shape[0]),
|
||||
next_sigma=sigmas[i+1].reshape(x.shape[0]),
|
||||
prev_sigma=(None if i==0 else sigmas[i-1].expand(x.shape[0])),
|
||||
sigma=sigmas[i].expand(x.shape[0]),
|
||||
next_sigma=sigmas[i+1].expand(x.shape[0]),
|
||||
denoiser=denoiser,
|
||||
x=x,
|
||||
c=c,
|
||||
@@ -391,7 +384,7 @@ if __name__ == "__main__":
|
||||
assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
|
||||
assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
|
||||
|
||||
c, uc = model.create_conditioning(args.prompt, args.width, args.height)
|
||||
c, uc = model.create_conditioning([args.prompt], args.width, args.height)
|
||||
del model.conditioner
|
||||
for v in c .values(): v.realize()
|
||||
for v in uc.values(): v.realize()
|
||||
|
||||
@@ -108,7 +108,7 @@ class Tokenizer:
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text:str, pad_with_zeros:bool=False):
|
||||
def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
|
||||
bpe_tokens: List[int] = []
|
||||
text = Tokenizer.whitespace_clean(text.strip()).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
@@ -123,7 +123,7 @@ class Tokenizer:
|
||||
class Embedder(ABC):
|
||||
input_key: str
|
||||
@abstractmethod
|
||||
def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -222,9 +222,11 @@ class FrozenClosedClipEmbedder(Embedder):
|
||||
self.transformer = Closed.ClipTextModel(ret_layer_idx)
|
||||
self.input_key = "txt"
|
||||
|
||||
def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
tokens = Tensor(self.tokenizer.encode(text))
|
||||
return self.transformer.text_model(tokens.reshape(1,-1))
|
||||
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
if isinstance(texts, str): texts = [texts]
|
||||
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
|
||||
tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
|
||||
return self.transformer.text_model(tokens.reshape(len(texts),-1))
|
||||
|
||||
|
||||
class Open:
|
||||
@@ -378,8 +380,10 @@ class FrozenOpenClipEmbedder(Embedder):
|
||||
else:
|
||||
return penultimate
|
||||
|
||||
def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
tokens = self.tokenize(text)
|
||||
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
|
||||
if isinstance(texts, str): texts = [texts]
|
||||
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
|
||||
tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
|
||||
return self.embed_tokens(tokens)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user