From c2689c505ede3d079812caa5622dc37c1f25a5e7 Mon Sep 17 00:00:00 2001 From: hooved <172129504+hooved@users.noreply.github.com> Date: Mon, 29 Sep 2025 21:50:14 -0400 Subject: [PATCH] Clip model updates for Stable Diffusion mlperf training (#12313) * stable diffusion mlperf clip changes * add clip tests * set gelu as attribute * add more tests * factor out GPUS * rerun CI * add imports to if blocks * remove unneeded axis * add clip tests to CI * move clip tests * add deps, disable max buf size --- .github/workflows/test.yml | 4 +- examples/mlperf/initializers.py | 4 ++ extra/models/clip.py | 47 ++++++++++------ .../external_test_models.py | 53 +++++++++++++++++++ 4 files changed, 92 insertions(+), 16 deletions(-) create mode 100644 test/external/mlperf_stable_diffusion/external_test_models.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bb5742a146..49be678d7f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -259,7 +259,7 @@ jobs: uses: ./.github/actions/setup-tinygrad with: key: unittest-12 - pydeps: "pillow" + pydeps: "pillow numpy ftfy regex" deps: testing_unit - name: Run unit tests run: python -m pytest -n=auto test/unit/ --durations=20 @@ -267,6 +267,8 @@ jobs: run: NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step - name: Run SDXL on NULL backend run: MAX_BUFFER_SIZE=0 NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights + - name: Run Clip tests for SD MLPerf on NULL backend + run: MAX_BUFFER_SIZE=0 NULL=1 python -m pytest -n=auto test/external/mlperf_stable_diffusion/external_test_models.py::TestOpenClip --durations=20 # TODO: support fake weights #- name: Run LLaMA 7B on 4 fake devices # run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." 
--count 3 --temperature 0 --timing diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py index d84e64bab9..9fdfec4f3d 100644 --- a/examples/mlperf/initializers.py +++ b/examples/mlperf/initializers.py @@ -17,6 +17,10 @@ def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor: std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978 return std * rand_truncn(*shape, **kwargs) +# Stable Diffusion v2 training uses default torch gelu, which doesn't use tanh approximation +def gelu_erf(x:Tensor) -> Tensor: + return 0.5 * x * (1.0 + (x / 1.4142135623730951).erf()) + class Conv2dHeNormal(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) diff --git a/extra/models/clip.py b/extra/models/clip.py index 7b176fcdf3..2801049ad8 100644 --- a/extra/models/clip.py +++ b/extra/models/clip.py @@ -9,6 +9,9 @@ from PIL import Image import numpy as np import re, gzip +# Allow for monkeypatching for mlperf. 
+gelu = Tensor.gelu + @lru_cache() def default_bpe(): # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) @@ -53,8 +56,8 @@ class Tokenizer: cs = [chr(n) for n in cs] return dict(zip(bs, cs)) class ClipTokenizer: - def __init__(self): - self.byte_encoder = Tokenizer.bytes_to_unicode() + def __init__(self, version=None): + self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n') merges = merges[1:49152-256-2+1] merges = [tuple(merge.split()) for merge in merges] @@ -62,11 +65,17 @@ class Tokenizer: vocab = vocab + [v+'</w>' for v in vocab] for merge in merges: vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) + if self.version == "sd_mlperf_v5_0": + import regex + vocab.extend(['<start_of_text>', '<end_of_text>']) + self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'} + self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE) + else: + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) self.encoder = dict(zip(vocab, range(len(vocab)))) self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) def bpe(self, token): if token in self.cache: @@ -110,8 +119,17 @@ class Tokenizer: def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]: bpe_tokens: List[int] = [] - text = Tokenizer.whitespace_clean(text.strip()).lower() - for token in re.findall(self.pat, text): + if self.version == "sd_mlperf_v5_0": + import regex, ftfy, html + text = ftfy.fix_text(text) + text = 
html.unescape(html.unescape(text)).strip() + text = Tokenizer.whitespace_clean(text).lower() + re_module = regex + else: + text = Tokenizer.whitespace_clean(text.strip()).lower() + re_module = re + + for token in re_module.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) # Truncation, keeping two slots for start and end tokens. @@ -252,10 +270,8 @@ class Open: q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)] attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - attn_output = attn_output.permute(2, 0, 1, 3).reshape(T*B, C) - + attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C) attn_output = self.out_proj(attn_output) - attn_output = attn_output.reshape(T, B, C) return attn_output @@ -263,9 +279,10 @@ class Open: def __init__(self, dims, hidden_dims): self.c_fc = Linear(dims, hidden_dims) self.c_proj = Linear(hidden_dims, dims) + self.gelu = gelu def __call__(self, x:Tensor) -> Tensor: - return x.sequential([self.c_fc, Tensor.gelu, self.c_proj]) + return x.sequential([self.c_fc, self.gelu, self.c_proj]) # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210 class ResidualAttentionBlock: @@ -350,15 +367,15 @@ class Open: # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498 class FrozenOpenClipEmbedder(Embedder): - def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False): - self.tokenizer = Tokenizer.ClipTokenizer() + def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, 
ln_penultimate:bool=False, clip_tokenizer_version=None): + self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version) self.model = Open.ClipTextTransformer(dims, n_heads, layers) self.return_pooled = return_pooled self.input_key = "txt" self.ln_penultimate = ln_penultimate def tokenize(self, text:str, device:Optional[str]=None) -> Tensor: - return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64, device=device).reshape(1,-1) + return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1) def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None): for r in self.model.transformer.resblocks: @@ -449,7 +466,7 @@ class OpenClipEncoder: x = x + self.positional_embedding x = self.transformer(x, attn_mask=self.attn_mask) x = self.ln_final(x) - x = x[:, tokens.argmax(axis=-1)] + x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)] x = x @ self.text_projection return x diff --git a/test/external/mlperf_stable_diffusion/external_test_models.py b/test/external/mlperf_stable_diffusion/external_test_models.py new file mode 100644 index 0000000000..e445849f32 --- /dev/null +++ b/test/external/mlperf_stable_diffusion/external_test_models.py @@ -0,0 +1,53 @@ +import unittest +from tinygrad import Tensor, dtypes, Device +from tinygrad.nn.state import get_parameters +from extra.models import clip +from examples.mlperf.initializers import gelu_erf +Device.DEFAULT="NULL" +GPUS = [f"NULL:{i}" for i in range(8)] + +clip_params = {"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True, "clip_tokenizer_version": "sd_mlperf_v5_0"} +def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder: + clip.gelu = gelu_erf + model = clip.FrozenOpenClipEmbedder(**clip_params) + if GPUS and len(GPUS) > 1: + for p in get_parameters(model): p.to_(GPUS) + return model +def get_tokens(BS:int) -> Tensor: return 
Tensor([0] * 77 * BS, dtype=dtypes.int32).reshape(-1, 77) + +class TestOpenClip(unittest.TestCase): + def test_tokenizer(self): + prompt = "Beautiful is better than ugly.\nExplicit is better than implicit.\nSimple is better than complex.\nComplex is better than complicated." + model = get_cond_stage_model() + tokens = model.tokenizer.encode(prompt, pad_with_zeros=True) + expected = [49406, 1215, 533, 1539, 1126, 8159, 269, 33228, 533, 1539, 1126, 15269, 585, 269, 4129, 533, 1539, 1126, 6324, 269, 6324, 533, + 1539, 1126, 16621, 269, 49407] + [0]*50 + self.assertEqual(tokens, expected) + + def test_clip_gelu_init(self): + for resblock in get_cond_stage_model().model.transformer.resblocks: + self.assertEqual(resblock.mlp.gelu, gelu_erf) + + def test_multigpu_clip_embed(self): + BS = 304 + model = get_cond_stage_model(GPUS) + tokens = get_tokens(BS) + embeds = model.embed_tokens(tokens.shard(GPUS, axis=0)).realize() + self.assertEqual(embeds.shape, (BS, 77, 1024)) + self.assertEqual(embeds.dtype, dtypes.float32) + + def test_multigpu_clip_score(self): + BS = 240 + vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14} + text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77} + clip.gelu = gelu_erf + clip_encoder = clip.OpenClipEncoder(1024, text_cfg, vision_cfg) + for p in get_parameters(clip_encoder): p.to_(GPUS) + tokens = get_tokens(BS) + imgs = Tensor.zeros(BS,3,224,224).contiguous() + scores = clip_encoder.get_clip_score(tokens.shard(GPUS, axis=0), imgs.shard(GPUS, axis=0)).realize() + self.assertEqual(scores.shape, (BS,)) + self.assertEqual(scores.dtype, dtypes.float32) + +if __name__=="__main__": + unittest.main() \ No newline at end of file