schedule sink folding with graph_rewrite [pr] (#7963)

* schedule sink folding with graph_rewrite [pr]

* x is reserved, use u

* match lazy const folding
This commit is contained in:
qazal
2024-11-30 05:30:41 -05:00
committed by GitHub
parent 10f431b96d
commit 4529c5d0da
3 changed files with 19 additions and 7 deletions

View File

@@ -1936,14 +1936,22 @@ class TestBigGraph(unittest.TestCase):
def test_sink_childless_const(self):
x = UOp.const(dtypes.int, 0)
big_graph = big_graph_rewrite(x.sink(), realizes:={})
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x,)))
self.assertIs(big_graph, UOp(Ops.NOOP))
self.assertEqual(len(realizes), 0)
def test_sink_childless_const_alt(self):
x = UOp.const(dtypes.int, 0)
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape((10, 10)))
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={})
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y)))
self.assertIs(big_graph, UOp(Ops.NOOP))
self.assertEqual(len(realizes), 0)
def test_sink_childless_const_alt_expanded(self):
# this is a real STORE of CONST
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
big_graph = big_graph_rewrite(out.sink(), realizes:={})
self.assertIs(big_graph, out.sink())
self.assertEqual(len(realizes), 1)
if __name__ == '__main__':