Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-22 21:38:10 -05:00

Commit: _pool2d -> _pool
@@ -10,7 +10,8 @@ def all_same(items): return all(x == items[0] for x in items) if len(items) > 0
 def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line
 def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
 def modn(x, a): return -((-x)%a) if x < 0 else x%a
-def make_pair(x:Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: return (x,x) if isinstance(x, int) else x
+def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
+def flatten(l): return [item for sublist in l for item in sublist]
 
 class Timing(object):
   def __enter__(self): self.st = time.monotonic_ns()
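The new cnt argument is what lets _pool build stride/dilation tuples for a kernel of any rank, and flatten is used below to splat interleaved shape pairs into reshape/expand calls. A quick sketch of the two helpers' behavior (standalone copies of the one-liners above, not imports from tinygrad):

from typing import Tuple, Union

def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l): return [item for sublist in l for item in sublist]

assert make_pair(3) == (3, 3)             # old 2D behavior is preserved by the default cnt=2
assert make_pair(3, cnt=3) == (3, 3, 3)   # new: an int broadcasts to any kernel rank
assert make_pair((2, 4)) == (2, 4)        # tuples pass through unchanged
assert flatten([(1, 2), (3, 4)]) == [1, 2, 3, 4]  # splats interleaved shape pairs into one flat list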
@@ -227,7 +227,7 @@ class ShapeTracker:
 
   def expand(self, new_shape : Tuple[int, ...]) -> ShapeTracker:
     assert isinstance(new_shape, tuple)
-    assert all(isinstance(x, int) for x in new_shape)
+    assert all(isinstance(x, int) for x in new_shape), f"non ints for expand in {new_shape}"
     assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
     strides : Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
     self.views[-1] = View(new_shape, strides, self.offset)
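Only the assert message changes here, but the surrounding context shows the broadcast rule that expand relies on: a size-1 dimension may grow to any size, and its stride is set to 0 so every index reads the same element. A hedged numpy illustration of the same zero-stride trick (numpy is used purely for illustration; ShapeTracker tracks this symbolically):

import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(3, dtype=np.float32).reshape(3, 1)
# expanding (3, 1) -> (3, 4): the size-1 axis gets stride 0, so no data is copied
b = as_strided(a, shape=(3, 4), strides=(a.strides[0], 0))
assert (b == a.repeat(4, axis=1)).all()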
@@ -3,7 +3,7 @@ from __future__ import annotations
 import math, functools, itertools
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
-from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
+from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer
 
 HLOP = getenv("HLOP", 0)
@@ -195,7 +195,7 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
-  def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=arg)
+  def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(arg))
 
   # ***** movement hlops *****
 
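The tuple(arg) coercion is what lets the new _pool below build its slice arguments as plain Python lists (slc_prefix + [...]). A tiny sketch of the shape of those arguments (the concrete numbers are made up for illustration):

slc_prefix = [(0, 4), (0, 3)]        # hypothetical prefix dims, e.g. (bs, c) = (4, 3)
arg = slc_prefix + [(0, 8), (0, 8)]  # _pool concatenates lists, so arg arrives as a list...
assert tuple(arg) == ((0, 4), (0, 3), (0, 8), (0, 8))  # ...and slice() now normalizes it to a tuple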
@@ -285,29 +285,33 @@ class Tensor:
 
   # ***** processing ops *****
 
-  def _pool2d(self, ky, kx, sy, sx, dy=1, dx=1):
-    if ky > sy or kx > sx or dy != 1 or dx != 1:
-      bs,c,iy,ix = self.shape
-      oy = (iy - dy * (ky-1) - 1)//sy + 1
-      ox = (ix - dx * (kx-1) - 1)//sx + 1
-      # duplicate the inputs for each of the kernels
-      #xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
-      # NOTE: if you oversize this, you can avoid the ZeroView creation. remove when optimizer can fix
-      ey, ex = math.ceil(ky*(iy+dy) / iy), math.ceil(kx*(ix+dx) / ix)
-      xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ey, iy, ex, ix).reshape(bs, c, ey*iy, ex*ix)
+  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
+    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
+    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
+    assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
+    slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
+    if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
+      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
+      e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)]  # expands such that we don't need padding
+      xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
       # slide by dilation
-      xup = xup.slice(((0,bs), (0,c), (0,ky*(iy+dy)), (0,kx*(ix+dx))))
-      xup = xup.reshape(bs, c, ky, iy+dy, kx, ix+dx)
-      xup = xup.slice(((0,bs), (0,c), (0,ky), (0,oy*sy), (0,kx), (0,ox*sx)))
+      xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
+      xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
+      xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
       # handle stride, and permute to move reduce to the end
-      return xup.reshape(bs, c, ky, oy, sy, kx, ox, sx)[:, :, :, :, 0, :, :, 0]
+      xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
+      xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
+      return xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
     else:
       # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
-      xup = self.slice(((0, self.shape[0]), (0, self.shape[1]), (0, (self.shape[2]+(sy-ky))//sy*sy), (0, (self.shape[3]+(sx-kx))//sx*sx)))
-      return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//sy, sy, xup.shape[3]//sx, sx))[:, :, :, :ky, :, :kx].permute(0, 1, 3, 2, 5, 4)
+      o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
+      xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
+      xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_))))
+      xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
+      return xup.permute(*range(len(prefix)), *flatten((len(prefix)+i*2+1, len(prefix)+i*2) for i in range(len(k_))))
 
-  def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).mean(axis=(2,4))
-  def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool2d(*make_pair(kernel_size), *make_pair(stride if stride is not None else kernel_size)).max(axis=(2,4))
+  def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=(2,4))
+  def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=(2,4))
 
   @image_conv2d_decorator
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
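To make the shape bookkeeping concrete: for a (bs, c, iy, ix) input, _pool((ky, kx), stride, dilation) returns shape (bs, c, ky, oy, kx, ox), prefix dims first, then each kernel dim followed by its output dim, which is why avg_pool2d and max_pool2d reduce over axes (2, 4). Below is a hedged numpy reference of the same windowing and layout (ref_pool is an illustration with dilation fixed to 1, not the tinygrad kernel):

import numpy as np

def ref_pool(x, k=(2,2), s=(2,2)):
    # gather every (ky, kx) window into the layout _pool uses: (bs, c, ky, oy, kx, ox)
    bs, c, iy, ix = x.shape
    (ky, kx), (sy, sx) = k, s
    oy, ox = (iy - ky)//sy + 1, (ix - kx)//sx + 1
    out = np.empty((bs, c, ky, oy, kx, ox), dtype=x.dtype)
    for i in range(ky):
        for j in range(kx):
            out[:, :, i, :, j, :] = x[:, :, i:i+oy*sy:sy, j:j+ox*sx:sx]
    return out

x = np.arange(2*3*8*8, dtype=np.float32).reshape(2, 3, 8, 8)
pooled = ref_pool(x).max(axis=(2, 4))   # max_pool2d analogue: reduce the kernel axes
assert pooled.shape == (2, 3, 4, 4)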
@@ -321,7 +325,7 @@ class Tensor:
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
 
     # conv2d is a pooling op (with padding)
-    x = self.pad2d(padding_)._pool2d(H,W,*make_pair(stride),*make_pair(dilation))
+    x = self.pad2d(padding_)._pool((H,W), stride, dilation)
 
     oy, ox, rcout = x.shape[3], x.shape[5], cout//groups
     # NOTE: we do this expand explicitly so the permute isn't pushed in the binop
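The call-site change is mechanical: the (H, W) kernel plus the scalar-or-tuple stride and dilation now go straight to _pool, whose (bs, cin, H, oy, W, ox) output layout is what x.shape[3] and x.shape[5] index afterwards. As a hedged sketch of why the window gathering is the heavy lifting in conv2d, here is a numpy reference (conv2d_ref is hypothetical: groups=1, dilation=1, no padding) that unfolds windows the same way and finishes the reduction with an einsum, where the real code uses tinygrad ops:

import numpy as np

def conv2d_ref(x, w, stride=1):
    # x: (bs, cin, iy, ix), w: (cout, cin, H, W); illustration only
    bs, cin, iy, ix = x.shape
    cout, _, H, W = w.shape
    s = stride
    oy, ox = (iy - H)//s + 1, (ix - W)//s + 1
    # "pool" step: gather every (H, W) window -> (bs, cin, H, oy, W, ox)
    win = np.empty((bs, cin, H, oy, W, ox), dtype=x.dtype)
    for i in range(H):
        for j in range(W):
            win[:, :, i, :, j, :] = x[:, :, i:i+oy*s:s, j:j+ox*s:s]
    # conv is then just a multiply + sum over (cin, H, W)
    return np.einsum("bchywx,ochw->boyx", win, w)

x = np.random.randn(1, 3, 8, 8).astype(np.float32)
w = np.random.randn(4, 3, 3, 3).astype(np.float32)
assert conv2d_ref(x, w).shape == (1, 4, 6, 6)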