def helper_push_permute_before_reshape(self, t, should_push=True, desired_reshape_arg=None, desired_permute_arg=None):
  """Assert the order of the two outermost movement ops recorded on ``t``.

  ``t`` is a tensor built as ``x.reshape(...).permute(...)``.

  When ``PUSH_PERMUTES`` is enabled and ``should_push`` is true, the permute is
  expected to have been moved in front of the reshape, with the given args.
  Otherwise the ops must still be in the order they were written (reshape,
  then permute); the ``desired_*`` args are not checked in that case.
  """
  outer = t.lazydata.op      # op attached to t's lazy buffer
  inner = outer.src[0].op    # op attached to that op's first source
  if PUSH_PERMUTES and should_push:
    assert inner.op == MovementOps.PERMUTE, 'Permute should be pushed before reshape'
    assert inner.arg == desired_permute_arg, f'Pushed permute arg should be {desired_permute_arg}'
    assert outer.op == MovementOps.RESHAPE, 'Reshape should be after permute'
    assert outer.arg == desired_reshape_arg, f'Reshape arg should be {desired_reshape_arg}'
  else:
    assert inner.op == MovementOps.RESHAPE, 'Reshape should be before permute'
    assert outer.op == MovementOps.PERMUTE, 'Permute should be after reshape'


def test_push_permute_before_reshape(self):
  """Permute is pushed through a reshape only when the reshape merges adjacent axes."""
  # reshape merges the last two axes -> permute can be expressed on the original axes
  t = Tensor.ones(1,2,3,4)
  t = t.reshape(1,2,3*4).permute(2,1,0)
  self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,1,0))

  # reshape does not simply merge adjacent axes -> no push expected
  t = Tensor.ones(1,2,3,4)
  t = t.reshape(3,1,2,4).permute(3,2,1,0)
  self.helper_push_permute_before_reshape(t, should_push=False)

  # same merge as the first case, with extra size-1 axes in the input
  t = Tensor.ones(1,2,3,1,4,1)
  t = t.reshape(1,2,3*4).permute(2,1,0)
  self.helper_push_permute_before_reshape(t, should_push=True, desired_reshape_arg=(12,2,1), desired_permute_arg=(2,3,4,5,1,0))

  # reshape adds an axis (more output axes than input axes) -> no push expected
  t = Tensor.ones(1,2,3,4)
  t = t.reshape(1,2,3,1,4).permute(4,3,2,1,0)
  self.helper_push_permute_before_reshape(t, should_push=False)


def test_push_permute_before_reduce(self):
  """Check the op chain produced by ``sum(axis=2).permute(2,1,0)`` with and without PUSH_PERMUTES."""
  t = Tensor.ones(1,2,3,4)
  t = t.sum(axis=2).permute(2,1,0)
  outer = t.lazydata.op
  middle = outer.src[0].op
  inner = middle.src[0].op
  if PUSH_PERMUTES:
    # expected chain: PERMUTE -> SUM -> RESHAPE
    assert inner.op == MovementOps.PERMUTE, 'Permute should be pushed before reduce'
    assert inner.arg == (3,1,0,2), 'Pushed permute arg error'
    assert middle.op == ReduceOps.SUM, 'Sum should be after permute'
    assert middle.arg == (4,2,1,1), 'Sum arg error'
    assert outer.op == MovementOps.RESHAPE, 'Reshape should be after Sum'
    assert outer.arg == (4,2,1), 'Reshape arg error'
  else:
    # expected chain: SUM -> RESHAPE -> PERMUTE
    assert inner.op == ReduceOps.SUM, 'Sum should be the first'
    assert inner.arg == (1,2,4,1), 'Sum arg error'
    assert middle.op == MovementOps.RESHAPE, 'Reshape should be after sum'
    assert middle.arg == (1,2,4), 'Reshape arg error'
    assert outer.op == MovementOps.PERMUTE, 'Permute should be after Reshape'
    assert outer.arg == (2,1,0), 'Permute arg error'


