From 1839e8c9b38fe9b670b3eefa50ddcd276520fc5f Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 20 Mar 2025 09:46:53 +0800 Subject: [PATCH] place masks in INDEX for TestGatedStoreRewrite [pr] (#9512) --- test/test_uops.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index 5a2c476822..cf003e8e81 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -247,10 +247,10 @@ class TestGatedStoreRewrite(unittest.TestCase): def test_tiny_gate_store(self): gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) - idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2))) - val = UOp.const(dtypes.float, 42.0) gate = gidx0= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is Ops.IF) @@ -265,11 +265,10 @@ class TestGatedStoreRewrite(unittest.TestCase): gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0 * UOp.const(dtypes.int, 2) - idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx)) + idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is Ops.IF) @@ -285,11 +284,11 @@ class TestGatedStoreRewrite(unittest.TestCase): gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0*UOp.const(dtypes.int, 2) - idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx)) - idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx)) - val = UOp.const(dtypes.float, 42.0) gate = gidx0= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) ifs = [u for u in uops if u.op is Ops.IF]