Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-07 22:23:55 -05:00)
Stable Diffusion model init for mlperf (#12314)
* include clip pr diff
* updated unet and sd init
* dehardcode default device
* revert beam hang workaround

Co-authored-by: chenyu <chenyu@fastmail.com>
@@ -2,7 +2,9 @@ import math
from typing import Union

from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix
from tinygrad.helpers import prod, argfix, Context
from tinygrad.nn.state import get_parameters
from extra.models.unet import UNetModel

# rejection sampling truncated randn
def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
@@ -131,3 +133,59 @@ class Conv2dRetinaNet(nn.Conv2d):
  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)

# copy torch AMP: isolate mixed precision to just the below autocast ops, instead of using dtypes.default_float which affects all new Tensors
class AutocastLinear(nn.Linear):
  cast_dtype=dtypes.bfloat16 # enable monkeypatching of the mixed precision dtype
  def __call__(self, x:Tensor) -> Tensor:
    dtype = type(self).cast_dtype
    return x.cast(dtype).linear(self.weight.cast(dtype).transpose(), self.bias.cast(dtype) if self.bias is not None else None)

class AutocastConv2d(nn.Conv2d):
  cast_dtype=dtypes.bfloat16
  def __call__(self, x:Tensor) -> Tensor:
    dtype = type(self).cast_dtype
    return x.cast(dtype).conv2d(self.weight.cast(dtype), self.bias.cast(dtype), self.groups, self.stride, self.dilation, self.padding)

# copy torch AMP: upcast to float32 before GroupNorm and LayerNorm
class AutocastGroupNorm(nn.GroupNorm):
  def __call__(self, x:Tensor) -> Tensor:
    return super().__call__(x.cast(dtypes.float32))

class AutocastLayerNorm(nn.LayerNorm):
  def __call__(self, x:Tensor) -> Tensor:
    return super().__call__(x.cast(dtypes.float32))

def zero_module(module):
  for p in get_parameters(module): p.assign(Tensor.zeros_like(p).contiguous())

# Stable Diffusion mlperf reference doesn't call scaled_dot_product_attention
# copy torch AMP: upcast to float32 before softmax on CUDA
def attn_f32_softmax(q:Tensor, k:Tensor, v:Tensor) -> Tensor:
  return (q.matmul(k.transpose(-2,-1), dtype=dtypes.float32) / math.sqrt(q.shape[-1])).softmax(-1).cast(q.dtype) @ v

def init_stable_diffusion(version:str, pretrained:str, devices:list[str]):
  from examples.stable_diffusion import StableDiffusion
  from tinygrad.nn.state import safe_load, safe_save, load_state_dict, get_state_dict
  from tempfile import TemporaryDirectory
  model = StableDiffusion(version=version, pretrained=pretrained)
  unet:UNetModel = model.model.diffusion_model

  # this prevents extra consumption of memory, enabling much larger BS
  Tensor.realize(*get_parameters(unet))
  with TemporaryDirectory(prefix="unet_init") as tmp:
    safe_save(get_state_dict(unet), init_fn:=f"{tmp}/init_model.safetensors")
    load_state_dict(unet, safe_load(init_fn))

  sqrt_alphas_cumprod = model.alphas_cumprod.sqrt().realize()
  sqrt_one_minus_alphas_cumprod = (1 - model.alphas_cumprod).sqrt().realize()

  if len(devices) > 1:
    to_move = [sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod]
    if version == "v2-mlperf-train": to_move += get_parameters(unet) + get_parameters(model.cond_stage_model)
    for p in to_move:
      p.to_(devices)
    with Context(BEAM=0):
      Tensor.realize(*to_move)

  return model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod
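For reference, a minimal usage sketch of the new init_stable_diffusion helper (not part of the diff): the checkpoint path is a placeholder, and the single-device list mirrors the test added at the bottom of this commit.

from examples.mlperf.initializers import init_stable_diffusion

# returns the wrapper model, the UNet to train, and the two realized noise-schedule tensors
model, unet, sqrt_acp, sqrt_omacp = init_stable_diffusion(
  version="v2-mlperf-train", pretrained="/path/to/512-base-ema.ckpt", devices=["CPU"])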
@@ -9,11 +9,13 @@ from typing import Dict, Any
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
from extra.models import unet, clip
from extra.models.unet import UNetModel
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, zero_module, attn_f32_softmax, gelu_erf
from extra.bench_log import BenchEvent, WallTimeEvent

class AttnBlock:
@@ -154,12 +156,46 @@ unet_params: Dict[str,Any] = {
  "use_linear": False,
}

mlperf_params: Dict[str,Any] = {"adm_in_ch": None, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2, 1], "num_res_blocks": 2,
                                "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [1, 1, 1, 1], "ctx_dim": 1024, "use_linear": True,
                                "num_groups":16, "st_norm_eps":1e-6}

class StableDiffusion:
  def __init__(self):
  def __init__(self, version:str|None=None, pretrained:str|None=None):
    self.alphas_cumprod = get_alphas_cumprod()
    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
    if version != "v2-mlperf-train":
      self.first_stage_model = AutoencoderKL() # only needed for decoding generated latents to images; not needed in mlperf training from preprocessed moments

    if not version:
      self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
      unet_init_params = unet_params
    elif version in {"v2-mlperf-train", "v2-mlperf-eval"}:
      unet_init_params = mlperf_params
      clip.gelu = gelu_erf
      self.cond_stage_model = FrozenOpenClipEmbedder(**{"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True,
                                                        "clip_tokenizer_version": "sd_mlperf_v5_0"})
      unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
      unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16
      if pretrained:
        print("loading text encoder")
        weights: dict[str,Tensor] = {k.replace("cond_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("cond_stage_model.")}
        weights["model.attn_mask"] = Tensor.full((77, 77), fill_value=float("-inf")).triu(1)
        load_state_dict(self.cond_stage_model, weights)
        # only the eval model needs the decoder
        if version == "v2-mlperf-eval":
          print("loading image latent encoder")
          weights = {k.replace("first_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("first_stage_model.")}
          load_state_dict(self.first_stage_model, weights)

    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_init_params))
    if version == "v2-mlperf-train":
      # the mlperf reference inits certain weights as zeroes
      for bb in flatten(self.model.diffusion_model.input_blocks) + self.model.diffusion_model.middle_block + flatten(self.model.diffusion_model.output_blocks):
        if isinstance(bb, unet.ResBlock):
          zero_module(bb.out_layers[3])
        elif isinstance(bb, unet.SpatialTransformer):
          zero_module(bb.proj_out)
      zero_module(self.model.diffusion_model.out[2])

  def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
    temperature = 1
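For clarity, a minimal sketch (not part of the diff) of the two construction paths the new __init__ supports; the checkpoint path is a placeholder.

from examples.stable_diffusion import StableDiffusion

sd = StableDiffusion()  # original path: Closed CLIP text encoder, unet_params, AutoencoderKL decoder
# mlperf training path: OpenCLIP text encoder, mlperf_params, bf16 autocast UNet, no VAE
sd_train = StableDiffusion(version="v2-mlperf-train", pretrained="/path/to/512-base-ema.ckpt")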
@@ -1,21 +1,24 @@
from tinygrad import Tensor, dtypes
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
from tinygrad import Tensor, dtypes, nn
from tinygrad.device import is_dtype_supported
from typing import Optional, Union, List, Any, Tuple
from typing import Optional, Union, List, Any, Tuple, Callable
import math

# allow for monkeypatching
Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
  half = dim // 2
  freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
  args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
  out = Tensor.cat(args.cos(), args.sin(), dim=-1)
  return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
  return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out

class ResBlock:
  def __init__(self, channels:int, emb_channels:int, out_channels:int):
  def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
    self.in_layers = [
      GroupNorm(32, channels),
      GroupNorm(num_groups, channels),
      Tensor.silu,
      Conv2d(channels, out_channels, 3, padding=1),
    ]
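A minimal sketch (not part of the diff) of how the module-level aliases introduced at the top of this hunk ("# allow for monkeypatching") are meant to be used, mirroring the swap done in stable_diffusion.py above; the aliases must be replaced before UNetModel is constructed.

from extra.models import unet
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, attn_f32_softmax
from tinygrad import dtypes

# swap the hooks, then construct UNetModel as usual; newly built layers pick up the Autocast classes
unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
unet.attention, unet.mixed_precision_dtype = attn_f32_softmax, dtypes.bfloat16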
@@ -24,7 +27,7 @@ class ResBlock:
      Linear(emb_channels, out_channels),
    ]
    self.out_layers = [
      GroupNorm(32, out_channels),
      GroupNorm(num_groups, out_channels),
      Tensor.silu,
      lambda x: x, # needed for weights loading code to work
      Conv2d(out_channels, out_channels, 3, padding=1),
@@ -45,35 +48,37 @@ class CrossAttention:
    self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.num_heads = n_heads
    self.head_size = d_head
    self.attn = attention
    self.to_out = [Linear(n_heads*d_head, query_dim)]

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    ctx = x if ctx is None else ctx
    q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
    attention = self.attn(q, k, v).transpose(1,2)
    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
    return h_.sequential(self.to_out)

class GEGLU:
  def __init__(self, dim_in:int, dim_out:int):
    self.proj = Linear(dim_in, dim_out * 2)
    self.gelu = gelu
    self.dim_out = dim_out

  def __call__(self, x:Tensor) -> Tensor:
    x, gate = self.proj(x).chunk(2, dim=-1)
    return x * gate.gelu()
    return x * self.gelu(gate)

class FeedForward:
  def __init__(self, dim:int, mult:int=4):
    self.net = [
    self.net: tuple[GEGLU, Callable, nn.Linear] = (
      GEGLU(dim, dim*mult),
      lambda x: x, # needed for weights loading code to work
      Linear(dim*mult, dim)
    ]
    )

  def __call__(self, x:Tensor) -> Tensor:
    return x.sequential(self.net)
    return x.sequential(list(self.net))

class BasicTransformerBlock:
  def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
@@ -92,12 +97,13 @@ class BasicTransformerBlock:

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
class SpatialTransformer:
  def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1):
  def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
               norm_eps:float=1e-5):
    if isinstance(ctx_dim, int):
      ctx_dim = [ctx_dim]*depth
    else:
      assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
    self.norm = GroupNorm(32, channels)
    self.norm = GroupNorm(32, channels, eps=norm_eps)
    assert channels == n_heads * d_head
    self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
    self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
@@ -134,7 +140,9 @@ class Upsample:

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
class UNetModel:
  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, n_heads:Optional[int]=None):
  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
               channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
               n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
    self.model_ch = model_ch
    self.num_res_blocks = [num_res_blocks] * len(channel_mult)
@@ -174,12 +182,12 @@ class UNetModel:
    for idx, mult in enumerate(channel_mult):
      for _ in range(self.num_res_blocks[idx]):
        layers: List[Any] = [
          ResBlock(ch, time_embed_dim, model_ch*mult),
          ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
        ]
        ch = mult * model_ch
        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))

        self.input_blocks.append(layers)
        input_block_channels.append(ch)
@@ -193,9 +201,9 @@ class UNetModel:

    d_head, n_heads = get_d_and_n_heads(ch)
    self.middle_block: List = [
      ResBlock(ch, time_embed_dim, ch),
      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1]),
      ResBlock(ch, time_embed_dim, ch),
      ResBlock(ch, time_embed_dim, ch, num_groups),
      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
      ResBlock(ch, time_embed_dim, ch, num_groups),
    ]

    self.output_blocks = []
@@ -203,13 +211,13 @@
      for i in range(self.num_res_blocks[idx] + 1):
        ich = input_block_channels.pop()
        layers = [
          ResBlock(ch + ich, time_embed_dim, model_ch*mult),
          ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
        ]
        ch = model_ch * mult

        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))

        if idx > 0 and i == self.num_res_blocks[idx]:
          layers.append(Upsample(ch))
@@ -217,7 +225,7 @@ class UNetModel:
        self.output_blocks.append(layers)

    self.out = [
      GroupNorm(32, ch),
      GroupNorm(num_groups, ch),
      Tensor.silu,
      Conv2d(model_ch, out_ch, 3, padding=1),
    ]
@@ -230,10 +238,10 @@ class UNetModel:
      assert y.shape[0] == x.shape[0]
      emb = emb + y.sequential(self.label_emb[0])

    if is_dtype_supported(dtypes.float16):
      emb = emb.cast(dtypes.float16)
      ctx = ctx.cast(dtypes.float16)
      x = x .cast(dtypes.float16)
    if is_dtype_supported(mixed_precision_dtype):
      emb = emb.cast(mixed_precision_dtype)
      ctx = ctx.cast(mixed_precision_dtype)
      x = x .cast(mixed_precision_dtype)

    def run(x:Tensor, bb) -> Tensor:
      if isinstance(bb, ResBlock): x = bb(x, emb)
@@ -1,10 +1,12 @@
import unittest
import numpy as np
from pathlib import Path
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters
from extra.models import clip
from examples.mlperf.initializers import gelu_erf
Device.DEFAULT="NULL"
GPUS = [f"NULL:{i}" for i in range(8)]
from examples.mlperf.initializers import gelu_erf, init_stable_diffusion, attn_f32_softmax
from typing import Literal

clip_params = {"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True, "clip_tokenizer_version": "sd_mlperf_v5_0"}
def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder:
@@ -30,6 +32,7 @@ class TestOpenClip(unittest.TestCase):

  def test_multigpu_clip_embed(self):
    BS = 304
    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
    model = get_cond_stage_model(GPUS)
    tokens = get_tokens(BS)
    embeds = model.embed_tokens(tokens.shard(GPUS, axis=0)).realize()
@@ -38,6 +41,7 @@ class TestOpenClip(unittest.TestCase):

  def test_multigpu_clip_score(self):
    BS = 240
    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
    vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
    text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
    clip.gelu = gelu_erf
@@ -49,5 +53,62 @@ class TestOpenClip(unittest.TestCase):
    self.assertEqual(scores.shape, (BS,))
    self.assertEqual(scores.dtype, dtypes.float32)

class TestInitStableDiffusion(unittest.TestCase):
  def setUp(self):
    # NOTE: set env variable based on where checkpoints are on the system
    self.CKPTDIR = Path(getenv("CKPTDIR", "/raid/weights/stable_diffusion"))

  def helper_test_init(self, version:Literal["v2-mlperf-train", "v2-mlperf-eval"]):
    model, unet, sqrt_acp, sqrt_omacp = init_stable_diffusion(version, self.CKPTDIR / "sd" / "512-base-ema.ckpt", ["CPU"])

    with self.subTest("test that StableDiffusion has correct models"):
      self.assertEqual(model.model.diffusion_model, unet)
      has_encoder = True if version=="v2-mlperf-eval" else False
      self.assertEqual(hasattr(model, "first_stage_model"), has_encoder, "only the eval model uses the encoder")
      self.assertTrue(isinstance(model.cond_stage_model, clip.FrozenOpenClipEmbedder))

    with self.subTest("test for mlperf unique attributes"):
      self.assertEqual(model.cond_stage_model.tokenizer.version, 'sd_mlperf_v5_0')
      self.assertEqual(unet.out[0].num_groups, 16)
      self.assertEqual(unet.input_blocks[1][1].norm.eps, 1e-6)
      self.assertEqual(unet.input_blocks[1][1].transformer_blocks[0].attn1.attn, attn_f32_softmax)

    with self.subTest("test loaded clip parameters"):
      sample = model.cond_stage_model.model.transformer.resblocks[8].mlp.c_fc.bias.flatten()[42:46].numpy()
      expected = np.array([-0.49812260270118713, -0.3039605915546417, -0.40284937620162964, -0.45069342851638794], dtype=np.float32)
      np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded clip parameters are incorrect")

    if version=="v2-mlperf-train":
      with self.subTest("test that zero_module worked"):
        self.assertTrue((unet.out[2].weight == 0).all().item(), "expected all zeroes")
        self.assertTrue((unet.out[2].bias == 0).all().item(), "expected all zeroes")
    elif version=="v2-mlperf-eval":
      with self.subTest("test loaded vae parameters"):
        sample = model.first_stage_model.decoder.up[0]['block'][1].conv2.weight.flatten()[42:46].numpy()
        expected = np.array([0.08192943036556244, 0.040095631033182144, 0.07541035860776901, 0.1475081741809845], dtype=np.float32)
        np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded vae parameters are incorrect")

    with self.subTest("check schedules"):
      expected = np.array([0.9995748996734619, 0.06826484948396683], dtype=np.float32)
      np.testing.assert_allclose(sqrt_acp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_acp is incorrect")
      expected = np.array([0.029155133292078972, 0.9976672530174255], dtype=np.float32)
      np.testing.assert_allclose(sqrt_omacp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_omacp is incorrect")

    with self.subTest("check mixed precision"):
      out = unet.input_blocks[2][1].proj_in(Tensor.randn(320, dtype=dtypes.float32))
      self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Linear")
      out = unet.out[2](Tensor.randn(304,320,64,64, dtype=dtypes.float32))
      self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Conv2d")
      out = unet.input_blocks[1][1].transformer_blocks[0].norm1(Tensor.randn(320, dtype=dtypes.bfloat16))
      self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by LayerNorm")
      out = unet.input_blocks[5][0].in_layers[0](Tensor.randn(304, 640, dtype=dtypes.bfloat16))
      self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by GroupNorm")

  def test_train_model(self):
    self.helper_test_init("v2-mlperf-train")

  def test_eval_model(self):
    self.helper_test_init("v2-mlperf-eval")

if __name__=="__main__":
  unittest.main()