sdxl batched inference fixes (#6293)

This commit is contained in:
Tobias Fischer
2024-08-28 07:44:58 -04:00
committed by GitHub
parent 85591bd1ae
commit 3517aa89d9
2 changed files with 28 additions and 31 deletions

View File

@@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin
from examples.stable_diffusion import ResnetBlock, Mid
import numpy as np
from typing import Dict, List, Callable, Optional, Any, Set, Tuple
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union
import argparse, tempfile
from abc import ABC, abstractmethod
from pathlib import Path
@@ -47,21 +47,14 @@ class DiffusionModel:
self.diffusion_model = UNetModel(*args, **kwargs)
class Embedder(ABC):
input_key: str
@abstractmethod
def __call__(self, x:Tensor) -> Tensor:
pass
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
class ConcatTimestepEmbedderND(Embedder):
def __init__(self, outdim:int, input_key:str):
self.outdim = outdim
self.input_key = input_key
def __call__(self, x:Tensor):
assert len(x.shape) == 2
def __call__(self, x:Union[str,List[str],Tensor]):
assert isinstance(x, Tensor) and len(x.shape) == 2
emb = timestep_embedding(x.flatten(), self.outdim)
emb = emb.reshape((x.shape[0],-1))
return emb
@@ -91,9 +84,8 @@ class Conditioner:
emb_out = embedder(batch[embedder.input_key])
if isinstance(emb_out, Tensor):
emb_out = [emb_out]
else:
assert isinstance(emb_out, (list, tuple))
emb_out = (emb_out,)
assert isinstance(emb_out, (list, tuple))
for emb in emb_out:
if embedder.input_key in force_zero_embeddings:
@@ -248,22 +240,23 @@ class SDXL:
self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True)
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173
def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
def create_conditioning(self, pos_prompts:List[str], img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
N = len(pos_prompts)
batch_c : Dict = {
"txt": pos_prompt,
"txt": pos_prompts,
"original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
"crop_coords_top_left": Tensor([0,0]).repeat(N,1),
"target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
"aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
}
batch_uc: Dict = {
"txt": "",
"txt": [""]*N,
"original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
"crop_coords_top_left": Tensor([0,0]).repeat(N,1),
"target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
"aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
}
return model.conditioner(batch_c), model.conditioner(batch_uc, force_zero_embeddings=["txt"])
return self.conditioner(batch_c), self.conditioner(batch_uc, force_zero_embeddings=["txt"])
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42
def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
@@ -293,14 +286,14 @@ class VanillaCFG:
def __init__(self, scale:float):
self.scale = scale
def prepare_inputs(self, x:Tensor, s:float, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Tensor]:
def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
c_out = {}
for k in c:
assert k in ["vector", "crossattn", "concat"]
c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
return Tensor.cat(x, x), Tensor.cat(s, s), c_out
def __call__(self, x:Tensor, sigma:float) -> Tensor:
def __call__(self, x:Tensor) -> Tensor:
x_u, x_c = x.chunk(2)
x_pred = x_u + self.scale*(x_c - x_u)
return x_pred
@@ -315,7 +308,7 @@ class DPMPP2MSampler:
def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
denoised = self.guider(denoised, sigma)
denoised = self.guider(denoised)
t, t_next = sigma.log().neg(), next_sigma.log().neg()
h = t_next - t
@@ -345,9 +338,9 @@ class DPMPP2MSampler:
with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
x, old_denoised = self.sampler_step(
old_denoised=old_denoised,
prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])),
sigma=sigmas[i].reshape(x.shape[0]),
next_sigma=sigmas[i+1].reshape(x.shape[0]),
prev_sigma=(None if i==0 else sigmas[i-1].expand(x.shape[0])),
sigma=sigmas[i].expand(x.shape[0]),
next_sigma=sigmas[i+1].expand(x.shape[0]),
denoiser=denoiser,
x=x,
c=c,
@@ -391,7 +384,7 @@ if __name__ == "__main__":
assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
c, uc = model.create_conditioning(args.prompt, args.width, args.height)
c, uc = model.create_conditioning([args.prompt], args.width, args.height)
del model.conditioner
for v in c .values(): v.realize()
for v in uc.values(): v.realize()

View File

@@ -108,7 +108,7 @@ class Tokenizer:
self.cache[token] = word
return word
def encode(self, text:str, pad_with_zeros:bool=False):
def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
bpe_tokens: List[int] = []
text = Tokenizer.whitespace_clean(text.strip()).lower()
for token in re.findall(self.pat, text):
@@ -123,7 +123,7 @@ class Tokenizer:
class Embedder(ABC):
input_key: str
@abstractmethod
def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
pass
@@ -222,9 +222,11 @@ class FrozenClosedClipEmbedder(Embedder):
self.transformer = Closed.ClipTextModel(ret_layer_idx)
self.input_key = "txt"
def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
tokens = Tensor(self.tokenizer.encode(text))
return self.transformer.text_model(tokens.reshape(1,-1))
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
if isinstance(texts, str): texts = [texts]
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
return self.transformer.text_model(tokens.reshape(len(texts),-1))
class Open:
@@ -378,8 +380,10 @@ class FrozenOpenClipEmbedder(Embedder):
else:
return penultimate
def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
tokens = self.tokenize(text)
def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
if isinstance(texts, str): texts = [texts]
assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
return self.embed_tokens(tokens)