place masks in INDEX for TestGatedStoreRewrite [pr] (#9512)

This commit is contained in:
qazal
2025-03-20 09:46:53 +08:00
committed by GitHub
parent bd731a8624
commit 1839e8c9b3

View File

@@ -247,10 +247,10 @@ class TestGatedStoreRewrite(unittest.TestCase):
  def test_tiny_gate_store(self):
  gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
  gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
- idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2)))
- val = UOp.const(dtypes.float, 42.0)
  gate = gidx0<UOp.const(dtypes.int, 1)
- store = UOp(Ops.STORE, dtypes.void, (idx, val, gate))
+ idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2), gate))
+ val = UOp.const(dtypes.float, 42.0)
+ store = UOp(Ops.STORE, dtypes.void, (idx, val))
  uops = to_uops_list([store])
  if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
  if_uop = next(u for u in uops if u.op is Ops.IF)
@@ -265,11 +265,10 @@ class TestGatedStoreRewrite(unittest.TestCase):
  gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
  gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
  idx = gidx0 * UOp.const(dtypes.int, 2)
- idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
+ idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0<UOp.const(dtypes.int, 1)))
  idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
  val = UOp.const(dtypes.float, 42.0)
- gate = gidx0<UOp.const(dtypes.int, 1)
- stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)]
+ stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
  uops = to_uops_list(stores)
  if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
  if_uop = next(u for u in uops if u.op is Ops.IF)
@@ -285,11 +284,11 @@ class TestGatedStoreRewrite(unittest.TestCase):
  gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
  gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
  idx = gidx0*UOp.const(dtypes.int, 2)
- idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
- idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
- val = UOp.const(dtypes.float, 42.0)
  gate = gidx0<UOp.const(dtypes.int, 1)
- stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)]
+ idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gate))
+ idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx, gate))
+ val = UOp.const(dtypes.float, 42.0)
+ stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
  uops = to_uops_list(stores)
  if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
  ifs = [u for u in uops if u.op is Ops.IF]