Remove contiguous on buffer (#8676)

* remove contiguous on buffer

* spec

* make things that can't be images not images
This commit is contained in:
qazal
2025-01-20 06:48:33 -05:00
committed by GitHub
parent 3499a2c72d
commit ed63ff2372
2 changed files with 27 additions and 0 deletions

View File

@@ -2314,5 +2314,30 @@ class TestBufferUOp(unittest.TestCase):
a2 = a.contiguous().realize()
self.assertEqual(a2.lazydata.base.realized.size, 16)
class TestContiguous(unittest.TestCase):
def test_contiguous_buffer(self):
a = Tensor.empty(4).lazydata
b = a.alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
def test_contiguous_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
def test_non_contiguous_buffer_view(self):
a = Tensor.empty(4, 1).lazydata
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
def test_size_change_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -386,6 +386,8 @@ ops_folding = symbolic_simple+PatternMatcher([
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),