diff --git a/test/test_schedule.py b/test/test_schedule.py index a7ad17ea9c..dd4f8b3288 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1941,9 +1941,17 @@ class TestBigGraph(unittest.TestCase): def test_sink_childless_const_alt(self): x = UOp.const(dtypes.int, 0) - y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape((10, 10))) + y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(())) big_graph = big_graph_rewrite(UOp.sink(x, y), realizes:={}) self.assertIs(big_graph, UOp(Ops.SINK, dtypes.void, (x, y))) + self.assertEqual(len(realizes), 1) # TODO: this should fold a flat CONST + + def test_sink_childless_const_alt_expanded(self): + # this is a real STORE of CONST (post expand) + y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(())) + out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,))) + big_graph = big_graph_rewrite(out.sink(), realizes:={}) + self.assertIs(big_graph, out.sink()) self.assertEqual(len(realizes), 1) if __name__ == '__main__':