Refactor contraction and add integration test cases for push permute (#650)

* Refactor contraction and add unit tests

* Fix typo; Fix TestConv.test_elu failure due to some ones in old_shape

* Add push permute test cases

* Fix mypy type annotation check error

* Add contraction unit test; Reshape to higher dimension is not contraction
This commit is contained in:
Alex Wang
2023-03-06 22:36:55 +08:00
committed by GitHub
parent cb5be9697c
commit 64ecbd91b5
3 changed files with 110 additions and 24 deletions

View File

@@ -10,7 +10,8 @@ import unittest
from tinygrad.tensor import Tensor, Device
from tinygrad import nn
from tinygrad.nn import optim
from tinygrad.ops import GlobalCounters
from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
from tinygrad.lazy import PUSH_PERMUTES
class CLCache():
def __enter__(self):
@@ -131,5 +132,66 @@ class TestOpt(unittest.TestCase):
print(img_conv)
assert len(GlobalCounters.cache) == 2, "optimizer didn't fold conv/relu"
def helper_push_permute_before_reshape(self, t, should_push=True, desired_reshape_arg=None, desired_permute_arg=None):
if PUSH_PERMUTES and should_push:
assert t.lazydata.op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reshape'
assert t.lazydata.op.src[0].op.arg == desired_permute_arg, f'Pushed permute arg should be {desired_permute_arg}'
assert t.lazydata.op.op == MovementOps.RESHAPE, 'Reshape should be after permute'
assert t.lazydata.op.arg == desired_reshape_arg, f'Reshape arg should be {desired_reshape_arg}'
else:
assert t.lazydata.op.src[0].op.op == MovementOps.RESHAPE, 'Reshape should before permute'
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after reshape'
def test_push_permute_before_reshape(self):
t = Tensor.ones(1,2,3,4)
t = t.reshape(1,2,3*4).permute(2,1,0)
self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,1,0))
t = Tensor.ones(1,2,3,4)
t = t.reshape(3,1,2,4).permute(3,2,1,0)
self.helper_push_permute_before_reshape(t, should_push=False)
t = Tensor.ones(1,2,3,1,4,1)
t = t.reshape(1,2,3*4).permute(2,1,0)
self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,4,5,1,0))
t = Tensor.ones(1,2,3,4)
t = t.reshape(1,2,3,1,4).permute(4,3,2,1,0)
self.helper_push_permute_before_reshape(t, should_push=False)
def test_push_permute_before_reduce(self):
t = Tensor.ones(1,2,3,4)
t = t.sum(axis=2).permute(2,1,0)
if PUSH_PERMUTES:
assert t.lazydata.op.src[0].op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
assert t.lazydata.op.src[0].op.src[0].op.arg == (3,1,0,2), 'Pushed permute arg error'
assert t.lazydata.op.src[0].op.op == ReduceOps.SUM, 'Sum should be after permute'
assert t.lazydata.op.src[0].op.arg == (4,2,1,1), 'Sum arg error'
assert t.lazydata.op.op == MovementOps.RESHAPE, 'Reshape should be after Sum'
assert t.lazydata.op.arg == (4,2,1), 'Reshape arg error'
else:
assert t.lazydata.op.src[0].op.src[0].op.op == ReduceOps.SUM, 'Sum should be the first'
assert t.lazydata.op.src[0].op.src[0].op.arg == (1,2,4,1), 'Sum arg error'
assert t.lazydata.op.src[0].op.op == MovementOps.RESHAPE, 'Reshape should be after sum'
assert t.lazydata.op.src[0].op.arg == (1,2,4), 'Reshape arg error'
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after Reshape'
assert t.lazydata.op.arg == (2,1,0), 'Permute arg error'
def test_push_permute_before_expand(self):
t = Tensor.ones(1,2,3,4)
t = t.expand(2,2,3,4).permute(3,2,1,0)
if PUSH_PERMUTES:
assert t.lazydata.op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
assert t.lazydata.op.src[0].op.arg == (3,2,1,0), 'Pushed permute arg error'
assert t.lazydata.op.op == MovementOps.EXPAND, 'Expand should be after permute'
assert t.lazydata.op.arg == (4,3,2,2), 'Expand arg error'
else:
assert t.lazydata.op.src[0].op.op == MovementOps.EXPAND, 'Expand should be the first'
assert t.lazydata.op.src[0].op.arg == (2,2,3,4), 'Expand arg error'
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after expand'
assert t.lazydata.op.arg == (3,2,1,0), 'Permute arg error'
if __name__ == '__main__':
unittest.main()