From d722ffbd04a2cc1f488a7f9d72aa88623df889f9 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 28 Feb 2023 11:35:19 -0800
Subject: [PATCH] _pool2d -> _pool

---
 tinygrad/helpers.py        |  3 ++-
 tinygrad/shape/__init__.py |  2 +-
 tinygrad/tensor.py         | 46 +++++++++++++++++++++-----------------
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 33f812567a..aa02b8d2bb 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -10,7 +10,8 @@ def all_same(items): return all(x == items[0] for x in items) if len(items) > 0
 def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line
 def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
 def modn(x, a): return -((-x)%a) if x < 0 else x%a
-def make_pair(x:Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: return (x,x) if isinstance(x, int) else x
+def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
+def flatten(l): return [item for sublist in l for item in sublist]
 
 class Timing(object):
   def __enter__(self): self.st = time.monotonic_ns()
diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py
index e261b39971..b99f2de586 100644
--- a/tinygrad/shape/__init__.py
+++ b/tinygrad/shape/__init__.py
@@ -227,7 +227,7 @@ class ShapeTracker:
 
   def expand(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
     assert isinstance(new_shape, tuple)
-    assert all(isinstance(x, int) for x in new_shape)
+    assert all(isinstance(x, int) for x in new_shape), f"non ints for expand in {new_shape}"
     assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
     strides : Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
     self.views[-1] = View(new_shape, strides, self.offset)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 94371bec8b..86efc9c9dd 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import math, functools, itertools
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
-from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
+from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer
 
 HLOP = getenv("HLOP", 0)
@@ -195,7 +195,7 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
-  def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=arg)
+  def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(arg))
 
   # ***** movement hlops *****
 
@@ -285,29 +285,33 @@ class Tensor:
 
   # ***** processing ops *****
 
-  def _pool2d(self, ky, kx, sy, sx, dy=1, dx=1):
-    if ky > sy or kx > sx or dy != 1 or dx != 1:
-      bs,c,iy,ix = self.shape
-      oy = (iy - dy * (ky-1) - 1)//sy + 1
-      ox = (ix - dx * (kx-1) - 1)//sx + 1
-      # duplicate the inputs for each of the kernels
-      #xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
-      # NOTE: if you oversize this, you can avoid the ZeroView creation. remove when optimizer can fix
-      ey, ex = math.ceil(ky*(iy+dy) / iy), math.ceil(kx*(ix+dx) / ix)
-      xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ey, iy, ex, ix).reshape(bs, c, ey*iy, ex*ix)
+  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
+    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
+    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
+    assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
+    slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
+    if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
+      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
+      e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
+      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
       # slide by dilation
-      xup = xup.slice(((0,bs), (0,c), (0,ky*(iy+dy)), (0,kx*(ix+dx))))
-      xup = xup.reshape(bs, c, ky, iy+dy, kx, ix+dx)
-      xup = xup.slice(((0,bs), (0,c), (0,ky), (0,oy*sy), (0,kx), (0,ox*sx)))
+      xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
+      xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
+      xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
       # handle stride, and permute to move reduce to the end
-      return xup.reshape(bs, c, ky, oy, sy, kx, ox, sx)[:, :, :, :, 0, :, :, 0]
+      xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
+      xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
+      return xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
     else:
       # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
-      xup = self.slice(((0, self.shape[0]), (0, self.shape[1]), (0, (self.shape[2]+(sy-ky))//sy*sy), (0, (self.shape[3]+(sx-kx))//sx*sx)))
-      return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//sy, sy, xup.shape[3]//sx, sx))[:, :, :, :ky, :, :kx].permute(0, 1, 3, 2, 5, 4)
+      o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
+      xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
+      xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
+      xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
+      return xup.permute(*range(len(prefix)), *flatten((len(prefix)+i*2+1, len(prefix)+i*2) for i in range(len(k_))))
 
-  def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).mean(axis=(2,4))
-  def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).max(axis=(2,4))
+  def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=(2,4))
+  def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=(2,4))
 
   @image_conv2d_decorator
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
@@ -321,7 +325,7 @@ class Tensor:
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
 
     # conv2d is a pooling op (with padding)
-    x = self.pad2d(padding_)._pool2d(H,W,*make_pair(stride),*make_pair(dilation))
+    x = self.pad2d(padding_)._pool((H,W),stride, dilation)
     oy, ox, rcout = x.shape[3], x.shape[5], cout//groups
 
     # NOTE: we do this expand explicitly so the permute isn't pushed in the binop
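
A minimal usage sketch of the generalized _pool, assuming Tensor.randn as the input constructor; the concrete shapes are illustrative only. It shows the layout _pool returns, (*prefix, k1, o1, k2, o2, ...), which is why avg_pool2d/max_pool2d reduce over axes (2,4) and conv2d reads oy, ox from x.shape[3] and x.shape[5]:

from tinygrad.tensor import Tensor

x = Tensor.randn(1, 3, 8, 8)  # (bs, c, iy, ix), shapes chosen for illustration

# kernel (2,2), stride 2, dilation 1 over the last two axes
# -> (bs, c, ky, oy, kx, ox) = (1, 3, 2, 4, 2, 4)
print(x._pool((2, 2), stride=2).shape)

# the public pooling ops reduce the kernel axes 2 and 4 away -> (1, 3, 4, 4)
print(x.avg_pool2d(kernel_size=(2, 2)).shape)
print(x.max_pool2d(kernel_size=(2, 2)).shape)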