Mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Stable Diffusion model init for mlperf (#12314)
* include clip pr diff
* updated unet and sd init
* dehardcode default device
* revert beam hang workaround

Co-authored-by: chenyu <chenyu@fastmail.com>
examples/mlperf/initializers.py

@@ -2,7 +2,9 @@ import math
 from typing import Union
 
 from tinygrad import Tensor, nn, dtypes
-from tinygrad.helpers import prod, argfix
+from tinygrad.helpers import prod, argfix, Context
+from tinygrad.nn.state import get_parameters
+from extra.models.unet import UNetModel
 
 # rejection sampling truncated randn
 def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
@@ -131,3 +133,59 @@ class Conv2dRetinaNet(nn.Conv2d):
   def __call__(self, x:Tensor) -> Tensor:
     return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                     groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)
+
+# copy torch AMP: isolate mixed precision to just the below autocast ops, instead of using dtypes.default_float which affects all new Tensors
+class AutocastLinear(nn.Linear):
+  cast_dtype=dtypes.bfloat16 # enable monkeypatching of the mixed precision dtype
+  def __call__(self, x:Tensor) -> Tensor:
+    dtype = type(self).cast_dtype
+    return x.cast(dtype).linear(self.weight.cast(dtype).transpose(), self.bias.cast(dtype) if self.bias is not None else None)
+
+class AutocastConv2d(nn.Conv2d):
+  cast_dtype=dtypes.bfloat16
+  def __call__(self, x:Tensor) -> Tensor:
+    dtype = type(self).cast_dtype
+    return x.cast(dtype).conv2d(self.weight.cast(dtype), self.bias.cast(dtype), self.groups, self.stride, self.dilation, self.padding)
+
+# copy torch AMP: upcast to float32 before GroupNorm and LayerNorm
+class AutocastGroupNorm(nn.GroupNorm):
+  def __call__(self, x:Tensor) -> Tensor:
+    return super().__call__(x.cast(dtypes.float32))
+
+class AutocastLayerNorm(nn.LayerNorm):
+  def __call__(self, x:Tensor) -> Tensor:
+    return super().__call__(x.cast(dtypes.float32))
+
+def zero_module(module):
+  for p in get_parameters(module): p.assign(Tensor.zeros_like(p).contiguous())
+
+# Stable Diffusion mlperf reference doesn't call scaled_dot_product_attention
+# copy torch AMP: upcast to float32 before softmax on CUDA
+def attn_f32_softmax(q:Tensor, k:Tensor, v:Tensor) -> Tensor:
+  return (q.matmul(k.transpose(-2,-1), dtype=dtypes.float32) / math.sqrt(q.shape[-1])).softmax(-1).cast(q.dtype) @ v
+
+def init_stable_diffusion(version:str, pretrained:str, devices:list[str]):
+  from examples.stable_diffusion import StableDiffusion
+  from tinygrad.nn.state import safe_load, safe_save, load_state_dict, get_state_dict
+  from tempfile import TemporaryDirectory
+  model = StableDiffusion(version=version, pretrained=pretrained)
+  unet:UNetModel = model.model.diffusion_model
+
+  # this prevents extra consumption of memory, enabling much larger BS
+  Tensor.realize(*get_parameters(unet))
+  with TemporaryDirectory(prefix="unet_init") as tmp:
+    safe_save(get_state_dict(unet), init_fn:=f"{tmp}/init_model.safetensors")
+    load_state_dict(unet, safe_load(init_fn))
+
+  sqrt_alphas_cumprod = model.alphas_cumprod.sqrt().realize()
+  sqrt_one_minus_alphas_cumprod = (1 - model.alphas_cumprod).sqrt().realize()
+
+  if len(devices) > 1:
+    to_move = [sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod]
+    if version == "v2-mlperf-train": to_move += get_parameters(unet) + get_parameters(model.cond_stage_model)
+    for p in to_move:
+      p.to_(devices)
+    with Context(BEAM=0):
+      Tensor.realize(*to_move)
+
+  return model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod
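The Autocast* layers above mirror torch AMP: matmuls and convolutions run in a low-precision dtype while normalization computes in float32. A minimal sketch of that contract, assuming a backend where bfloat16 is supported (shapes are arbitrary):

from tinygrad import Tensor, dtypes
from examples.mlperf.initializers import AutocastLinear, AutocastGroupNorm

x = Tensor.randn(4, 320, dtype=dtypes.float32)
print(AutocastLinear(320, 320)(x).dtype)    # bfloat16: input and weights are downcast before the matmul

h = Tensor.randn(4, 320, 8, 8, dtype=dtypes.bfloat16)
print(AutocastGroupNorm(16, 320)(h).dtype)  # float32: input is upcast before the normalization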
examples/stable_diffusion.py

@@ -9,11 +9,13 @@ from typing import Dict, Any
 from PIL import Image
 import numpy as np
 from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
-from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm
+from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
 from tinygrad.nn import Conv2d, GroupNorm
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
-from extra.models.clip import Closed, Tokenizer
+from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
+from extra.models import unet, clip
 from extra.models.unet import UNetModel
+from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, zero_module, attn_f32_softmax, gelu_erf
 from extra.bench_log import BenchEvent, WallTimeEvent
 
 class AttnBlock:
@@ -154,12 +156,46 @@ unet_params: Dict[str,Any] = {
   "use_linear": False,
 }
 
+mlperf_params: Dict[str,Any] = {"adm_in_ch": None, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2, 1], "num_res_blocks": 2,
+                                "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [1, 1, 1, 1], "ctx_dim": 1024, "use_linear": True,
+                                "num_groups":16, "st_norm_eps":1e-6}
+
 class StableDiffusion:
-  def __init__(self):
+  def __init__(self, version:str|None=None, pretrained:str|None=None):
     self.alphas_cumprod = get_alphas_cumprod()
-    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_params))
-    self.first_stage_model = AutoencoderKL()
-    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
+    if version != "v2-mlperf-train":
+      self.first_stage_model = AutoencoderKL() # only needed for decoding generated latents to images; not needed in mlperf training from preprocessed moments
+
+    if not version:
+      self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = Closed.ClipTextTransformer()))
+      unet_init_params = unet_params
+    elif version in {"v2-mlperf-train", "v2-mlperf-eval"}:
+      unet_init_params = mlperf_params
+      clip.gelu = gelu_erf
+      self.cond_stage_model = FrozenOpenClipEmbedder(**{"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True,
+                                                        "clip_tokenizer_version": "sd_mlperf_v5_0"})
+      unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
+      unet.attention, unet.gelu, unet.mixed_precision_dtype = attn_f32_softmax, gelu_erf, dtypes.bfloat16
+      if pretrained:
+        print("loading text encoder")
+        weights: dict[str,Tensor] = {k.replace("cond_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("cond_stage_model.")}
+        weights["model.attn_mask"] = Tensor.full((77, 77), fill_value=float("-inf")).triu(1)
+        load_state_dict(self.cond_stage_model, weights)
+        # only the eval model needs the decoder
+        if version == "v2-mlperf-eval":
+          print("loading image latent encoder")
+          weights = {k.replace("first_stage_model.", "", 1):v for k,v in torch_load(pretrained)["state_dict"].items() if k.startswith("first_stage_model.")}
+          load_state_dict(self.first_stage_model, weights)
+
+    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel(**unet_init_params))
+    if version == "v2-mlperf-train":
+      # the mlperf reference inits certain weights as zeroes
+      for bb in flatten(self.model.diffusion_model.input_blocks) + self.model.diffusion_model.middle_block + flatten(self.model.diffusion_model.output_blocks):
+        if isinstance(bb, unet.ResBlock):
+          zero_module(bb.out_layers[3])
+        elif isinstance(bb, unet.SpatialTransformer):
+          zero_module(bb.proj_out)
+      zero_module(self.model.diffusion_model.out[2])
 
   def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
     temperature = 1
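With the new constructor arguments, the mlperf variants are selected entirely through `version`, and `pretrained` points at the reference checkpoint whose text-encoder (and, for eval, VAE) weights are loaded. A usage sketch; the checkpoint path below is a placeholder:

from examples.stable_diffusion import StableDiffusion

# training variant: OpenCLIP text encoder, mlperf UNet hyperparameters, Autocast layers
# patched into extra.models.unet, and zero-initialized output projections
train_model = StableDiffusion(version="v2-mlperf-train", pretrained="/path/to/512-base-ema.ckpt")

# eval variant additionally keeps first_stage_model and loads the VAE weights for decoding latents
eval_model = StableDiffusion(version="v2-mlperf-eval", pretrained="/path/to/512-base-ema.ckpt")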
extra/models/unet.py

@@ -1,21 +1,24 @@
-from tinygrad import Tensor, dtypes
-from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
+from tinygrad import Tensor, dtypes, nn
 from tinygrad.device import is_dtype_supported
-from typing import Optional, Union, List, Any, Tuple
+from typing import Optional, Union, List, Any, Tuple, Callable
 import math
 
+# allow for monkeypatching
+Linear, Conv2d, GroupNorm, LayerNorm = nn.Linear, nn.Conv2d, nn.GroupNorm, nn.LayerNorm
+attention, gelu, mixed_precision_dtype = Tensor.scaled_dot_product_attention, Tensor.gelu, dtypes.float16
+
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
 def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
   half = dim // 2
   freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
   args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
   out = Tensor.cat(args.cos(), args.sin(), dim=-1)
-  return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
+  return out.cast(mixed_precision_dtype) if is_dtype_supported(mixed_precision_dtype) else out
 
 class ResBlock:
-  def __init__(self, channels:int, emb_channels:int, out_channels:int):
+  def __init__(self, channels:int, emb_channels:int, out_channels:int, num_groups:int=32):
     self.in_layers = [
-      GroupNorm(32, channels),
+      GroupNorm(num_groups, channels),
       Tensor.silu,
       Conv2d(channels, out_channels, 3, padding=1),
     ]
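The module-level aliases added at the top of extra/models/unet.py are what the StableDiffusion change hooks into: rebinding `unet.Linear`, `unet.attention`, etc. before constructing a UNetModel makes every layer built afterwards use the replacements. A minimal sketch of that pattern, using the classes from this PR:

from tinygrad import dtypes
from extra.models import unet
from examples.mlperf.initializers import AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm, attn_f32_softmax

# models constructed after this point pick up the patched ops and the bfloat16 mixed-precision dtype
unet.Linear, unet.Conv2d, unet.GroupNorm, unet.LayerNorm = AutocastLinear, AutocastConv2d, AutocastGroupNorm, AutocastLayerNorm
unet.attention, unet.mixed_precision_dtype = attn_f32_softmax, dtypes.bfloat16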
@@ -24,7 +27,7 @@ class ResBlock:
       Linear(emb_channels, out_channels),
     ]
     self.out_layers = [
-      GroupNorm(32, out_channels),
+      GroupNorm(num_groups, out_channels),
       Tensor.silu,
       lambda x: x, # needed for weights loading code to work
       Conv2d(out_channels, out_channels, 3, padding=1),
@@ -45,35 +48,37 @@ class CrossAttention:
     self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
     self.num_heads = n_heads
     self.head_size = d_head
+    self.attn = attention
     self.to_out = [Linear(n_heads*d_head, query_dim)]
 
   def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
     ctx = x if ctx is None else ctx
     q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
     q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
-    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
+    attention = self.attn(q, k, v).transpose(1,2)
     h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
     return h_.sequential(self.to_out)
 
 class GEGLU:
   def __init__(self, dim_in:int, dim_out:int):
     self.proj = Linear(dim_in, dim_out * 2)
+    self.gelu = gelu
     self.dim_out = dim_out
 
   def __call__(self, x:Tensor) -> Tensor:
     x, gate = self.proj(x).chunk(2, dim=-1)
-    return x * gate.gelu()
+    return x * self.gelu(gate)
 
 class FeedForward:
   def __init__(self, dim:int, mult:int=4):
-    self.net = [
+    self.net: tuple[GEGLU, Callable, nn.Linear] = (
       GEGLU(dim, dim*mult),
       lambda x: x, # needed for weights loading code to work
       Linear(dim*mult, dim)
-    ]
+    )
 
   def __call__(self, x:Tensor) -> Tensor:
-    return x.sequential(self.net)
+    return x.sequential(list(self.net))
 
 class BasicTransformerBlock:
   def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
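`CrossAttention` now calls the module-level `attention` hook, so the mlperf path can swap in `attn_f32_softmax` (explicit float32 matmul and softmax) for `Tensor.scaled_dot_product_attention`. A rough sanity sketch; with float32 inputs the two paths should agree closely, though the tolerance below is a guess rather than a guarantee:

import numpy as np
from tinygrad import Tensor, dtypes
from examples.mlperf.initializers import attn_f32_softmax

q, k, v = [Tensor.randn(1, 8, 16, 64, dtype=dtypes.float32) for _ in range(3)]
a = attn_f32_softmax(q, k, v).numpy()
b = Tensor.scaled_dot_product_attention(q, k, v).numpy()
np.testing.assert_allclose(a, b, rtol=1e-3, atol=1e-3)  # same math, different softmax precision path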
@@ -92,12 +97,13 @@ class BasicTransformerBlock:
 
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
 class SpatialTransformer:
-  def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1):
+  def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1,
+               norm_eps:float=1e-5):
     if isinstance(ctx_dim, int):
       ctx_dim = [ctx_dim]*depth
     else:
       assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
-    self.norm = GroupNorm(32, channels)
+    self.norm = GroupNorm(32, channels, eps=norm_eps)
     assert channels == n_heads * d_head
     self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
     self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
@@ -134,7 +140,9 @@ class Upsample:
 
 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
 class UNetModel:
-  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, n_heads:Optional[int]=None):
+  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int,
+               channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None,
+               n_heads:Optional[int]=None, num_groups:int=32, st_norm_eps:float=1e-5):
     self.model_ch = model_ch
     self.num_res_blocks = [num_res_blocks] * len(channel_mult)
 
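The new `num_groups` and `st_norm_eps` arguments let the mlperf config use 16-group GroupNorms and a 1e-6 SpatialTransformer norm eps without touching the defaults used elsewhere. A construction sketch with the `mlperf_params` values from this PR (tensor initialization in tinygrad is lazy, so nothing heavy runs here):

from extra.models.unet import UNetModel

unet = UNetModel(adm_in_ch=None, in_ch=4, out_ch=4, model_ch=320, attention_resolutions=[4, 2, 1],
                 num_res_blocks=2, channel_mult=[1, 2, 4, 4], d_head=64, transformer_depth=[1, 1, 1, 1],
                 ctx_dim=1024, use_linear=True, num_groups=16, st_norm_eps=1e-6)
print(unet.out[0].num_groups, unet.input_blocks[1][1].norm.eps)  # 16 1e-06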
@@ -174,12 +182,12 @@ class UNetModel:
     for idx, mult in enumerate(channel_mult):
       for _ in range(self.num_res_blocks[idx]):
         layers: List[Any] = [
-          ResBlock(ch, time_embed_dim, model_ch*mult),
+          ResBlock(ch, time_embed_dim, model_ch*mult, num_groups),
         ]
         ch = mult * model_ch
         if ds in attention_resolutions:
           d_head, n_heads = get_d_and_n_heads(ch)
-          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
+          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
 
         self.input_blocks.append(layers)
         input_block_channels.append(ch)
@@ -193,9 +201,9 @@ class UNetModel:
 
     d_head, n_heads = get_d_and_n_heads(ch)
     self.middle_block: List = [
-      ResBlock(ch, time_embed_dim, ch),
-      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1]),
-      ResBlock(ch, time_embed_dim, ch),
+      ResBlock(ch, time_embed_dim, ch, num_groups),
+      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1], norm_eps=st_norm_eps),
+      ResBlock(ch, time_embed_dim, ch, num_groups),
     ]
 
     self.output_blocks = []
@@ -203,13 +211,13 @@ class UNetModel:
       for i in range(self.num_res_blocks[idx] + 1):
         ich = input_block_channels.pop()
         layers = [
-          ResBlock(ch + ich, time_embed_dim, model_ch*mult),
+          ResBlock(ch + ich, time_embed_dim, model_ch*mult, num_groups),
         ]
         ch = model_ch * mult
 
         if ds in attention_resolutions:
           d_head, n_heads = get_d_and_n_heads(ch)
-          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
+          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx], norm_eps=st_norm_eps))
 
         if idx > 0 and i == self.num_res_blocks[idx]:
           layers.append(Upsample(ch))
@@ -217,7 +225,7 @@ class UNetModel:
         self.output_blocks.append(layers)
 
     self.out = [
-      GroupNorm(32, ch),
+      GroupNorm(num_groups, ch),
       Tensor.silu,
       Conv2d(model_ch, out_ch, 3, padding=1),
     ]
@@ -230,10 +238,10 @@ class UNetModel:
       assert y.shape[0] == x.shape[0]
       emb = emb + y.sequential(self.label_emb[0])
 
-    if is_dtype_supported(dtypes.float16):
-      emb = emb.cast(dtypes.float16)
-      ctx = ctx.cast(dtypes.float16)
-      x   = x  .cast(dtypes.float16)
+    if is_dtype_supported(mixed_precision_dtype):
+      emb = emb.cast(mixed_precision_dtype)
+      ctx = ctx.cast(mixed_precision_dtype)
+      x   = x  .cast(mixed_precision_dtype)
 
     def run(x:Tensor, bb) -> Tensor:
       if isinstance(bb, ResBlock): x = bb(x, emb)
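Because `timestep_embedding` and the forward-pass casts now read `mixed_precision_dtype`, switching the whole UNet between float16 and bfloat16 is a one-variable change. A small sketch, assuming the device supports bfloat16:

from tinygrad import Tensor, dtypes
from extra.models import unet

unet.mixed_precision_dtype = dtypes.bfloat16
emb = unet.timestep_embedding(Tensor([0., 250., 500., 999.]), 320)
print(emb.shape, emb.dtype)  # (4, 320) bfloat16 (stays float32 if bfloat16 is unsupported)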
@@ -1,10 +1,12 @@
 import unittest
+import numpy as np
+from pathlib import Path
 from tinygrad import Tensor, dtypes, Device
+from tinygrad.helpers import getenv
 from tinygrad.nn.state import get_parameters
 from extra.models import clip
-from examples.mlperf.initializers import gelu_erf
-Device.DEFAULT="NULL"
-GPUS = [f"NULL:{i}" for i in range(8)]
+from examples.mlperf.initializers import gelu_erf, init_stable_diffusion, attn_f32_softmax
+from typing import Literal
 
 clip_params = {"dims": 1024, "n_heads": 16, "layers": 24, "return_pooled": False, "ln_penultimate": True, "clip_tokenizer_version": "sd_mlperf_v5_0"}
 def get_cond_stage_model(GPUS:list[str]|None=None) -> clip.FrozenOpenClipEmbedder:
@@ -30,6 +32,7 @@ class TestOpenClip(unittest.TestCase):
 
   def test_multigpu_clip_embed(self):
     BS = 304
+    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
     model = get_cond_stage_model(GPUS)
     tokens = get_tokens(BS)
     embeds = model.embed_tokens(tokens.shard(GPUS, axis=0)).realize()
@@ -38,6 +41,7 @@ class TestOpenClip(unittest.TestCase):
 
   def test_multigpu_clip_score(self):
     BS = 240
+    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(8)]
     vision_cfg = {'width': 1280, 'layers': 32, 'd_head': 80, 'image_size': 224, 'patch_size': 14}
     text_cfg = {'width': 1024, 'n_heads': 16, 'layers': 24, 'vocab_size': 49408, 'ctx_length': 77}
     clip.gelu = gelu_erf
@@ -49,5 +53,62 @@ class TestOpenClip(unittest.TestCase):
     self.assertEqual(scores.shape, (BS,))
     self.assertEqual(scores.dtype, dtypes.float32)
 
+class TestInitStableDiffusion(unittest.TestCase):
+  def setUp(self):
+    # NOTE: set env variable based on where checkpoints are on the system
+    self.CKPTDIR = Path(getenv("CKPTDIR", "/raid/weights/stable_diffusion"))
+
+  def helper_test_init(self, version:Literal["v2-mlperf-train", "v2-mlperf-eval"]):
+    model, unet, sqrt_acp, sqrt_omacp = init_stable_diffusion(version, self.CKPTDIR / "sd" / "512-base-ema.ckpt", ["CPU"])
+
+    with self.subTest("test that StableDiffusion has correct models"):
+      self.assertEqual(model.model.diffusion_model, unet)
+      has_encoder = True if version=="v2-mlperf-eval" else False
+      self.assertEqual(hasattr(model, "first_stage_model"), has_encoder, "only the eval model uses the encoder")
+      self.assertTrue(isinstance(model.cond_stage_model, clip.FrozenOpenClipEmbedder))
+
+    with self.subTest("test for mlperf unique attributes"):
+      self.assertEqual(model.cond_stage_model.tokenizer.version, 'sd_mlperf_v5_0')
+      self.assertEqual(unet.out[0].num_groups, 16)
+      self.assertEqual(unet.input_blocks[1][1].norm.eps, 1e-6)
+      self.assertEqual(unet.input_blocks[1][1].transformer_blocks[0].attn1.attn, attn_f32_softmax)
+
+    with self.subTest("test loaded clip parameters"):
+      sample = model.cond_stage_model.model.transformer.resblocks[8].mlp.c_fc.bias.flatten()[42:46].numpy()
+      expected = np.array([-0.49812260270118713, -0.3039605915546417, -0.40284937620162964, -0.45069342851638794], dtype=np.float32)
+      np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded clip parameters are incorrect")
+
+    if version=="v2-mlperf-train":
+      with self.subTest("test that zero_module worked"):
+        self.assertTrue((unet.out[2].weight == 0).all().item(), "expected all zeroes")
+        self.assertTrue((unet.out[2].bias == 0).all().item(), "expected all zeroes")
+    elif version=="v2-mlperf-eval":
+      with self.subTest("test loaded vae parameters"):
+        sample = model.first_stage_model.decoder.up[0]['block'][1].conv2.weight.flatten()[42:46].numpy()
+        expected = np.array([0.08192943036556244, 0.040095631033182144, 0.07541035860776901, 0.1475081741809845], dtype=np.float32)
+        np.testing.assert_allclose(sample, expected, rtol=1e-7, atol=0, err_msg="loaded vae parameters are incorrect")
+
+    with self.subTest("check schedules"):
+      expected = np.array([0.9995748996734619, 0.06826484948396683], dtype=np.float32)
+      np.testing.assert_allclose(sqrt_acp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_acp is incorrect")
+      expected = np.array([0.029155133292078972, 0.9976672530174255], dtype=np.float32)
+      np.testing.assert_allclose(sqrt_omacp[[0,-1]].numpy(), expected, rtol=1e-7, atol=0, err_msg="sqrt_omacp is incorrect")
+
+    with self.subTest("check mixed precision"):
+      out = unet.input_blocks[2][1].proj_in(Tensor.randn(320, dtype=dtypes.float32))
+      self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Linear")
+      out = unet.out[2](Tensor.randn(304,320,64,64, dtype=dtypes.float32))
+      self.assertEqual(out.dtype, dtypes.bfloat16, "expected float32 to be downcast to bfloat16 by Conv2d")
+      out = unet.input_blocks[1][1].transformer_blocks[0].norm1(Tensor.randn(320, dtype=dtypes.bfloat16))
+      self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by LayerNorm")
+      out = unet.input_blocks[5][0].in_layers[0](Tensor.randn(304, 640, dtype=dtypes.bfloat16))
+      self.assertEqual(out.dtype, dtypes.float32, "expected bfloat16 to be upcast to float32 by GroupNorm")
+
+  def test_train_model(self):
+    self.helper_test_init("v2-mlperf-train")
+
+  def test_eval_model(self):
+    self.helper_test_init("v2-mlperf-eval")
+
 if __name__=="__main__":
   unittest.main()
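The expected values in the "check schedules" subtest follow directly from the diffusion beta schedule; a numpy check, assuming `get_alphas_cumprod` uses the usual SD scaled-linear schedule (beta going from 0.00085 to 0.012 over 1000 steps):

import numpy as np

betas = np.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=np.float32)**2
alphas_cumprod = np.cumprod(1.0 - betas)
print(np.sqrt(alphas_cumprod[[0, -1]]))        # ~[0.9995749  0.06826485], matching sqrt_acp
print(np.sqrt(1.0 - alphas_cumprod[[0, -1]]))  # ~[0.02915513 0.99766725], matching sqrt_omacp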