mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00

Pulled CLIP and UNet into Separate Files (#5253)

* pulled clip and unet into separate files
* reference cleanup, lru cache fix
* better pool indexing
examples/sdxl.py (519 lines changed)
@@ -4,14 +4,16 @@
 # mlfoundations/open_clip | MIT | https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/LICENSE

 from tinygrad import Tensor, TinyJit, dtypes
-from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm, Embedding
+from tinygrad.nn import Conv2d, GroupNorm
 from tinygrad.nn.state import safe_load, load_state_dict
-from tinygrad.helpers import fetch, trange, colored, THREEFRY
-from examples.stable_diffusion import ClipTokenizer, ResnetBlock, Mid, Downsample, Upsample
+from tinygrad.helpers import fetch, trange, colored
+from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
+from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
+from examples.stable_diffusion import ResnetBlock, Mid
 import numpy as np

-from typing import Dict, List, Union, Callable, Optional, Any, Set, Tuple
-import math, argparse, tempfile
+from typing import Dict, List, Callable, Optional, Any, Set, Tuple
+import argparse, tempfile
 from abc import ABC, abstractmethod
 from pathlib import Path
 from PIL import Image
@@ -22,13 +24,13 @@ from PIL import Image
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_refiner.yaml
 configs: Dict = {
   "SDXL_Base": {
-    "model": {"adm_in_channels": 2816, "in_channels": 4, "out_channels": 4, "model_channels": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048},
+    "model": {"adm_in_ch": 2816, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048, "use_linear": True},
     "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]},
     "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
     "denoiser": {"num_idx": 1000},
   },
   "SDXL_Refiner": {
-    "model": {"adm_in_channels": 2560, "in_channels": 4, "out_channels": 4, "model_channels": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280]},
+    "model": {"adm_in_ch": 2560, "in_ch": 4, "out_ch": 4, "model_ch": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280], "use_linear": True},
     "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "aesthetic_score"]},
     "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
     "denoiser": {"num_idx": 1000},
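The model config keys were shortened here to match the kwargs of the new UNetModel in extra/models/unet.py (adm_in_channels becomes adm_in_ch, model_channels becomes model_ch, and so on), and a use_linear flag was added, which presumably selects the Linear rather than 1x1 Conv2d projections in the SpatialTransformer (the old SDXL copy used Linear, the SD1 copy used Conv2d). A minimal sketch of how a config entry reaches the model, assuming the imports above:

from extra.models.unet import UNetModel

params = configs["SDXL_Base"]["model"]
model = UNetModel(**params)  # dict keys must line up exactly with the new kwarg names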
@@ -40,256 +42,6 @@ def tensor_identity(x:Tensor) -> Tensor:
   return x


-class UNet:
-  """
-  Namespace for UNet model components.
-  """
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L136
-  class ResBlock:
-    def __init__(self, channels:int, emb_channels:int, out_channels:int):
-      self.in_layers = [
-        GroupNorm(32, channels),
-        Tensor.silu,
-        Conv2d(channels, out_channels, 3, padding=1),
-      ]
-      self.emb_layers = [
-        Tensor.silu,
-        Linear(emb_channels, out_channels),
-      ]
-      self.out_layers = [
-        GroupNorm(32, out_channels),
-        Tensor.silu,
-        lambda x: x, # needed for weights loading code to work
-        Conv2d(out_channels, out_channels, 3, padding=1),
-      ]
-      self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else tensor_identity
-
-    def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
-      h = x.sequential(self.in_layers)
-      emb_out = emb.sequential(self.emb_layers)
-      h = h + emb_out.reshape(*emb_out.shape, 1, 1)
-      h = h.sequential(self.out_layers)
-      return self.skip_connection(x) + h
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L163
-  class CrossAttention:
-    def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
-      self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
-      self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
-      self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
-      self.num_heads = n_heads
-      self.head_size = d_head
-      self.to_out = [Linear(n_heads*d_head, query_dim)]
-
-    def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
-      ctx = x if ctx is None else ctx
-      q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
-      q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
-      attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
-      h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
-      return h_.sequential(self.to_out)
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L180
-  class GEGLU:
-    def __init__(self, dim_in:int, dim_out:int):
-      self.proj = Linear(dim_in, dim_out * 2)
-      self.dim_out = dim_out
-
-    def __call__(self, x:Tensor) -> Tensor:
-      x, gate = self.proj(x).chunk(2, dim=-1)
-      return x * gate.gelu()
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L189
-  class FeedForward:
-    def __init__(self, dim:int, mult:int=4):
-      self.net = [
-        UNet.GEGLU(dim, dim*mult),
-        lambda x: x, # needed for weights loading code to work
-        Linear(dim*mult, dim)
-      ]
-
-    def __call__(self, x:Tensor) -> Tensor:
-      return x.sequential(self.net)
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L200
-  class BasicTransformerBlock:
-    def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
-      self.attn1 = UNet.CrossAttention(dim, dim, n_heads, d_head)
-      self.ff = UNet.FeedForward(dim)
-      self.attn2 = UNet.CrossAttention(dim, ctx_dim, n_heads, d_head)
-      self.norm1 = LayerNorm(dim)
-      self.norm2 = LayerNorm(dim)
-      self.norm3 = LayerNorm(dim)
-
-    def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
-      x = x + self.attn1(self.norm1(x))
-      x = x + self.attn2(self.norm2(x), ctx=ctx)
-      x = x + self.ff(self.norm3(x))
-      return x
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L215
-  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
-  class SpatialTransformer:
-    def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], depth:int=1):
-      if isinstance(ctx_dim, int):
-        ctx_dim = [ctx_dim]*depth
-      else:
-        assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
-      self.norm = GroupNorm(32, channels)
-      assert channels == n_heads * d_head
-      self.proj_in = Linear(channels, n_heads * d_head)
-      self.transformer_blocks = [UNet.BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
-      self.proj_out = Linear(n_heads * d_head, channels)
-
-    def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
-      B, C, H, W = x.shape
-      x_in = x
-      x = self.norm(x)
-      x = x.reshape(B, C, H*W).permute(0,2,1) # b c h w -> b c (h w) -> b (h w) c
-      x = self.proj_in(x)
-      for block in self.transformer_blocks:
-        x = block(x, ctx=ctx)
-      x = self.proj_out(x)
-      x = x.permute(0,2,1).reshape(B, C, H, W) # b (h w) c -> b c (h w) -> b c h w
-      return x + x_in
-
-
-# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
-# https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L251
-def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
-  half = dim // 2
-  freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
-  args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
-  return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
-
-
-# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
-# https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L257
-class UNetModel:
-  def __init__(self, adm_in_channels:int, in_channels:int, out_channels:int, model_channels:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], d_head:int, transformer_depth:List[int], ctx_dim:Union[int,List[int]]):
-    self.in_channels = in_channels
-    self.model_channels = model_channels
-    self.out_channels = out_channels
-    self.num_res_blocks = [num_res_blocks] * len(channel_mult)
-
-    self.attention_resolutions = attention_resolutions
-    self.dropout = 0.0
-    self.channel_mult = channel_mult
-    self.conv_resample = True
-    self.num_classes = "sequential"
-    self.use_checkpoint = False
-    self.d_head = d_head
-
-    time_embed_dim = model_channels * 4
-    self.time_embed = [
-      Linear(model_channels, time_embed_dim),
-      Tensor.silu,
-      Linear(time_embed_dim, time_embed_dim),
-    ]
-
-    self.label_emb = [
-      [
-        Linear(adm_in_channels, time_embed_dim),
-        Tensor.silu,
-        Linear(time_embed_dim, time_embed_dim),
-      ]
-    ]
-
-    self.input_blocks = [
-      [Conv2d(in_channels, model_channels, 3, padding=1)]
-    ]
-    input_block_channels = [model_channels]
-    ch = model_channels
-    ds = 1
-    for idx, mult in enumerate(channel_mult):
-      for _ in range(self.num_res_blocks[idx]):
-        layers: List[Any] = [
-          UNet.ResBlock(ch, time_embed_dim, model_channels*mult),
-        ]
-        ch = mult * model_channels
-        if ds in attention_resolutions:
-          n_heads = ch // d_head
-          layers.append(UNet.SpatialTransformer(ch, n_heads, d_head, ctx_dim, depth=transformer_depth[idx]))
-
-        self.input_blocks.append(layers)
-        input_block_channels.append(ch)
-
-      if idx != len(channel_mult) - 1:
-        self.input_blocks.append([
-          Downsample(ch),
-        ])
-        input_block_channels.append(ch)
-        ds *= 2
-
-    n_heads = ch // d_head
-    self.middle_block: List = [
-      UNet.ResBlock(ch, time_embed_dim, ch),
-      UNet.SpatialTransformer(ch, n_heads, d_head, ctx_dim, depth=transformer_depth[-1]),
-      UNet.ResBlock(ch, time_embed_dim, ch),
-    ]
-
-    self.output_blocks = []
-    for idx, mult in list(enumerate(channel_mult))[::-1]:
-      for i in range(self.num_res_blocks[idx] + 1):
-        ich = input_block_channels.pop()
-        layers = [
-          UNet.ResBlock(ch + ich, time_embed_dim, model_channels*mult),
-        ]
-        ch = model_channels * mult
-
-        if ds in attention_resolutions:
-          n_heads = ch // d_head
-          layers.append(UNet.SpatialTransformer(ch, n_heads, d_head, ctx_dim, depth=transformer_depth[idx]))
-
-        if idx > 0 and i == self.num_res_blocks[idx]:
-          layers.append(Upsample(ch))
-          ds //= 2
-        self.output_blocks.append(layers)
-
-    self.out = [
-      GroupNorm(32, ch),
-      Tensor.silu,
-      Conv2d(model_channels, out_channels, 3, padding=1),
-    ]
-
-  def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Tensor) -> Tensor:
-    t_emb = timestep_embedding(tms, self.model_channels).cast(dtypes.float16)
-    emb = t_emb.sequential(self.time_embed)
-
-    assert y.shape[0] == x.shape[0]
-    emb = emb + y.sequential(self.label_emb[0])
-
-    emb = emb.cast(dtypes.float16)
-    ctx = ctx.cast(dtypes.float16)
-    x   = x  .cast(dtypes.float16)
-
-    def run(x:Tensor, bb) -> Tensor:
-      if isinstance(bb, UNet.ResBlock): x = bb(x, emb)
-      elif isinstance(bb, UNet.SpatialTransformer): x = bb(x, ctx)
-      else: x = bb(x)
-      return x
-
-    saved_inputs = []
-    for b in self.input_blocks:
-      for bb in b:
-        x = run(x, bb)
-      saved_inputs.append(x)
-    for bb in self.middle_block:
-      x = run(x, bb)
-    for b in self.output_blocks:
-      x = x.cat(saved_inputs.pop(), dim=1)
-      for bb in b:
-        x = run(x, bb)
-
-    return x.sequential(self.out)
-
-
 class DiffusionModel:
   def __init__(self, *args, **kwargs):
     self.diffusion_model = UNetModel(*args, **kwargs)
@@ -302,230 +54,6 @@ class Embedder(ABC):
     pass


-class Closed:
-  """
-  Namespace for OpenAI CLIP model components.
-  """
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L329
-  class ClipMlp:
-    def __init__(self):
-      self.fc1 = Linear(768, 3072)
-      self.fc2 = Linear(3072, 768)
-
-    def __call__(self, h:Tensor) -> Tensor:
-      h = self.fc1(h)
-      h = h.quick_gelu()
-      h = self.fc2(h)
-      return h
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L340
-  class ClipAttention:
-    def __init__(self):
-      self.embed_dim = 768
-      self.num_heads = 12
-      self.head_dim = self.embed_dim // self.num_heads
-      self.k_proj = Linear(self.embed_dim, self.embed_dim)
-      self.v_proj = Linear(self.embed_dim, self.embed_dim)
-      self.q_proj = Linear(self.embed_dim, self.embed_dim)
-      self.out_proj = Linear(self.embed_dim, self.embed_dim)
-
-    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
-      bsz, tgt_len, embed_dim = hidden_states.shape
-      q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
-      q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
-      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
-      return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L357
-  class ClipEncoderLayer:
-    def __init__(self):
-      self.self_attn = Closed.ClipAttention()
-      self.layer_norm1 = LayerNorm(768)
-      self.mlp = Closed.ClipMlp()
-      self.layer_norm2 = LayerNorm(768)
-
-    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
-      residual = hidden_states
-      hidden_states = self.layer_norm1(hidden_states)
-      hidden_states = self.self_attn(hidden_states, causal_attention_mask)
-      hidden_states = residual + hidden_states
-
-      residual = hidden_states
-      hidden_states = self.layer_norm2(hidden_states)
-      hidden_states = self.mlp(hidden_states)
-      hidden_states = residual + hidden_states
-
-      return hidden_states
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L386
-  class ClipTextEmbeddings:
-    def __init__(self):
-      self.token_embedding = Embedding(49408, 768)
-      self.position_embedding = Embedding(77, 768)
-
-    def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
-      return self.token_embedding(input_ids) + self.position_embedding(position_ids)
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L377
-  class ClipEncoder:
-    def __init__(self, layer_count:int=12):
-      self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]
-
-    def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
-      # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
-      layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
-      for l in layers:
-        x = l(x, causal_attention_mask)
-      return x
-
-
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L394
-  class ClipTextTransformer:
-    def __init__(self, ret_layer_idx:Optional[int]=None):
-      self.embeddings = Closed.ClipTextEmbeddings()
-      self.encoder = Closed.ClipEncoder()
-      self.final_layer_norm = LayerNorm(768)
-      self.ret_layer_idx = ret_layer_idx
-
-    def __call__(self, input_ids:Tensor) -> Tensor:
-      x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
-      x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
-      return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x
-
-  class ClipTextModel:
-    def __init__(self):
-      self.text_model = Closed.ClipTextTransformer(ret_layer_idx=11)
-
-
-# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
-class FrozenClosedClipEmbedder(Embedder):
-  def __init__(self):
-    self.tokenizer = ClipTokenizer()
-    self.transformer = Closed.ClipTextModel()
-    self.input_key = "txt"
-
-  def __call__(self, text:Tensor) -> Tensor:
-    tokens = Tensor(self.tokenizer.encode(text))
-    return self.transformer.text_model(tokens.reshape(1,-1))
-
-
-class Open:
-  """
-  Namespace for OpenCLIP model components.
-  """
-
-  class MultiheadAttention:
-    def __init__(self, dims:int, n_heads:int):
-      self.dims = dims
-      self.n_heads = n_heads
-      self.d_head = self.dims // self.n_heads
-
-      self.in_proj_bias = Tensor.empty(3*dims)
-      self.in_proj_weight = Tensor.empty(3*dims, dims)
-      self.out_proj = Linear(dims, dims)
-
-    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
-      T,B,C = x.shape
-
-      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
-      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2)
-      q,k,v = proj.chunk(3)
-
-      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)]
-
-      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-      attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C)
-
-      attn_output = self.out_proj(attn_output)
-      attn_output = attn_output.reshape(T, B, C)
-
-      return attn_output
-
-
-  class Mlp:
-    def __init__(self, dims, hidden_dims):
-      self.c_fc = Linear(dims, hidden_dims)
-      self.c_proj = Linear(hidden_dims, dims)
-
-    def __call__(self, x:Tensor) -> Tensor:
-      return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])
-
-
-  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
-  class ResidualAttentionBlocks:
-    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
-      self.ln_1 = LayerNorm(dims)
-      self.attn = Open.MultiheadAttention(dims, n_heads)
-
-      self.ln_2 = LayerNorm(dims)
-      self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))
-
-    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
-      x = x + self.attn(self.ln_1(x), attn_mask=attn_mask)
-      x = x + self.mlp(self.ln_2(x))
-      return x
-
-
-  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
-  class ClipTransformer:
-    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
-      self.resblocks = [
-        Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers)
-      ]
-
-    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
-      x = x.transpose(0, 1)
-      for r in self.resblocks:
-        x = r(x, attn_mask=attn_mask)
-      x = x.transpose(0, 1)
-      return x
-
-
-  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
-  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
-  class ClipTextTransformer:
-    def __init__(self, dims:int, vocab_size:int=49408, n_heads:int=20, ctx_length:int=77, layers:int=32):
-      self.token_embedding = Embedding(vocab_size, dims)
-      self.positional_embedding = Tensor.empty(ctx_length, dims)
-      self.transformer = Open.ClipTransformer(dims, layers, n_heads)
-      self.ln_final = LayerNorm(dims)
-      self.text_projection = Tensor.empty(dims, dims)
-
-    @property
-    def attn_mask(self) -> Tensor:
-      if not hasattr(self, "_attn_mask"):
-        self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1)
-      return self._attn_mask
-
-
-# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
-class FrozenOpenClipEmbedder(Embedder):
-  def __init__(self, dims:int=1280):
-    self.model = Open.ClipTextTransformer(dims)
-    self.input_key = "txt"
-    self.tokenizer = ClipTokenizer()
-
-  def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
-    for r in self.model.transformer.resblocks:
-      x, penultimate = r(x, attn_mask=attn_mask), x
-    return x.permute(1,0,2), penultimate.permute(1,0,2)
-
-  def __call__(self, text:Tensor) -> Tensor:
-    tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1)
-
-    x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
-    x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
-    x = self.model.ln_final(x)
-    pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1)] @ self.model.text_projection
-
-    return penultimate, pooled
-
-
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
 class ConcatTimestepEmbedderND(Embedder):
   def __init__(self, outdim:int, input_key:str):
@@ -547,8 +75,8 @@ class Conditioner:

   def __init__(self, concat_embedders:List[str]):
     self.embedders = [
-      FrozenClosedClipEmbedder(),
-      FrozenOpenClipEmbedder(),
+      FrozenClosedClipEmbedder(ret_layer_idx=11),
+      FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True),
     ]
     for input_key in concat_embedders:
       self.embedders.append(ConcatTimestepEmbedderND(256, input_key))
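The embedders are now constructed explicitly from the generalized classes in extra/models/clip.py: the OpenAI branch stops at layer 11 (the penultimate hidden state), and the OpenCLIP branch returns both the penultimate states and the pooled projection, which the SDXL denoiser consumes as cond["crossattn"] and cond["vector"]. A rough usage sketch (shapes are my reading of the code below, not verified output):

open_clip = FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True)
penultimate, pooled = open_clip("a cherry blossom tree")
# penultimate: roughly (1, 77, 1280), pooled: roughly (1, 1280)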
@@ -585,28 +113,6 @@ class FirstStage:
   Namespace for First Stage Model components
   """

-  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L74
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L102
-  class Downsample:
-    def __init__(self, dims:int):
-      self.conv = Conv2d(dims, dims, kernel_size=3, stride=2, padding=(0,1,0,1))
-
-    def __call__(self, x:Tensor) -> Tensor:
-      return self.conv(x)
-
-
-  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L58
-  # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L83
-  class Upsample:
-    def __init__(self, dims:int):
-      self.conv = Conv2d(dims, dims, kernel_size=3, stride=1, padding=1)
-
-    def __call__(self, x:Tensor) -> Tensor:
-      B,C,Y,X = x.shape
-      x = x.reshape(B, C, Y, 1, X, 1).expand(B, C, Y, 2, X, 2).reshape(B, C, Y*2, X*2)
-      return self.conv(x)
-
-
   # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L487
   class Encoder:
     def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):

@@ -626,7 +132,7 @@ class FirstStage:
           block.append(ResnetBlock(block_in, block_out))
           block_in = block_out

-        downsample = tensor_identity if (i_level == len(ch_mult)-1) else FirstStage.Downsample(block_in)
+        downsample = tensor_identity if (i_level == len(ch_mult)-1) else Downsample(block_in)
        self.down.append(BlockEntry(block, downsample))

      self.mid = Mid(block_in)
@@ -779,7 +285,6 @@ class SDXL:

     return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], cond["vector"], c_out, x))

   # https://github.com/tinygrad/tinygrad/blob/64cda3c481613f4ca98eeb40ad2bce7a9d0749a3/examples/stable_diffusion.py#L543
   def decode(self, x:Tensor) -> Tensor:
     return self.first_stage_model.decode(1.0 / 0.13025 * x)
examples/stable_diffusion.py

@@ -2,16 +2,18 @@
 # https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
 import tempfile
 from pathlib import Path
-import gzip, argparse, math, re
-from functools import lru_cache
+import argparse
 from collections import namedtuple
+from typing import Dict, Any

 from PIL import Image
 import numpy as np
 from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
-from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, THREEFRY
-from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
+from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
+from tinygrad.nn import Conv2d, GroupNorm
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
+from extra.models.clip import Closed, Tokenizer
+from extra.models.unet import UNetModel

 class AttnBlock:
   def __init__(self, in_channels):
@@ -131,391 +133,32 @@ class AutoencoderKL:
     latent = self.post_quant_conv(latent)
     return self.decoder(latent)

-# not to be confused with ResnetBlock
-class ResBlock:
-  def __init__(self, channels, emb_channels, out_channels):
-    self.in_layers = [
-      GroupNorm(32, channels),
-      Tensor.silu,
-      Conv2d(channels, out_channels, 3, padding=1)
-    ]
-    self.emb_layers = [
-      Tensor.silu,
-      Linear(emb_channels, out_channels)
-    ]
-    self.out_layers = [
-      GroupNorm(32, out_channels),
-      Tensor.silu,
-      lambda x: x, # needed for weights loading code to work
-      Conv2d(out_channels, out_channels, 3, padding=1)
-    ]
-    self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
-
-  def __call__(self, x, emb):
-    h = x.sequential(self.in_layers)
-    emb_out = emb.sequential(self.emb_layers)
-    h = h + emb_out.reshape(*emb_out.shape, 1, 1)
-    h = h.sequential(self.out_layers)
-    ret = self.skip_connection(x) + h
-    return ret
-
-class CrossAttention:
-  def __init__(self, query_dim, context_dim, n_heads, d_head):
-    self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
-    self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
-    self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
-    self.num_heads = n_heads
-    self.head_size = d_head
-    self.to_out = [Linear(n_heads*d_head, query_dim)]
-
-  def __call__(self, x, context=None):
-    context = x if context is None else context
-    q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
-    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
-    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
-    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
-    return h_.sequential(self.to_out)
-
-class GEGLU:
-  def __init__(self, dim_in, dim_out):
-    self.proj = Linear(dim_in, dim_out * 2)
-    self.dim_out = dim_out
-
-  def __call__(self, x):
-    x, gate = self.proj(x).chunk(2, dim=-1)
-    return x * gate.gelu()
-
-class FeedForward:
-  def __init__(self, dim, mult=4):
-    self.net = [
-      GEGLU(dim, dim*mult),
-      lambda x: x, # needed for weights loading code to work
-      Linear(dim*mult, dim)
-    ]
-
-  def __call__(self, x):
-    return x.sequential(self.net)
-
-class BasicTransformerBlock:
-  def __init__(self, dim, context_dim, n_heads, d_head):
-    self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
-    self.ff = FeedForward(dim)
-    self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head)
-    self.norm1 = LayerNorm(dim)
-    self.norm2 = LayerNorm(dim)
-    self.norm3 = LayerNorm(dim)
-
-  def __call__(self, x, context=None):
-    x = self.attn1(self.norm1(x)) + x
-    x = self.attn2(self.norm2(x), context=context) + x
-    x = self.ff(self.norm3(x)) + x
-    return x
-
-class SpatialTransformer:
-  def __init__(self, channels, context_dim, n_heads, d_head):
-    self.norm = GroupNorm(32, channels)
-    assert channels == n_heads * d_head
-    self.proj_in = Conv2d(channels, n_heads * d_head, 1)
-    self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
-    self.proj_out = Conv2d(n_heads * d_head, channels, 1)
-
-  def __call__(self, x, context=None):
-    b, c, h, w = x.shape
-    x_in = x
-    x = self.norm(x)
-    x = self.proj_in(x)
-    x = x.reshape(b, c, h*w).permute(0,2,1)
-    for block in self.transformer_blocks:
-      x = block(x, context=context)
-    x = x.permute(0,2,1).reshape(b, c, h, w)
-    ret = self.proj_out(x) + x_in
-    return ret
-
-class Downsample:
-  def __init__(self, channels):
-    self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
-
-  def __call__(self, x):
-    return self.op(x)
-
-class Upsample:
-  def __init__(self, channels):
-    self.conv = Conv2d(channels, channels, 3, padding=1)
-
-  def __call__(self, x):
-    bs,c,py,px = x.shape
-    x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
-    return self.conv(x)
-
-def timestep_embedding(timesteps, dim, max_period=10000):
-  half = dim // 2
-  freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
-  args = timesteps * freqs
-  return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)
-
-class UNetModel:
-  def __init__(self):
-    self.time_embed = [
-      Linear(320, 1280),
-      Tensor.silu,
-      Linear(1280, 1280),
-    ]
-    self.input_blocks = [
-      [Conv2d(4, 320, kernel_size=3, padding=1)],
-      [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
-      [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
-      [Downsample(320)],
-      [ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
-      [ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
-      [Downsample(640)],
-      [ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
-      [ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
-      [Downsample(1280)],
-      [ResBlock(1280, 1280, 1280)],
-      [ResBlock(1280, 1280, 1280)]
-    ]
-    self.middle_block = [
-      ResBlock(1280, 1280, 1280),
-      SpatialTransformer(1280, 768, 8, 160),
-      ResBlock(1280, 1280, 1280)
-    ]
-    self.output_blocks = [
-      [ResBlock(2560, 1280, 1280)],
-      [ResBlock(2560, 1280, 1280)],
-      [ResBlock(2560, 1280, 1280), Upsample(1280)],
-      [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
-      [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
-      [ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
-      [ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
-      [ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
-      [ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
-      [ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
-      [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
-      [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
-    ]
-    self.out = [
-      GroupNorm(32, 320),
-      Tensor.silu,
-      Conv2d(320, 4, kernel_size=3, padding=1)
-    ]
-
-  def __call__(self, x, timesteps=None, context=None):
-    # TODO: real time embedding
-    t_emb = timestep_embedding(timesteps, 320)
-    emb = t_emb.sequential(self.time_embed)
-
-    def run(x, bb):
-      if isinstance(bb, ResBlock): x = bb(x, emb)
-      elif isinstance(bb, SpatialTransformer): x = bb(x, context)
-      else: x = bb(x)
-      return x
-
-    saved_inputs = []
-    for i,b in enumerate(self.input_blocks):
-      #print("input block", i)
-      for bb in b:
-        x = run(x, bb)
-      saved_inputs.append(x)
-    for bb in self.middle_block:
-      x = run(x, bb)
-    for i,b in enumerate(self.output_blocks):
-      #print("output block", i)
-      x = x.cat(saved_inputs.pop(), dim=1)
-      for bb in b:
-        x = run(x, bb)
-    return x.sequential(self.out)
-
-class CLIPMLP:
-  def __init__(self):
-    self.fc1 = Linear(768, 3072)
-    self.fc2 = Linear(3072, 768)
-
-  def __call__(self, hidden_states):
-    hidden_states = self.fc1(hidden_states)
-    hidden_states = hidden_states.quick_gelu()
-    hidden_states = self.fc2(hidden_states)
-    return hidden_states
-
-class CLIPAttention:
-  def __init__(self):
-    self.embed_dim = 768
-    self.num_heads = 12
-    self.head_dim = self.embed_dim // self.num_heads
-    self.k_proj = Linear(self.embed_dim, self.embed_dim)
-    self.v_proj = Linear(self.embed_dim, self.embed_dim)
-    self.q_proj = Linear(self.embed_dim, self.embed_dim)
-    self.out_proj = Linear(self.embed_dim, self.embed_dim)
-
-  def __call__(self, hidden_states, causal_attention_mask):
-    bsz, tgt_len, embed_dim = hidden_states.shape
-    q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
-    q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
-    attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
-    return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
-
-class CLIPEncoderLayer:
-  def __init__(self):
-    self.self_attn = CLIPAttention()
-    self.layer_norm1 = LayerNorm(768)
-    self.mlp = CLIPMLP()
-    self.layer_norm2 = LayerNorm(768)
-
-  def __call__(self, hidden_states, causal_attention_mask):
-    residual = hidden_states
-    hidden_states = self.layer_norm1(hidden_states)
-    hidden_states = self.self_attn(hidden_states, causal_attention_mask)
-    hidden_states = residual + hidden_states
-
-    residual = hidden_states
-    hidden_states = self.layer_norm2(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    hidden_states = residual + hidden_states
-
-    return hidden_states
-
-class CLIPEncoder:
-  def __init__(self):
-    self.layers = [CLIPEncoderLayer() for i in range(12)]
-
-  def __call__(self, hidden_states, causal_attention_mask):
-    for l in self.layers:
-      hidden_states = l(hidden_states, causal_attention_mask)
-    return hidden_states
-
-class CLIPTextEmbeddings:
-  def __init__(self):
-    self.token_embedding = Embedding(49408, 768)
-    self.position_embedding = Embedding(77, 768)
-
-  def __call__(self, input_ids, position_ids):
-    return self.token_embedding(input_ids) + self.position_embedding(position_ids)
-
-class CLIPTextTransformer:
-  def __init__(self):
-    self.embeddings = CLIPTextEmbeddings()
-    self.encoder = CLIPEncoder()
-    self.final_layer_norm = LayerNorm(768)
-
-  def __call__(self, input_ids):
-    x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
-    x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1))
-    return self.final_layer_norm(x)
-
-# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
-@lru_cache()
-def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
-
-def get_pairs(word):
-  """Return set of symbol pairs in a word.
-  Word is represented as tuple of symbols (symbols being variable-length strings).
-  """
-  return set(zip(word, word[1:]))
-
-def whitespace_clean(text):
-  text = re.sub(r'\s+', ' ', text)
-  text = text.strip()
-  return text
-
-def bytes_to_unicode():
-  """
-  Returns list of utf-8 byte and a corresponding list of unicode strings.
-  The reversible bpe codes work on unicode strings.
-  This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-  When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-  This is a significant percentage of your normal, say, 32K bpe vocab.
-  To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-  And avoids mapping to whitespace/control characters the bpe code barfs on.
-  """
-  bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-  cs = bs[:]
-  n = 0
-  for b in range(2**8):
-    if b not in bs:
-      bs.append(b)
-      cs.append(2**8+n)
-      n += 1
-  cs = [chr(n) for n in cs]
-  return dict(zip(bs, cs))
-
-class ClipTokenizer:
-  def __init__(self, bpe_path: str = default_bpe()):
-    self.byte_encoder = bytes_to_unicode()
-    merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
-    merges = merges[1:49152-256-2+1]
-    merges = [tuple(merge.split()) for merge in merges]
-    vocab = list(bytes_to_unicode().values())
-    vocab = vocab + [v+'</w>' for v in vocab]
-    for merge in merges:
-      vocab.append(''.join(merge))
-    vocab.extend(['<|startoftext|>', '<|endoftext|>'])
-    self.encoder = dict(zip(vocab, range(len(vocab))))
-    self.bpe_ranks = dict(zip(merges, range(len(merges))))
-    self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
-    self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
-
-  def bpe(self, token):
-    if token in self.cache:
-      return self.cache[token]
-    word = tuple(token[:-1]) + ( token[-1] + '</w>',)
-    pairs = get_pairs(word)
-
-    if not pairs:
-      return token+'</w>'
-
-    while True:
-      bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
-      if bigram not in self.bpe_ranks:
-        break
-      first, second = bigram
-      new_word = []
-      i = 0
-      while i < len(word):
-        try:
-          j = word.index(first, i)
-          new_word.extend(word[i:j])
-          i = j
-        except Exception:
-          new_word.extend(word[i:])
-          break
-
-        if word[i] == first and i < len(word)-1 and word[i+1] == second:
-          new_word.append(first+second)
-          i += 2
-        else:
-          new_word.append(word[i])
-          i += 1
-      new_word = tuple(new_word)
-      word = new_word
-      if len(word) == 1:
-        break
-      pairs = get_pairs(word)
-    word = ' '.join(word)
-    self.cache[token] = word
-    return word
-
-  def encode(self, text, pad_with_zeros=False):
-    bpe_tokens = []
-    text = whitespace_clean(text.strip()).lower()
-    for token in re.findall(self.pat, text):
-      token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-      bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
-    # Truncation, keeping two slots for start and end tokens.
-    if len(bpe_tokens) > 75:
-      bpe_tokens = bpe_tokens[:75]
-    return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
-
 def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
   betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
   alphas = 1.0 - betas
   alphas_cumprod = np.cumprod(alphas, axis=0)
   return Tensor(alphas_cumprod)
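get_alphas_cumprod builds the standard DDPM noise schedule directly from the code above: betas are linear in sqrt-space between beta_start and beta_end over T = 1000 training steps, and the function returns the running product of the alphas,

\[ \beta_t = \left(\sqrt{\beta_{\text{start}}} + \tfrac{t}{T-1}\left(\sqrt{\beta_{\text{end}}} - \sqrt{\beta_{\text{start}}}\right)\right)^{\!2}, \qquad \bar{\alpha}_t = \prod_{s=0}^{t} (1 - \beta_s) \]

so alphas_cumprod[t] is the fraction of signal variance remaining after t noising steps.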
+unet_params: Dict[str,Any] = {
+  "adm_in_ch": None,
+  "in_ch": 4,
+  "out_ch": 4,
+  "model_ch": 320,
+  "attention_resolutions": [4, 2, 1],
+  "num_res_blocks": 2,
+  "channel_mult": [1, 2, 4, 4],
+  "n_heads": 8,
+  "transformer_depth": [1, 1, 1, 1],
+  "ctx_dim": 768,
+  "use_linear": False,
+}

 class StableDiffusion:
   def __init__(self):
     self.alphas_cumprod = get_alphas_cumprod()
-    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
+    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
     self.first_stage_model = AutoencoderKL()
-    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
+    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))

   def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
     temperature = 1
@@ -597,7 +240,7 @@ if __name__ == "__main__":
       v.replace(v.cast(dtypes.float16).realize())

   # run through CLIP to get context
-  tokenizer = ClipTokenizer()
+  tokenizer = Tokenizer.ClipTokenizer()
   prompt = Tensor([tokenizer.encode(args.prompt)])
   context = model.cond_stage_model.transformer.text_model(prompt).realize()
   print("got CLIP context", context.shape)
extra/models/clip.py (new file, 348 lines)

@@ -0,0 +1,348 @@
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding

from typing import List, Optional, Union, Tuple
from abc import ABC, abstractmethod
from functools import lru_cache
import re, gzip

@lru_cache()
def default_bpe():
  # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
  return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

class Tokenizer:
  """
  Namespace for CLIP Text Tokenizer components.
  """

  @staticmethod
  def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))

  @staticmethod
  def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

  @staticmethod
  def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
      if b not in bs:
        bs.append(b)
        cs.append(2**8+n)
        n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
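bytes_to_unicode gives every possible byte value a printable unicode character, so BPE can treat arbitrary UTF-8 as plain strings without whitespace or control characters in the vocab. A quick illustration of the mapping (my own example, not from the diff):

byte_encoder = Tokenizer.bytes_to_unicode()
assert byte_encoder[ord("A")] == "A"   # printable ASCII maps to itself
assert byte_encoder[ord(" ")] == "Ġ"   # space (0x20) is remapped above U+0100
assert len(set(byte_encoder.values())) == 256  # the mapping is a bijection over all bytes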
  class ClipTokenizer:
    def __init__(self):
      self.byte_encoder = Tokenizer.bytes_to_unicode()
      merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
      merges = merges[1:49152-256-2+1]
      merges = [tuple(merge.split()) for merge in merges]
      vocab = list(Tokenizer.bytes_to_unicode().values())
      vocab = vocab + [v+'</w>' for v in vocab]
      for merge in merges:
        vocab.append(''.join(merge))
      vocab.extend(['<|startoftext|>', '<|endoftext|>'])
      self.encoder = dict(zip(vocab, range(len(vocab))))
      self.bpe_ranks = dict(zip(merges, range(len(merges))))
      self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
      self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

    def bpe(self, token):
      if token in self.cache:
        return self.cache[token]
      word = tuple(token[:-1]) + ( token[-1] + '</w>',)
      pairs = Tokenizer.get_pairs(word)

      if not pairs:
        return token+'</w>'

      while True:
        bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
          break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
          try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
          except Exception:
            new_word.extend(word[i:])
            break

          if word[i] == first and i < len(word)-1 and word[i+1] == second:
            new_word.append(first+second)
            i += 2
          else:
            new_word.append(word[i])
            i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
          break
        pairs = Tokenizer.get_pairs(word)
      word = ' '.join(word)
      self.cache[token] = word
      return word
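The bpe loop repeatedly merges the lowest-ranked adjacent symbol pair until no known merge remains, starting from single characters with a </w> end-of-word marker. A hand-worked sketch (the intermediate merge order depends on the downloaded merge table, so treat this as illustrative):

tok = Tokenizer.ClipTokenizer()
# "hello" enters as ('h','e','l','l','o</w>') and collapses pair by pair,
# e.g. ('h','e') -> 'he', ('l','l') -> 'll', ... until the whole word is one symbol
print(tok.bpe("hello"))  # -> "hello</w>" for words common enough to be in the CLIP vocab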
    def encode(self, text:str, pad_with_zeros:bool=False):
      bpe_tokens: List[int] = []
      text = Tokenizer.whitespace_clean(text.strip()).lower()
      for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
      # Truncation, keeping two slots for start and end tokens.
      if len(bpe_tokens) > 75:
        bpe_tokens = bpe_tokens[:75]
      return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
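encode always returns a length-77 sequence: start token 49406, up to 75 BPE ids, end token 49407, then padding with either zeros (the OpenCLIP convention) or repeated end tokens (the OpenAI CLIP convention). A quick check, assuming the BPE vocab downloads successfully:

tok = Tokenizer.ClipTokenizer()
ids = tok.encode("a photo of a cat")
assert len(ids) == 77 and ids[0] == 49406 and 49407 in ids
ids_open = tok.encode("a photo of a cat", pad_with_zeros=True)
assert ids_open[-1] == 0  # OpenCLIP-style zero padding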
class Embedder(ABC):
  input_key: str
  @abstractmethod
  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    pass


class Closed:
  """
  Namespace for OpenAI CLIP model components.
  """
  class ClipMlp:
    def __init__(self):
      self.fc1 = Linear(768, 3072)
      self.fc2 = Linear(3072, 768)

    def __call__(self, h:Tensor) -> Tensor:
      h = self.fc1(h)
      h = h.quick_gelu()
      h = self.fc2(h)
      return h

  class ClipAttention:
    def __init__(self):
      self.embed_dim = 768
      self.num_heads = 12
      self.head_dim = self.embed_dim // self.num_heads
      self.k_proj = Linear(self.embed_dim, self.embed_dim)
      self.v_proj = Linear(self.embed_dim, self.embed_dim)
      self.q_proj = Linear(self.embed_dim, self.embed_dim)
      self.out_proj = Linear(self.embed_dim, self.embed_dim)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      bsz, tgt_len, embed_dim = hidden_states.shape
      q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
      q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
      return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

  class ClipEncoderLayer:
    def __init__(self):
      self.self_attn = Closed.ClipAttention()
      self.layer_norm1 = LayerNorm(768)
      self.mlp = Closed.ClipMlp()
      self.layer_norm2 = LayerNorm(768)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      residual = hidden_states
      hidden_states = self.layer_norm1(hidden_states)
      hidden_states = self.self_attn(hidden_states, causal_attention_mask)
      hidden_states = residual + hidden_states

      residual = hidden_states
      hidden_states = self.layer_norm2(hidden_states)
      hidden_states = self.mlp(hidden_states)
      hidden_states = residual + hidden_states

      return hidden_states

  class ClipTextEmbeddings:
    def __init__(self):
      self.token_embedding = Embedding(49408, 768)
      self.position_embedding = Embedding(77, 768)

    def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
      return self.token_embedding(input_ids) + self.position_embedding(position_ids)

  class ClipEncoder:
    def __init__(self, layer_count:int=12):
      self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]

    def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
      # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
      layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
      for l in layers:
        x = l(x, causal_attention_mask)
      return x

  class ClipTextTransformer:
    def __init__(self, ret_layer_idx:Optional[int]=None):
      self.embeddings = Closed.ClipTextEmbeddings()
      self.encoder = Closed.ClipEncoder()
      self.final_layer_norm = LayerNorm(768)
      self.ret_layer_idx = ret_layer_idx

    def __call__(self, input_ids:Tensor) -> Tensor:
      x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
      x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
      return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x

  class ClipTextModel:
    def __init__(self, ret_layer_idx:Optional[int]):
      self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)


# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
class FrozenClosedClipEmbedder(Embedder):
  def __init__(self, ret_layer_idx:Optional[int]=None):
    self.tokenizer = Tokenizer.ClipTokenizer()
    self.transformer = Closed.ClipTextModel(ret_layer_idx)
    self.input_key = "txt"

  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    tokens = Tensor(self.tokenizer.encode(text))
    return self.transformer.text_model(tokens.reshape(1,-1))


class Open:
  """
  Namespace for OpenCLIP model components.
  """
  class MultiheadAttention:
    def __init__(self, dims:int, n_heads:int):
      self.dims = dims
      self.n_heads = n_heads
      self.d_head = self.dims // self.n_heads

      self.in_proj_bias = Tensor.empty(3*dims)
      self.in_proj_weight = Tensor.empty(3*dims, dims)
      self.out_proj = Linear(dims, dims)

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      T,B,C = x.shape

      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2)
      q,k,v = proj.chunk(3)

      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)]

      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
      attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C)

      attn_output = self.out_proj(attn_output)
      attn_output = attn_output.reshape(T, B, C)

      return attn_output
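Open.MultiheadAttention mirrors the torch.nn.MultiheadAttention checkpoint layout: a single packed (3*dims, dims) in_proj_weight holds the q, k and v projections stacked row-wise, which is why the forward pass does one linear and then chunks into three. A rough sketch of the equivalent unpacked view (illustrative only, not part of the file):

# w_q, w_k, w_v = in_proj_weight[:dims], in_proj_weight[dims:2*dims], in_proj_weight[2*dims:]
# q = x.linear(w_q.T, in_proj_bias[:dims])  # same result as the first third of the packed projection
# ... k and v likewise; packing keeps the weight names compatible with OpenCLIP checkpoints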
  class Mlp:
    def __init__(self, dims, hidden_dims):
      self.c_fc = Linear(dims, hidden_dims)
      self.c_proj = Linear(hidden_dims, dims)

    def __call__(self, x:Tensor) -> Tensor:
      return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
  class ResidualAttentionBlocks:
    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
      self.ln_1 = LayerNorm(dims)
      self.attn = Open.MultiheadAttention(dims, n_heads)

      self.ln_2 = LayerNorm(dims)
      self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x + self.attn(self.ln_1(x), attn_mask=attn_mask)
      x = x + self.mlp(self.ln_2(x))
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
  class ClipTransformer:
    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
      self.resblocks = [
        Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers)
      ]

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x.transpose(0, 1).contiguous()
      for r in self.resblocks:
        x = r(x, attn_mask=attn_mask)
      x = x.transpose(0, 1)
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
  class ClipTextTransformer:
    def __init__(self, dims:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
      self.token_embedding = Embedding(vocab_size, dims)
      self.positional_embedding = Tensor.empty(ctx_length, dims)
      self.transformer = Open.ClipTransformer(dims, layers, n_heads)
      self.ln_final = LayerNorm(dims)
      self.text_projection = Tensor.empty(dims, dims)

    @property
    def attn_mask(self) -> Tensor:
      if not hasattr(self, "_attn_mask"):
        self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1)
      return self._attn_mask

    def __call__(self, text:Tensor) -> Tensor:
      seq_len = text.shape[1]

      x = self.token_embedding(text)
      x = x + self.positional_embedding[:seq_len]
      x = self.transformer(x, attn_mask=self.attn_mask)
      x = self.ln_final(x)

      pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
      return pooled
|
||||
|
||||
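The cached causal mask keeps each position from attending to later tokens, and the pooled output is read at the end-of-text position: CLIP's end-of-text token has the highest id in the 49408-entry vocabulary (49407), so an argmax over the raw token ids locates it. A small illustration of that trick (token ids here are illustrative):

tokens = Tensor([[49406, 320, 1125, 49407]])  # <start> "a" <word> <end>
print(tokens.argmax(-1).numpy())              # -> [3], the EOT index used for pooling
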
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
  def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool):
    self.tokenizer = Tokenizer.ClipTokenizer()
    self.model = Open.ClipTextTransformer(dims, n_heads, layers)
    self.return_pooled = return_pooled
    self.input_key = "txt"

  def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
    for r in self.model.transformer.resblocks:
      x, penultimate = r(x, attn_mask=attn_mask), x
    return x.permute(1,0,2), penultimate.permute(1,0,2)

  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1)

    x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
    x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)

    if self.return_pooled:
      x = self.model.ln_final(x)
      pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection
      return penultimate, pooled
    else:
      return penultimate

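A usage sketch (illustrative, assumes pretrained OpenCLIP weights are loaded; the sizes below are plausible ViT-bigG text-tower dimensions, width 1280 with 20 heads over 32 layers, not values confirmed by this diff):

emb = FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True)
penultimate, pooled = emb("a cherry blossom tree in spring")
print(penultimate.shape, pooled.shape)  # expected: (1, 77, 1280) and (1, 1280)
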
253
extra/models/unet.py
Normal file
@@ -0,0 +1,253 @@
from tinygrad import Tensor, dtypes
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm

from typing import Optional, Union, List, Any, Tuple
import math

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
  half = dim // 2
  freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
  args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
  return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)

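A worked example of the sinusoidal embedding (illustrative): for dim=320 the first 160 output channels are cosines and the last 160 are sines of the timestep scaled by geometrically spaced frequencies.

t = Tensor([0.0, 500.0, 999.0])
emb = timestep_embedding(t, dim=320)
assert emb.shape == (3, 320)
print(emb.numpy()[0, :3])  # t=0: cos(0)=1 in every cosine channel
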
class ResBlock:
  def __init__(self, channels:int, emb_channels:int, out_channels:int):
    self.in_layers = [
      GroupNorm(32, channels),
      Tensor.silu,
      Conv2d(channels, out_channels, 3, padding=1),
    ]
    self.emb_layers = [
      Tensor.silu,
      Linear(emb_channels, out_channels),
    ]
    self.out_layers = [
      GroupNorm(32, out_channels),
      Tensor.silu,
      lambda x: x, # needed for weights loading code to work
      Conv2d(out_channels, out_channels, 3, padding=1),
    ]
    self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x)

  def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
    h = x.sequential(self.in_layers)
    emb_out = emb.sequential(self.emb_layers)
    h = h + emb_out.reshape(*emb_out.shape, 1, 1)
    h = h.sequential(self.out_layers)
    return self.skip_connection(x) + h

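The time embedding enters as a per-channel bias: emb_out has shape (batch, out_channels), and the reshape appends two singleton spatial dims so it broadcasts over H and W. A toy shape check (illustrative values):

rb = ResBlock(channels=320, emb_channels=1280, out_channels=320)
x = Tensor.randn(1, 320, 8, 8)
emb = Tensor.randn(1, 1280)
assert rb(x, emb).shape == (1, 320, 8, 8)
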
class CrossAttention:
  def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
    self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
    self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.num_heads = n_heads
    self.head_size = d_head
    self.to_out = [Linear(n_heads*d_head, query_dim)]

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    ctx = x if ctx is None else ctx
    q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
    return h_.sequential(self.to_out)

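With ctx=None this degenerates to self-attention (which is why attn1 in BasicTransformerBlock below is built with ctx_dim equal to query_dim); with a text context it attends latent tokens over the 77 prompt embeddings. A shape sketch using SDXL-base-like sizes (illustrative):

ca = CrossAttention(query_dim=320, ctx_dim=2048, n_heads=5, d_head=64)
x = Tensor.randn(1, 4096, 320)   # flattened 64x64 latent, (B, H*W, C)
ctx = Tensor.randn(1, 77, 2048)  # text conditioning
assert ca(x, ctx).shape == (1, 4096, 320)
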
class GEGLU:
  def __init__(self, dim_in:int, dim_out:int):
    self.proj = Linear(dim_in, dim_out * 2)
    self.dim_out = dim_out

  def __call__(self, x:Tensor) -> Tensor:
    x, gate = self.proj(x).chunk(2, dim=-1)
    return x * gate.gelu()


class FeedForward:
  def __init__(self, dim:int, mult:int=4):
    self.net = [
      GEGLU(dim, dim*mult),
      lambda x: x, # needed for weights loading code to work
      Linear(dim*mult, dim)
    ]

  def __call__(self, x:Tensor) -> Tensor:
    return x.sequential(self.net)

class BasicTransformerBlock:
  def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
    self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
    self.ff = FeedForward(dim)
    self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
    self.norm1 = LayerNorm(dim)
    self.norm2 = LayerNorm(dim)
    self.norm3 = LayerNorm(dim)

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    x = x + self.attn1(self.norm1(x))
    x = x + self.attn2(self.norm2(x), ctx=ctx)
    x = x + self.ff(self.norm3(x))
    return x

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
class SpatialTransformer:
  def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1):
    if isinstance(ctx_dim, int):
      ctx_dim = [ctx_dim]*depth
    else:
      assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
    self.norm = GroupNorm(32, channels)
    assert channels == n_heads * d_head
    self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
    self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
    self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
    self.use_linear = use_linear

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    b, c, h, w = x.shape
    x_in = x
    x = self.norm(x)
    ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ]
    x = x.sequential(ops if self.use_linear else ops[::-1])
    for block in self.transformer_blocks:
      x = block(x, ctx=ctx)
    ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ]
    x = x.sequential(ops if self.use_linear else ops[::-1])
    return x + x_in

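The two ops lists flatten NCHW activations to (B, H*W, C) token form around the projections, with the order flipped when proj_in/proj_out are 1x1 convs instead of linears, and the residual add restores the input. A shape round-trip sketch (illustrative sizes):

st = SpatialTransformer(channels=320, n_heads=5, d_head=64, ctx_dim=2048, use_linear=True, depth=1)
x = Tensor.randn(1, 320, 32, 32)
ctx = Tensor.randn(1, 77, 2048)
assert st(x, ctx).shape == (1, 320, 32, 32)
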
class Downsample:
  def __init__(self, channels:int):
    self.op = Conv2d(channels, channels, 3, stride=2, padding=1)

  def __call__(self, x:Tensor) -> Tensor:
    return self.op(x)

class Upsample:
  def __init__(self, channels:int):
    self.conv = Conv2d(channels, channels, 3, padding=1)

  def __call__(self, x:Tensor) -> Tensor:
    bs,c,py,px = x.shape
    z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
    return self.conv(z)

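The reshape/expand chain is nearest-neighbor 2x upsampling written without an interpolate kernel: each pixel is viewed as (py, 1, px, 1), broadcast into a 2x2 block, then flattened back. A concrete trace (illustrative):

x = Tensor([[1, 2], [3, 4]]).reshape(1, 1, 2, 2)
z = x.reshape(1, 1, 2, 1, 2, 1).expand(1, 1, 2, 2, 2, 2).reshape(1, 1, 4, 4)
print(z.numpy()[0, 0])
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
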
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
class UNetModel:
  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, n_heads:Optional[int]=None):
    self.model_ch = model_ch
    self.num_res_blocks = [num_res_blocks] * len(channel_mult)

    self.attention_resolutions = attention_resolutions
    self.d_head = d_head
    self.n_heads = n_heads
    def get_d_and_n_heads(dims:int) -> Tuple[int,int]:
      if self.d_head is None:
        assert self.n_heads is not None, "d_head and n_heads cannot both be None"
        return dims // self.n_heads, self.n_heads
      else:
        assert self.n_heads is None, "d_head and n_heads cannot both be non-None"
        return self.d_head, dims // self.d_head

    time_embed_dim = model_ch * 4
    self.time_embed = [
      Linear(model_ch, time_embed_dim),
      Tensor.silu,
      Linear(time_embed_dim, time_embed_dim),
    ]

    if adm_in_ch is not None:
      self.label_emb = [
        [
          Linear(adm_in_ch, time_embed_dim),
          Tensor.silu,
          Linear(time_embed_dim, time_embed_dim),
        ]
      ]

    self.input_blocks: List[Any] = [
      [Conv2d(in_ch, model_ch, 3, padding=1)]
    ]
    input_block_channels = [model_ch]
    ch = model_ch
    ds = 1
    for idx, mult in enumerate(channel_mult):
      for _ in range(self.num_res_blocks[idx]):
        layers: List[Any] = [
          ResBlock(ch, time_embed_dim, model_ch*mult),
        ]
        ch = mult * model_ch
        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))

        self.input_blocks.append(layers)
        input_block_channels.append(ch)

      if idx != len(channel_mult) - 1:
        self.input_blocks.append([
          Downsample(ch),
        ])
        input_block_channels.append(ch)
        ds *= 2

    d_head, n_heads = get_d_and_n_heads(ch)
    self.middle_block: List = [
      ResBlock(ch, time_embed_dim, ch),
      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1]),
      ResBlock(ch, time_embed_dim, ch),
    ]

    self.output_blocks = []
    for idx, mult in list(enumerate(channel_mult))[::-1]:
      for i in range(self.num_res_blocks[idx] + 1):
        ich = input_block_channels.pop()
        layers = [
          ResBlock(ch + ich, time_embed_dim, model_ch*mult),
        ]
        ch = model_ch * mult

        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))

        if idx > 0 and i == self.num_res_blocks[idx]:
          layers.append(Upsample(ch))
          ds //= 2
        self.output_blocks.append(layers)

    self.out = [
      GroupNorm(32, ch),
      Tensor.silu,
      Conv2d(model_ch, out_ch, 3, padding=1),
    ]

  def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
    t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16)
    emb = t_emb.sequential(self.time_embed)

    if y is not None:
      assert y.shape[0] == x.shape[0]
      emb = emb + y.sequential(self.label_emb[0])

    emb = emb.cast(dtypes.float16)
    ctx = ctx.cast(dtypes.float16)
    x = x.cast(dtypes.float16)

    def run(x:Tensor, bb) -> Tensor:
      if isinstance(bb, ResBlock): x = bb(x, emb)
      elif isinstance(bb, SpatialTransformer): x = bb(x, ctx)
      else: x = bb(x)
      return x

    # encoder: save each block's output for the decoder's skip connections
    saved_inputs = []
    for b in self.input_blocks:
      for bb in b:
        x = run(x, bb)
      saved_inputs.append(x)
    for bb in self.middle_block:
      x = run(x, bb)
    for b in self.output_blocks:
      # concat the matching encoder activation along channels (U-Net skip)
      x = x.cat(saved_inputs.pop(), dim=1)
      for bb in b:
        x = run(x, bb)
    return x.sequential(self.out)

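A hypothetical smoke test tying this back to the config dicts in examples/sdxl.py (not part of the diff): build the SDXL-base UNet and run one denoising step on a dummy latent; the shapes follow the config values.

from examples.sdxl import configs
model = UNetModel(**configs["SDXL_Base"]["model"])
x = Tensor.randn(1, 4, 32, 32)   # latent
tms = Tensor([801])              # timestep
ctx = Tensor.randn(1, 77, 2048)  # text context from the conditioner
y = Tensor.randn(1, 2816)        # pooled text + size/crop embeddings
assert model(x, tms, ctx, y).shape == (1, 4, 32, 32)
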
4
test/external/external_test_jit_on_models.py
vendored
@@ -30,8 +30,8 @@ class TestJittedModels(unittest.TestCase):
  @unittest.skipUnless(not CI, "huge for CI")
  def test_jitted_stable_diffusion(self):
    from examples.stable_diffusion import UNetModel
    model = UNetModel()
    from examples.stable_diffusion import UNetModel, unet_params
    model = UNetModel(**unet_params)
    derandomize_model(model)
    def test(t, t2): return model(t, 801, t2).realize()

@@ -12,7 +12,8 @@ from test.helpers import derandomize_model, is_dtype_supported
from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
from examples.hlb_cifar10 import SpeedyResNet, hyp
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
from examples.stable_diffusion import UNetModel, ResBlock
from examples.stable_diffusion import UNetModel, unet_params
from extra.models.unet import ResBlock

global_mem_used = 0
def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):

@@ -49,7 +50,7 @@ class TestRealWorld(unittest.TestCase):
  @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
  @unittest.skipIf(CI, "too big for CI")
  def test_stable_diffusion(self):
    model = UNetModel()
    model = UNetModel(**unet_params)
    derandomize_model(model)
    @TinyJit
    def test(t, t2): return model(t, 801, t2).realize()