from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding
from typing import List, Optional, Union, Tuple
from abc import ABC, abstractmethod
from functools import lru_cache
import re, gzip

# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
  return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

class Tokenizer:
  """
  Namespace for CLIP Text Tokenizer components.
  """
  @staticmethod
  def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))
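
  # e.g. get_pairs(('l', 'o', 'w</w>')) -> {('l', 'o'), ('o', 'w</w>')}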

  @staticmethod
  def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

  @staticmethod
  def bytes_to_unicode():
    """
    Returns a mapping from utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
    while avoiding mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
      if b not in bs:
        bs.append(b)
        cs.append(2**8+n)
        n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
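
  # For illustration: printable bytes map to themselves ("a" stays "a"), while the remaining
  # bytes are shifted past 255, e.g. the space byte 0x20 maps to chr(288) ("Ġ").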

  class ClipTokenizer:
    def __init__(self):
      self.byte_encoder = Tokenizer.bytes_to_unicode()
      merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
      merges = merges[1:49152-256-2+1]
      merges = [tuple(merge.split()) for merge in merges]
      vocab = list(Tokenizer.bytes_to_unicode().values())
      vocab = vocab + [v+'</w>' for v in vocab]
      for merge in merges:
        vocab.append(''.join(merge))
      vocab.extend(['<|startoftext|>', '<|endoftext|>'])
      self.encoder = dict(zip(vocab, range(len(vocab))))
      self.bpe_ranks = dict(zip(merges, range(len(merges))))
      self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
      self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

    def bpe(self, token):
      if token in self.cache:
        return self.cache[token]
      word = tuple(token[:-1]) + (token[-1] + '</w>',)
      pairs = Tokenizer.get_pairs(word)

      if not pairs:
        return token+'</w>'

      while True:
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
          break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
          try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
          except Exception:
            new_word.extend(word[i:])
            break

          if word[i] == first and i < len(word)-1 and word[i+1] == second:
            new_word.append(first+second)
            i += 2
          else:
            new_word.append(word[i])
            i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
          break
        pairs = Tokenizer.get_pairs(word)
      word = ' '.join(word)
      self.cache[token] = word
      return word
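
    # e.g. bpe("dog") starts from ('d', 'o', 'g</w>') and repeatedly merges the lowest-ranked
    # adjacent pair found in bpe_ranks, returning the final sub-tokens joined by spaces.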

    def encode(self, text:str, pad_with_zeros:bool=False):
      bpe_tokens: List[int] = []
      text = Tokenizer.whitespace_clean(text.strip()).lower()
      for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
      # Truncation, keeping two slots for start and end tokens.
      if len(bpe_tokens) > 75:
        bpe_tokens = bpe_tokens[:75]
      return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
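
    # The result is always 77 ids long: <|startoftext|> (49406), at most 75 BPE tokens, <|endoftext|>
    # (49407), then padding with zeros or repeated end tokens, e.g. encode("") == [49406, 49407] + [49407]*75.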


class Embedder(ABC):
  input_key: str
  @abstractmethod
  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    pass


class Closed:
  """
  Namespace for OpenAI CLIP model components.
  """
  class ClipMlp:
    def __init__(self):
      self.fc1 = Linear(768, 3072)
      self.fc2 = Linear(3072, 768)

    def __call__(self, h:Tensor) -> Tensor:
      h = self.fc1(h)
      h = h.quick_gelu()
      h = self.fc2(h)
      return h

  class ClipAttention:
    def __init__(self):
      self.embed_dim = 768
      self.num_heads = 12
      self.head_dim = self.embed_dim // self.num_heads
      self.k_proj = Linear(self.embed_dim, self.embed_dim)
      self.v_proj = Linear(self.embed_dim, self.embed_dim)
      self.q_proj = Linear(self.embed_dim, self.embed_dim)
      self.out_proj = Linear(self.embed_dim, self.embed_dim)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      bsz, tgt_len, embed_dim = hidden_states.shape
      q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
      q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
      return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

  class ClipEncoderLayer:
    def __init__(self):
      self.self_attn = Closed.ClipAttention()
      self.layer_norm1 = LayerNorm(768)
      self.mlp = Closed.ClipMlp()
      self.layer_norm2 = LayerNorm(768)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      residual = hidden_states
      hidden_states = self.layer_norm1(hidden_states)
      hidden_states = self.self_attn(hidden_states, causal_attention_mask)
      hidden_states = residual + hidden_states

      residual = hidden_states
      hidden_states = self.layer_norm2(hidden_states)
      hidden_states = self.mlp(hidden_states)
      hidden_states = residual + hidden_states

      return hidden_states

  class ClipTextEmbeddings:
    def __init__(self):
      self.token_embedding = Embedding(49408, 768)
      self.position_embedding = Embedding(77, 768)

    def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
      return self.token_embedding(input_ids) + self.position_embedding(position_ids)

  class ClipEncoder:
    def __init__(self, layer_count:int=12):
      self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]

    def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
      # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
      layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
      for l in layers:
        x = l(x, causal_attention_mask)
      return x
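
    # e.g. with the default 12 layers, ret_layer_idx=11 runs only the first 11 and returns the
    # penultimate hidden state (hidden_states[11] in the convention noted above).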

  class ClipTextTransformer:
    def __init__(self, ret_layer_idx:Optional[int]=None):
      self.embeddings = Closed.ClipTextEmbeddings()
      self.encoder = Closed.ClipEncoder()
      self.final_layer_norm = LayerNorm(768)
      self.ret_layer_idx = ret_layer_idx

    def __call__(self, input_ids:Tensor) -> Tensor:
      x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
      x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
      return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x

  class ClipTextModel:
    def __init__(self, ret_layer_idx:Optional[int]):
      self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)


# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
class FrozenClosedClipEmbedder(Embedder):
  def __init__(self, ret_layer_idx:Optional[int]=None):
    self.tokenizer = Tokenizer.ClipTokenizer()
    self.transformer = Closed.ClipTextModel(ret_layer_idx)
    self.input_key = "txt"

  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    tokens = Tensor(self.tokenizer.encode(text))
    return self.transformer.text_model(tokens.reshape(1,-1))
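
# Usage sketch: the classes above only define the architecture; pretrained CLIP weights still have
# to be loaded into `transformer` (e.g. via tinygrad.nn.state.load_state_dict) before the output is
# meaningful.
#   embedder = FrozenClosedClipEmbedder()
#   context = embedder("a photograph of an astronaut riding a horse")  # Tensor of shape (1, 77, 768)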


class Open:
  """
  Namespace for OpenCLIP model components.
  """
  class MultiheadAttention:
    def __init__(self, dims:int, n_heads:int):
      self.dims = dims
      self.n_heads = n_heads
      self.d_head = self.dims // self.n_heads

      self.in_proj_bias = Tensor.empty(3*dims)
      self.in_proj_weight = Tensor.empty(3*dims, dims)
      self.out_proj = Linear(dims, dims)

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      T,B,C = x.shape

      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2)
      q,k,v = proj.chunk(3)

      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)]

      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
      attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C)

      attn_output = self.out_proj(attn_output)

      attn_output = attn_output.reshape(T, B, C)

      return attn_output

  class Mlp:
    def __init__(self, dims, hidden_dims):
      self.c_fc = Linear(dims, hidden_dims)
      self.c_proj = Linear(hidden_dims, dims)

    def __call__(self, x:Tensor) -> Tensor:
      return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
  class ResidualAttentionBlocks:
    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
      self.ln_1 = LayerNorm(dims)
      self.attn = Open.MultiheadAttention(dims, n_heads)

      self.ln_2 = LayerNorm(dims)
      self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x + self.attn(self.ln_1(x), attn_mask=attn_mask)
      x = x + self.mlp(self.ln_2(x))
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
  class ClipTransformer:
    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
      self.resblocks = [
        Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers)
      ]

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x.transpose(0, 1).contiguous()
      for r in self.resblocks:
        x = r(x, attn_mask=attn_mask)
      x = x.transpose(0, 1)
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
  class ClipTextTransformer:
    def __init__(self, dims:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
      self.token_embedding = Embedding(vocab_size, dims)
      self.positional_embedding = Tensor.empty(ctx_length, dims)
      self.transformer = Open.ClipTransformer(dims, layers, n_heads)
      self.ln_final = LayerNorm(dims)
      self.text_projection = Tensor.empty(dims, dims)

    @property
    def attn_mask(self) -> Tensor:
      if not hasattr(self, "_attn_mask"):
        self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1)
      return self._attn_mask

    def __call__(self, text:Tensor) -> Tensor:
      seq_len = text.shape[1]

      x = self.token_embedding(text)
      x = x + self.positional_embedding[:seq_len]
      x = self.transformer(x, attn_mask=self.attn_mask)
      x = self.ln_final(x)

      pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
      return pooled


# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
  def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool):
    self.tokenizer = Tokenizer.ClipTokenizer()
    self.model = Open.ClipTextTransformer(dims, n_heads, layers)
    self.return_pooled = return_pooled
    self.input_key = "txt"

  def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
    # after the loop, x holds the final block's output and penultimate the output of the block before it
    for r in self.model.transformer.resblocks:
      x, penultimate = r(x, attn_mask=attn_mask), x
    return x.permute(1,0,2), penultimate.permute(1,0,2)

  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1)

    x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
    x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)

    if self.return_pooled:
      x = self.model.ln_final(x)
      pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection
      return penultimate, pooled
    else:
      return penultimate
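
# Usage sketch: dims/n_heads/layers must match the OpenCLIP checkpoint being loaded; the ViT-bigG
# text tower used by SDXL is commonly configured as dims=1280, n_heads=20, layers=32.
#   embedder = FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True)
#   penultimate, pooled = embedder("a cat")  # shapes (1, 77, 1280) and (1, 1280) once weights are loaded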