Stable Diffusion model init for mlperf (#12314)

* include clip pr diff

* updated unet and sd init

* dehardcode default device

* revert beam hang workaround

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: hooved
Date: 2025-10-02 02:28:41 -04:00
Committed by: GitHub
Parent: 0eee93f0c0
Commit: 0f804c9a83
4 changed files with 200 additions and 37 deletions

View File

@@ -2,7 +2,9 @@ import math
from typing import Union
from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix
from tinygrad.helpers import prod, argfix, Context
from tinygrad.nn.state import get_parameters
from extra.models.unet import UNetModel
# rejection sampling truncated randn
def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
@@ -131,3 +133,59 @@ class Conv2dRetinaNet(nn.Conv2d):
def __call__(self, x:Tensor) -> Tensor:
return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)
# copy torch AMP: isolate mixed precision to just the below autocast ops, instead of using dtypes.default_float which affects all new Tensors
class AutocastLinear(nn.Linear):
cast_dtype=dtypes.bfloat16 # enable monkeypatching of the mixed precision dtype
def __call__(self, x:Tensor) -> Tensor:
dtype = type(self).cast_dtype
return x.cast(dtype).linear(self.weight.cast(dtype).transpose(), self.bias.cast(dtype) if self.bias is not None else None)
class AutocastConv2d(nn.Conv2d):
cast_dtype=dtypes.bfloat16
def __call__(self, x:Tensor) -> Tensor:
dtype = type(self).cast_dtype
return x.cast(dtype).conv2d(self.weight.cast(dtype), self.bias.cast(dtype), self.groups, self.stride, self.dilation, self.padding)
# copy torch AMP: upcast to float32 before GroupNorm and LayerNorm
class AutocastGroupNorm(nn.GroupNorm):
def __call__(self, x:Tensor) -> Tensor:
return super().__call__(x.cast(dtypes.float32))
class AutocastLayerNorm(nn.LayerNorm):
def __call__(self, x:Tensor) -> Tensor:
return super().__call__(x.cast(dtypes.float32))
def zero_module(module):
for p in get_parameters(module): p.assign(Tensor.zeros_like(p).contiguous())
# Stable Diffusion mlperf reference doesn't call scaled_dot_product_attention
# copy torch AMP: upcast to float32 before softmax on CUDA
def attn_f32_softmax(q:Tensor, k:Tensor, v:Tensor) -> Tensor:
return (q.matmul(k.transpose(-2,-1), dtype=dtypes.float32) / math.sqrt(q.shape[-1])).softmax(-1).cast(q.dtype) @ v
def init_stable_diffusion(version:str, pretrained:str, devices:list[str]):
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import safe_load, safe_save, load_state_dict, get_state_dict
from tempfile import TemporaryDirectory
model = StableDiffusion(version=version, pretrained=pretrained)
unet:UNetModel = model.model.diffusion_model
# this prevents extra memory consumption, enabling a much larger BS
Tensor.realize(*get_parameters(unet))
with TemporaryDirectory(prefix="unet_init") as tmp:
safe_save(get_state_dict(unet), init_fn:=f"{tmp}/init_model.safetensors")
load_state_dict(unet, safe_load(init_fn))
sqrt_alphas_cumprod = model.alphas_cumprod.sqrt().realize()
sqrt_one_minus_alphas_cumprod = (1 - model.alphas_cumprod).sqrt().realize()
if len(devices) > 1:
to_move = [sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod]
if version == "v2-mlperf-train": to_move += get_parameters(unet) + get_parameters(model.cond_stage_model)
for p in to_move:
p.to_(devices)
with Context(BEAM=0):
Tensor.realize(*to_move)
return model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod
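A quick behavior sketch of the pieces added above (not part of the diff; assumes a backend with bfloat16 support, and the checkpoint path is a placeholder):
# --- behavior sketch (not part of this diff) ---
from tinygrad import Tensor, dtypes
from examples.mlperf.initializers import AutocastLinear, AutocastGroupNorm, init_stable_diffusion

y = AutocastLinear(320, 320)(Tensor.randn(4, 320, dtype=dtypes.float32))
assert y.dtype == dtypes.bfloat16   # autocast ops downcast their inputs, mirroring torch AMP
z = AutocastGroupNorm(16, 320)(Tensor.randn(4, 320, 8, 8, dtype=dtypes.bfloat16))
assert z.dtype == dtypes.float32    # norms always compute in float32

# build the mlperf training model on a single device; the checkpoint path is a placeholder
model, diffusion_model, sqrt_acp, sqrt_omacp = init_stable_diffusion(
  "v2-mlperf-train", "/path/to/sd/512-base-ema.ckpt", ["CPU"])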

View File

@@ -9,11 +9,13 @@ from typing import Dict, Any
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
from extra.models import unet, clip
from extra.models.unet import UNetModel
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, zero_module, attn_f32_softmax, gelu_erf
from extra.bench_log import BenchEvent, WallTimeEvent
class AttnBlock:
@@ -154,12 +156,46 @@ unet_params: Dict[str,Any] = {
"use_linear": False,
}
mlperf_params: Dict[str,Any] = {"adm_in_ch": None, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2, 1], "num_res_blocks": 2,
"channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [1, 1, 1, 1], "ctx_dim": 1024, "use_linear": True,
"num_groups":16, "st_norm_eps":1e-6}
class StableDiffusion:
def __init__(self):
def __init__(self, version:str|None=None, pretrained:str|None=None):
self.alphas_cumprod = get_alphas_cumprod()
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
if version != "v2-mlperf-train":
self.first_stage_model = AutoencoderKL() # only needed for decoding generated latents to images; not needed in mlperf training from preprocessed moments
if not version:
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
unet_init_params = unet_params
elif version in {"v2-mlperf-train", "v2-mlperf-eval"}:
unet_init_params = mlperf_params
clip.gelu = gelu_erf
self.cond_stage_model = FrozenOpenClipEmbedder(**{"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True,
"clip_tokenizer_version": "sd_mlperf_v5_0"})
unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16
if pretrained:
print("loading text encoder")
weights: dict[str,Tensor] = {k.replace("cond_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("cond_stage_model.")}
weights["model.attn_mask"] = Tensor.full((77, 77), fill_value=float("-inf")).triu(1)
load_state_dict(self.cond_stage_model, weights)
# only the eval model needs the decoder
if version == "v2-mlperf-eval":
print("loading image latent encoder")
weights = {k.replace("first_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("first_stage_model.")}
load_state_dict(self.first_stage_model, weights)
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_init_params))
if version == "v2-mlperf-train":
# the mlperf reference inits certain weights as zeroes
for bb in flatten(self.model.diffusion_model.input_blocks) + self.model.diffusion_model.middle_block + flatten(self.model.diffusion_model.output_blocks):
if isinstance(bb, unet.ResBlock):
zero_module(bb.out_layers[3])
elif isinstance(bb, unet.SpatialTransformer):
zero_module(bb.proj_out)
zero_module(self.model.diffusion_model.out[2])
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
temperature = 1
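For reference, a sketch of the constructor paths this change introduces (checkpoint path is a placeholder; weights are only loaded when pretrained is given):
# --- constructor paths (sketch, not part of this diff) ---
from examples.stable_diffusion import StableDiffusion

sd = StableDiffusion()                                  # original path: CLIP text encoder, VAE, unet_params
sd_train = StableDiffusion(version="v2-mlperf-train")   # OpenCLIP encoder, mlperf_params UNet, zero-init'd out layers, no VAE
sd_eval = StableDiffusion(version="v2-mlperf-eval",     # same UNet config, but keeps the VAE and loads its weights
                          pretrained="/path/to/512-base-ema.ckpt")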

View File

@@ -1,21 +1,24 @@
from tinygrad import Tensor, dtypes
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from typing import Optional, Union, List, Any, Tuple
from typing import Optional, Union, List, Any, Tuple, Callable
import math
# allow for monkeypatching
Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
class ResBlock:
def __init__(self, channels:int, emb_channels:int, out_channels:int):
def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
self.in_layers = [
GroupNorm(32, channels),
GroupNorm(num_groups, channels),
Tensor.silu,
Conv2d(channels, out_channels, 3, padding=1),
]
@@ -24,7 +27,7 @@ class ResBlock:
Linear(emb_channels, out_channels),
]
self.out_layers = [
GroupNorm(32, out_channels),
GroupNorm(num_groups, out_channels),
Tensor.silu,
lambda x: x, # needed for weights loading code to work
Conv2d(out_channels, out_channels, 3, padding=1),
@@ -45,35 +48,37 @@ class CrossAttention:
self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
self.num_heads = n_heads
self.head_size = d_head
self.attn = attention
self.to_out = [Linear(n_heads*d_head, query_dim)]
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
ctx = x if ctx is None else ctx
q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
attention = self.attn(q, k, v).transpose(1,2)
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
return h_.sequential(self.to_out)
class GEGLU:
def __init__(self, dim_in:int, dim_out:int):
self.proj = Linear(dim_in, dim_out * 2)
self.gelu = gelu
self.dim_out = dim_out
def __call__(self, x:Tensor) -> Tensor:
x, gate = self.proj(x).chunk(2, dim=-1)
return x * gate.gelu()
return x * self.gelu(gate)
class FeedForward:
def __init__(self, dim:int, mult:int=4):
self.net = [
self.net: tuple[GEGLU, Callable, nn.Linear] = (
GEGLU(dim, dim*mult),
lambda x: x, # needed for weights loading code to work
Linear(dim*mult, dim)
]
)
def __call__(self, x:Tensor) -> Tensor:
return x.sequential(self.net)
return x.sequential(list(self.net))
class BasicTransformerBlock:
def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
@@ -92,12 +97,13 @@ class BasicTransformerBlock:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
class SpatialTransformer:
def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1):
def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
norm_eps:float=1e-5):
if isinstance(ctx_dim, int):
ctx_dim = [ctx_dim]*depth
else:
assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
self.norm = GroupNorm(32, channels)
self.norm = GroupNorm(32, channels, eps=norm_eps)
assert channels == n_heads * d_head
self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
@@ -134,7 +140,9 @@ class Upsample:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
class UNetModel:
def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, n_heads:Optional[int]=None):
def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
self.model_ch = model_ch
self.num_res_blocks = [num_res_blocks] * len(channel_mult)
@@ -174,12 +182,12 @@ class UNetModel:
for idx, mult in enumerate(channel_mult):
for _ in range(self.num_res_blocks[idx]):
layers: List[Any] = [
ResBlock(ch, time_embed_dim, model_ch*mult),
ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
]
ch = mult * model_ch
if ds in attention_resolutions:
d_head, n_heads = get_d_and_n_heads(ch)
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
self.input_blocks.append(layers)
input_block_channels.append(ch)
@@ -193,9 +201,9 @@ class UNetModel:
d_head, n_heads = get_d_and_n_heads(ch)
self.middle_block: List = [
ResBlock(ch, time_embed_dim, ch),
SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1]),
ResBlock(ch, time_embed_dim, ch),
ResBlock(ch, time_embed_dim, ch, num_groups),
SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
ResBlock(ch, time_embed_dim, ch, num_groups),
]
self.output_blocks = []
@@ -203,13 +211,13 @@ class UNetModel:
for i in range(self.num_res_blocks[idx] + 1):
ich = input_block_channels.pop()
layers = [
ResBlock(ch + ich, time_embed_dim, model_ch*mult),
ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
]
ch = model_ch * mult
if ds in attention_resolutions:
d_head, n_heads = get_d_and_n_heads(ch)
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
if idx > 0 and i == self.num_res_blocks[idx]:
layers.append(Upsample(ch))
@@ -217,7 +225,7 @@ class UNetModel:
self.output_blocks.append(layers)
self.out = [
GroupNorm(32, ch),
GroupNorm(num_groups, ch),
Tensor.silu,
Conv2d(model_ch, out_ch, 3, padding=1),
]
@@ -230,10 +238,10 @@ class UNetModel:
assert y.shape[0] == x.shape[0]
emb = emb + y.sequential(self.label_emb[0])
if is_dtype_supported(dtypes.float16):
emb = emb.cast(dtypes.float16)
ctx = ctx.cast(dtypes.float16)
x = x .cast(dtypes.float16)
if is_dtype_supported(mixed_precision_dtype):
emb = emb.cast(mixed_precision_dtype)
ctx = ctx.cast(mixed_precision_dtype)
x = x .cast(mixed_precision_dtype)
def run(x:Tensor, bb) -> Tensor:
if isinstance(bb, ResBlock): x = bb(x, emb)
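The new module-level names (Linear, Conv2d, GroupNorm, LayerNorm, attention, gelu, mixed_precision_dtype) are hooks meant to be overridden before the UNet is built; a sketch mirroring what examples/stable_diffusion.py does for the mlperf config:
# --- monkeypatching sketch (mirrors the mlperf path in examples/stable_diffusion.py) ---
from tinygrad import dtypes
from extra.models import unet
from examples.mlperf.initializers import (AutocastLinear, AutocastConv2d, AutocastGroupNorm,
                                          AutocastLayerNorm, attn_f32_softmax, gelu_erf)

unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16

# mlperf_params from examples/stable_diffusion.py, spelled out here
model = unet.UNetModel(adm_in_ch=None, in_ch=4, out_ch=4, model_ch=320, attention_resolutions=[4, 2, 1],
                       num_res_blocks=2, channel_mult=[1, 2, 4, 4], d_head=64, transformer_depth=[1, 1, 1, 1],
                       ctx_dim=1024, use_linear=True, num_groups=16, st_norm_eps=1e-6)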

View File

@@ -1,10 +1,12 @@
import unittest
import numpy as np
from pathlib import Path
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters
from extra.models import clip
from examples.mlperf.initializers import gelu_erf
Device.DEFAULT="NULL"
GPUS = [f"NULL:{i}" for i in range(8)]
from examples.mlperf.initializers import gelu_erf, init_stable_diffusion, attn_f32_softmax
from typing import Literal
clip_params = {"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True, "clip_tokenizer_version": "sd_mlperf_v5_0"}
def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder:
@@ -30,6 +32,7 @@ class TestOpenClip(unittest.TestCase):
def test_multigpu_clip_embed(self):
BS = 304
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
model = get_cond_stage_model(GPUS)
tokens = get_tokens(BS)
embeds = model.embed_tokens(tokens.shard(GPUS, axis=0)).realize()
@@ -38,6 +41,7 @@ class TestOpenClip(unittest.TestCase):
def test_multigpu_clip_score(self):
BS = 240
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
clip.gelu = gelu_erf
@@ -49,5 +53,62 @@ class TestOpenClip(unittest.TestCase):
self.assertEqual(scores.shape, (BS,))
self.assertEqual(scores.dtype, dtypes.float32)
class TestInitStableDiffusion(unittest.TestCase):
def setUp(self):
# NOTE: set env variable based on where checkpoints are on the system
self.CKPTDIR = Path(getenv("CKPTDIR", "/raid/weights/stable_diffusion"))
def helper_test_init(self, version:Literal["v2-mlperf-train", "v2-mlperf-eval"]):
model, unet, sqrt_acp, sqrt_omacp = init_stable_diffusion(version, self.CKPTDIR / "sd" / "512-base-ema.ckpt", ["CPU"])
with self.subTest("test that StableDiffusion has correct models"):
self.assertEqual(model.model.diffusion_model, unet)
has_encoder = version == "v2-mlperf-eval"
self.assertEqual(hasattr(model, "first_stage_model"), has_encoder, "only the eval model uses the encoder")
self.assertTrue(isinstance(model.cond_stage_model, clip.FrozenOpenClipEmbedder))
with self.subTest("test for mlperf unique attributes"):
self.assertEqual(model.cond_stage_model.tokenizer.version, 'sd_mlperf_v5_0')
self.assertEqual(unet.out[0].num_groups, 16)
self.assertEqual(unet.input_blocks[1][1].norm.eps, 1e-6)
self.assertEqual(unet.input_blocks[1][1].transformer_blocks[0].attn1.attn, attn_f32_softmax)
with self.subTest("test loaded clip parameters"):
sample = model.cond_stage_model.model.transformer.resblocks[8].mlp.c_fc.bias.flatten()[42:46].numpy()
expected = np.array([-0.49812260270118713, -0.3039605915546417, -0.40284937620162964, -0.45069342851638794], dtype=np.float32)
np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded clip parameters are incorrect")
if version=="v2-mlperf-train":
with self.subTest("test that zero_module worked"):
self.assertTrue((unet.out[2].weight == 0).all().item(), "expected all zeroes")
self.assertTrue((unet.out[2].bias == 0).all().item(), "expected all zeroes")
elif version=="v2-mlperf-eval":
with self.subTest("test loaded vae parameters"):
sample = model.first_stage_model.decoder.up[0]['block'][1].conv2.weight.flatten()[42:46].numpy()
expected = np.array([0.08192943036556244, 0.040095631033182144, 0.07541035860776901, 0.1475081741809845], dtype=np.float32)
np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded vae parameters are incorrect")
with self.subTest("check schedules"):
expected = np.array([0.9995748996734619, 0.06826484948396683], dtype=np.float32)
np.testing.assert_allclose(sqrt_acp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_acp is incorrect")
expected = np.array([0.029155133292078972, 0.9976672530174255], dtype=np.float32)
np.testing.assert_allclose(sqrt_omacp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_omacp is incorrect")
with self.subTest("check mixed precision"):
out = unet.input_blocks[2][1].proj_in(Tensor.randn(320, dtype=dtypes.float32))
self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Linear")
out = unet.out[2](Tensor.randn(304,320,64,64, dtype=dtypes.float32))
self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Conv2d")
out = unet.input_blocks[1][1].transformer_blocks[0].norm1(Tensor.randn(320, dtype=dtypes.bfloat16))
self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by LayerNorm")
out = unet.input_blocks[5][0].in_layers[0](Tensor.randn(304, 640, dtype=dtypes.bfloat16))
self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by GroupNorm")
def test_train_model(self):
self.helper_test_init("v2-mlperf-train")
def test_eval_model(self):
self.helper_test_init("v2-mlperf-eval")
if __name__=="__main__":
unittest.main()
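A note on running these tests: the checkpoint is resolved relative to the CKPTDIR env var; the sketch below only shows that lookup, and the directory and test filename are placeholders:
# --- checkpoint lookup used by the tests (sketch) ---
# run as e.g.:  CKPTDIR=/my/weights python3 <this_test_file>.py
from pathlib import Path
from tinygrad.helpers import getenv
ckpt = Path(getenv("CKPTDIR", "/raid/weights/stable_diffusion")) / "sd" / "512-base-ema.ckpt"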