Revert "schedule sink folding with graph_rewrite [pr] (#7963)" (#7965)

This reverts commit 4529c5d0da.
This commit is contained in:
qazal
2024-11-30 06:02:06 -05:00
committed by GitHub
parent 4529c5d0da
commit 8780818d04
3 changed files with 7 additions and 19 deletions

View File

@@ -1936,22 +1936,14 @@ class TestBigGraph(unittest.TestCase):
def test_sink_childless_const(self):
x = UOp.const(dtypes.int, 0)
big_graph = big_graph_rewrite(x.sink(), realizes:={})
self.assertIs(big_graph, UOp(Ops.NOOP))
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x,)))
self.assertEqual(len(realizes), 0)
def test_sink_childless_const_alt(self):
x = UOp.const(dtypes.int, 0)
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape((10, 10)))
big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={})
self.assertIs(big_graph, UOp(Ops.NOOP))
self.assertEqual(len(realizes), 0)
def test_sink_childless_const_alt_expanded(self):
# this is a real STORE of CONST
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
big_graph = big_graph_rewrite(out.sink(), realizes:={})
self.assertIs(big_graph, out.sink())
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y)))
self.assertEqual(len(realizes), 1)
if __name__ == '__main__':