mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
update store load noop pattern to use Invalid (#12141)
* update pattern * add test
This commit is contained in:
@@ -419,7 +419,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_where_on_gated_load_fold(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 128)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
|
||||
ld = d0.index(ridx0, ridx0<50).load()
|
||||
w = (ridx0<50).where(ld, 5)
|
||||
uops = to_uops_list([w])
|
||||
@@ -429,7 +429,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 128)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
|
||||
ld = d0.index(ridx0, (ridx0<50).logical_not()).load()
|
||||
w = (ridx0<50).where(5, ld)
|
||||
uops = to_uops_list([w])
|
||||
@@ -437,6 +437,18 @@ class TestUOpGraph(unittest.TestCase):
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD: assert u.src[1].arg==5
|
||||
|
||||
def test_where_in_store_becomes_gate(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
|
||||
idx = d0.index(ridx0)
|
||||
ld = idx.load()
|
||||
val = (ridx0<50).where(5, ld)
|
||||
st = idx.store(val, ridx0)
|
||||
uops = to_uops_list([st])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.STORE: assert u.src[1].arg==5
|
||||
|
||||
def test_load_idx_becomes_int(self):
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
|
||||
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1)
|
||||
|
||||
@@ -519,7 +519,7 @@ sym = symbolic_flat+PatternMatcher([
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
|
||||
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
|
||||
lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
|
||||
lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
|
||||
|
||||
Reference in New Issue
Block a user