mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Clip model updates for Stable Diffusion mlperf training (#12313)
* stable diffusion mlperf clip changes
* add clip tests
* set gelu as attribute
* add more tests
* factor out GPUS
* rerun CI
* add imports to if blocks
* remove unneeded axis
* add clip tests to CI
* move clip tests
* add deps, disable max buf size
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -259,7 +259,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: unittest-12
|
||||
pydeps: "pillow"
|
||||
pydeps: "pillow numpy ftfy regex"
|
||||
deps: testing_unit
|
||||
- name: Run unit tests
|
||||
run: python -m pytest -n=auto test/unit/ --durations=20
|
||||
@@ -267,6 +267,8 @@ jobs:
|
||||
run: NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
- name: Run SDXL on NULL backend
|
||||
run: MAX_BUFFER_SIZE=0 NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
|
||||
- name: Run Clip tests for SD MLPerf on NULL backend
|
||||
run: MAX_BUFFER_SIZE=0 NULL=1 python -m pytest -n=auto test/external/mlperf_stable_diffusion/external_test_models.py::TestOpenClip --durations=20
|
||||
# TODO: support fake weights
|
||||
#- name: Run LLaMA 7B on 4 fake devices
|
||||
# run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 3 --temperature 0 --timing
|
||||
|
||||
@@ -17,6 +17,10 @@ def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
|
||||
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978
|
||||
return std * rand_truncn(*shape, **kwargs)
|
||||
|
||||
# Stable Diffusion v2 training uses default torch gelu, which doesn't use tanh approximation
|
||||
def gelu_erf(x:Tensor) -> Tensor:
  """Erf-based GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))``.

  The error function is evaluated directly through ``x.erf()`` rather than via a tanh approximation.
  """
  # math.sqrt(2.0) is exactly the float 1.4142135623730951
  erf_term = (x / math.sqrt(2.0)).erf()
  return 0.5 * x * (1.0 + erf_term)
|
||||
|
||||
class Conv2dHeNormal(nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||
|
||||
@@ -9,6 +9,9 @@ from PIL import Image
|
||||
import numpy as np
|
||||
import re, gzip
|
||||
|
||||
# Allow for monkeypatching for mlperf.
|
||||
gelu = Tensor.gelu
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
|
||||
@@ -53,8 +56,8 @@ class Tokenizer:
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
class ClipTokenizer:
|
||||
def __init__(self):
|
||||
self.byte_encoder = Tokenizer.bytes_to_unicode()
|
||||
def __init__(self, version=None):
|
||||
self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version
|
||||
merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
@@ -62,11 +65,17 @@ class Tokenizer:
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
if self.version == "sd_mlperf_v5_0":
|
||||
import regex
|
||||
vocab.extend(['<start_of_text>', '<end_of_text>'])
|
||||
self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'}
|
||||
self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
|
||||
else:
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
@@ -110,8 +119,17 @@ class Tokenizer:
|
||||
|
||||
def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
|
||||
bpe_tokens: List[int] = []
|
||||
text = Tokenizer.whitespace_clean(text.strip()).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
if self.version == "sd_mlperf_v5_0":
|
||||
import regex, ftfy, html
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text)).strip()
|
||||
text = Tokenizer.whitespace_clean(text).lower()
|
||||
re_module = regex
|
||||
else:
|
||||
text = Tokenizer.whitespace_clean(text.strip()).lower()
|
||||
re_module = re
|
||||
|
||||
for token in re_module.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
# Truncation, keeping two slots for start and end tokens.
|
||||
@@ -252,10 +270,8 @@ class Open:
|
||||
q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]
|
||||
|
||||
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).reshape(T*B, C)
|
||||
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
attn_output = attn_output.reshape(T, B, C)
|
||||
|
||||
return attn_output
|
||||
|
||||
@@ -263,9 +279,10 @@ class Open:
|
||||
def __init__(self, dims, hidden_dims):
|
||||
self.c_fc = Linear(dims, hidden_dims)
|
||||
self.c_proj = Linear(hidden_dims, dims)
|
||||
self.gelu = gelu
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])
|
||||
return x.sequential([self.c_fc, self.gelu, self.c_proj])
|
||||
|
||||
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
|
||||
class ResidualAttentionBlock:
|
||||
@@ -350,15 +367,15 @@ class Open:
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
|
||||
class FrozenOpenClipEmbedder(Embedder):
|
||||
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
|
||||
self.tokenizer = Tokenizer.ClipTokenizer()
|
||||
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None):
|
||||
self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version)
|
||||
self.model = Open.ClipTextTransformer(dims, n_heads, layers)
|
||||
self.return_pooled = return_pooled
|
||||
self.input_key = "txt"
|
||||
self.ln_penultimate = ln_penultimate
|
||||
|
||||
def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
|
||||
return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64, device=device).reshape(1,-1)
|
||||
return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1)
|
||||
|
||||
def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
|
||||
for r in self.model.transformer.resblocks:
|
||||
@@ -449,7 +466,7 @@ class OpenClipEncoder:
|
||||
x = x + self.positional_embedding
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = self.ln_final(x)
|
||||
x = x[:, tokens.argmax(axis=-1)]
|
||||
x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
|
||||
x = x @ self.text_projection
|
||||
return x
|
||||
|
||||
|
||||
53
test/external/mlperf_stable_diffusion/external_test_models.py
vendored
Normal file
53
test/external/mlperf_stable_diffusion/external_test_models.py
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from extra.models import clip
|
||||
from examples.mlperf.initializers import gelu_erf
|
||||
# These tests are run on the NULL backend; GPUS lists the eight NULL devices used for sharding.
Device.DEFAULT = "NULL"
GPUS = [f"NULL:{idx}" for idx in range(8)]

