swap src[2] and src[3] in load [run_process_replay] (#5821)

* swap src[2] and src[3] in load [run_process_replay]

* cleanups + bugfix

* fix ptx
This commit is contained in:
George Hotz
2024-07-30 14:04:13 -07:00
committed by GitHub
parent 17a2f74412
commit 693990a346
8 changed files with 30 additions and 30 deletions

View File

@@ -231,8 +231,8 @@ class TestUOpGraph(TestUOps):
glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (1, False))
glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (2, False))
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3)))
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
@@ -246,8 +246,8 @@ class TestUOpGraph(TestUOps):
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, None, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
@@ -438,18 +438,18 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_fold_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate, UOp.const(dtypes.float, i))) for i in range(4)]
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertListEqual([src.arg for src in single_load.src[3].src], [0.0, 1.0, 2.0, 3.0])
self.assertListEqual([src.arg for src in single_load.src[2].src], [0.0, 1.0, 2.0, 3.0])
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate if i == 0 else gate2, UOp.const(dtypes.float, i))) for i in range(4)]
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3