Stable Diffusion model init for mlperf (#12314)

* include clip pr diff

* updated unet and sd init

* dehardcode default device

* revert beam hang workaround

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
hooved
2025-10-02 02:28:41 -04:00
committed by GitHub
parent 0eee93f0c0
commit 0f804c9a83
4 changed files with 200 additions and 37 deletions

View File

@@ -2,7 +2,9 @@ import math
from typing import Union
from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix
from tinygrad.helpers import prod, argfix, Context
from tinygrad.nn.state import get_parameters
from extra.models.unet import UNetModel
# rejection sampling truncated randn
def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
@@ -131,3 +133,59 @@ class Conv2dRetinaNet(nn.Conv2d):
def __call__(self, x:Tensor) -> Tensor:
  # Cast weight (and bias, when present) down to the active default float before convolving.
  w = self.weight.cast(dtypes.default_float)
  b = None if self.bias is None else self.bias.cast(dtypes.default_float)
  return x.conv2d(w, b, groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)
# copy torch AMP: isolate mixed precision to just the below autocast ops, instead of using dtypes.default_float which affects all new Tensors
class AutocastLinear(nn.Linear):
  """Linear layer that casts its input and parameters to a mixed-precision dtype before the matmul."""
  cast_dtype = dtypes.bfloat16  # enable monkeypatching of the mixed precision dtype
  def __call__(self, x:Tensor) -> Tensor:
    dt = type(self).cast_dtype
    w = self.weight.cast(dt).transpose()
    b = self.bias.cast(dt) if self.bias is not None else None
    return x.cast(dt).linear(w, b)
class AutocastConv2d(nn.Conv2d):
  """Conv2d that casts input, weight, and bias to a mixed-precision dtype (torch-AMP style)."""
  cast_dtype = dtypes.bfloat16  # enable monkeypatching of the mixed precision dtype
  def __call__(self, x:Tensor) -> Tensor:
    dtype = type(self).cast_dtype
    # fix: guard bias against None (nn.Conv2d(bias=False) leaves self.bias as None);
    # the sibling AutocastLinear and Conv2dRetinaNet both already guard this way
    bias = self.bias.cast(dtype) if self.bias is not None else None
    return x.cast(dtype).conv2d(self.weight.cast(dtype), bias, self.groups, self.stride, self.dilation, self.padding)
# copy torch AMP: upcast to float32 before GroupNorm and LayerNorm
class AutocastGroupNorm(nn.GroupNorm):
  """GroupNorm that first upcasts its input to float32, then applies the parent norm."""
  def __call__(self, x:Tensor) -> Tensor:
    upcast = x.cast(dtypes.float32)
    return super().__call__(upcast)
class AutocastLayerNorm(nn.LayerNorm):
  """LayerNorm that first upcasts its input to float32, then applies the parent norm."""
  def __call__(self, x:Tensor) -> Tensor:
    upcast = x.cast(dtypes.float32)
    return super().__call__(upcast)
def zero_module(module):
  """Overwrite every parameter of `module` with zeros, in place."""
  for param in get_parameters(module):
    param.assign(Tensor.zeros_like(param).contiguous())
# Stable Diffusion mlperf reference doesn't call scaled_dot_product_attention
# copy torch AMP: upcast to float32 before softmax on CUDA
def attn_f32_softmax(q:Tensor, k:Tensor, v:Tensor) -> Tensor:
  """Scaled dot-product attention with the scores/softmax computed in float32, result cast back to q's dtype."""
  scores = q.matmul(k.transpose(-2, -1), dtype=dtypes.float32)
  weights = (scores / math.sqrt(q.shape[-1])).softmax(-1).cast(q.dtype)
  return weights @ v
def init_stable_diffusion(version:str, pretrained:str, devices:list[str]):
  """Construct a StableDiffusion model for mlperf.

  Returns (model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod); the two
  schedule tensors are realized, and for multi-device runs the listed tensors are
  moved to all `devices` before being realized with BEAM disabled.
  """
  from examples.stable_diffusion import StableDiffusion
  from tinygrad.nn.state import safe_load, safe_save, load_state_dict, get_state_dict
  from tempfile import TemporaryDirectory
  model = StableDiffusion(version=version, pretrained=pretrained)
  unet:UNetModel = model.model.diffusion_model
  # this prevents extra consumption of memory, enabling much larger BS
  Tensor.realize(*get_parameters(unet))
  # round-trip the freshly-initialized unet weights through a safetensors file on disk;
  # the temp directory (and file) is deleted when the `with` block exits
  with TemporaryDirectory(prefix="unet_init") as tmp:
    safe_save(get_state_dict(unet), init_fn:=f"{tmp}/init_model.safetensors")
    load_state_dict(unet, safe_load(init_fn))
  # precompute sqrt(alphas_cumprod) and sqrt(1 - alphas_cumprod) from the model's schedule
  sqrt_alphas_cumprod = model.alphas_cumprod.sqrt().realize()
  sqrt_one_minus_alphas_cumprod = (1 - model.alphas_cumprod).sqrt().realize()
  if len(devices) > 1:
    to_move = [sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod]
    # the train variant also moves the unet and text-encoder parameters across devices
    if version == "v2-mlperf-train": to_move += get_parameters(unet) + get_parameters(model.cond_stage_model)
    for p in to_move:
      p.to_(devices)
    # realize the device moves with BEAM search disabled
    with Context(BEAM=0):
      Tensor.realize(*to_move)
  return model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod

View File

@@ -9,11 +9,13 @@ from typing import Dict, Any
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
from extra.models import unet, clip
from extra.models.unet import UNetModel
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, zero_module, attn_f32_softmax, gelu_erf
from extra.bench_log import BenchEvent, WallTimeEvent
class AttnBlock:
@@ -154,12 +156,46 @@ unet_params: Dict[str,Any] = {
"use_linear": False,
}
# UNetModel constructor kwargs for the mlperf model variants (ctx_dim=1024, linear transformer projections,
# 16 norm groups) — passed as UNetModel(**mlperf_params); contrast with the default unet_params above
mlperf_params: Dict[str,Any] = {"adm_in_ch": None, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2, 1], "num_res_blocks": 2,
                                "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [1, 1, 1, 1], "ctx_dim": 1024, "use_linear": True,
                                "num_groups":16, "st_norm_eps":1e-6}
