From 8c9c1cf62f61cecea272fa0fe8645ab721a3aa41 Mon Sep 17 00:00:00 2001 From: Tobias Fischer Date: Mon, 1 Jul 2024 22:33:01 -0400 Subject: [PATCH] Pulled CLIP and UNet into Separate Files (#5253) * pulled clip and unet into separate files * reference cleanup, lru cache fix * better pool indexing --- examples/sdxl.py | 519 +------------------ examples/stable_diffusion.py | 403 +------------- extra/models/clip.py | 348 +++++++++++++ extra/models/unet.py | 253 +++++++++ test/external/external_test_jit_on_models.py | 4 +- test/models/test_real_world.py | 5 +- 6 files changed, 641 insertions(+), 891 deletions(-) create mode 100644 extra/models/clip.py create mode 100644 extra/models/unet.py diff --git a/examples/sdxl.py b/examples/sdxl.py index 26ed3a4ef3..f9fc2248ea 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -4,14 +4,16 @@ # mlfoundations/open_clip | MIT | https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/LICENSE from tinygrad import Tensor, TinyJit, dtypes -from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm, Embedding +from tinygrad.nn import Conv2d, GroupNorm from tinygrad.nn.state import safe_load, load_state_dict -from tinygrad.helpers import fetch, trange, colored, THREEFRY -from examples.stable_diffusion import ClipTokenizer, ResnetBlock, Mid, Downsample, Upsample +from tinygrad.helpers import fetch, trange, colored +from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder +from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding +from examples.stable_diffusion import ResnetBlock, Mid import numpy as np -from typing import Dict, List, Union, Callable, Optional, Any, Set, Tuple -import math, argparse, tempfile +from typing import Dict, List, Callable, Optional, Any, Set, Tuple +import argparse, tempfile from abc import ABC, abstractmethod from pathlib import Path from PIL import Image @@ -22,13 +24,13 @@ from PIL import Image # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_refiner.yaml configs: Dict = { "SDXL_Base": { - "model": {"adm_in_channels": 2816, "in_channels": 4, "out_channels": 4, "model_channels": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048}, + "model": {"adm_in_ch": 2816, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048, "use_linear": True}, "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]}, "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256}, "denoiser": {"num_idx": 1000}, }, "SDXL_Refiner": { - "model": {"adm_in_channels": 2560, "in_channels": 4, "out_channels": 4, "model_channels": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280]}, + "model": {"adm_in_ch": 2560, "in_ch": 4, "out_ch": 4, "model_ch": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280], "use_linear": True}, "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left",
"aesthetic_score"]}, "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256}, "denoiser": {"num_idx": 1000}, @@ -40,256 +42,6 @@ def tensor_identity(x:Tensor) -> Tensor: return x -class UNet: - """ - Namespace for UNet model components. - """ - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L136 - class ResBlock: - def __init__(self, channels:int, emb_channels:int, out_channels:int): - self.in_layers = [ - GroupNorm(32, channels), - Tensor.silu, - Conv2d(channels, out_channels, 3, padding=1), - ] - self.emb_layers = [ - Tensor.silu, - Linear(emb_channels, out_channels), - ] - self.out_layers = [ - GroupNorm(32, out_channels), - Tensor.silu, - lambda x: x, # needed for weights loading code to work - Conv2d(out_channels, out_channels, 3, padding=1), - ] - self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else tensor_identity - - def __call__(self, x:Tensor, emb:Tensor) -> Tensor: - h = x.sequential(self.in_layers) - emb_out = emb.sequential(self.emb_layers) - h = h + emb_out.reshape(*emb_out.shape, 1, 1) - h = h.sequential(self.out_layers) - return self.skip_connection(x) + h - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L163 - class CrossAttention: - def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int): - self.to_q = Linear(query_dim, n_heads*d_head, bias=False) - self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False) - self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False) - self.num_heads = n_heads - self.head_size = d_head - self.to_out = [Linear(n_heads*d_head, query_dim)] - - def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: - ctx = x if ctx is None else ctx - q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx) - q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] - attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2) - h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) - return h_.sequential(self.to_out) - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L180 - class GEGLU: - def __init__(self, dim_in:int, dim_out:int): - self.proj = Linear(dim_in, dim_out * 2) - self.dim_out = dim_out - - def __call__(self, x:Tensor) -> Tensor: - x, gate = self.proj(x).chunk(2, dim=-1) - return x * gate.gelu() - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L189 - class FeedForward: - def __init__(self, dim:int, mult:int=4): - self.net = [ - UNet.GEGLU(dim, dim*mult), - lambda x: x, # needed for weights loading code to work - Linear(dim*mult, dim) - ] - - def __call__(self, x:Tensor) -> Tensor: - return x.sequential(self.net) - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L200 - class BasicTransformerBlock: - def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int): - self.attn1 = UNet.CrossAttention(dim, dim, n_heads, d_head) - self.ff = UNet.FeedForward(dim) - self.attn2 = UNet.CrossAttention(dim, ctx_dim, n_heads, d_head) - self.norm1 = LayerNorm(dim) - self.norm2 = LayerNorm(dim) - self.norm3 = LayerNorm(dim) - - def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: - x = x + 
self.attn1(self.norm1(x)) - x = x + self.attn2(self.norm2(x), ctx=ctx) - x = x + self.ff(self.norm3(x)) - return x - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L215 - # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619 - class SpatialTransformer: - def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], depth:int=1): - if isinstance(ctx_dim, int): - ctx_dim = [ctx_dim]*depth - else: - assert isinstance(ctx_dim, list) and depth == len(ctx_dim) - self.norm = GroupNorm(32, channels) - assert channels == n_heads * d_head - self.proj_in = Linear(channels, n_heads * d_head) - self.transformer_blocks = [UNet.BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)] - self.proj_out = Linear(n_heads * d_head, channels) - - def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: - B, C, H, W = x.shape - x_in = x - x = self.norm(x) - x = x.reshape(B, C, H*W).permute(0,2,1) # b c h w -> b c (h w) -> b (h w) c - x = self.proj_in(x) - for block in self.transformer_blocks: - x = block(x, ctx=ctx) - x = self.proj_out(x) - x = x.permute(0,2,1).reshape(B, C, H, W) # b (h w) c -> b c (h w) -> b c h w - return x + x_in - - -# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207 -# https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L251 -def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): - half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() - args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) - return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16) - - -# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472 -# https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L257 -class UNetModel: - def __init__(self, adm_in_channels:int, in_channels:int, out_channels:int, model_channels:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], d_head:int, transformer_depth:List[int], ctx_dim:Union[int,List[int]]): - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = [num_res_blocks] * len(channel_mult) - - self.attention_resolutions = attention_resolutions - self.dropout = 0.0 - self.channel_mult = channel_mult - self.conv_resample = True - self.num_classes = "sequential" - self.use_checkpoint = False - self.d_head = d_head - - time_embed_dim = model_channels * 4 - self.time_embed = [ - Linear(model_channels, time_embed_dim), - Tensor.silu, - Linear(time_embed_dim, time_embed_dim), - ] - - self.label_emb = [ - [ - Linear(adm_in_channels, time_embed_dim), - Tensor.silu, - Linear(time_embed_dim, time_embed_dim), - ] - ] - - self.input_blocks = [ - [Conv2d(in_channels, model_channels, 3, padding=1)] - ] - input_block_channels = [model_channels] - ch = model_channels - ds = 1 - for idx, mult in enumerate(channel_mult): - for _ in range(self.num_res_blocks[idx]): - layers: List[Any] = [ - UNet.ResBlock(ch, time_embed_dim, model_channels*mult), - ] - ch = mult * model_channels - if ds in attention_resolutions: - n_heads = ch // d_head - 
layers.append(UNet.SpatialTransformer(ch, n_heads, d_head, ctx_dim, depth=transformer_depth[idx])) - - self.input_blocks.append(layers) - input_block_channels.append(ch) - - if idx != len(channel_mult) - 1: - self.input_blocks.append([ - Downsample(ch), - ]) - input_block_channels.append(ch) - ds *= 2 - - n_heads = ch // d_head - self.middle_block: List = [ - UNet.ResBlock(ch, time_embed_dim, ch), - UNet.SpatialTransformer(ch, n_heads, d_head, ctx_dim, depth=transformer_depth[-1]), - UNet.ResBlock(ch, time_embed_dim, ch), - ] - - self.output_blocks = [] - for idx, mult in list(enumerate(channel_mult))[::-1]: - for i in range(self.num_res_blocks[idx] + 1): - ich = input_block_channels.pop() - layers = [ - UNet.ResBlock(ch + ich, time_embed_dim, model_channels*mult), - ] - ch = model_channels * mult - - if ds in attention_resolutions: - n_heads = ch // d_head - layers.append(UNet.SpatialTransformer(ch, n_heads, d_head, ctx_dim, depth=transformer_depth[idx])) - - if idx > 0 and i == self.num_res_blocks[idx]: - layers.append(Upsample(ch)) - ds //= 2 - self.output_blocks.append(layers) - - self.out = [ - GroupNorm(32, ch), - Tensor.silu, - Conv2d(model_channels, out_channels, 3, padding=1), - ] - - def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Tensor) -> Tensor: - t_emb = timestep_embedding(tms, self.model_channels).cast(dtypes.float16) - emb = t_emb.sequential(self.time_embed) - - assert y.shape[0] == x.shape[0] - emb = emb + y.sequential(self.label_emb[0]) - - emb = emb.cast(dtypes.float16) - ctx = ctx.cast(dtypes.float16) - x = x .cast(dtypes.float16) - - def run(x:Tensor, bb) -> Tensor: - if isinstance(bb, UNet.ResBlock): x = bb(x, emb) - elif isinstance(bb, UNet.SpatialTransformer): x = bb(x, ctx) - else: x = bb(x) - return x - - saved_inputs = [] - for b in self.input_blocks: - for bb in b: - x = run(x, bb) - saved_inputs.append(x) - for bb in self.middle_block: - x = run(x, bb) - for b in self.output_blocks: - x = x.cat(saved_inputs.pop(), dim=1) - for bb in b: - x = run(x, bb) - - return x.sequential(self.out) - - class DiffusionModel: def __init__(self, *args, **kwargs): self.diffusion_model = UNetModel(*args, **kwargs) @@ -302,230 +54,6 @@ class Embedder(ABC): pass -class Closed: - """ - Namespace for OpenAI CLIP model components. 
- """ - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L329 - class ClipMlp: - def __init__(self): - self.fc1 = Linear(768, 3072) - self.fc2 = Linear(3072, 768) - - def __call__(self, h:Tensor) -> Tensor: - h = self.fc1(h) - h = h.quick_gelu() - h = self.fc2(h) - return h - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L340 - class ClipAttention: - def __init__(self): - self.embed_dim = 768 - self.num_heads = 12 - self.head_dim = self.embed_dim // self.num_heads - self.k_proj = Linear(self.embed_dim, self.embed_dim) - self.v_proj = Linear(self.embed_dim, self.embed_dim) - self.q_proj = Linear(self.embed_dim, self.embed_dim) - self.out_proj = Linear(self.embed_dim, self.embed_dim) - - def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor: - bsz, tgt_len, embed_dim = hidden_states.shape - q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) - q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)] - attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask) - return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)) - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L357 - class ClipEncoderLayer: - def __init__(self): - self.self_attn = Closed.ClipAttention() - self.layer_norm1 = LayerNorm(768) - self.mlp = Closed.ClipMlp() - self.layer_norm2 = LayerNorm(768) - - def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor: - residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states, causal_attention_mask) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L386 - class ClipTextEmbeddings: - def __init__(self): - self.token_embedding = Embedding(49408, 768) - self.position_embedding = Embedding(77, 768) - - def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor: - return self.token_embedding(input_ids) + self.position_embedding(position_ids) - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L377 - class ClipEncoder: - def __init__(self, layer_count:int=12): - self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)] - - def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor: - # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state - layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx] - for l in layers: - x = l(x, causal_attention_mask) - return x - - - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L394 - class ClipTextTransformer: - def __init__(self, ret_layer_idx:Optional[int]=None): - self.embeddings = Closed.ClipTextEmbeddings() - self.encoder = Closed.ClipEncoder() - self.final_layer_norm = LayerNorm(768) - self.ret_layer_idx = ret_layer_idx - - def 
__call__(self, input_ids:Tensor) -> Tensor: - x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1)) - x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx) - return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x - - class ClipTextModel: - def __init__(self): - self.text_model = Closed.ClipTextTransformer(ret_layer_idx=11) - - -# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331 -class FrozenClosedClipEmbedder(Embedder): - def __init__(self): - self.tokenizer = ClipTokenizer() - self.transformer = Closed.ClipTextModel() - self.input_key = "txt" - - def __call__(self, text:Tensor) -> Tensor: - tokens = Tensor(self.tokenizer.encode(text)) - return self.transformer.text_model(tokens.reshape(1,-1)) - - -class Open: - """ - Namespace for OpenCLIP model components. - """ - - class MultiheadAttention: - def __init__(self, dims:int, n_heads:int): - self.dims = dims - self.n_heads = n_heads - self.d_head = self.dims // self.n_heads - - self.in_proj_bias = Tensor.empty(3*dims) - self.in_proj_weight = Tensor.empty(3*dims, dims) - self.out_proj = Linear(dims, dims) - - def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: - T,B,C = x.shape - - proj = x.linear(self.in_proj_weight.T, self.in_proj_bias) - proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2) - q,k,v = proj.chunk(3) - - q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)] - - attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C) - - attn_output = self.out_proj(attn_output) - attn_output = attn_output.reshape(T, B, C) - - return attn_output - - - class Mlp: - def __init__(self, dims, hidden_dims): - self.c_fc = Linear(dims, hidden_dims) - self.c_proj = Linear(hidden_dims, dims) - - def __call__(self, x:Tensor) -> Tensor: - return x.sequential([self.c_fc, Tensor.gelu, self.c_proj]) - - - # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210 - class ResidualAttentionBlocks: - def __init__(self, dims:int, n_heads:int, mlp_ratio:float): - self.ln_1 = LayerNorm(dims) - self.attn = Open.MultiheadAttention(dims, n_heads) - - self.ln_2 = LayerNorm(dims) - self.mlp = Open.Mlp(dims, int(dims * mlp_ratio)) - - def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: - x = x + self.attn(self.ln_1(x), attn_mask=attn_mask) - x = x + self.mlp(self.ln_2(x)) - return x - - - # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317 - class ClipTransformer: - def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0): - self.resblocks = [ - Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers) - ] - - def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: - x = x.transpose(0, 1) - for r in self.resblocks: - x = r(x, attn_mask=attn_mask) - x = x.transpose(0, 1) - return x - - - # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220 - # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661 - class ClipTextTransformer: - def __init__(self, dims:int, vocab_size:int=49408, 
n_heads:int=20, ctx_length:int=77, layers:int=32): - self.token_embedding = Embedding(vocab_size, dims) - self.positional_embedding = Tensor.empty(ctx_length, dims) - self.transformer = Open.ClipTransformer(dims, layers, n_heads) - self.ln_final = LayerNorm(dims) - self.text_projection = Tensor.empty(dims, dims) - - @property - def attn_mask(self) -> Tensor: - if not hasattr(self, "_attn_mask"): - self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1) - return self._attn_mask - - -# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396 -class FrozenOpenClipEmbedder(Embedder): - def __init__(self, dims:int=1280): - self.model = Open.ClipTextTransformer(dims) - self.input_key = "txt" - self.tokenizer = ClipTokenizer() - - def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None): - for r in self.model.transformer.resblocks: - x, penultimate = r(x, attn_mask=attn_mask), x - return x.permute(1,0,2), penultimate.permute(1,0,2) - - def __call__(self, text:Tensor) -> Tensor: - tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1) - - x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2) - x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - x = self.model.ln_final(x) - pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1)] @ self.model.text_projection - - return penultimate, pooled - - # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913 class ConcatTimestepEmbedderND(Embedder): def __init__(self, outdim:int, input_key:str): @@ -547,8 +75,8 @@ class Conditioner: def __init__(self, concat_embedders:List[str]): self.embedders = [ - FrozenClosedClipEmbedder(), - FrozenOpenClipEmbedder(), + FrozenClosedClipEmbedder(ret_layer_idx=11), + FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True), ] for input_key in concat_embedders: self.embedders.append(ConcatTimestepEmbedderND(256, input_key)) @@ -585,28 +113,6 @@ class FirstStage: Namespace for First Stage Model components """ - # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L74 - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L102 - class Downsample: - def __init__(self, dims:int): - self.conv = Conv2d(dims, dims, kernel_size=3, stride=2, padding=(0,1,0,1)) - - def __call__(self, x:Tensor) -> Tensor: - return self.conv(x) - - - # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L58 - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L83 - class Upsample: - def __init__(self, dims:int): - self.conv = Conv2d(dims, dims, kernel_size=3, stride=1, padding=1) - - def __call__(self, x:Tensor) -> Tensor: - B,C,Y,X = x.shape - x = x.reshape(B, C, Y, 1, X, 1).expand(B, C, Y, 2, X, 2).reshape(B, C, Y*2, X*2) - return self.conv(x) - - # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L487 class Encoder: def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int): @@ -626,7 +132,7 @@ class FirstStage: 
block.append(ResnetBlock(block_in, block_out)) block_in = block_out - downsample = tensor_identity if (i_level == len(ch_mult)-1) else FirstStage.Downsample(block_in) + downsample = tensor_identity if (i_level == len(ch_mult)-1) else Downsample(block_in) self.down.append(BlockEntry(block, downsample)) self.mid = Mid(block_in) @@ -779,7 +285,6 @@ class SDXL: return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], cond["vector"], c_out, x)) - # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L543 def decode(self, x:Tensor) -> Tensor: return self.first_stage_model.decode(1.0 / 0.13025 * x) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 8a6ac5b679..8618ef5f4d 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -2,16 +2,18 @@ # https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md import tempfile from pathlib import Path -import gzip, argparse, math, re -from functools import lru_cache +import argparse from collections import namedtuple +from typing import Dict, Any from PIL import Image import numpy as np from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit -from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, THREEFRY -from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding +from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm +from tinygrad.nn import Conv2d, GroupNorm from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict +from extra.models.clip import Closed, Tokenizer +from extra.models.unet import UNetModel class AttnBlock: def __init__(self, in_channels): @@ -131,391 +133,32 @@ class AutoencoderKL: latent = self.post_quant_conv(latent) return self.decoder(latent) -# not to be confused with ResnetBlock -class ResBlock: - def __init__(self, channels, emb_channels, out_channels): - self.in_layers = [ - GroupNorm(32, channels), - Tensor.silu, - Conv2d(channels, out_channels, 3, padding=1) - ] - self.emb_layers = [ - Tensor.silu, - Linear(emb_channels, out_channels) - ] - self.out_layers = [ - GroupNorm(32, out_channels), - Tensor.silu, - lambda x: x, # needed for weights loading code to work - Conv2d(out_channels, out_channels, 3, padding=1) - ] - self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x - - def __call__(self, x, emb): - h = x.sequential(self.in_layers) - emb_out = emb.sequential(self.emb_layers) - h = h + emb_out.reshape(*emb_out.shape, 1, 1) - h = h.sequential(self.out_layers) - ret = self.skip_connection(x) + h - return ret - -class CrossAttention: - def __init__(self, query_dim, context_dim, n_heads, d_head): - self.to_q = Linear(query_dim, n_heads*d_head, bias=False) - self.to_k = Linear(context_dim, n_heads*d_head, bias=False) - self.to_v = Linear(context_dim, n_heads*d_head, bias=False) - self.num_heads = n_heads - self.head_size = d_head - self.to_out = [Linear(n_heads*d_head, query_dim)] - - def __call__(self, x, context=None): - context = x if context is None else context - q,k,v = self.to_q(x), self.to_k(context), self.to_v(context) - q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] - attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2) - h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) - return h_.sequential(self.to_out) - -class GEGLU: - def __init__(self, 
dim_in, dim_out): - self.proj = Linear(dim_in, dim_out * 2) - self.dim_out = dim_out - - def __call__(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * gate.gelu() - -class FeedForward: - def __init__(self, dim, mult=4): - self.net = [ - GEGLU(dim, dim*mult), - lambda x: x, # needed for weights loading code to work - Linear(dim*mult, dim) - ] - - def __call__(self, x): - return x.sequential(self.net) - -class BasicTransformerBlock: - def __init__(self, dim, context_dim, n_heads, d_head): - self.attn1 = CrossAttention(dim, dim, n_heads, d_head) - self.ff = FeedForward(dim) - self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head) - self.norm1 = LayerNorm(dim) - self.norm2 = LayerNorm(dim) - self.norm3 = LayerNorm(dim) - - def __call__(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - -class SpatialTransformer: - def __init__(self, channels, context_dim, n_heads, d_head): - self.norm = GroupNorm(32, channels) - assert channels == n_heads * d_head - self.proj_in = Conv2d(channels, n_heads * d_head, 1) - self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)] - self.proj_out = Conv2d(n_heads * d_head, channels, 1) - - def __call__(self, x, context=None): - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.reshape(b, c, h*w).permute(0,2,1) - for block in self.transformer_blocks: - x = block(x, context=context) - x = x.permute(0,2,1).reshape(b, c, h, w) - ret = self.proj_out(x) + x_in - return ret - -class Downsample: - def __init__(self, channels): - self.op = Conv2d(channels, channels, 3, stride=2, padding=1) - - def __call__(self, x): - return self.op(x) - -class Upsample: - def __init__(self, channels): - self.conv = Conv2d(channels, channels, 3, padding=1) - - def __call__(self, x): - bs,c,py,px = x.shape - x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) - return self.conv(x) - -def timestep_embedding(timesteps, dim, max_period=10000): - half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() - args = timesteps * freqs - return Tensor.cat(args.cos(), args.sin()).reshape(1, -1) - -class UNetModel: - def __init__(self): - self.time_embed = [ - Linear(320, 1280), - Tensor.silu, - Linear(1280, 1280), - ] - self.input_blocks = [ - [Conv2d(4, 320, kernel_size=3, padding=1)], - [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [Downsample(320)], - [ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [Downsample(640)], - [ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [Downsample(1280)], - [ResBlock(1280, 1280, 1280)], - [ResBlock(1280, 1280, 1280)] - ] - self.middle_block = [ - ResBlock(1280, 1280, 1280), - SpatialTransformer(1280, 768, 8, 160), - ResBlock(1280, 1280, 1280) - ] - self.output_blocks = [ - [ResBlock(2560, 1280, 1280)], - [ResBlock(2560, 1280, 1280)], - [ResBlock(2560, 1280, 1280), Upsample(1280)], - [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)], - [ResBlock(1920, 1280, 640), 
SpatialTransformer(640, 768, 8, 80)], # 6 - [ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)], - [ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - ] - self.out = [ - GroupNorm(32, 320), - Tensor.silu, - Conv2d(320, 4, kernel_size=3, padding=1) - ] - - def __call__(self, x, timesteps=None, context=None): - # TODO: real time embedding - t_emb = timestep_embedding(timesteps, 320) - emb = t_emb.sequential(self.time_embed) - - def run(x, bb): - if isinstance(bb, ResBlock): x = bb(x, emb) - elif isinstance(bb, SpatialTransformer): x = bb(x, context) - else: x = bb(x) - return x - - saved_inputs = [] - for i,b in enumerate(self.input_blocks): - #print("input block", i) - for bb in b: - x = run(x, bb) - saved_inputs.append(x) - for bb in self.middle_block: - x = run(x, bb) - for i,b in enumerate(self.output_blocks): - #print("output block", i) - x = x.cat(saved_inputs.pop(), dim=1) - for bb in b: - x = run(x, bb) - return x.sequential(self.out) - -class CLIPMLP: - def __init__(self): - self.fc1 = Linear(768, 3072) - self.fc2 = Linear(3072, 768) - - def __call__(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = hidden_states.quick_gelu() - hidden_states = self.fc2(hidden_states) - return hidden_states - -class CLIPAttention: - def __init__(self): - self.embed_dim = 768 - self.num_heads = 12 - self.head_dim = self.embed_dim // self.num_heads - self.k_proj = Linear(self.embed_dim, self.embed_dim) - self.v_proj = Linear(self.embed_dim, self.embed_dim) - self.q_proj = Linear(self.embed_dim, self.embed_dim) - self.out_proj = Linear(self.embed_dim, self.embed_dim) - - def __call__(self, hidden_states, causal_attention_mask): - bsz, tgt_len, embed_dim = hidden_states.shape - q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) - q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)] - attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask) - return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)) - -class CLIPEncoderLayer: - def __init__(self): - self.self_attn = CLIPAttention() - self.layer_norm1 = LayerNorm(768) - self.mlp = CLIPMLP() - self.layer_norm2 = LayerNorm(768) - - def __call__(self, hidden_states, causal_attention_mask): - residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states, causal_attention_mask) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - -class CLIPEncoder: - def __init__(self): - self.layers = [CLIPEncoderLayer() for i in range(12)] - - def __call__(self, hidden_states, causal_attention_mask): - for l in self.layers: - hidden_states = l(hidden_states, causal_attention_mask) - return hidden_states - -class CLIPTextEmbeddings: - def __init__(self): - self.token_embedding = Embedding(49408, 768) - self.position_embedding = Embedding(77, 768) - - def __call__(self, input_ids, position_ids): - return self.token_embedding(input_ids) + self.position_embedding(position_ids) - -class CLIPTextTransformer: - def __init__(self): - 
self.embeddings = CLIPTextEmbeddings() - self.encoder = CLIPEncoder() - self.final_layer_norm = LayerNorm(768) - - def __call__(self, input_ids): - x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1)) - x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1)) - return self.final_layer_norm(x) - -# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) -@lru_cache() -def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz") - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - return set(zip(word, word[1:])) - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - -class ClipTokenizer: - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text, pad_with_zeros=False): - bpe_tokens = [] - text = whitespace_clean(text.strip()).lower() - for token in re.findall(self.pat, text): - token = 
''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - # Truncation, keeping two slots for start and end tokens. - if len(bpe_tokens) > 75: - bpe_tokens = bpe_tokens[:75] - return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2) - def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000): betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2 alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) return Tensor(alphas_cumprod) +unet_params: Dict[str,Any] = { + "adm_in_ch": None, + "in_ch": 4, + "out_ch": 4, + "model_ch": 320, + "attention_resolutions": [4, 2, 1], + "num_res_blocks": 2, + "channel_mult": [1, 2, 4, 4], + "n_heads": 8, + "transformer_depth": [1, 1, 1, 1], + "ctx_dim": 768, + "use_linear": False, +} + class StableDiffusion: def __init__(self): self.alphas_cumprod = get_alphas_cumprod() - self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel()) + self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params)) self.first_stage_model = AutoencoderKL() - self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer())) + self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer())) def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev): temperature = 1 @@ -597,7 +240,7 @@ if __name__ == "__main__": v.replace(v.cast(dtypes.float16).realize()) # run through CLIP to get context - tokenizer = ClipTokenizer() + tokenizer = Tokenizer.ClipTokenizer() prompt = Tensor([tokenizer.encode(args.prompt)]) context = model.cond_stage_model.transformer.text_model(prompt).realize() print("got CLIP context", context.shape) diff --git a/extra/models/clip.py b/extra/models/clip.py new file mode 100644 index 0000000000..0ab70e0b71 --- /dev/null +++ b/extra/models/clip.py @@ -0,0 +1,348 @@ +from tinygrad import Tensor, dtypes +from tinygrad.helpers import fetch +from tinygrad.nn import Linear, LayerNorm, Embedding + +from typing import List, Optional, Union, Tuple +from abc import ABC, abstractmethod +from functools import lru_cache +import re, gzip + +@lru_cache() +def default_bpe(): + # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) + return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz") + +class Tokenizer: + """ + Namespace for CLIP Text Tokenizer components. + """ + + @staticmethod + def get_pairs(word): + """ + Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + return set(zip(word, word[1:])) + @staticmethod + def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + @staticmethod + def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
+ This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + class ClipTokenizer: + def __init__(self): + self.byte_encoder = Tokenizer.bytes_to_unicode() + merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(Tokenizer.bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = Tokenizer.get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + pairs = Tokenizer.get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text:str, pad_with_zeros:bool=False): + bpe_tokens: List[int] = [] + text = Tokenizer.whitespace_clean(text.strip()).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + # Truncation, keeping two slots for start and end tokens. + if len(bpe_tokens) > 75: + bpe_tokens = bpe_tokens[:75] + return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2) + + +class Embedder(ABC): + input_key: str + @abstractmethod + def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]: + pass + + +class Closed: + """ + Namespace for OpenAI CLIP model components. 
+ """ + class ClipMlp: + def __init__(self): + self.fc1 = Linear(768, 3072) + self.fc2 = Linear(3072, 768) + + def __call__(self, h:Tensor) -> Tensor: + h = self.fc1(h) + h = h.quick_gelu() + h = self.fc2(h) + return h + + class ClipAttention: + def __init__(self): + self.embed_dim = 768 + self.num_heads = 12 + self.head_dim = self.embed_dim // self.num_heads + self.k_proj = Linear(self.embed_dim, self.embed_dim) + self.v_proj = Linear(self.embed_dim, self.embed_dim) + self.q_proj = Linear(self.embed_dim, self.embed_dim) + self.out_proj = Linear(self.embed_dim, self.embed_dim) + + def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor: + bsz, tgt_len, embed_dim = hidden_states.shape + q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)] + attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask) + return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)) + + class ClipEncoderLayer: + def __init__(self): + self.self_attn = Closed.ClipAttention() + self.layer_norm1 = LayerNorm(768) + self.mlp = Closed.ClipMlp() + self.layer_norm2 = LayerNorm(768) + + def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states, causal_attention_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + class ClipTextEmbeddings: + def __init__(self): + self.token_embedding = Embedding(49408, 768) + self.position_embedding = Embedding(77, 768) + + def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor: + return self.token_embedding(input_ids) + self.position_embedding(position_ids) + + class ClipEncoder: + def __init__(self, layer_count:int=12): + self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)] + + def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor: + # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state + layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx] + for l in layers: + x = l(x, causal_attention_mask) + return x + + class ClipTextTransformer: + def __init__(self, ret_layer_idx:Optional[int]=None): + self.embeddings = Closed.ClipTextEmbeddings() + self.encoder = Closed.ClipEncoder() + self.final_layer_norm = LayerNorm(768) + self.ret_layer_idx = ret_layer_idx + + def __call__(self, input_ids:Tensor) -> Tensor: + x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1)) + x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx) + return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x + + class ClipTextModel: + def __init__(self, ret_layer_idx:Optional[int]): + self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx) + + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331 +class FrozenClosedClipEmbedder(Embedder): + def __init__(self, ret_layer_idx:Optional[int]=None): + self.tokenizer = Tokenizer.ClipTokenizer() + 
self.transformer = Closed.ClipTextModel(ret_layer_idx) + self.input_key = "txt" + + def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]: + tokens = Tensor(self.tokenizer.encode(text)) + return self.transformer.text_model(tokens.reshape(1,-1)) + + +class Open: + """ + Namespace for OpenCLIP model components. + """ + class MultiheadAttention: + def __init__(self, dims:int, n_heads:int): + self.dims = dims + self.n_heads = n_heads + self.d_head = self.dims // self.n_heads + + self.in_proj_bias = Tensor.empty(3*dims) + self.in_proj_weight = Tensor.empty(3*dims, dims) + self.out_proj = Linear(dims, dims) + + def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: + T,B,C = x.shape + + proj = x.linear(self.in_proj_weight.T, self.in_proj_bias) + proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2) + q,k,v = proj.chunk(3) + + q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)] + + attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C) + + attn_output = self.out_proj(attn_output) + attn_output = attn_output.reshape(T, B, C) + + return attn_output + + class Mlp: + def __init__(self, dims, hidden_dims): + self.c_fc = Linear(dims, hidden_dims) + self.c_proj = Linear(hidden_dims, dims) + + def __call__(self, x:Tensor) -> Tensor: + return x.sequential([self.c_fc, Tensor.gelu, self.c_proj]) + + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210 + class ResidualAttentionBlocks: + def __init__(self, dims:int, n_heads:int, mlp_ratio:float): + self.ln_1 = LayerNorm(dims) + self.attn = Open.MultiheadAttention(dims, n_heads) + + self.ln_2 = LayerNorm(dims) + self.mlp = Open.Mlp(dims, int(dims * mlp_ratio)) + + def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: + x = x + self.attn(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317 + class ClipTransformer: + def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0): + self.resblocks = [ + Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers) + ] + + def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor: + x = x.transpose(0, 1).contiguous() + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + x = x.transpose(0, 1) + return x + + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220 + # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661 + class ClipTextTransformer: + def __init__(self, dims:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77): + self.token_embedding = Embedding(vocab_size, dims) + self.positional_embedding = Tensor.empty(ctx_length, dims) + self.transformer = Open.ClipTransformer(dims, layers, n_heads) + self.ln_final = LayerNorm(dims) + self.text_projection = Tensor.empty(dims, dims) + + @property + def attn_mask(self) -> Tensor: + if not hasattr(self, "_attn_mask"): + self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1) + return self._attn_mask + + def __call__(self, text:Tensor) -> Tensor: + seq_len = text.shape[1] + + x = self.token_embedding(text) + x = x + 
self.positional_embedding[:seq_len] + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) + + pooled = x[:, text.argmax(dim=-1)] @ self.text_projection + return pooled + + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396 +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498 +class FrozenOpenClipEmbedder(Embedder): + def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool): + self.tokenizer = Tokenizer.ClipTokenizer() + self.model = Open.ClipTextTransformer(dims, n_heads, layers) + self.return_pooled = return_pooled + self.input_key = "txt" + + def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None): + for r in self.model.transformer.resblocks: + x, penultimate = r(x, attn_mask=attn_mask), x + return x.permute(1,0,2), penultimate.permute(1,0,2) + + def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]: + tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1) + + x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2) + x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + + if self.return_pooled: + x = self.model.ln_final(x) + pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection + return penultimate, pooled + else: + return penultimate diff --git a/extra/models/unet.py b/extra/models/unet.py new file mode 100644 index 0000000000..fad41443cb --- /dev/null +++ b/extra/models/unet.py @@ -0,0 +1,253 @@ +from tinygrad import Tensor, dtypes +from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm + +from typing import Optional, Union, List, Any, Tuple +import math + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207 +def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): + half = dim // 2 + freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() + args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) + return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16) + +class ResBlock: + def __init__(self, channels:int, emb_channels:int, out_channels:int): + self.in_layers = [ + GroupNorm(32, channels), + Tensor.silu, + Conv2d(channels, out_channels, 3, padding=1), + ] + self.emb_layers = [ + Tensor.silu, + Linear(emb_channels, out_channels), + ] + self.out_layers = [ + GroupNorm(32, out_channels), + Tensor.silu, + lambda x: x, # needed for weights loading code to work + Conv2d(out_channels, out_channels, 3, padding=1), + ] + self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x) + + def __call__(self, x:Tensor, emb:Tensor) -> Tensor: + h = x.sequential(self.in_layers) + emb_out = emb.sequential(self.emb_layers) + h = h + emb_out.reshape(*emb_out.shape, 1, 1) + h = h.sequential(self.out_layers) + return self.skip_connection(x) + h + +class CrossAttention: + def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int): + self.to_q = Linear(query_dim, n_heads*d_head, bias=False) + self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False) + self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False) + self.num_heads = n_heads + self.head_size = d_head + self.to_out = [Linear(n_heads*d_head, query_dim)] + + def __call__(self, 
x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + ctx = x if ctx is None else ctx + q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx) + q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] + attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2) + h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) + return h_.sequential(self.to_out) + +class GEGLU: + def __init__(self, dim_in:int, dim_out:int): + self.proj = Linear(dim_in, dim_out * 2) + self.dim_out = dim_out + + def __call__(self, x:Tensor) -> Tensor: + x, gate = self.proj(x).chunk(2, dim=-1) + return x * gate.gelu() + +class FeedForward: + def __init__(self, dim:int, mult:int=4): + self.net = [ + GEGLU(dim, dim*mult), + lambda x: x, # needed for weights loading code to work + Linear(dim*mult, dim) + ] + + def __call__(self, x:Tensor) -> Tensor: + return x.sequential(self.net) + +class BasicTransformerBlock: + def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int): + self.attn1 = CrossAttention(dim, dim, n_heads, d_head) + self.ff = FeedForward(dim) + self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head) + self.norm1 = LayerNorm(dim) + self.norm2 = LayerNorm(dim) + self.norm3 = LayerNorm(dim) + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + x = x + self.attn1(self.norm1(x)) + x = x + self.attn2(self.norm2(x), ctx=ctx) + x = x + self.ff(self.norm3(x)) + return x + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619 +class SpatialTransformer: + def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1): + if isinstance(ctx_dim, int): + ctx_dim = [ctx_dim]*depth + else: + assert isinstance(ctx_dim, list) and depth == len(ctx_dim) + self.norm = GroupNorm(32, channels) + assert channels == n_heads * d_head + self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1) + self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)] + self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1) + self.use_linear = use_linear + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ] + x = x.sequential(ops if self.use_linear else ops[::-1]) + for block in self.transformer_blocks: + x = block(x, ctx=ctx) + ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ] + x = x.sequential(ops if self.use_linear else ops[::-1]) + return x + x_in + +class Downsample: + def __init__(self, channels:int): + self.op = Conv2d(channels, channels, 3, stride=2, padding=1) + + def __call__(self, x:Tensor) -> Tensor: + return self.op(x) + +class Upsample: + def __init__(self, channels:int): + self.conv = Conv2d(channels, channels, 3, padding=1) + + def __call__(self, x:Tensor) -> Tensor: + bs,c,py,px = x.shape + z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) + return self.conv(z) + +# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472 +class UNetModel: + def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], 
num_res_blocks:int, channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, n_heads:Optional[int]=None): + self.model_ch = model_ch + self.num_res_blocks = [num_res_blocks] * len(channel_mult) + + self.attention_resolutions = attention_resolutions + self.d_head = d_head + self.n_heads = n_heads + def get_d_and_n_heads(dims:int) -> Tuple[int,int]: + if self.d_head is None: + assert self.n_heads is not None, f"d_head and n_heads cannot both be None" + return dims // self.n_heads, self.n_heads + else: + assert self.n_heads is None, f"d_head and n_heads cannot both be non-None" + return self.d_head, dims // self.d_head + + time_embed_dim = model_ch * 4 + self.time_embed = [ + Linear(model_ch, time_embed_dim), + Tensor.silu, + Linear(time_embed_dim, time_embed_dim), + ] + + if adm_in_ch is not None: + self.label_emb = [ + [ + Linear(adm_in_ch, time_embed_dim), + Tensor.silu, + Linear(time_embed_dim, time_embed_dim), + ] + ] + + self.input_blocks: List[Any] = [ + [Conv2d(in_ch, model_ch, 3, padding=1)] + ] + input_block_channels = [model_ch] + ch = model_ch + ds = 1 + for idx, mult in enumerate(channel_mult): + for _ in range(self.num_res_blocks[idx]): + layers: List[Any] = [ + ResBlock(ch, time_embed_dim, model_ch*mult), + ] + ch = mult * model_ch + if ds in attention_resolutions: + d_head, n_heads = get_d_and_n_heads(ch) + layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx])) + + self.input_blocks.append(layers) + input_block_channels.append(ch) + + if idx != len(channel_mult) - 1: + self.input_blocks.append([ + Downsample(ch), + ]) + input_block_channels.append(ch) + ds *= 2 + + d_head, n_heads = get_d_and_n_heads(ch) + self.middle_block: List = [ + ResBlock(ch, time_embed_dim, ch), + SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1]), + ResBlock(ch, time_embed_dim, ch), + ] + + self.output_blocks = [] + for idx, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[idx] + 1): + ich = input_block_channels.pop() + layers = [ + ResBlock(ch + ich, time_embed_dim, model_ch*mult), + ] + ch = model_ch * mult + + if ds in attention_resolutions: + d_head, n_heads = get_d_and_n_heads(ch) + layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx])) + + if idx > 0 and i == self.num_res_blocks[idx]: + layers.append(Upsample(ch)) + ds //= 2 + self.output_blocks.append(layers) + + self.out = [ + GroupNorm(32, ch), + Tensor.silu, + Conv2d(model_ch, out_ch, 3, padding=1), + ] + + def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor: + t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16) + emb = t_emb.sequential(self.time_embed) + + if y is not None: + assert y.shape[0] == x.shape[0] + emb = emb + y.sequential(self.label_emb[0]) + + emb = emb.cast(dtypes.float16) + ctx = ctx.cast(dtypes.float16) + x = x .cast(dtypes.float16) + + def run(x:Tensor, bb) -> Tensor: + if isinstance(bb, ResBlock): x = bb(x, emb) + elif isinstance(bb, SpatialTransformer): x = bb(x, ctx) + else: x = bb(x) + return x + + saved_inputs = [] + for b in self.input_blocks: + for bb in b: + x = run(x, bb) + saved_inputs.append(x) + for bb in self.middle_block: + x = run(x, bb) + for b in self.output_blocks: + x = x.cat(saved_inputs.pop(), dim=1) + for bb in b: + x = run(x, bb) + return x.sequential(self.out) diff --git 
a/test/external/external_test_jit_on_models.py b/test/external/external_test_jit_on_models.py index 92b18bdb03..c04a6ed8ae 100644 --- a/test/external/external_test_jit_on_models.py +++ b/test/external/external_test_jit_on_models.py @@ -30,8 +30,8 @@ class TestJittedModels(unittest.TestCase): @unittest.skipUnless(not CI, "huge for CI") def test_jitted_stable_diffusion(self): - from examples.stable_diffusion import UNetModel - model = UNetModel() + from examples.stable_diffusion import UNetModel, unet_params + model = UNetModel(**unet_params) derandomize_model(model) def test(t, t2): return model(t, 801, t2).realize() diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 2b6e42b592..49a7531e81 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -12,7 +12,8 @@ from test.helpers import derandomize_model, is_dtype_supported from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS from examples.hlb_cifar10 import SpeedyResNet, hyp from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS -from examples.stable_diffusion import UNetModel, ResBlock +from examples.stable_diffusion import UNetModel, unet_params +from extra.models.unet import ResBlock global_mem_used = 0 def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False): @@ -49,7 +50,7 @@ class TestRealWorld(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipIf(CI, "too big for CI") def test_stable_diffusion(self): - model = UNetModel() + model = UNetModel(**unet_params) derandomize_model(model) @TinyJit def test(t, t2): return model(t, 801, t2).realize()
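
Usage sketch (illustrative, not part of the patch): the snippet below exercises the relocated modules using only the constructor signatures, the unet_params dict, and the Conditioner changes that appear in this diff; the tensor shapes, timestep value, and prompt string are assumptions.

# Minimal sketch of the post-refactor imports and entry points.
# Signatures come from the diff above; shapes/values below are assumptions.
from tinygrad import Tensor
from extra.models.unet import UNetModel
from extra.models.clip import Tokenizer, FrozenOpenClipEmbedder
from examples.stable_diffusion import unet_params  # SD 1.x config added by this patch

# SD 1.x UNet is now built from unet_params instead of the old hard-coded layout
model = UNetModel(**unet_params)
x   = Tensor.randn(1, 4, 64, 64)   # assumed latent shape
tms = Tensor([801])                # timestep as a Tensor (the tests pass a bare int)
ctx = Tensor.randn(1, 77, 768)     # assumed CLIP context; y (label embedding) is now optional
eps = model(x, tms, ctx)

# The tokenizer now lives under the Tokenizer namespace in extra/models/clip.py
tokens = Tokenizer.ClipTokenizer().encode("an example prompt")

# SDXL-style text encoder construction, matching the updated Conditioner in sdxl.py
# (weights still need to be loaded before the output is meaningful)
clip_g = FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True)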