mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
rangeify attn tests (#12377)
This commit is contained in:
@@ -2,9 +2,11 @@ import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.apps.llm import apply_rope
|
||||
#from tinygrad.engine.realize import run_schedule
|
||||
|
||||
# TODO: test_scheduler, but just in uint
|
||||
class TestAttention(unittest.TestCase):
|
||||
@unittest.skipIf(RANGEIFY > 0, "not half on rangeify")
|
||||
def test_half_qkv_buffers(self):
|
||||
BS, seqlen, dim = 10, 4, 100
|
||||
q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
|
||||
@@ -12,11 +14,12 @@ class TestAttention(unittest.TestCase):
|
||||
v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
|
||||
attn = q.scaled_dot_product_attention(k, v)
|
||||
sched = attn.schedule()
|
||||
#run_schedule(sched[:])
|
||||
# attention has 5 kernels now
|
||||
self.assertEqual(len(sched), 4 if RANGEIFY else 5)
|
||||
softmax_inputs = sched[1:4]
|
||||
for si in softmax_inputs:
|
||||
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"
|
||||
for i,si in enumerate(softmax_inputs):
|
||||
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
|
||||
|
||||
def test_apply_rope(self):
|
||||
x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
|
||||
|
||||
@@ -58,7 +58,8 @@ def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
|
||||
assert (Hd & 1) == 0, "RoPE requires an even head dimension"
|
||||
half = Hd // 2
|
||||
angles = (Tensor.arange(T, dtype="float32") + start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :]
|
||||
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype), angles.sin().reshape(1, 1, T, half).cast(x.dtype)
|
||||
# contiguous here allows RoPE to be pruned in the JIT
|
||||
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
|
||||
x_pairs = x.reshape(B, H, T, half, 2)
|
||||
return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
|
||||
x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
|
||||
|
||||
Reference in New Issue
Block a user