mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
place masks in INDEX for TestGatedStoreRewrite [pr] (#9512)
This commit is contained in:
@@ -247,10 +247,10 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val, gate))
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2), gate))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val))
|
||||
uops = to_uops_list([store])
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.op is Ops.IF)
|
||||
@@ -265,11 +265,10 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gidx0<UOp.const(dtypes.int, 1)))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)]
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.op is Ops.IF)
|
||||
@@ -285,11 +284,11 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0<UOp.const(dtypes.int, 1)
|
||||
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)]
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx, gate))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx, gate))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
stores = [UOp.store(idx0, val), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
ifs = [u for u in uops if u.op is Ops.IF]
|
||||
|
||||
Reference in New Issue
Block a user