update store load noop pattern to use Invalid (#12141)

* update pattern

* add test
This commit is contained in:
Sieds Lykles
2025-09-12 22:25:53 +02:00
committed by GitHub
parent 647965fb09
commit 62376c8b2b
2 changed files with 15 additions and 3 deletions

View File

@@ -419,7 +419,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 128)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0, ridx0<50).load()
w = (ridx0<50).where(ld, 5)
uops = to_uops_list([w])
@@ -429,7 +429,7 @@ class TestUOpGraph(unittest.TestCase):
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 128)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0, (ridx0<50).logical_not()).load()
w = (ridx0<50).where(5, ld)
uops = to_uops_list([w])
@@ -437,6 +437,18 @@ class TestUOpGraph(unittest.TestCase):
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD: assert u.src[1].arg==5
def test_where_in_store_becomes_gate(self):
ridx0 = UOp.range(100, 0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
idx = d0.index(ridx0)
ld = idx.load()
val = (ridx0<50).where(5, ld)
st = idx.store(val, ridx0)
uops = to_uops_list([st])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.STORE: assert u.src[1].arg==5
def test_load_idx_becomes_int(self):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), (), 1)

View File

@@ -519,7 +519,7 @@ sym = symbolic_flat+PatternMatcher([
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
lambda index, gate, alt, store: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt, *store.src[2:])),
# fold gated LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0