mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
rangeify attn tests (#12377)
This commit is contained in:
@@ -2,9 +2,11 @@ import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.apps.llm import apply_rope
|
||||
#from tinygrad.engine.realize import run_schedule
|
||||
|
||||
# TODO: test_scheduler, but just in uint
|
||||
class TestAttention(unittest.TestCase):
|
||||
@unittest.skipIf(RANGEIFY > 0, "not half on rangeify")
|
||||
def test_half_qkv_buffers(self):
|
||||
BS, seqlen, dim = 10, 4, 100
|
||||
q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
|
||||
@@ -12,11 +14,12 @@ class TestAttention(unittest.TestCase):
|
||||
v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
|
||||
attn = q.scaled_dot_product_attention(k, v)
|
||||
sched = attn.schedule()
|
||||
#run_schedule(sched[:])
|
||||
# attention has 5 kernels now
|
||||
self.assertEqual(len(sched), 4 if RANGEIFY else 5)
|
||||
softmax_inputs = sched[1:4]
|
||||
for si in softmax_inputs:
|
||||
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"
|
||||
for i,si in enumerate(softmax_inputs):
|
||||
assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}"
|
||||
|
||||
def test_apply_rope(self):
|
||||
x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32)
|
||||
|
||||
Reference in New Issue
Block a user