Clip model updates for Stable Diffusion mlperf training (#12313)

* stable diffusion mlperf clip changes

* add clip tests

* set gelu as attribute

* add more tests

* factor out GPUS

* rerun CI

* add imports to if blocks

* remove unneeded axis

* add clip tests to CI

* move clip tests

* add deps, disable max buf size
This commit is contained in:
hooved
2025-09-29 21:50:14 -04:00
committed by GitHub
parent cdfa0f29fd
commit c2689c505e
4 changed files with 92 additions and 16 deletions

View File

@@ -259,7 +259,7 @@ jobs:
uses: ./.github/actions/setup-tinygrad
with:
key: unittest-12
pydeps: "pillow"
pydeps: "pillow numpy ftfy regex"
deps: testing_unit
- name: Run unit tests
run: python -m pytest -n=auto test/unit/ --durations=20
@@ -267,6 +267,8 @@ jobs:
run: NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
- name: Run SDXL on NULL backend
run: MAX_BUFFER_SIZE=0 NULL=1 DEBUG=1 python3 examples/sdxl.py --seed 0 --noshow --timing --fakeweights
- name: Run Clip tests for SD MLPerf on NULL backend
run: MAX_BUFFER_SIZE=0 NULL=1 python -m pytest -n=auto test/external/mlperf_stable_diffusion/external_test_models.py::TestOpenClip --durations=20
# TODO: support fake weights
#- name: Run LLaMA 7B on 4 fake devices
# run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 3 --temperature 0 --timing

View File

@@ -17,6 +17,10 @@ def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978
return std * rand_truncn(*shape, **kwargs)
# Stable Diffusion v2 training uses default torch gelu, which doesn't use tanh approximation
def gelu_erf(x:Tensor) -> Tensor:
return 0.5 * x * (1.0 + (x / 1.4142135623730951).erf())
class Conv2dHeNormal(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

View File

@@ -9,6 +9,9 @@ from PIL import Image
import numpy as np
import re, gzip
# Allow for monkeypatching for mlperf.
gelu = Tensor.gelu
@lru_cache()
def default_bpe():
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@@ -53,8 +56,8 @@ class Tokenizer:
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
class ClipTokenizer:
def __init__(self):
self.byte_encoder = Tokenizer.bytes_to_unicode()
def __init__(self, version=None):
self.byte_encoder, self.version = Tokenizer.bytes_to_unicode(), version
merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
@@ -62,11 +65,17 @@ class Tokenizer:
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
if self.version == "sd_mlperf_v5_0":
import regex
vocab.extend(['<start_of_text>', '<end_of_text>'])
self.cache = {'<start_of_text>': '<start_of_text>', '<end_of_text>': '<end_of_text>'}
self.pat = regex.compile(r"""<start_of_text>|<end_of_text>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
else:
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
@@ -110,8 +119,17 @@ class Tokenizer:
def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
bpe_tokens: List[int] = []
text = Tokenizer.whitespace_clean(text.strip()).lower()
for token in re.findall(self.pat, text):
if self.version == "sd_mlperf_v5_0":
import regex, ftfy, html
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text)).strip()
text = Tokenizer.whitespace_clean(text).lower()
re_module = regex
else:
text = Tokenizer.whitespace_clean(text.strip()).lower()
re_module = re
for token in re_module.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
# Truncation, keeping two slots for start and end tokens.
@@ -252,10 +270,8 @@ class Open:
q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn_output = attn_output.permute(2, 0, 1, 3).reshape(T*B, C)
attn_output = attn_output.permute(2, 0, 1, 3).reshape(T, B, C)
attn_output = self.out_proj(attn_output)
attn_output = attn_output.reshape(T, B, C)
return attn_output
@@ -263,9 +279,10 @@ class Open:
def __init__(self, dims, hidden_dims):
self.c_fc = Linear(dims, hidden_dims)
self.c_proj = Linear(hidden_dims, dims)
self.gelu = gelu
def __call__(self, x:Tensor) -> Tensor:
return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])
return x.sequential([self.c_fc, self.gelu, self.c_proj])
# https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
class ResidualAttentionBlock:
@@ -350,15 +367,15 @@ class Open:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
self.tokenizer = Tokenizer.ClipTokenizer()
def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False, clip_tokenizer_version=None):
self.tokenizer = Tokenizer.ClipTokenizer(version=clip_tokenizer_version)
self.model = Open.ClipTextTransformer(dims, n_heads, layers)
self.return_pooled = return_pooled
self.input_key = "txt"
self.ln_penultimate = ln_penultimate
def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64, device=device).reshape(1,-1)
return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int32, device=device).reshape(1,-1)
def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
for r in self.model.transformer.resblocks:
@@ -449,7 +466,7 @@ class OpenClipEncoder:
x = x + self.positional_embedding
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.ln_final(x)
x = x[:, tokens.argmax(axis=-1)]
x = x[Tensor.arange(x.shape[0], device=x.device), tokens.argmax(axis=-1)]
x = x @ self.text_projection
return x

View File

@@ -0,0 +1,53 @@
import unittest
from tinygrad import Tensor, dtypes, Device
from tinygrad.nn.state import get_parameters
from extra.models import clip
from examples.mlperf.initializers import gelu_erf
Device.DEFAULT="NULL"
GPUS = [f"NULL:{i}" for i in range(8)]
clip_params = {"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True, "clip_tokenizer_version": "sd_mlperf_v5_0"}
def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder:
clip.gelu = gelu_erf
model = clip.FrozenOpenClipEmbedder(**clip_params)
if GPUS and len(GPUS) > 1:
for p in get_parameters(model): p.to_(GPUS)
return model
def get_tokens(BS:int) -> Tensor: return Tensor([0] * 77 * BS, dtype=dtypes.int32).reshape(-1, 77)
class TestOpenClip(unittest.TestCase):
def test_tokenizer(self):
prompt = "Beautiful is better than ugly.\nExplicit is better than implicit.\nSimple is better than complex.\nComplex is better than complicated."
model = get_cond_stage_model()
tokens = model.tokenizer.encode(prompt, pad_with_zeros=True)
expected = [49406, 1215, 533, 1539, 1126, 8159, 269, 33228, 533, 1539, 1126, 15269, 585, 269, 4129, 533, 1539, 1126, 6324, 269, 6324, 533,
1539, 1126, 16621, 269, 49407] + [0]*50
self.assertEqual(tokens, expected)
def test_clip_gelu_init(self):
for resblock in get_cond_stage_model().model.transformer.resblocks:
self.assertEqual(resblock.mlp.gelu, gelu_erf)
def test_multigpu_clip_embed(self):
BS = 304
model = get_cond_stage_model(GPUS)
tokens = get_tokens(BS)
embeds = model.embed_tokens(tokens.shard(GPUS, axis=0)).realize()
self.assertEqual(embeds.shape, (BS, 77, 1024))
self.assertEqual(embeds.dtype, dtypes.float32)
def test_multigpu_clip_score(self):
BS = 240
vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
clip.gelu = gelu_erf
clip_encoder = clip.OpenClipEncoder(1024, text_cfg, vision_cfg)
for p in get_parameters(clip_encoder): p.to_(GPUS)
tokens = get_tokens(BS)
imgs = Tensor.zeros(BS,3,224,224).contiguous()
scores = clip_encoder.get_clip_score(tokens.shard(GPUS, axis=0), imgs.shard(GPUS, axis=0)).realize()
self.assertEqual(scores.shape, (BS,))
self.assertEqual(scores.dtype, dtypes.float32)
if __name__=="__main__":
unittest.main()