From 3517aa89d95e745947adefd7613662d4d4ba84fe Mon Sep 17 00:00:00 2001 From: Tobias Fischer Date: Wed, 28 Aug 2024 07:44:58 -0400 Subject: [PATCH] sdxl batched inference fixes (#6293) --- examples/sdxl.py | 41 +++++++++++++++++------------------------ extra/models/clip.py | 18 +++++++++++------- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/examples/sdxl.py b/examples/sdxl.py index aa1ad53e54..5196430b83 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin from examples.stable_diffusion import ResnetBlock, Mid import numpy as np -from typing import Dict, List, Callable, Optional, Any, Set, Tuple +from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union import argparse, tempfile from abc import ABC, abstractmethod from pathlib import Path @@ -47,21 +47,14 @@ class DiffusionModel: self.diffusion_model = UNetModel(*args, **kwargs) -class Embedder(ABC): - input_key: str - @abstractmethod - def __call__(self, x:Tensor) -> Tensor: - pass - - # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913 class ConcatTimestepEmbedderND(Embedder): def __init__(self, outdim:int, input_key:str): self.outdim = outdim self.input_key = input_key - def __call__(self, x:Tensor): - assert len(x.shape) == 2 + def __call__(self, x:Union[str,List[str],Tensor]): + assert isinstance(x, Tensor) and len(x.shape) == 2 emb = timestep_embedding(x.flatten(), self.outdim) emb = emb.reshape((x.shape[0],-1)) return emb @@ -91,9 +84,8 @@ class Conditioner: emb_out = embedder(batch[embedder.input_key]) if isinstance(emb_out, Tensor): - emb_out = [emb_out] - else: - assert isinstance(emb_out, (list, tuple)) + emb_out = (emb_out,) + assert isinstance(emb_out, (list, tuple)) for emb in emb_out: if embedder.input_key in force_zero_embeddings: @@ -248,22 +240,23 @@ class SDXL: self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True) # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173 - def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]: + def create_conditioning(self, pos_prompts:List[str], img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]: + N = len(pos_prompts) batch_c : Dict = { - "txt": pos_prompt, + "txt": pos_prompts, "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1), "crop_coords_top_left": Tensor([0,0]).repeat(N,1), "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1), "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1), } batch_uc: Dict = { - "txt": "", + "txt": [""]*N, "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1), "crop_coords_top_left": Tensor([0,0]).repeat(N,1), "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1), "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1), } - return model.conditioner(batch_c), model.conditioner(batch_uc, force_zero_embeddings=["txt"]) + return self.conditioner(batch_c), self.conditioner(batch_uc, force_zero_embeddings=["txt"]) # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42 def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor: @@ -293,14 +286,14 @@ class VanillaCFG: def __init__(self, scale:float): self.scale = scale - def prepare_inputs(self, x:Tensor, s:float, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Tensor]: + def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]: c_out = {} for k in c: assert k in ["vector", "crossattn", "concat"] c_out[k] = Tensor.cat(uc[k], c[k], dim=0) return Tensor.cat(x, x), Tensor.cat(s, s), c_out - def __call__(self, x:Tensor, sigma:float) -> Tensor: + def __call__(self, x:Tensor) -> Tensor: x_u, x_c = x.chunk(2) x_pred = x_u + self.scale*(x_c - x_u) return x_pred @@ -315,7 +308,7 @@ class DPMPP2MSampler: def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]: denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc)) - denoised = self.guider(denoised, sigma) + denoised = self.guider(denoised) t, t_next = sigma.log().neg(), next_sigma.log().neg() h = t_next - t @@ -345,9 +338,9 @@ class DPMPP2MSampler: with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"): x, old_denoised = self.sampler_step( old_denoised=old_denoised, - prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])), - sigma=sigmas[i].reshape(x.shape[0]), - next_sigma=sigmas[i+1].reshape(x.shape[0]), + prev_sigma=(None if i==0 else sigmas[i-1].expand(x.shape[0])), + sigma=sigmas[i].expand(x.shape[0]), + next_sigma=sigmas[i+1].expand(x.shape[0]), denoiser=denoiser, x=x, c=c, @@ -391,7 +384,7 @@ if __name__ == "__main__": assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}" assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}" - c, uc = model.create_conditioning(args.prompt, args.width, args.height) + c, uc = model.create_conditioning([args.prompt], args.width, args.height) del model.conditioner for v in c .values(): v.realize() for v in uc.values(): v.realize() diff --git a/extra/models/clip.py b/extra/models/clip.py index d8c6879456..96ec072aa1 100644 --- a/extra/models/clip.py +++ b/extra/models/clip.py @@ -108,7 +108,7 @@ class Tokenizer: self.cache[token] = word return word - def encode(self, text:str, pad_with_zeros:bool=False): + def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]: bpe_tokens: List[int] = [] text = Tokenizer.whitespace_clean(text.strip()).lower() for token in re.findall(self.pat, text): @@ -123,7 +123,7 @@ class Tokenizer: class Embedder(ABC): input_key: str @abstractmethod - def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]: + def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]: pass @@ -222,9 +222,11 @@ class FrozenClosedClipEmbedder(Embedder): self.transformer = Closed.ClipTextModel(ret_layer_idx) self.input_key = "txt" - def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]: - tokens = Tensor(self.tokenizer.encode(text)) - return self.transformer.text_model(tokens.reshape(1,-1)) + def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]: + if isinstance(texts, str): texts = [texts] + assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}" + tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0) + return self.transformer.text_model(tokens.reshape(len(texts),-1)) class Open: @@ -378,8 +380,10 @@ class FrozenOpenClipEmbedder(Embedder): else: return penultimate - def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]: - tokens = self.tokenize(text) + def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]: + if isinstance(texts, str): texts = [texts] + assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}" + tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0) return self.embed_tokens(tokens)