mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -236,6 +236,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_pointer_arithmetics_caching(self):
|
||||
from tinygrad.renderer.assembly import ptr_ar
|
||||
uops = UOpGraph()
|
||||
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
|
||||
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
|
||||
@@ -247,7 +248,8 @@ class TestAssembly(unittest.TestCase):
|
||||
u8 = uops.add(UOps.ALU, dtypes.int, (u4, u6), BinaryOps.ADD)
|
||||
u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u7))
|
||||
u10 = uops.add(UOps.LOAD, dtypes.int, (u1, u8))
|
||||
_uops_to_prg(uops)
|
||||
ptr_ar(u9, uops)
|
||||
ptr_ar(u10, uops)
|
||||
self.assertEqual(u9.vin[0], u10.vin[0])
|
||||
self.assertEqual(u9.vin[1].uop, UOps.CONST)
|
||||
self.assertEqual(u9.vin[1].arg, u5.arg*dtypes.float.itemsize)
|
||||
|
||||
Reference in New Issue
Block a user