diff --git a/extra/debug_sd_speed.py b/extra/debug_sd_speed.py
new file mode 100644
index 0000000000..9d7af29cff
--- /dev/null
+++ b/extra/debug_sd_speed.py
@@ -0,0 +1,99 @@
+# NOTE: this is written in a way that checking out an old commit still works
+# fast SD 297ms step on M1 Max, 4444e6d https://github.com/tinygrad/tinygrad/pull/2129
+# lazy rewrite, 1765849 https://github.com/tinygrad/tinygrad/pull/2878
+# SD 415ms step on M1 Max on master around 11/15/2024
+
+import time
+from typing import Optional
+try: from tinygrad.jit import TinyJit
+except ImportError: from tinygrad import TinyJit
+from tinygrad.tensor import Tensor, Device
+from tinygrad.helpers import GlobalCounters
+from tinygrad.nn import Linear, LayerNorm
+from tinygrad.nn.state import get_parameters
+
+class CrossAttention:
+  def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
+    self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
+    self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
+    self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
+    self.num_heads = n_heads
+    self.head_size = d_head
+    self.to_out = [Linear(n_heads*d_head, query_dim)]
+
+  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
+    ctx = x if ctx is None else ctx
+    q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
+    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
+    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
+    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
+    return h_.sequential(self.to_out)
+
+class GEGLU:
+  def __init__(self, dim_in:int, dim_out:int):
+    self.proj = Linear(dim_in, dim_out * 2)
+    self.dim_out = dim_out
+
+  def __call__(self, x:Tensor) -> Tensor:
+    x, gate = self.proj(x).chunk(2, dim=-1)
+    return x * gate.gelu()
+
+class FeedForward:
+  def __init__(self, dim:int, mult:int=4):
+    self.net = [
+      GEGLU(dim, dim*mult),
+      lambda x: x,  # needed for weights loading code to work
+      Linear(dim*mult, dim)
+    ]
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return x.sequential(self.net)
+
+class BasicTransformerBlock:
+  def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
+    self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
+    self.ff = FeedForward(dim)
+    self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
+    self.norm1 = LayerNorm(dim)
+    self.norm2 = LayerNorm(dim)
+    self.norm3 = LayerNorm(dim)
+
+  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
+    x = x + self.attn1(self.norm1(x))           # 5.4 before, 6.8 master
+    x = x + self.attn2(self.norm2(x), ctx=ctx)  # 12 before, 12 master
+    x = x + self.ff(self.norm3(x))              # 23 before, 27 master
+    return x
+
+def helper_test(gen, model):
+  tms = []
+  for _ in range(5):
+    early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
+    GlobalCounters.reset()
+    Device[Device.DEFAULT].synchronize()
+    st = time.perf_counter_ns()
+    model(*early_gen)
+    Device[Device.DEFAULT].synchronize()
+    tms.append(time.perf_counter_ns() - st)
+  print(f"{min(tms)/1e6=:.2f} ms")
+
+def derandomize_model(model):
+  for p in get_parameters(model):
+    p.lazydata = Tensor.empty(*p.shape, device=p.device, dtype=p.dtype).lazydata
+    p.realize()
+
+def test_transformer_block():
+  # dim, d_head, x = 320, 40, (4096, 320)   # 137ms 4444e6d, 115ms master
+  # dim, d_head, x = 640, 80, (1024, 640)   # 36ms 4444e6d, 31ms master
+  dim, d_head, x = 1280, 160, (256, 1280)   # 23ms 4444e6d, 28ms master, 31ms on 176584993
+
+  model = [BasicTransformerBlock(dim, 768, 8, d_head) for _ in range(4)]
+
+  derandomize_model(model)
+  @TinyJit
+  def test(t, t2):
+    for l in model: t = l(t, t2)
+    return t.realize()
+  helper_test(lambda: (Tensor.empty(2, *x), Tensor.empty(2, 77, 768)), test)
+
+if __name__ == "__main__":
+  test_transformer_block()
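
The script can be run directly (python3 extra/debug_sd_speed.py) to reproduce the 1280/160 measurement from test_transformer_block. To time the two commented-out configurations in the same run, a small sweep could be appended to the file; the sketch below is not part of the commit, it assumes it sits at the end of extra/debug_sd_speed.py (reusing BasicTransformerBlock, derandomize_model, helper_test, TinyJit and Tensor from the file), and the CONFIGS / sweep names are illustrative:

# sketch: benchmark each (dim, d_head, input shape) listed in test_transformer_block
CONFIGS = [
  (320, 40, (4096, 320)),
  (640, 80, (1024, 640)),
  (1280, 160, (256, 1280)),
]

def sweep():
  for dim, d_head, x in CONFIGS:
    # same setup as test_transformer_block: 4 blocks, ctx_dim=768, 8 heads
    model = [BasicTransformerBlock(dim, 768, 8, d_head) for _ in range(4)]
    derandomize_model(model)
    @TinyJit
    def test(t, t2):
      for l in model: t = l(t, t2)
      return t.realize()
    print(f"{dim=} {d_head=} {x=}: ", end="")
    helper_test(lambda: (Tensor.empty(2, *x), Tensor.empty(2, 77, 768)), test)

Each configuration gets its own TinyJit instance, so captured kernels are never reused across shapes, and helper_test still reports the best of five calls per configuration.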