const folding tests [pr] (#7967)

This commit is contained in:
qazal
2024-11-30 06:27:30 -05:00
committed by GitHub
parent 8780818d04
commit 5615e92df8

View File

@@ -1941,9 +1941,17 @@ class TestBigGraph(unittest.TestCase):
def test_sink_childless_const_alt(self):
x = UOp.const(dtypes.int, 0)
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape((10, 10)))
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={})
self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y)))
self.assertEqual(len(realizes), 1) # TODO: this should fold a flat CONST
def test_sink_childless_const_alt_expanded(self):
# this is a real STORE of CONST (post expand)
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
big_graph = big_graph_rewrite(out.sink(), realizes:={})
self.assertIs(big_graph, out.sink())
self.assertEqual(len(realizes), 1)
if __name__ == '__main__':