mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Refactor contraction and add integration test cases for push permute (#650)
* Refactor contraction and add unit tests * Fix typo; Fix TestConv.test_elu failure due to some ones in old_shape * Add push permute test cases * Fix mypy type annotation check error * Add contraction unit test; Reshape to higher dimension is not contraction
This commit is contained in:
64
test/external/external_test_opt.py
vendored
64
test/external/external_test_opt.py
vendored
@@ -10,7 +10,8 @@ import unittest
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad import nn
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
|
||||
from tinygrad.lazy import PUSH_PERMUTES
|
||||
|
||||
class CLCache():
|
||||
def __enter__(self):
|
||||
@@ -131,5 +132,66 @@ class TestOpt(unittest.TestCase):
|
||||
print(img_conv)
|
||||
assert len(GlobalCounters.cache) == 2, "optimizer didn't fold conv/relu"
|
||||
|
||||
def helper_push_permute_before_reshape(self, t, should_push=True, desired_reshape_arg=None, desired_permute_arg=None):
  """Assert the order of the two outermost movement ops of `t`'s lazy graph.

  t: tensor built as `<something>.reshape(...).permute(...)`.
  should_push: whether the permute is expected to be moved in front of the reshape.
  desired_reshape_arg / desired_permute_arg: expected op args; only checked in the pushed case.
  """
  outer = t.lazydata.op
  inner = outer.src[0].op
  if PUSH_PERMUTES and should_push:
    # pushed: PERMUTE is now the inner op and RESHAPE the outer one
    assert inner.op == MovementOps.PERMUTE, 'Permute should be pushed before reshape'
    assert inner.arg == desired_permute_arg, f'Pushed permute arg should be {desired_permute_arg}'
    assert outer.op == MovementOps.RESHAPE, 'Reshape should be after permute'
    assert outer.arg == desired_reshape_arg, f'Reshape arg should be {desired_reshape_arg}'
  else:
    # not pushed: ops stay in the order they were applied (reshape first, permute on top)
    assert inner.op == MovementOps.RESHAPE, 'Reshape should be before permute'
    assert outer.op == MovementOps.PERMUTE, 'Permute should be after reshape'
|
||||
|
||||
|
||||
def test_push_permute_before_reshape(self):
  """Exercise permute-after-reshape on four reshape patterns, two pushed and two not."""
  # reshape merges the last two axes (3,4 -> 12): permute is expected to be pushed
  merged = Tensor.ones(1,2,3,4).reshape(1,2,3*4)
  merged = merged.permute(2,1,0)
  self.helper_push_permute_before_reshape(merged, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,1,0))

  # reshape reorders the axis sizes: permute is expected to stay on top
  shuffled = Tensor.ones(1,2,3,4).reshape(3,1,2,4)
  shuffled = shuffled.permute(3,2,1,0)
  self.helper_push_permute_before_reshape(shuffled, should_push=False)

  # same merge as the first case, but with extra size-1 axes in the source shape
  with_ones = Tensor.ones(1,2,3,1,4,1).reshape(1,2,3*4)
  with_ones = with_ones.permute(2,1,0)
  self.helper_push_permute_before_reshape(with_ones, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,4,5,1,0))

  # reshape to a higher rank: permute is expected to stay on top
  widened = Tensor.ones(1,2,3,4).reshape(1,2,3,1,4)
  widened = widened.permute(4,3,2,1,0)
  self.helper_push_permute_before_reshape(widened, should_push=False)
|
||||
|
||||
|
||||
def test_push_permute_before_reduce(self):
  """Check the op order in the lazy graph for `sum(axis=2)` followed by a permute."""
  t = Tensor.ones(1,2,3,4).sum(axis=2).permute(2,1,0)
  # the three outermost ops of the graph, from the result inwards
  top = t.lazydata.op
  middle = top.src[0].op
  bottom = middle.src[0].op
  if PUSH_PERMUTES:
    # expected chain (inner -> outer): PERMUTE, SUM, RESHAPE
    assert bottom.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
    assert bottom.arg == (3,1,0,2), 'Pushed permute arg error'
    assert middle.op == ReduceOps.SUM, 'Sum should be after permute'
    assert middle.arg == (4,2,1,1), 'Sum arg error'
    assert top.op == MovementOps.RESHAPE, 'Reshape should be after Sum'
    assert top.arg == (4,2,1), 'Reshape arg error'
  else:
    # expected chain (inner -> outer): SUM, RESHAPE, PERMUTE
    assert bottom.op == ReduceOps.SUM, 'Sum should be the first'
    assert bottom.arg == (1,2,4,1), 'Sum arg error'
    assert middle.op == MovementOps.RESHAPE, 'Reshape should be after sum'
    assert middle.arg == (1,2,4), 'Reshape arg error'
    assert top.op == MovementOps.PERMUTE, 'Permute should be after Reshape'
    assert top.arg == (2,1,0), 'Permute arg error'
