mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
Tensor UOps can become a buffer or const after scheduling (#8698)
* spec * work * update test_viewed_consts_do_not_realize * remove
This commit is contained in:
@@ -2366,5 +2366,56 @@ class TestContiguous(unittest.TestCase):
|
||||
b = a.expand((4, 4)).contiguous().contiguous()
|
||||
check_schedule(b, 1)
|
||||
|
||||
|
||||
class TestUOpBecome(unittest.TestCase):
|
||||
# the simplest case, if we create a new BUFFER for this UOp
|
||||
def test_new_buffer(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
add = a+b
|
||||
check_schedule(add, 1)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
|
||||
|
||||
def test_new_buffer_view(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
add = (a+b).reshape(8, 2)
|
||||
check_schedule(add, 1)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
|
||||
# VIEW is preserverd after the becomes rewrite.
|
||||
self.assertEqual(add.lazydata.shape, (8, 2))
|
||||
assert add.lazydata is not add.lazydata.base
|
||||
|
||||
def test_become_existing_buffer(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = a*1
|
||||
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
|
||||
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
|
||||
|
||||
def test_become_const_in_base(self):
|
||||
a = Tensor.empty(4)
|
||||
b = a*0
|
||||
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
|
||||
|
||||
def test_become_const_in_view(self):
|
||||
# if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged.
|
||||
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
|
||||
b = add.shrink(((0, 1), (0, 0)))
|
||||
check_schedule(b, 0)
|
||||
assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
|
||||
self.assertEqual(b.shape, (1, 0))
|
||||
# the base is untouched.
|
||||
assert UPat(Ops.ADD).match(add.lazydata, {})
|
||||
|
||||
def test_become_const_from_const(self):
|
||||
const_add = Tensor(1)+Tensor(2)
|
||||
assert UPat(Ops.ADD).match(const_add.lazydata, {})
|
||||
check_schedule(const_add, 0)
|
||||
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user