# tinygrad/extra/debug_sd_speed.py (GitHub file-view header: 100 lines, 3.6 KiB, Python)

# NOTE: this is written so that it still runs after checking out an older commit
# fast SD 297ms step on M1 Max, 4444e6d https://github.com/tinygrad/tinygrad/pull/2129
# lazy rewrite, 1765849 https://github.com/tinygrad/tinygrad/pull/2878
# SD 415ms step on M1 Max on master around 11/15/2024
import time
from typing import Optional
try: from tinygrad.jit import TinyJit
except ImportError: from tinygrad import TinyJit
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import GlobalCounters
from tinygrad.nn import Linear, LayerNorm
from tinygrad.nn.state import get_parameters
class CrossAttention:
  """Multi-head attention of `x` onto `ctx`; falls back to self-attention when `ctx` is None.

  q/k/v projections have no bias. The output projection is stored as a
  one-element list `to_out`, presumably so parameter names match checkpoint
  keys like `to_out.0.weight` (TODO confirm against the weight loader).
  """
  def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
    inner = n_heads * d_head
    self.to_q = Linear(query_dim, inner, bias=False)
    self.to_k = Linear(ctx_dim, inner, bias=False)
    self.to_v = Linear(ctx_dim, inner, bias=False)
    self.num_heads, self.head_size = n_heads, d_head
    self.to_out = [Linear(inner, query_dim)]

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    src = ctx if ctx is not None else x
    bs = x.shape[0]

    def split_heads(t):
      # reshape last dim into (num_heads, head_size), then swap axes 1 and 2
      return t.reshape(bs, -1, self.num_heads, self.head_size).transpose(1, 2)

    q = split_heads(self.to_q(x))
    k = split_heads(self.to_k(src))
    v = split_heads(self.to_v(src))
    attended = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2)
    # fold the heads back into a single feature dimension
    merged = attended.reshape(bs, -1, self.num_heads * self.head_size)
    return merged.sequential(self.to_out)
class GEGLU:
  """Gated GELU: project to 2*dim_out, split in half, return value * gelu(gate)."""
  def __init__(self, dim_in:int, dim_out:int):
    self.proj = Linear(dim_in, 2 * dim_out)
    self.dim_out = dim_out

  def __call__(self, x:Tensor) -> Tensor:
    projected = self.proj(x)
    value, gate = projected.chunk(2, dim=-1)
    return value * gate.gelu()
class FeedForward:
  """GEGLU expansion to dim*mult followed by a linear projection back to dim."""
  def __init__(self, dim:int, mult:int=4):
    hidden = dim * mult
    # the identity at index 1 is needed for the weight loading code; presumably
    # it keeps the real layers at indices net.0 / net.2 (TODO confirm)
    self.net = [GEGLU(dim, hidden), lambda x: x, Linear(hidden, dim)]

  def __call__(self, x:Tensor) -> Tensor:
    return x.sequential(self.net)
class BasicTransformerBlock:
  """Self-attention, cross-attention onto `ctx`, then feed-forward.

  Each sublayer sees a LayerNorm'd input and its output is added back residually.
  """
  def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
    self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
    self.ff = FeedForward(dim)
    self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
    self.norm1, self.norm2, self.norm3 = LayerNorm(dim), LayerNorm(dim), LayerNorm(dim)

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    # trailing numbers are the author's timing notes (older commit vs master)
    h = x + self.attn1(self.norm1(x))           # 5.4 before, 6.8 master
    h = h + self.attn2(self.norm2(h), ctx=ctx)  # 12 before, 12 master
    return h + self.ff(self.norm3(h))           # 23 before, 27 master
def helper_test(gen, model, cnt:int=5) -> float:
  """Time `model` over `cnt` runs and print (and return) the fastest one in ms.

  gen: zero-arg callable returning the positional args for `model`. Tensor
    entries are realized before the clock starts, so input creation is not timed.
  model: the callable under test (e.g. a TinyJit-wrapped function).
  cnt: number of timed runs; defaults to 5, the previously hard-coded value.

  Returns the minimum run time in milliseconds. Taking the minimum keeps
  slow warm-up runs (e.g. JIT capture) from skewing the result.
  Raises ValueError if cnt < 1.
  """
  if cnt < 1: raise ValueError(f"cnt must be >= 1, got {cnt}")
  tms = []
  for _ in range(cnt):
    early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
    GlobalCounters.reset()
    # synchronize before starting and before stopping the clock so the
    # measurement brackets only the device work issued by `model`
    Device[Device.DEFAULT].synchronize()
    st = time.perf_counter_ns()
    model(*early_gen)
    Device[Device.DEFAULT].synchronize()
    tms.append(time.perf_counter_ns() - st)
  print(f"{min(tms)/1e6=:.2f} ms")
  return min(tms)/1e6
def derandomize_model(model):
  """Point every parameter of `model` at fresh `Tensor.empty` storage and realize it.

  The parameter keeps its shape, device and dtype, but its contents become
  uninitialized. Presumably this skips random-init work, which a speed
  benchmark does not need (TODO confirm intent).
  """
  for param in get_parameters(model):
    blank = Tensor.empty(*param.shape, device=param.device, dtype=param.dtype)
    param.lazydata = blank.lazydata
    param.realize()
def test_transformer_block():
  """Time four stacked BasicTransformerBlocks under TinyJit on empty inputs.

  Alternative configurations, with the author's recorded timings:
    dim, d_head, x = 320, 40, (4096, 320)  # 137ms 4444e6d 115ms master
    dim, d_head, x = 640, 80, (1024, 640)  # 36ms 4444e6d, 31ms master
  """
  dim, d_head, x_shape = 1280, 160, (256, 1280)  # 23ms 4444e6d, 28ms master, 31ms on 176584993
  blocks = [BasicTransformerBlock(dim, 768, 8, d_head) for _ in range(4)]
  derandomize_model(blocks)

  @TinyJit
  def run_blocks(t, t2):
    for blk in blocks:
      t = blk(t, t2)
    return t.realize()

  def make_inputs():
    # batch of 2 for both the activations and the (77, 768) context
    return Tensor.empty(2, *x_shape), Tensor.empty(2, 77, 768)

  helper_test(make_inputs, run_blocks)
if __name__ == "__main__":
  # run the transformer-block benchmark when executed as a script
  test_transformer_block()