# NOTE: this is written in a way that checkout back to old commit still works
# fast SD 297ms step on M1 Max, 4444e6d https://github.com/tinygrad/tinygrad/pull/2129
# lazy rewrite, 1765849 https://github.com/tinygrad/tinygrad/pull/2878
# SD 415ms step on M1 Max on master around 11/15/2024

import time
from typing import Optional
try: from tinygrad.jit import TinyJit
except ImportError: from tinygrad import TinyJit
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import GlobalCounters
from tinygrad.nn import Linear, LayerNorm
from tinygrad.nn.state import get_parameters

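# multi-head cross-attention (self-attention when ctx is None), as used in the SD UNet transformer blocks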
class CrossAttention:
  def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
    self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
    self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
    self.num_heads = n_heads
    self.head_size = d_head
    self.to_out = [Linear(n_heads*d_head, query_dim)]

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    ctx = x if ctx is None else ctx
    q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
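    # split heads: (bs, seq, n_heads*head_size) -> (bs, n_heads, seq, head_size)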
    q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
    attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
    h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
    return h_.sequential(self.to_out)

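# GELU-gated linear unit: project to 2*dim_out, then gate one half with the gelu of the other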
class GEGLU:
  def __init__(self, dim_in:int, dim_out:int):
    self.proj = Linear(dim_in, dim_out * 2)
    self.dim_out = dim_out

  def __call__(self, x:Tensor) -> Tensor:
    x, gate = self.proj(x).chunk(2, dim=-1)
    return x * gate.gelu()

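# transformer MLP: GEGLU expansion to dim*mult, then a Linear projection back down to dim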
class FeedForward:
  def __init__(self, dim:int, mult:int=4):
    self.net = [
      GEGLU(dim, dim*mult),
      lambda x: x, # needed for weights loading code to work
      Linear(dim*mult, dim)
    ]

  def __call__(self, x:Tensor) -> Tensor:
    return x.sequential(self.net)

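# pre-norm transformer block: self-attention, cross-attention over ctx, then feed-forward, each with a residual add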
class BasicTransformerBlock:
  def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
    self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
    self.ff = FeedForward(dim)
    self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
    self.norm1 = LayerNorm(dim)
    self.norm2 = LayerNorm(dim)
    self.norm3 = LayerNorm(dim)

  def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
    x = x + self.attn1(self.norm1(x)) # 5.4 before, 6.8 master
    x = x + self.attn2(self.norm2(x), ctx=ctx) # 12 before, 12 master
    x = x + self.ff(self.norm3(x)) # 23 before, 27 master
    return x

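# realizes the generated inputs, then times model(*inputs) over 5 runs with device syncs around each call, printing the fastest in ms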
def helper_test(gen, model):
  tms = []
  for _ in range(5):
    early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
    GlobalCounters.reset()
    Device[Device.DEFAULT].synchronize()
    st = time.perf_counter_ns()
    model(*early_gen)
    Device[Device.DEFAULT].synchronize()
    tms.append(time.perf_counter_ns() - st)
  print(f"{min(tms)/1e6=:.2f} ms")

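# replace every parameter with an uninitialized Tensor.empty buffer so the benchmark spends no time on random weight init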
def derandomize_model(model):
  for p in get_parameters(model):
    p.lazydata = Tensor.empty(*p.shape, device=p.device, dtype=p.dtype).lazydata
    p.realize()

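# benchmark a JIT'd stack of 4 BasicTransformerBlocks at one of the SD UNet sizes (batch 2, 77-token context of dim 768)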
def test_transformer_block():
  # dim, d_head, x = 320, 40, (4096, 320) # 137ms 4444e6d 115ms master
  # dim, d_head, x = 640, 80, (1024, 640) # 36ms 4444e6d, 31ms master
  dim, d_head, x = 1280, 160, (256, 1280) # 23ms 4444e6d, 28ms master, 31ms on 176584993

  model = [BasicTransformerBlock(dim, 768, 8, d_head) for _ in range(4)]

  derandomize_model(model)
  @TinyJit
  def test(t, t2):
    for l in model: t = l(t, t2)
    return t.realize()
  helper_test(lambda: (Tensor.empty(2, *x), Tensor.empty(2, 77, 768)), test)

if __name__ == "__main__":
  test_transformer_block()