Refactor contraction and add integration test cases for push permute (#650)

* Refactor contraction and add unit tests

* Fix typo; Fix TestConv.test_elu failure due to some ones in old_shape

* Add push permute test cases

* Fix mypy type annotation check error

* Add contraction unit test; Reshape to higher dimension is not contraction
This commit is contained in:
Alex Wang
2023-03-06 22:36:55 +08:00
committed by GitHub
parent cb5be9697c
commit 64ecbd91b5
3 changed files with 110 additions and 24 deletions

View File

@@ -10,7 +10,8 @@ import unittest
from tinygrad.tensor import Tensor, Device
from tinygrad import nn
from tinygrad.nn import optim
from tinygrad.ops import GlobalCounters
from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
from tinygrad.lazy import PUSH_PERMUTES
class CLCache():
def __enter__(self):
@@ -131,5 +132,66 @@ class TestOpt(unittest.TestCase):
print(img_conv)
assert len(GlobalCounters.cache) == 2, "optimizer didn't fold conv/relu"
def helper_push_permute_before_reshape(self, t, should_push=True, desired_reshape_arg=None, desired_permute_arg=None):
if PUSH_PERMUTES and should_push:
assert t.lazydata.op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reshape'
assert t.lazydata.op.src[0].op.arg == desired_permute_arg, f'Pushed permute arg should be {desired_permute_arg}'
assert t.lazydata.op.op == MovementOps.RESHAPE, 'Reshape should be after permute'
assert t.lazydata.op.arg == desired_reshape_arg, f'Reshape arg should be {desired_reshape_arg}'
else:
assert t.lazydata.op.src[0].op.op == MovementOps.RESHAPE, 'Reshape should before permute'
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after reshape'
def test_push_permute_before_reshape(self):
t = Tensor.ones(1,2,3,4)
t = t.reshape(1,2,3*4).permute(2,1,0)
self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,1,0))
t = Tensor.ones(1,2,3,4)
t = t.reshape(3,1,2,4).permute(3,2,1,0)
self.helper_push_permute_before_reshape(t, should_push=False)
t = Tensor.ones(1,2,3,1,4,1)
t = t.reshape(1,2,3*4).permute(2,1,0)
self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,4,5,1,0))
t = Tensor.ones(1,2,3,4)
t = t.reshape(1,2,3,1,4).permute(4,3,2,1,0)
self.helper_push_permute_before_reshape(t, should_push=False)
def test_push_permute_before_reduce(self):
t = Tensor.ones(1,2,3,4)
t = t.sum(axis=2).permute(2,1,0)
if PUSH_PERMUTES:
assert t.lazydata.op.src[0].op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
assert t.lazydata.op.src[0].op.src[0].op.arg == (3,1,0,2), 'Pushed permute arg error'
assert t.lazydata.op.src[0].op.op == ReduceOps.SUM, 'Sum should be after permute'
assert t.lazydata.op.src[0].op.arg == (4,2,1,1), 'Sum arg error'
assert t.lazydata.op.op == MovementOps.RESHAPE, 'Reshape should be after Sum'
assert t.lazydata.op.arg == (4,2,1), 'Reshape arg error'
else:
assert t.lazydata.op.src[0].op.src[0].op.op == ReduceOps.SUM, 'Sum should be the first'
assert t.lazydata.op.src[0].op.src[0].op.arg == (1,2,4,1), 'Sum arg error'
assert t.lazydata.op.src[0].op.op == MovementOps.RESHAPE, 'Reshape should be after sum'
assert t.lazydata.op.src[0].op.arg == (1,2,4), 'Reshape arg error'
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after Reshape'
assert t.lazydata.op.arg == (2,1,0), 'Permute arg error'
def test_push_permute_before_expand(self):
t = Tensor.ones(1,2,3,4)
t = t.expand(2,2,3,4).permute(3,2,1,0)
if PUSH_PERMUTES:
assert t.lazydata.op.src[0].op.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
assert t.lazydata.op.src[0].op.arg == (3,2,1,0), 'Pushed permute arg error'
assert t.lazydata.op.op == MovementOps.EXPAND, 'Expand should be after permute'
assert t.lazydata.op.arg == (4,3,2,2), 'Expand arg error'
else:
assert t.lazydata.op.src[0].op.op == MovementOps.EXPAND, 'Expand should be the first'
assert t.lazydata.op.src[0].op.arg == (2,2,3,4), 'Expand arg error'
assert t.lazydata.op.op == MovementOps.PERMUTE, 'Permute should be after expand'
assert t.lazydata.op.arg == (3,2,1,0), 'Permute arg error'
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,31 @@
#!/usr/bin/env python
import unittest
from tinygrad.lazy import get_contraction
class TestFlopCounter(unittest.TestCase):
def test_contraction(self):
r = get_contraction((1,2,3,4), (2,3,4))
self.assertEqual(r, [[0, 1], [2], [3]])
r = get_contraction((1,2,3,1,4), (1,2,3,4))
self.assertEqual(r, [[0], [1], [2], [3, 4]])
r = get_contraction((1,2,3,1,4,1,1), (2,3,4))
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
r = get_contraction((1,2,3,4), (1,2,3*4))
self.assertEqual(r, [[0], [1], [2, 3]])
r = get_contraction((1,2,3,4), (2,1,3,4))
self.assertEqual(r, None)
r = get_contraction((1,2,3,4), (1,2,3,4,1))
self.assertEqual(r, None)
r = get_contraction((1,2,3,4), (1,2,6,2))
self.assertEqual(r, None)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
import os, sys, weakref, importlib, inspect
import os, sys, weakref, importlib, inspect, functools
from weakref import WeakValueDictionary
from tinygrad.helpers import prod, getenv
from tinygrad.shape import ShapeTracker
@@ -78,6 +78,19 @@ def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tu
assert y.op in BinaryOps or y.op in UnaryOps
return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src]) # type: ignore
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):
if len(new_shape) > len(old_shape): return None
new_shape_i : int = 0
shape_idx_groups : List[List[int]] = [[] for _ in range(len(new_shape))]
for old_shape_i, t in enumerate(old_shape):
if new_shape[new_shape_i] % t != 0 or prod([old_shape[x] for x in shape_idx_groups[new_shape_i]]) * t > new_shape[new_shape_i]:
return None
shape_idx_groups[new_shape_i].append(old_shape_i)
if prod([old_shape[x] for x in shape_idx_groups[new_shape_i]]) == new_shape[new_shape_i] and new_shape_i < len(new_shape) - 1:
new_shape_i += 1
return shape_idx_groups
def support_weakref(x): return x
@support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
class LazyBuffer:
@@ -204,28 +217,8 @@ class LazyBuffer:
# move permutes before reshapes if we can
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
# TODO: this is atrocious code
# is contract? if so, group the axis
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):
out : List[List[int]] = []
curr : List[int] = []
for t in old_shape:
if len(out) >= len(new_shape): break
if t*prod(curr) <= new_shape[len(out)]:
curr.append(t)
else:
out.append(curr)
curr = [t]
out.append(curr)
if len(new_shape) == len(out) and all(prod(i) == j and len(i) >= 1 for i,j in zip(out, new_shape)):
return out
if contraction := get_contraction(self.op.src[0].shape, self.shape):
numbered, start = [], 0
for c in contraction:
numbered.append(list(range(start, start+len(c))))
start += len(c)
new_arg = []
for p in arg: new_arg += numbered[p]
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
new_arg : List[int] = functools.reduce(lambda r, x: r + shape_idx_groups[x], arg, [])
self.op.src[0].children.discard(self) # this changes nothing?
return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(new_arg)) \
.movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)