mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix issues with pointer provenance in load/store through ALU (#3916)
* Track pointer provenance in load/store through ALU Previously load/store could be incorrectly rendered into ld.global/st.global when the input was an ALU op that performed an address computation with DEFINE_LOCAL on one of the arguments. * Simplify the load provenance workaround The issue is that we can render the same code twice, and on the second run the opstream is already modified so that vin[0] isn't a DEFINE_*, which overwrites initially correct .shared wth .global. * Add a couple tests for basic local use * Skip local tests on LLVM since it doesn't implement DEFINE_LOCAL
This commit is contained in:
committed by
GitHub
parent
d651835ef5
commit
514c43201d
@@ -49,6 +49,18 @@ def _test_single_value_const(vals, op, dts):
|
||||
buf.copyout(ret.data)
|
||||
return ret[0]
|
||||
|
||||
def _test_uops_result(output_dtype, uops, res):
|
||||
# uops = []
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True))
|
||||
# res = output_fn(uops)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype)
|
||||
prg = _uops_to_prg(UOpGraph(uops))
|
||||
prg.exec([buf])
|
||||
ret = np.empty(1, output_dtype.np)
|
||||
buf.copyout(ret.data)
|
||||
return ret[0]
|
||||
|
||||
class TestUOps(unittest.TestCase):
|
||||
def _equal(self, v1, v2):
|
||||
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2)
|
||||
@@ -194,5 +206,24 @@ class TestConstantFolding(unittest.TestCase):
|
||||
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
|
||||
assert any(uop.uop is UOps.BITCAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
|
||||
|
||||
class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT in {"LLVM"}, "device doesn't support local memory")
|
||||
def test_local_basic(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
|
||||
uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42)))
|
||||
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0)))
|
||||
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"LLVM"}, "device doesn't support local memory")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
|
||||
uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
|
||||
uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
|
||||
ofs = uop(uops, UOps.LOAD, dtypes.int32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1)))
|
||||
sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
|
||||
self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -65,7 +65,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
return ld_rep(root,x,y)
|
||||
|
||||
def ptr_ar(root):
|
||||
root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
assert root.arg in {'.shared', '.global', None}
|
||||
if root.arg is None: root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
|
||||
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
|
||||
if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
|
||||
|
||||
Reference in New Issue
Block a user