mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Stable Diffusion model init for mlperf (#12314)
* include clip pr diff * updated unet and sd init * dehardcode default device * revert beam hang workaround --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -2,7 +2,9 @@ import math
|
||||
from typing import Union
|
||||
|
||||
from tinygrad import Tensor, nn, dtypes
|
||||
from tinygrad.helpers import prod, argfix
|
||||
from tinygrad.helpers import prod, argfix, Context
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from extra.models.unet import UNetModel
|
||||
|
||||
# rejection sampling truncated randn
|
||||
def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
|
||||
@@ -131,3 +133,59 @@ class Conv2dRetinaNet(nn.Conv2d):
|
||||
def __call__(self, x:Tensor) -> Tensor:
  """Run the convolution with weight/bias cast to the current dtypes.default_float."""
  w = self.weight.cast(dtypes.default_float)
  b = None if self.bias is None else self.bias.cast(dtypes.default_float)
  return x.conv2d(w, b, groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)
|
||||
|
||||
# copy torch AMP: isolate mixed precision to just the below autocast ops, instead of using dtypes.default_float which affects all new Tensors
class AutocastLinear(nn.Linear):
  """nn.Linear that casts its input and parameters to cast_dtype before the matmul (torch-AMP-style autocast)."""
  cast_dtype = dtypes.bfloat16  # class attribute so the mixed precision dtype can be monkeypatched
  def __call__(self, x:Tensor) -> Tensor:
    dt = type(self).cast_dtype
    w = self.weight.cast(dt).transpose()
    b = None if self.bias is None else self.bias.cast(dt)
    return x.cast(dt).linear(w, b)
|
||||
|
||||
class AutocastConv2d(nn.Conv2d):
  """nn.Conv2d that casts its input and parameters to cast_dtype before the conv (torch-AMP-style autocast)."""
  cast_dtype = dtypes.bfloat16  # monkeypatchable, kept in sync with AutocastLinear.cast_dtype by callers
  def __call__(self, x:Tensor) -> Tensor:
    dtype = type(self).cast_dtype
    # fix: guard against bias=None (nn.Conv2d supports bias=False), consistent with AutocastLinear above
    bias = self.bias.cast(dtype) if self.bias is not None else None
    return x.cast(dtype).conv2d(self.weight.cast(dtype), bias, self.groups, self.stride, self.dilation, self.padding)
|
||||
|
||||
# copy torch AMP: upcast to float32 before GroupNorm and LayerNorm
class AutocastGroupNorm(nn.GroupNorm):
  """GroupNorm computed in float32 regardless of the incoming activation dtype."""
  def __call__(self, x:Tensor) -> Tensor:
    upcast = x.cast(dtypes.float32)
    return super().__call__(upcast)
|
||||
|
||||
class AutocastLayerNorm(nn.LayerNorm):
  """LayerNorm computed in float32 regardless of the incoming activation dtype."""
  def __call__(self, x:Tensor) -> Tensor:
    upcast = x.cast(dtypes.float32)
    return super().__call__(upcast)
|
||||
|
||||
def zero_module(module):
  """Zero out every parameter of `module` in place (the mlperf reference zero-inits certain layers)."""
  for param in get_parameters(module):
    param.assign(Tensor.zeros_like(param).contiguous())
|
||||
|
||||
# Stable Diffusion mlperf reference doesn't call scaled_dot_product_attention
# copy torch AMP: upcast to float32 before softmax on CUDA
def attn_f32_softmax(q:Tensor, k:Tensor, v:Tensor) -> Tensor:
  """Attention with the q@k^T scores accumulated and softmaxed in float32, then cast back to q.dtype before @ v."""
  scores = q.matmul(k.transpose(-2, -1), dtype=dtypes.float32) / math.sqrt(q.shape[-1])
  return scores.softmax(-1).cast(q.dtype) @ v
|
||||
|
||||
def init_stable_diffusion(version:str, pretrained:str, devices:list[str]):
  """Build a StableDiffusion model plus the diffusion-schedule tensors used by the mlperf run.

  Returns (model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod).
  """
  from examples.stable_diffusion import StableDiffusion
  from tinygrad.nn.state import safe_load, safe_save, load_state_dict, get_state_dict
  from tempfile import TemporaryDirectory

  model = StableDiffusion(version=version, pretrained=pretrained)
  unet:UNetModel = model.model.diffusion_model

  # round-trip the freshly-initialized weights through disk;
  # this prevents extra consumption of memory, enabling much larger BS
  Tensor.realize(*get_parameters(unet))
  with TemporaryDirectory(prefix="unet_init") as tmp:
    init_fn = f"{tmp}/init_model.safetensors"
    safe_save(get_state_dict(unet), init_fn)
    load_state_dict(unet, safe_load(init_fn))

  sqrt_alphas_cumprod = model.alphas_cumprod.sqrt().realize()
  sqrt_one_minus_alphas_cumprod = (1 - model.alphas_cumprod).sqrt().realize()

  if len(devices) > 1:
    # distribute the schedule tensors (and, for training, all weights) across the devices
    to_move = [sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod]
    if version == "v2-mlperf-train": to_move += get_parameters(unet) + get_parameters(model.cond_stage_model)
    for t in to_move:
      t.to_(devices)
    with Context(BEAM=0):
      Tensor.realize(*to_move)

  return model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod
|
||||
|
||||
@@ -9,11 +9,13 @@ from typing import Dict, Any
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
|
||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
|
||||
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
|
||||
from tinygrad.nn import Conv2d, GroupNorm
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
from extra.models.clip import Closed, Tokenizer
|
||||
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
|
||||
from extra.models import unet, clip
|
||||
from extra.models.unet import UNetModel
|
||||
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, zero_module, attn_f32_softmax, gelu_erf
|
||||
from extra.bench_log import BenchEvent, WallTimeEvent
|
||||
|
||||
class AttnBlock:
|
||||
@@ -154,12 +156,46 @@ unet_params: Dict[str,Any] = {
|
||||
"use_linear": False,
|
||||
}
|
||||
|
||||
# UNet hyperparameters for the mlperf v2 Stable Diffusion configuration (ctx_dim 1024, linear transformer projections)
mlperf_params: Dict[str,Any] = {
  "adm_in_ch": None, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2, 1], "num_res_blocks": 2,
  "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [1, 1, 1, 1], "ctx_dim": 1024, "use_linear": True,
  "num_groups": 16, "st_norm_eps": 1e-6,
}
|
||||
|
||||
class StableDiffusion:
  # NOTE(review): this span of the scraped diff interleaved removed (pre-patch) and added (post-patch)
  # lines, which is not valid Python; reconstructed here as the post-patch version the diff describes.
  def __init__(self, version:str|None=None, pretrained:str|None=None):
    """Assemble the SD graph: schedule, (optional) VAE, text encoder, and UNet.

    version: None for the original inference config, or "v2-mlperf-train"/"v2-mlperf-eval"
             for the mlperf configs (OpenCLIP text encoder, autocast UNet ops).
    pretrained: checkpoint path; for mlperf versions the text encoder (and, for eval, the VAE)
                weights are loaded from its state_dict.
    """
    self.alphas_cumprod = get_alphas_cumprod()
    if version != "v2-mlperf-train":
      self.first_stage_model = AutoencoderKL() # only needed for decoding generated latents to images; not needed in mlperf training from preprocessed moments

    if not version:
      self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
      unet_init_params = unet_params
    elif version in {"v2-mlperf-train", "v2-mlperf-eval"}:
      unet_init_params = mlperf_params
      # mlperf reference uses erf-based gelu and runs autocast ops inside the UNet
      clip.gelu = gelu_erf
      self.cond_stage_model = FrozenOpenClipEmbedder(**{"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True,
                                                        "clip_tokenizer_version": "sd_mlperf_v5_0"})
      unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
      unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16
      if pretrained:
        print("loading text encoder")
        weights: dict[str,Tensor] = {k.replace("cond_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("cond_stage_model.")}
        weights["model.attn_mask"] = Tensor.full((77, 77), fill_value=float("-inf")).triu(1)
        load_state_dict(self.cond_stage_model, weights)
        # only the eval model needs the decoder
        if version == "v2-mlperf-eval":
          print("loading image latent encoder")
          weights = {k.replace("first_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("first_stage_model.")}
          load_state_dict(self.first_stage_model, weights)

    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_init_params))
    if version == "v2-mlperf-train":
      # the mlperf reference inits certain weights as zeroes
      for bb in flatten(self.model.diffusion_model.input_blocks) + self.model.diffusion_model.middle_block + flatten(self.model.diffusion_model.output_blocks):
        if isinstance(bb, unet.ResBlock):
          zero_module(bb.out_layers[3])
        elif isinstance(bb, unet.SpatialTransformer):
          zero_module(bb.proj_out)
      zero_module(self.model.diffusion_model.out[2])
|
||||
|
||||
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
|
||||
temperature = 1
|
||||
|
||||
Reference in New Issue
Block a user