Ptx beam fix (#4296)

* Fix beam search for PTX

* Fix ptr_ar test
This commit is contained in:
Szymon Ożóg
2024-04-25 21:39:39 +02:00
committed by GitHub
parent f9a7badace
commit f1ebcffb87
2 changed files with 7 additions and 3 deletions

View File

@@ -236,6 +236,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_pointer_arithmetics_caching(self):
+ from tinygrad.renderer.assembly import ptr_ar
uops = UOpGraph()
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
@@ -247,7 +248,8 @@ class TestAssembly(unittest.TestCase):
u8 = uops.add(UOps.ALU, dtypes.int, (u4, u6), BinaryOps.ADD)
u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u7))
u10 = uops.add(UOps.LOAD, dtypes.int, (u1, u8))
- _uops_to_prg(uops)
+ ptr_ar(u9, uops)
+ ptr_ar(u10, uops)
self.assertEqual(u9.vin[0], u10.vin[0])
self.assertEqual(u9.vin[1].uop, UOps.CONST)
self.assertEqual(u9.vin[1].arg, u5.arg*dtypes.float.itemsize)