give BUFFER UOp a ShapeTracker [pr] (#8811)

* give BUFFER UOp a ShapeTracker [pr]

* move that

* update contiguous

* test_advancedindex should use movement ops
This commit is contained in:
qazal
2025-01-30 15:33:32 -05:00
committed by GitHub
parent 5527f86a8f
commit 5643429c17
5 changed files with 33 additions and 17 deletions

View File

@@ -2272,6 +2272,15 @@ class TestCopyFolding(unittest.TestCase):
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
# NOTE: Tensor.empty(4) always creates a VIEW(BUFFER) with ShapeTracker((4,)), we simplify this to jsut a BUFFER
# in the scheduler because buffer already has shape (4,)
self.assertIs(b, a.base)
def test_copy_to_same_device_alt(self):
a = Tensor.empty(4, 4).lazydata
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
def test_clone(self):
@@ -2455,14 +2464,17 @@ class TestUOpBecome(unittest.TestCase):
b = Tensor.empty(4, 4)
add = a+b
check_schedule(add, 1)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
# NOTE: realized base is always a flat buffer
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
# the Tensor UOp can optionally stack a VIEW on top of BUFFER
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(add.lazydata, {})
def test_new_buffer_view(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = (a+b).reshape(8, 2)
check_schedule(add, 1)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
# VIEW is preserverd after the becomes rewrite.
self.assertEqual(add.lazydata.shape, (8, 2))
assert add.lazydata is not add.lazydata.base
@@ -2472,7 +2484,7 @@ class TestUOpBecome(unittest.TestCase):
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_become_const_in_base(self):