# Keyword arguments forwarded to clip.FrozenOpenClipEmbedder by get_cond_stage_model.
clip_params = {
  "dims": 1024,
  "n_heads": 16,
  "layers": 24,
  "return_pooled": False,
  "ln_penultimate": True,
  "clip_tokenizer_version": "sd_mlperf_v5_0",
}
|
||||
def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder:
  """Build the CLIP text embedder described by ``clip_params``.

  ``clip.gelu`` is rebound to ``gelu_erf`` before the model is constructed; the rebinding is module-wide
  and is not undone here. When ``GPUS`` names more than one device, every parameter of the model is moved
  in place with ``to_(GPUS)``; otherwise the model is returned as constructed.
  """
  clip.gelu = gelu_erf
  embedder = clip.FrozenOpenClipEmbedder(**clip_params)
  multi_device = bool(GPUS) and len(GPUS) > 1
  if not multi_device:
    return embedder
  for param in get_parameters(embedder):
    param.to_(GPUS)
  return embedder
|
||||
def get_tokens(BS:int) -> Tensor:
  """Return an all-zero int32 token tensor of ``77 * BS`` ids reshaped into rows of 77."""
  flat_ids = [0] * 77 * BS
  return Tensor(flat_ids, dtype=dtypes.int32).reshape(-1, 77)
|
||||
|
||||
class TestOpenClip(unittest.TestCase):
  # Tokenizer and multi-device checks for the CLIP models configured above.

  def test_tokenizer(self):
    # A multi-line prompt must encode to the 27 ids below followed by 50 zeros of padding.
    prompt = "Beautiful is better than ugly.\nExplicit is better than implicit.\nSimple is better than complex.\nComplex is better than complicated."
    tokenizer = get_cond_stage_model().tokenizer
    token_ids = tokenizer.encode(prompt, pad_with_zeros=True)
    prompt_ids = [
      49406, 1215, 533, 1539, 1126, 8159, 269, 33228, 533, 1539, 1126, 15269, 585, 269, 4129, 533, 1539, 1126, 6324, 269, 6324, 533,
      1539, 1126, 16621, 269, 49407,
    ]
    padding = [0]*50
    self.assertEqual(token_ids, prompt_ids + padding)

  def test_clip_gelu_init(self):
    # Every residual block's MLP must hold the gelu installed by get_cond_stage_model.
    resblocks = get_cond_stage_model().model.transformer.resblocks
    for block in resblocks:
      self.assertEqual(block.mlp.gelu, gelu_erf)

  def test_multigpu_clip_embed(self):
    # Embed a batch sharded along axis 0 across GPUS and check the output shape and dtype.
    batch_size = 304
    embedder = get_cond_stage_model(GPUS)
    sharded_tokens = get_tokens(batch_size).shard(GPUS, axis=0)
    out = embedder.embed_tokens(sharded_tokens).realize()
    self.assertEqual(out.shape, (batch_size, 77, 1024))
    self.assertEqual(out.dtype, dtypes.float32)

  def test_multigpu_clip_score(self):
    # Score sharded token/image batches with OpenClipEncoder and check one float32 score per sample.
    batch_size = 240
    vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
    text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
    clip.gelu = gelu_erf
    encoder = clip.OpenClipEncoder(1024, text_cfg, vision_cfg)
    for param in get_parameters(encoder):
      param.to_(GPUS)
    sharded_tokens = get_tokens(batch_size).shard(GPUS, axis=0)
    sharded_imgs = Tensor.zeros(batch_size, 3, 224, 224).contiguous().shard(GPUS, axis=0)
    scores = encoder.get_clip_score(sharded_tokens, sharded_imgs).realize()
    self.assertEqual(scores.shape, (batch_size,))
    self.assertEqual(scores.dtype, dtypes.float32)
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
  unittest.main()
|
||||
Reference in New Issue
Block a user