fix tensor realization bug in #8975 (#8984)

* fix tensor realization bug in #8975

* that's a reshape now

* work

* works

* give those tests better names

* test when multiple mops result in the same ShapeTracker

* test_become_existing_buf_complex is enough

* that too
This commit is contained in:
qazal
2025-02-10 13:51:30 +01:00
committed by GitHub
parent b17ec42b56
commit cd77e51810
2 changed files with 54 additions and 11 deletions

View File

@@ -2513,18 +2513,20 @@ class TestUOpBecome(unittest.TestCase):
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
assert UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))).match(b.lazydata, {}) # scheduling backtracks to the movement op if the realized tensor becomes a view
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
# TODO: this fails because the shrink must be applied on top of the BUFFER
# currently it's a VIEW
@unittest.expectedFailure
def test_become_buf_with_mops(self):
a = Tensor.empty(2, 4, 2)
noop = a.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0
# before realizing, this tensor is base
assert noop.lazydata is noop.lazydata.base
noop.realize()
# it becomes a realized view after realize
assert noop.lazydata is not noop.lazydata.base
assert noop.lazydata.is_realized
late_add = noop+2
late_add.realize() # UOp verification error
late_add.realize()
def test_become_const_in_base(self):
a = Tensor.empty(4)
@@ -2549,5 +2551,46 @@ class TestUOpBecome(unittest.TestCase):
check_schedule(const_add, 0)
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
# tensors can become another realized tensor source
def test_become_existing_buf_simple(self):
a = Tensor.empty(4, 4)
b = a+0
check_schedule(b, 0)
assert b.lazydata.is_realized
self.assertIs(a.lazydata, b.lazydata)
# they can also chain other movement ops on top of the tensor source
def test_become_existing_buf_view(self):
a = Tensor.empty(4, 4)
b = a.permute((1, 0))+0
check_schedule(b, 0)
self.assertIs(b.lazydata, a.lazydata.permute((1, 0)))
def test_become_existing_buf_view_alt(self):
a = Tensor.empty(4, 4)
b = a.permute((1, 0)).reshape((8, 2))+0
check_schedule(b, 0)
self.assertIs(b.lazydata, a.lazydata.permute((1, 0)).reshape((8, 2)))
# they can also have other base parents that simplified, in that case we just backtrack to the chained mops
def test_become_existing_buf_complex(self):
a = Tensor.empty(4, 4)
b = (a.permute((1, 0))+0).reshape((8, 2))+0
check_schedule(b, 0)
self.assertIs(b.lazydata, a.lazydata.permute((1, 0)).reshape((8, 2)))
assert b.lazydata.is_realized
def test_become_multiple_choices(self):
a = Tensor.empty(16)
b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
check_schedule([b, c], 0)
assert all_same([x.lazydata.base.realized for x in [a,b,c]])
# these movement ops result in the same ShapeTracker
assert b.lazydata.st == c.lazydata.st
# the decision for which movement op to pick is local, we could also make this always pick the simplest one
assert UPat(Ops.SHRINK, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)),)),)).match(b.lazydata, {})
assert UPat(Ops.SHRINK, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)),)).match(c.lazydata, {})
if __name__ == '__main__':
unittest.main(verbosity=2)