mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-14 08:35:17 -05:00
example script to show BasicTransformerBlock speed regression (#7724)
This commit is contained in:
99
extra/debug_sd_speed.py
Normal file
99
extra/debug_sd_speed.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# NOTE: this is written in a way that checkout back to old commit still works
|
||||
# fast SD 297ms step on M1 Max, 4444e6d https://github.com/tinygrad/tinygrad/pull/2129
|
||||
# lazy rewrite, 1765849 https://github.com/tinygrad/tinygrad/pull/2878
|
||||
# SD 415ms step on M1 Max on master around 11/15/2024
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
try: from tinygrad.jit import TinyJit
|
||||
except ImportError: from tinygrad import TinyJit
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.nn import Linear, LayerNorm
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
class CrossAttention:
|
||||
def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
|
||||
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
|
||||
self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
|
||||
self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
|
||||
self.num_heads = n_heads
|
||||
self.head_size = d_head
|
||||
self.to_out = [Linear(n_heads*d_head, query_dim)]
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
ctx = x if ctx is None else ctx
|
||||
q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
|
||||
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
|
||||
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
|
||||
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
|
||||
return h_.sequential(self.to_out)
|
||||
|
||||
class GEGLU:
|
||||
def __init__(self, dim_in:int, dim_out:int):
|
||||
self.proj = Linear(dim_in, dim_out * 2)
|
||||
self.dim_out = dim_out
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * gate.gelu()
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim:int, mult:int=4):
|
||||
self.net = [
|
||||
GEGLU(dim, dim*mult),
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Linear(dim*mult, dim)
|
||||
]
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.sequential(self.net)
|
||||
|
||||
class BasicTransformerBlock:
|
||||
def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
|
||||
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
||||
self.ff = FeedForward(dim)
|
||||
self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.norm3 = LayerNorm(dim)
|
||||
|
||||
def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
|
||||
x = x + self.attn1(self.norm1(x)) # 5.4 before, # 6.8 master
|
||||
x = x + self.attn2(self.norm2(x), ctx=ctx) # 12 before, 12 master
|
||||
x = x + self.ff(self.norm3(x)) # 23 before, # 27 master
|
||||
return x
|
||||
|
||||
def helper_test(gen, model):
|
||||
tms = []
|
||||
for _ in range(5):
|
||||
early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
|
||||
GlobalCounters.reset()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
st = time.perf_counter_ns()
|
||||
model(*early_gen)
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
print(f"{min(tms)/1e6=:.2f} ms")
|
||||
|
||||
def derandomize_model(model):
|
||||
for p in get_parameters(model):
|
||||
p.lazydata = Tensor.empty(*p.shape, device=p.device, dtype=p.dtype).lazydata
|
||||
p.realize()
|
||||
|
||||
def test_transformer_block():
|
||||
# dim, d_head, x = 320, 40, (4096, 320) # 137ms 4444e6d 115ms master
|
||||
# dim, d_head, x = 640, 80, (1024, 640) # 36ms 4444e6d, 31ms master
|
||||
dim, d_head, x = 1280, 160, (256, 1280) # 23ms 4444e6d, 28ms master, 31ms on 176584993
|
||||
|
||||
model = [BasicTransformerBlock(dim, 768, 8, d_head) for _ in range(4)]
|
||||
|
||||
derandomize_model(model)
|
||||
@TinyJit
|
||||
def test(t, t2):
|
||||
for l in model: t = l(t, t2)
|
||||
return t.realize()
|
||||
helper_test(lambda: (Tensor.empty(2, *x), Tensor.empty(2, 77, 768)), test)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_transformer_block()
|
||||
Reference in New Issue
Block a user