rangeify attn tests (#12377)

This commit is contained in:
George Hotz
2025-10-01 09:59:19 +08:00
committed by GitHub
parent 26247573e1
commit 4c9a930de2
2 changed files with 7 additions and 3 deletions

View File

@@ -2,9 +2,11 @@ import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.helpers import RANGEIFY
from tinygrad.apps.llm import apply_rope
#from tinygrad.engine.realize import run_schedule
# TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase):
@unittest.skipIf(RANGEIFY > 0, "not half on rangeify")
def test_half_qkv_buffers(self):
BS, seqlen, dim = 10, 4, 100
q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
@@ -12,11 +14,12 @@ class TestAttention(unittest.TestCase):
v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
attn = q.scaled_dot_product_attention(k, v)
sched = attn.schedule()
#run_schedule(sched[:])
# attention has 5 kernels now
self.assertEqual(len(sched), 4 if RANGEIFY else 5)
softmax_inputs = sched[1:4]
for si in softmax_inputs:
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"
for i,si in enumerate(softmax_inputs):
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
def test_apply_rope(self):
x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)

View File

@@ -58,7 +58,8 @@ def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
assert (Hd & 1) == 0, "RoPE requires an even head dimension"
half = Hd // 2
angles = (Tensor.arange(T, dtype="float32") + start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :]
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype), angles.sin().reshape(1, 1, T, half).cast(x.dtype)
# contiguous here allows RoPE to be pruned in the JIT
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
x_pairs = x.reshape(B, H, T, half, 2)
return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)