mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
swap src[2] and src[3] in load [run_process_replay] (#5821)
* swap src[2] and src[3] in load [run_process_replay] * cleanups + bugfix * fix ptx
This commit is contained in:
@@ -231,8 +231,8 @@ class TestUOpGraph(TestUOps):
|
||||
glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (1, False))
|
||||
glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (2, False))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3)))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
@@ -246,8 +246,8 @@ class TestUOpGraph(TestUOps):
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, None, (st, ))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
@@ -438,18 +438,18 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
def test_simple_load_fold_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate, UOp.const(dtypes.float, i))) for i in range(4)]
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
|
||||
self.assertListEqual([src.arg for src in single_load.src[3].src], [0.0, 1.0, 2.0, 3.0])
|
||||
self.assertListEqual([src.arg for src in single_load.src[2].src], [0.0, 1.0, 2.0, 3.0])
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate if i == 0 else gate2, UOp.const(dtypes.float, i))) for i in range(4)]
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
|
||||
|
||||
Reference in New Issue
Block a user