test case of NOOP store load folding (#13997)

This commit is contained in:
chenyu
2026-01-03 14:39:26 -05:00
committed by GitHub
parent c1b8644a3f
commit 8003db2a28

View File

@@ -1014,6 +1014,21 @@ class TestInvalidIndex(unittest.TestCase):
c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
class TestStoreLoadFolding(unittest.TestCase):
"""Tests for store(index, load(index)) -> NOOP rule. This rule matches patterns that EMERGE during simplification."""
def test_store_load_folding(self):
# store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0)
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
index = buf.index(UOp.const(dtypes.index, 0))
# Direct: store(idx, load(idx)) -> NOOP
self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP)
# Emergent: store(idx, load(idx) + 0) -> store(idx, load(idx)) -> NOOP
self.assertEqual(graph_rewrite(index.store(index.load() + UOp.const(dtypes.int, 0)), sym).op, Ops.NOOP)
# Emergent: store(idx, load(idx) * 1) -> store(idx, load(idx)) -> NOOP
self.assertEqual(graph_rewrite(index.store(index.load() * UOp.const(dtypes.int, 1)), sym).op, Ops.NOOP)
# Negative: store(idx, load(idx) + 1) should NOT fold
self.assertEqual(graph_rewrite(index.store(index.load() + UOp.const(dtypes.int, 1)), sym).op, Ops.STORE)
class TestSymbolicRealWorld(unittest.TestCase):
def test_resnet_half(self):
gidx0 = Variable("gidx0", 0, 3)