# Patch summary (free text ahead of the first "diff --git" header; the patch body below is unchanged).
#
# Files touched:
#   test/test_schedule.py        - adds the TestContiguous test case (hunk @@ -2314,5 +2314,30 @@)
#   tinygrad/engine/schedule.py  - adds one rewrite rule to ops_folding (hunk @@ -386,6 +386,8 @@)
#
# schedule.py change:
#   New ops_folding rule matching CONTIGUOUS whose only source is a VIEW whose only source is a BUFFER.
#   The lambda returns the VIEW itself when both conditions hold:
#     * view.st.contiguous is truthy, and
#     * view.size == buf.size (the view covers exactly as many elements as the buffer);
#   otherwise it returns None.
#   The matched CONTIGUOUS node is bound as "root" but is not used by the lambda.
#   NOTE(review): returning None presumably means "no rewrite, keep the node" - confirm against PatternMatcher.
#
# test_schedule.py change (each test builds `a` from Tensor.empty(...).lazydata, applies
# .alu(Ops.CONTIGUOUS) to `a` or to a movement op on it, then runs schedule_graph_rewrite on the result):
#   test_contiguous_buffer          - CONTIGUOUS directly on the empty(4) uop; the result must be `a` itself (assertIs).
#   test_contiguous_buffer_view     - CONTIGUOUS on reshape((2, 2)); the result must be a.buf_uop.view(unwrap(b.st)).
#   test_non_contiguous_buffer_view - empty(4, 1) expanded to (4, 4); the result must still match
#                                     CONTIGUOUS(VIEW(BUFFER)), i.e. the CONTIGUOUS is kept.
#   test_size_change_buffer_view    - reshape((1, 1, 4)) then shrink of the last axis to (0, 3); the result must
#                                     still match CONTIGUOUS(VIEW(BUFFER)), i.e. the CONTIGUOUS is kept.
#
# NOTE(review): the line breaks inside this patch appear to have been collapsed in this copy (the hunks sit on
# two physical lines). Restore the original line structure from the upstream commit before applying - confirm
# indentation against the target files rather than guessing.
diff --git a/test/test_schedule.py b/test/test_schedule.py index cf9fd16eb2..d712b97b70 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2314,5 +2314,30 @@ class TestBufferUOp(unittest.TestCase): a2 = a.contiguous().realize() self.assertEqual(a2.lazydata.base.realized.size, 16) +class TestContiguous(unittest.TestCase): + def test_contiguous_buffer(self): + a = Tensor.empty(4).lazydata + b = a.alu(Ops.CONTIGUOUS) + b = schedule_graph_rewrite(b) + self.assertIs(b, a) + + def test_contiguous_buffer_view(self): + a = Tensor.empty(4).lazydata + b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS) + b = schedule_graph_rewrite(b) + self.assertIs(b, a.buf_uop.view(unwrap(b.st))) + + def test_non_contiguous_buffer_view(self): + a = Tensor.empty(4, 1).lazydata + b = a.expand((4, 4)).alu(Ops.CONTIGUOUS) + b = schedule_graph_rewrite(b) + assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + + def test_size_change_buffer_view(self): + a = Tensor.empty(4).lazydata + b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS) + b = schedule_graph_rewrite(b) + assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {}) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 6b1d2bf223..f31b58b97c 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -386,6 +386,8 @@ ops_folding = symbolic_simple+PatternMatcher([ # no COPY to same device, except clone (arg is True) (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), + (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), + lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), # support for using a contiguous permuted view instead of the parent view if 
one exists (UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous), (UPat(GroupOp.ALU, name="alu"), replace_contiguous),