def test_push_permute_before_expand(self):
  """Check the op chain produced by ``expand(2,2,3,4).permute(3,2,1,0)`` with and without PUSH_PERMUTES."""
  t = Tensor.ones(1,2,3,4)
  t = t.expand(2,2,3,4).permute(3,2,1,0)
  outer = t.lazydata.op
  inner = outer.src[0].op
  if PUSH_PERMUTES:
    # expected chain: PERMUTE -> EXPAND
    assert inner.op == MovementOps.PERMUTE, 'Permute should be pushed before expand'
    assert inner.arg == (3,2,1,0), 'Pushed permute arg error'
    assert outer.op == MovementOps.EXPAND, 'Expand should be after permute'
    assert outer.arg == (4,3,2,2), 'Expand arg error'
  else:
    # expected chain: EXPAND -> PERMUTE
    assert inner.op == MovementOps.EXPAND, 'Expand should be the first'
    assert inner.arg == (2,2,3,4), 'Expand arg error'
    assert outer.op == MovementOps.PERMUTE, 'Permute should be after expand'
    assert outer.arg == (3,2,1,0), 'Permute arg error'
class TestFlopCounter(unittest.TestCase):
  """Unit tests for get_contraction."""

  def test_contraction(self):
    # (old_shape, new_shape, expected groups or None when new_shape is not a contraction of old_shape)
    cases = [
      ((1,2,3,4), (2,3,4), [[0, 1], [2], [3]]),
      ((1,2,3,1,4), (1,2,3,4), [[0], [1], [2], [3, 4]]),
      ((1,2,3,1,4,1,1), (2,3,4), [[0, 1], [2], [3, 4, 5, 6]]),
      ((1,2,3,4), (1,2,3*4), [[0], [1], [2, 3]]),
      ((1,2,3,4), (2,1,3,4), None),
      ((1,2,3,4), (1,2,3,4,1), None),
      ((1,2,3,4), (1,2,6,2), None),
      # edge cases: empty target, zero-sized axis, last group never filled
      ((), (), []),
      ((1,), (), None),
      ((0,), (0,), None),
      ((2,3), (2,6), None),
    ]
    for old_shape, new_shape, expected in cases:
      with self.subTest(old_shape=old_shape, new_shape=new_shape):
        self.assertEqual(get_contraction(old_shape, new_shape), expected)


def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Optional[List[List[int]]]:
  """Group the axes of ``old_shape`` into consecutive runs that multiply out to ``new_shape``.

  Args:
    old_shape: shape before the reshape.
    new_shape: shape after the reshape.

  Returns:
    A list with one entry per axis of ``new_shape``; entry ``i`` holds the
    indices of the consecutive ``old_shape`` axes that are merged into new axis
    ``i``. An entry may be empty when the new axis has size 1. Returns ``None``
    when ``new_shape`` cannot be obtained by merging adjacent axes of
    ``old_shape``: more new axes than old ones, an axis that does not divide
    its target, a group that overshoots or never reaches its target, or a
    zero-sized old axis.
  """
  if len(new_shape) > len(old_shape): return None
  # with no new axes there is no group an old axis could be placed in
  if not new_shape: return None if old_shape else []
  groups : List[List[int]] = [[] for _ in new_shape]
  last_group = len(new_shape) - 1
  group_i, group_prod = 0, 1   # group being filled and the product of the axes already in it
  for old_i, dim in enumerate(old_shape):
    target = new_shape[group_i]
    # dim == 0 is rejected first so the divisibility test never divides by zero
    if dim == 0 or target % dim != 0 or group_prod * dim > target: return None
    groups[group_i].append(old_i)
    group_prod *= dim
    # stay on the last group once it is full so trailing size-1 axes are absorbed into it
    if group_prod == target and group_i < last_group:
      group_i, group_prod = group_i + 1, 1
  # the group being filled must be complete and any group never reached must be a size-1 axis
  if group_prod != new_shape[group_i] or any(d != 1 for d in new_shape[group_i+1:]): return None
  return groups