mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
test case of NOOP store load folding (#13997)
This commit is contained in:
@@ -1014,6 +1014,21 @@ class TestInvalidIndex(unittest.TestCase):
|
||||
c2 = UOp.const(dtypes.index.vec(4), (1, Invalid, 1, 1))
|
||||
self.assertIs((c1+c2).simplify(), UOp.const(dtypes.index.vec(4), (2, Invalid, Invalid, Invalid)))
|
||||
|
||||
class TestStoreLoadFolding(unittest.TestCase):
|
||||
"""Tests for store(index, load(index)) -> NOOP rule. This rule matches patterns that EMERGE during simplification."""
|
||||
def test_store_load_folding(self):
|
||||
# store(idx, load(idx)) -> NOOP, including emergent patterns like store(idx, load(idx) + 0)
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
index = buf.index(UOp.const(dtypes.index, 0))
|
||||
# Direct: store(idx, load(idx)) -> NOOP
|
||||
self.assertEqual(graph_rewrite(index.store(index.load()), sym).op, Ops.NOOP)
|
||||
# Emergent: store(idx, load(idx) + 0) -> store(idx, load(idx)) -> NOOP
|
||||
self.assertEqual(graph_rewrite(index.store(index.load() + UOp.const(dtypes.int, 0)), sym).op, Ops.NOOP)
|
||||
# Emergent: store(idx, load(idx) * 1) -> store(idx, load(idx)) -> NOOP
|
||||
self.assertEqual(graph_rewrite(index.store(index.load() * UOp.const(dtypes.int, 1)), sym).op, Ops.NOOP)
|
||||
# Negative: store(idx, load(idx) + 1) should NOT fold
|
||||
self.assertEqual(graph_rewrite(index.store(index.load() + UOp.const(dtypes.int, 1)), sym).op, Ops.STORE)
|
||||
|
||||
class TestSymbolicRealWorld(unittest.TestCase):
|
||||
def test_resnet_half(self):
|
||||
gidx0 = Variable("gidx0", 0, 3)
|
||||
|
||||
Reference in New Issue
Block a user