Remove contiguous on buffer (#8676)

* remove contiguous on buffer

* spec

* make things that can't be images not images
This commit is contained in:
qazal
2025-01-20 06:48:33 -05:00
committed by GitHub
parent 3499a2c72d
commit ed63ff2372
2 changed files with 27 additions and 0 deletions

View File

@@ -2314,5 +2314,30 @@ class TestBufferUOp(unittest.TestCase):
a2 = a.contiguous().realize()
self.assertEqual(a2.lazydata.base.realized.size, 16)
class TestContiguous(unittest.TestCase):
def test_contiguous_buffer(self):
a = Tensor.empty(4).lazydata
b = a.alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
def test_contiguous_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((2, 2)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
self.assertIs(b, a.buf_uop.view(unwrap(b.st)))
def test_non_contiguous_buffer_view(self):
a = Tensor.empty(4, 1).lazydata
b = a.expand((4, 4)).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
def test_size_change_buffer_view(self):
a = Tensor.empty(4).lazydata
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).alu(Ops.CONTIGUOUS)
b = schedule_graph_rewrite(b)
assert UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)))).match(b, {})
if __name__ == '__main__':
unittest.main(verbosity=2)