Files
tinygrad/test/external/mlperf_stable_diffusion/external_test_models.py
hooved 0f804c9a83 Stable Diffusion model init for mlperf (#12314)
* include clip pr diff

* updated unet and sd init

* dehardcode default device

* revert beam hang workaround

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-10-02 02:28:41 -04:00

114 lines
6.3 KiB
Python

import unittest
import numpy as np
from pathlib import Path
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters
from extra.models import clip
from examples.mlperf.initializers import gelu_erf, init_stable_diffusion, attn_f32_softmax
from typing import Literal
clip_params = {"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True, "clip_tokenizer_version": "sd_mlperf_v5_0"}
def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder:
clip.gelu = gelu_erf
model = clip.FrozenOpenClipEmbedder(**clip_params)
if GPUS and len(GPUS) > 1:
for p in get_parameters(model): p.to_(GPUS)
return model
def get_tokens(BS:int) -> Tensor: return Tensor([0] * 77 * BS, dtype=dtypes.int32).reshape(-1, 77)
class TestOpenClip(unittest.TestCase):
def test_tokenizer(self):
prompt = "Beautiful is better than ugly.\nExplicit is better than implicit.\nSimple is better than complex.\nComplex is better than complicated."
model = get_cond_stage_model()
tokens = model.tokenizer.encode(prompt, pad_with_zeros=True)
expected = [49406, 1215, 533, 1539, 1126, 8159, 269, 33228, 533, 1539, 1126, 15269, 585, 269, 4129, 533, 1539, 1126, 6324, 269, 6324, 533,
1539, 1126, 16621, 269, 49407] + [0]*50
self.assertEqual(tokens, expected)
def test_clip_gelu_init(self):
for resblock in get_cond_stage_model().model.transformer.resblocks:
self.assertEqual(resblock.mlp.gelu, gelu_erf)
def test_multigpu_clip_embed(self):
BS = 304
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
model = get_cond_stage_model(GPUS)
tokens = get_tokens(BS)
embeds = model.embed_tokens(tokens.shard(GPUS, axis=0)).realize()
self.assertEqual(embeds.shape, (BS, 77, 1024))
self.assertEqual(embeds.dtype, dtypes.float32)
def test_multigpu_clip_score(self):
BS = 240
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
clip.gelu = gelu_erf
clip_encoder = clip.OpenClipEncoder(1024, text_cfg, vision_cfg)
for p in get_parameters(clip_encoder): p.to_(GPUS)
tokens = get_tokens(BS)
imgs = Tensor.zeros(BS,3,224,224).contiguous()
scores = clip_encoder.get_clip_score(tokens.shard(GPUS, axis=0), imgs.shard(GPUS, axis=0)).realize()
self.assertEqual(scores.shape, (BS,))
self.assertEqual(scores.dtype, dtypes.float32)
class TestInitStableDiffusion(unittest.TestCase):
def setUp(self):
# NOTE: set env variable based on where checkpoints are on the system
self.CKPTDIR = Path(getenv("CKPTDIR", "/raid/weights/stable_diffusion"))
def helper_test_init(self, version:Literal["v2-mlperf-train", "v2-mlperf-eval"]):
model, unet, sqrt_acp, sqrt_omacp = init_stable_diffusion(version, self.CKPTDIR / "sd" / "512-base-ema.ckpt", ["CPU"])
with self.subTest("test that StableDiffusion has correct models"):
self.assertEqual(model.model.diffusion_model, unet)
has_encoder = True if version=="v2-mlperf-eval" else False
self.assertEqual(hasattr(model, "first_stage_model"), has_encoder, "only the eval model uses the encoder")
self.assertTrue(isinstance(model.cond_stage_model, clip.FrozenOpenClipEmbedder))
with self.subTest("test for mlperf unique attributes"):
self.assertEqual(model.cond_stage_model.tokenizer.version, 'sd_mlperf_v5_0')
self.assertEqual(unet.out[0].num_groups, 16)
self.assertEqual(unet.input_blocks[1][1].norm.eps, 1e-6)
self.assertEqual(unet.input_blocks[1][1].transformer_blocks[0].attn1.attn, attn_f32_softmax)
with self.subTest("test loaded clip parameters"):
sample = model.cond_stage_model.model.transformer.resblocks[8].mlp.c_fc.bias.flatten()[42:46].numpy()
expected = np.array([-0.49812260270118713, -0.3039605915546417, -0.40284937620162964, -0.45069342851638794], dtype=np.float32)
np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded clip parameters are incorrect")
if version=="v2-mlperf-train":
with self.subTest("test that zero_module worked"):
self.assertTrue((unet.out[2].weight == 0).all().item(), "expected all zeroes")
self.assertTrue((unet.out[2].bias == 0).all().item(), "expected all zeroes")
elif version=="v2-mlperf-eval":
with self.subTest("test loaded vae parameters"):
sample = model.first_stage_model.decoder.up[0]['block'][1].conv2.weight.flatten()[42:46].numpy()
expected = np.array([0.08192943036556244, 0.040095631033182144, 0.07541035860776901, 0.1475081741809845], dtype=np.float32)
np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded vae parameters are incorrect")
with self.subTest("check schedules"):
expected = np.array([0.9995748996734619, 0.06826484948396683], dtype=np.float32)
np.testing.assert_allclose(sqrt_acp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_acp is incorrect")
expected = np.array([0.029155133292078972, 0.9976672530174255], dtype=np.float32)
np.testing.assert_allclose(sqrt_omacp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_omacp is incorrect")
with self.subTest("check mixed precision"):
out = unet.input_blocks[2][1].proj_in(Tensor.randn(320, dtype=dtypes.float32))
self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Linear")
out = unet.out[2](Tensor.randn(304,320,64,64, dtype=dtypes.float32))
self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Conv2d")
out = unet.input_blocks[1][1].transformer_blocks[0].norm1(Tensor.randn(320, dtype=dtypes.bfloat16))
self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by LayerNorm")
out = unet.input_blocks[5][0].in_layers[0](Tensor.randn(304, 640, dtype=dtypes.bfloat16))
self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by GroupNorm")
def test_train_model(self):
self.helper_test_init("v2-mlperf-train")
def test_eval_model(self):
self.helper_test_init("v2-mlperf-eval")
if __name__=="__main__":
unittest.main()