|
||||
|
||||
def test_push_permute_before_expand(self):
  """Check the op order in the lazy graph for an expand followed by a permute."""
  t = Tensor.ones(1,2,3,4)
  t = t.expand(2,2,3,4).permute(3,2,1,0)
  outer = t.lazydata.op
  inner = outer.src[0].op
  if PUSH_PERMUTES:
    # expected chain (inner -> outer): PERMUTE, EXPAND
    assert inner.op == MovementOps.PERMUTE, 'Permute should be pushed before expand'
    assert inner.arg == (3,2,1,0), 'Pushed permute arg error'
    assert outer.op == MovementOps.EXPAND, 'Expand should be after permute'
    assert outer.arg == (4,3,2,2), 'Expand arg error'
  else:
    # expected chain (inner -> outer): EXPAND, PERMUTE
    assert inner.op == MovementOps.EXPAND, 'Expand should be the first'
    assert inner.arg == (2,2,3,4), 'Expand arg error'
    assert outer.op == MovementOps.PERMUTE, 'Permute should be after expand'
    assert outer.arg == (3,2,1,0), 'Permute arg error'
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
31
test/unit/test_contraction.py
Normal file
31
test/unit/test_contraction.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.lazy import get_contraction
|
||||
|
||||
# NOTE(review): the class name looks copied from the flop counter tests; this class only tests get_contraction
class TestFlopCounter(unittest.TestCase):
  def test_contraction(self):
    """Run get_contraction on fixed (old_shape, new_shape) pairs and compare with the expected grouping."""
    # (old_shape, new_shape, expected groups of old axis indexes, or None when no contraction is expected)
    cases = [
      ((1,2,3,4), (2,3,4), [[0, 1], [2], [3]]),
      ((1,2,3,1,4), (1,2,3,4), [[0], [1], [2], [3, 4]]),
      ((1,2,3,1,4,1,1), (2,3,4), [[0, 1], [2], [3, 4, 5, 6]]),
      ((1,2,3,4), (1,2,3*4), [[0], [1], [2, 3]]),
      ((1,2,3,4), (2,1,3,4), None),
      ((1,2,3,4), (1,2,3,4,1), None),
      ((1,2,3,4), (1,2,6,2), None),
    ]
    for old_shape, new_shape, expected in cases:
      r = get_contraction(old_shape, new_shape)
      self.assertEqual(r, expected)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
|
||||
import os, sys, weakref, importlib, inspect
|
||||
import os, sys, weakref, importlib, inspect, functools
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import prod, getenv
|
||||
from tinygrad.shape import ShapeTracker
|
||||
@@ -78,6 +78,19 @@ def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tu
|
||||
assert y.op in BinaryOps or y.op in UnaryOps
|
||||
return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src]) # type: ignore
|
||||
|
||||
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Optional[List[List[int]]]:
  """Group the axes of old_shape so that each group multiplies out to one axis of new_shape.

  Returns one list of old axis indexes per new axis, in order, e.g.
  get_contraction((1,2,3,4), (2,3,4)) == [[0, 1], [2], [3]].
  A group may be empty when the matching new axis has size 1 and no old axis is left for it.
  Returns None when new_shape cannot be obtained by merging adjacent axes of old_shape:
  higher rank, reordered sizes, a split axis, a zero-sized axis, or mismatched element counts.
  """
  if len(new_shape) > len(old_shape): return None
  # no target axis to group into: only the empty shape contracts to the empty shape
  if not new_shape: return None if old_shape else []
  shape_idx_groups : List[List[int]] = [[] for _ in new_shape]
  new_shape_i : int = 0
  group_size : int = 1  # running product of the old axes already placed in the current group
  for old_shape_i, t in enumerate(old_shape):
    target = new_shape[new_shape_i]
    # t <= 0 is rejected up front: a zero-sized axis would make the modulo below raise ZeroDivisionError
    if t <= 0 or target % t != 0 or group_size * t > target: return None
    shape_idx_groups[new_shape_i].append(old_shape_i)
    group_size *= t
    # group is full: move on, unless this is the last group (it absorbs any trailing size-1 axes)
    if group_size == target and new_shape_i < len(new_shape) - 1:
      new_shape_i, group_size = new_shape_i + 1, 1
  # the group we stopped in must be full, and any group never reached must stand for a size-1 axis
  if group_size != new_shape[new_shape_i] or any(s != 1 for s in new_shape[new_shape_i+1:]): return None
  return shape_idx_groups
|
||||
|
||||
|
||||
def support_weakref(x):
  """Identity decorator: returns `x` unchanged. Applied to LazyBuffer so mypyc does not make it a native class."""
  return x
|
||||
@support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
|
||||
class LazyBuffer:
|
||||
@@ -204,28 +217,8 @@ class LazyBuffer:
|
||||
|
||||
# move permutes before reshapes if we can
|
||||
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
|
||||
# TODO: this is atrocious code
|
||||
# is contract? if so, group the axis
|
||||
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):
|
||||
out : List[List[int]] = []
|
||||
curr : List[int] = []
|
||||
for t in old_shape:
|
||||
if len(out) >= len(new_shape): break
|
||||
if t*prod(curr) <= new_shape[len(out)]:
|
||||
curr.append(t)
|
||||
else:
|
||||
out.append(curr)
|
||||
curr = [t]
|
||||
out.append(curr)
|
||||
if len(new_shape) == len(out) and all(prod(i) == j and len(i) >= 1 for i,j in zip(out, new_shape)):
|
||||
return out
|
||||
if contraction := get_contraction(self.op.src[0].shape, self.shape):
|
||||
numbered, start = [], 0
|
||||
for c in contraction:
|
||||
numbered.append(list(range(start, start+len(c))))
|
||||
start += len(c)
|
||||
new_arg = []
|
||||
for p in arg: new_arg += numbered[p]
|
||||
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
|
||||
new_arg : List[int] = functools.reduce(lambda r, x: r + shape_idx_groups[x], arg, [])
|
||||
self.op.src[0].children.discard(self) # this changes nothing?
|
||||
return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(new_arg)) \
|
||||
.movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)
|
||||
|
||||
Reference in New Issue
Block a user