diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 4cc015da2e..82a2767dce 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -1014,6 +1014,21 @@ class TestInvalidIndex(unittest.TestCase): c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1)) self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid))) +class TestStoreLoadFolding(unittest.TestCase): + """Tests for store(index, load(index)) -> NOOP rule. This rule matches patterns that EMERGE during simplification.""" + def test_store_load_folding(self): + # store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0) + buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) + index = buf.index(UOp.const(dtypes.index, 0)) + # Direct: store(idx, load(idx)) -> NOOP + self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP) + # Emergent: store(idx, load(idx) + 0) -> store(idx, load(idx)) -> NOOP + self.assertEqual(graph_rewrite(index.store(index.load() + UOp.const(dtypes.int, 0)), sym).op, Ops.NOOP) + # Emergent: store(idx, load(idx) * 1) -> store(idx, load(idx)) -> NOOP + self.assertEqual(graph_rewrite(index.store(index.load() * UOp.const(dtypes.int, 1)), sym).op, Ops.NOOP) + # Negative: store(idx, load(idx) + 1) should NOT fold + self.assertEqual(graph_rewrite(index.store(index.load() + UOp.const(dtypes.int, 1)), sym).op, Ops.STORE) + class TestSymbolicRealWorld(unittest.TestCase): def test_resnet_half(self): gidx0 = Variable("gidx0", 0, 3)