mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Refactor contraction and add integration test cases for push permute (#650)
* Refactor contraction and add unit tests * Fix typo; Fix TestConv.test_elu failure due to some ones in old_shape * Add push permute test cases * Fix mypy type annotation check error * Add contraction unit test; Reshape to higher dimension is not contraction
This commit is contained in:
64
test/external/external_test_opt.py
vendored
64
test/external/external_test_opt.py
vendored
@@ -10,7 +10,8 @@ import unittest
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad import nn
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
|
||||
from tinygrad.lazy import PUSH_PERMUTES
|
||||
|
||||
class CLCache():
|
||||
def __enter__(self):
|
||||
@@ -131,5 +132,66 @@ class TestOpt(unittest.TestCase):
|
||||
print(img_conv)
|
||||
assert len(GlobalCounters.cache) == 2, "optimizer didn't fold conv/relu"
|
||||
|
||||
def helper_push_permute_before_reshape(self, t, should_push=True, desired_reshape_arg=None, desired_permute_arg=None):
|
||||
if PUSH_PERMUTES and should_push:
|
||||
assert t.lazydata.op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reshape'
|
||||
assert t.lazydata.op.src[0].op.arg == desired_permute_arg, f'Pushed permute arg should be {desired_permute_arg}'
|
||||
assert t.lazydata.op.op == MovementOps.RESHAPE, 'Reshape should be after permute'
|
||||
assert t.lazydata.op.arg == desired_reshape_arg, f'Reshape arg should be {desired_reshape_arg}'
|
||||
else:
|
||||
assert t.lazydata.op.src[0].op.op == MovementOps.RESHAPE, 'Reshape should before permute'
|
||||
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after reshape'
|
||||
|
||||
|
||||
def test_push_permute_before_reshape(self):
|
||||
t = Tensor.ones(1,2,3,4)
|
||||
t = t.reshape(1,2,3*4).permute(2,1,0)
|
||||
self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,1,0))
|
||||
|
||||
t = Tensor.ones(1,2,3,4)
|
||||
t = t.reshape(3,1,2,4).permute(3,2,1,0)
|
||||
self.helper_push_permute_before_reshape(t, should_push=False)
|
||||
|
||||
t = Tensor.ones(1,2,3,1,4,1)
|
||||
t = t.reshape(1,2,3*4).permute(2,1,0)
|
||||
self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,4,5,1,0))
|
||||
|
||||
t = Tensor.ones(1,2,3,4)
|
||||
t = t.reshape(1,2,3,1,4).permute(4,3,2,1,0)
|
||||
self.helper_push_permute_before_reshape(t, should_push=False)
|
||||
|
||||
|
||||
def test_push_permute_before_reduce(self):
|
||||
t = Tensor.ones(1,2,3,4)
|
||||
t = t.sum(axis=2).permute(2,1,0)
|
||||
if PUSH_PERMUTES:
|
||||
assert t.lazydata.op.src[0].op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
|
||||
assert t.lazydata.op.src[0].op.src[0].op.arg == (3,1,0,2), 'Pushed permute arg error'
|
||||
assert t.lazydata.op.src[0].op.op == ReduceOps.SUM, 'Sum should be after permute'
|
||||
assert t.lazydata.op.src[0].op.arg == (4,2,1,1), 'Sum arg error'
|
||||
assert t.lazydata.op.op == MovementOps.RESHAPE, 'Reshape should be after Sum'
|
||||
assert t.lazydata.op.arg == (4,2,1), 'Reshape arg error'
|
||||
else:
|
||||
assert t.lazydata.op.src[0].op.src[0].op.op == ReduceOps.SUM, 'Sum should be the first'
|
||||
assert t.lazydata.op.src[0].op.src[0].op.arg == (1,2,4,1), 'Sum arg error'
|
||||
assert t.lazydata.op.src[0].op.op == MovementOps.RESHAPE, 'Reshape should be after sum'
|
||||
assert t.lazydata.op.src[0].op.arg == (1,2,4), 'Reshape arg error'
|
||||
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after Reshape'
|
||||
assert t.lazydata.op.arg == (2,1,0), 'Permute arg error'
|
||||
|
||||
def test_push_permute_before_expand(self):
|
||||
t = Tensor.ones(1,2,3,4)
|
||||
t = t.expand(2,2,3,4).permute(3,2,1,0)
|
||||
if PUSH_PERMUTES:
|
||||
assert t.lazydata.op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
|
||||
assert t.lazydata.op.src[0].op.arg == (3,2,1,0), 'Pushed permute arg error'
|
||||
assert t.lazydata.op.op == MovementOps.EXPAND, 'Expand should be after permute'
|
||||
assert t.lazydata.op.arg == (4,3,2,2), 'Expand arg error'
|
||||
else:
|
||||
assert t.lazydata.op.src[0].op.op == MovementOps.EXPAND, 'Expand should be the first'
|
||||
assert t.lazydata.op.src[0].op.arg == (2,2,3,4), 'Expand arg error'
|
||||
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after expand'
|
||||
assert t.lazydata.op.arg == (3,2,1,0), 'Permute arg error'
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user