class StableDiffusion:
  """Stable Diffusion model: diffusion unet + first-stage autoencoder + text conditioner."""
  # NOTE(review): this span is a rendered git diff with indentation stripped — the old
  # signature `def __init__(self):` and the three unconditional attribute assignments
  # below it coexist with the new version-aware __init__; the re-indentation here is a
  # best-effort reconstruction — reconcile against the actual file before relying on it.
  def __init__(self):
  def __init__(self, version:str|None=None, pretrained:str|None=None):
    # version: None selects the original inference model; "v2-mlperf-train"/"v2-mlperf-eval"
    #   select the mlperf variants. pretrained: path to a torch checkpoint whose state_dict
    #   carries cond_stage_model.* (and, for eval, first_stage_model.*) weights.
    self.alphas_cumprod = get_alphas_cumprod()
    # (old, removed-by-diff lines: unconditional model / first_stage_model / cond_stage_model init)
    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
    if version != "v2-mlperf-train":
      self.first_stage_model = AutoencoderKL() # only needed for decoding generated latents to images; not needed in mlperf training from preprocessed moments
    if not version:
      # original path: closed CLIP text transformer and the default unet hyperparameters
      self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
      unet_init_params = unet_params
    elif version in {"v2-mlperf-train", "v2-mlperf-eval"}:
      unet_init_params = mlperf_params
      # mlperf path: swap module-level hooks so the unet is built with autocast layers,
      # erf-based gelu, and float32-softmax attention; text encoder is OpenCLIP
      clip.gelu = gelu_erf
      self.cond_stage_model = FrozenOpenClipEmbedder(**{"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True,
                                                        "clip_tokenizer_version": "sd_mlperf_v5_0"})
      unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
      unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16
      if pretrained:
        print("loading text encoder")
        weights: dict[str,Tensor] = {k.replace("cond_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("cond_stage_model.")}
        # causal mask for the 77-token context: upper triangle filled with -inf
        weights["model.attn_mask"] = Tensor.full((77, 77), fill_value=float("-inf")).triu(1)
        load_state_dict(self.cond_stage_model, weights)
        # only the eval model needs the decoder
        if version == "v2-mlperf-eval":
          print("loading image latent encoder")
          weights = {k.replace("first_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("first_stage_model.")}
          load_state_dict(self.first_stage_model, weights)
    # build the unet AFTER the hooks above so the mlperf variants pick up the autocast layers
    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_init_params))
    if version == "v2-mlperf-train":
      # the mlperf reference inits certain weights as zeroes
      for bb in flatten(self.model.diffusion_model.input_blocks) + self.model.diffusion_model.middle_block + flatten(self.model.diffusion_model.output_blocks):
        if isinstance(bb, unet.ResBlock):
          zero_module(bb.out_layers[3])
        elif isinstance(bb, unet.SpatialTransformer):
          zero_module(bb.proj_out)
      zero_module(self.model.diffusion_model.out[2])
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
temperature = 1