mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 07:35:16 -05:00
give BUFFER UOp a ShapeTracker [pr] (#8811)
* give BUFFER UOp a ShapeTracker [pr] * move that * update contiguous * test_advancedindex should use movement ops
This commit is contained in:
@@ -2272,6 +2272,15 @@ class TestCopyFolding(unittest.TestCase):
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
b = schedule_graph_rewrite(b)
|
||||
# NOTE: Tensor.empty(4) always creates a VIEW(BUFFER) with ShapeTracker((4,)), we simplify this to jsut a BUFFER
|
||||
# in the scheduler because buffer already has shape (4,)
|
||||
self.assertIs(b, a.base)
|
||||
|
||||
def test_copy_to_same_device_alt(self):
|
||||
a = Tensor.empty(4, 4).lazydata
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
|
||||
def test_clone(self):
|
||||
@@ -2455,14 +2464,17 @@ class TestUOpBecome(unittest.TestCase):
|
||||
b = Tensor.empty(4, 4)
|
||||
add = a+b
|
||||
check_schedule(add, 1)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
|
||||
# NOTE: realized base is always a flat buffer
|
||||
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
|
||||
# the Tensor UOp can optionally stack a VIEW on top of BUFFER
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(add.lazydata, {})
|
||||
|
||||
def test_new_buffer_view(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
add = (a+b).reshape(8, 2)
|
||||
check_schedule(add, 1)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
|
||||
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
|
||||
# VIEW is preserverd after the becomes rewrite.
|
||||
self.assertEqual(add.lazydata.shape, (8, 2))
|
||||
assert add.lazydata is not add.lazydata.base
|
||||
@@ -2472,7 +2484,7 @@ class TestUOpBecome(unittest.TestCase):
|
||||
b = a*1
|
||||
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
|
||||
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
def test_become_const_in_base(self):
|
||||
|
||||
Reference in New Issue
Block a user