mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-11 16:08:10 -05:00
strided Pool funcs (#74)
* *Pool2D GPU forward supports stride
* kernel_size from ctx instead of saved_tensors
* *Pool2D CPU forward supports stride
* update ctx.stride properly
This commit is contained in:
@@ -87,18 +87,23 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,stride=(2,1)).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)
|
||||
|
||||
def test_maxpool2x2(self):
|
||||
helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d, gpu=self.gpu, forward_only=self.gpu)
|
||||
def test_maxpool2d(self):
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
for strd in [(1,1), (2,1), (2,2), (4,2)]:
|
||||
# TODO Grad tolerance for CPU implementation needs to be slightly relaxed; why?
|
||||
with self.subTest(kernel_size=ksz, stride=strd):
|
||||
helper_test_op([(32,2,110,28)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, stride=strd),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=self.gpu, grad_atol=1e-3)
|
||||
|
||||
def test_maxpool_sizes(self):
|
||||
for sz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
helper_test_op([(32,2,110,28)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=sz), gpu=self.gpu, forward_only=self.gpu)
|
||||
|
||||
def test_avgpool2x2(self):
|
||||
# TODO Grad tolerance needs to be slightly relaxed; why?
|
||||
helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d, gpu=self.gpu, grad_atol=1e-5)
|
||||
def test_avgpool2d(self):
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
for strd in [(1,1), (2,1), (2,2), (4,2)]:
|
||||
# TODO Grad tolerance needs to be slightly relaxed; why?
|
||||
with self.subTest(kernel_size=ksz, stride=strd):
|
||||
helper_test_op([(32,2,111,28)],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, stride=strd),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, grad_atol=1e-5)
|
||||
|
||||
if GPU:
|
||||
class TestOpsGPU(TestOps):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
import numpy as np
|
||||
from .tensor import Function, register
|
||||
|
||||
@@ -236,47 +237,54 @@ register('conv2d', Conv2D)
|
||||
|
||||
# ************* pooling ops *************
|
||||
|
||||
def stack_for_pool(x, py, px):
|
||||
my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
|
||||
stack = []
|
||||
xup = x[:, :, :my, :mx]
|
||||
for Y in range(py):
|
||||
for X in range(px):
|
||||
stack.append(xup[:, :, Y::py, X::px][None])
|
||||
return np.concatenate(stack, axis=0)
|
||||
def stack_for_pool(x, kernel_size, stride, fill_value=0):
|
||||
(ky, kx), (py, px) = kernel_size, stride
|
||||
my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1
|
||||
stack = fill_value*np.ones((ky, kx, *x.shape[:2], my+ky, mx+kx), dtype=x.dtype)
|
||||
for Y in range(ky):
|
||||
for X in range(kx):
|
||||
sl = x[..., Y:Y+my*py+ky:py, X:X+mx*px+kx:px]
|
||||
stack[Y, X, ..., :sl.shape[2], :sl.shape[3]] = sl
|
||||
return stack.reshape(-1, *stack.shape[2:]), (my, mx)
|
||||
|
||||
def unstack_for_pool(fxn, s, py, px):
|
||||
my, mx = (s[2]//py)*py, (s[3]//px)*px
|
||||
for Y in range(py):
|
||||
for X in range(px):
|
||||
ll = fxn(Y*px+X)
|
||||
def unstack_for_pool(fxn, s, kernel_size, stride):
|
||||
(ky, kx), (py, px) = kernel_size, stride
|
||||
for Y in range(ky):
|
||||
for X in range(kx):
|
||||
ll = fxn(Y*kx+X)
|
||||
if X == 0 and Y == 0:
|
||||
ret = np.zeros(s, dtype=ll.dtype)
|
||||
ret[:, :, Y:my:py, X:mx:px] = ll
|
||||
return ret
|
||||
ret = np.zeros((*s[:2], s[2]+ky, s[3]+kx), dtype=ll.dtype)
|
||||
ret[..., Y:Y+ll.shape[2]*py:py, X:X+ll.shape[3]*px:px] = ll
|
||||
return ret[..., :s[2], :s[3]]
|
||||
|
||||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, kernel_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *kernel_size)
|
||||
idxs = np.argmax(stack, axis=0)
|
||||
def forward(ctx, x, kernel_size=(2, 2), stride=None):
|
||||
if not stride:
|
||||
ctx.stride = stride = kernel_size
|
||||
stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=-np.inf)
|
||||
idxs = np.nanargmax(stack, axis=0)[..., :my, :mx]
|
||||
ctx.save_for_backward(idxs, x.shape)
|
||||
return np.max(stack, axis=0)
|
||||
return np.amax(stack, axis=0)[..., :my, :mx]
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
idxs,s = ctx.saved_tensors
|
||||
return unstack_for_pool(
|
||||
lambda idx: grad_output * (idxs == idx),
|
||||
s, *ctx.kernel_size)
|
||||
s, ctx.kernel_size, ctx.stride)
|
||||
register('max_pool2d', MaxPool2D)
|
||||
|
||||
class AvgPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, kernel_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *kernel_size)
|
||||
def forward(ctx, x, kernel_size=(2, 2), stride=None):
|
||||
if not stride:
|
||||
ctx.stride = stride = kernel_size
|
||||
stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=np.nan)
|
||||
ctx.save_for_backward(x.shape)
|
||||
return np.mean(stack, axis=0)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return np.nanmean(stack, axis=0)[...,:my, :mx]
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
@@ -284,6 +292,6 @@ class AvgPool2D(Function):
|
||||
py, px = ctx.kernel_size
|
||||
return unstack_for_pool(
|
||||
lambda idx: grad_output/py/px,
|
||||
s, py, px)
|
||||
s, ctx.kernel_size, ctx.stride)
|
||||
register('avg_pool2d', AvgPool2D)
|
||||
|
||||
|
||||
@@ -28,14 +28,15 @@ def clbuild(cl_ctx, prg):
|
||||
def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0):
|
||||
prg = """
|
||||
__kernel void subsample(
|
||||
__global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size, int nelem
|
||||
__global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size,
|
||||
uint2 stride, int nelem
|
||||
) {
|
||||
int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0));
|
||||
int oid = gid.x + osize.x*(gid.y + osize.y*gid.z);
|
||||
float group_res = """+str(init_val)+""";
|
||||
for (uint j=0; j<kernel_size.y; ++j) {
|
||||
for (uint i=0; i<kernel_size.x; ++i) {
|
||||
int iid = (gid.x*kernel_size.x+i) + isize.x*((gid.y*kernel_size.y+j) + isize.y*gid.z);
|
||||
int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
|
||||
if (iid < nelem)
|
||||
"""+iter_op+""";
|
||||
}
|
||||
@@ -45,17 +46,19 @@ def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0):
|
||||
"""
|
||||
return clbuild(cl_ctx, prg)
|
||||
|
||||
def subsample_op(ctx, input, kernel_size, iter_op, result_op, init_val=0):
|
||||
N, C, Y, X = input.shape
|
||||
py,px = kernel_size
|
||||
ret = buffer_new(ctx, (N, C, Y//py, X//px))
|
||||
osize = np.array((X//px, Y//py), dtype=cl.cltypes.uint2)
|
||||
isize = np.array((X, Y), dtype=cl.cltypes.uint2)
|
||||
ksize = np.array((px,py), dtype=cl.cltypes.uint2)
|
||||
def subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=0):
|
||||
py, px = stride
|
||||
N, C, Yin, Xin = input.shape
|
||||
Yout, Xout = (Yin-kernel_size[0])//py+1, (Xin-kernel_size[1])//px+1
|
||||
ret = buffer_new(ctx, (N, C, Yout, Xout))
|
||||
osize = np.array((Xout, Yout), dtype=cl.cltypes.uint2)
|
||||
isize = np.array((Xin, Yin), dtype=cl.cltypes.uint2)
|
||||
ksize = np.array(kernel_size[::-1], dtype=cl.cltypes.uint2)
|
||||
strd = np.array((px, py), dtype=cl.cltypes.uint2)
|
||||
prg = cl_subsample_krnl_build(ctx.cl_ctx, iter_op, result_op, init_val=init_val)
|
||||
prg.subsample(ctx.cl_queue, (N*C, Y//py, X//px), None,
|
||||
ret, input, osize, isize, ksize, np.int32(input.size))
|
||||
ctx.data = np.empty((N, C, Y, X)) # set shape expectation on tensor instance
|
||||
prg.subsample(ctx.cl_queue, (N*C, Yout, Xout), None,
|
||||
ret, input, osize, isize, ksize, strd, np.int32(input.size))
|
||||
ctx.data = np.empty((N, C, Yout, Xout)) # set shape expectation on tensor instance
|
||||
return ret
|
||||
|
||||
@functools.lru_cache
|
||||
@@ -377,27 +380,31 @@ register('sigmoid', Sigmoid, gpu=True)
|
||||
|
||||
class AvgPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel_size=(2, 2)):
|
||||
def forward(ctx, input, kernel_size=(2, 2), stride=None):
|
||||
if not stride:
|
||||
ctx.stride = stride = kernel_size
|
||||
iter_op = "group_res += input[iid]"
|
||||
result_op = "group_res / (kernel_size.x * kernel_size.y)"
|
||||
ret = subsample_op(ctx, input, kernel_size, iter_op, result_op)
|
||||
ctx.save_for_backward(kernel_size, input.shape)
|
||||
ret = subsample_op(ctx, input, kernel_size, stride, iter_op, result_op)
|
||||
ctx.save_for_backward(input.shape)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
kernel_size, orig_shape = ctx.saved_tensors
|
||||
result_op = "input[iid] / (float)(kernel_size.x * kernel_size.y)"
|
||||
return supersample_op(ctx, grad_output, orig_shape, kernel_size, result_op)
|
||||
orig_shape, = ctx.saved_tensors
|
||||
result_op = "input[iid] / (kernel_size.x * kernel_size.y)"
|
||||
return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size, result_op)
|
||||
register('avg_pool2d', AvgPool2D, gpu=True)
|
||||
|
||||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel_size=(2, 2)):
|
||||
def forward(ctx, input, kernel_size=(2, 2), stride=None):
|
||||
if not stride:
|
||||
ctx.stride = stride = kernel_size
|
||||
init_val = "FLT_MIN"
|
||||
iter_op = "group_res = max(group_res, input[iid])"
|
||||
result_op = "group_res"
|
||||
return subsample_op(ctx, input, kernel_size, iter_op, result_op, init_val=init_val)
|
||||
return subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=init_val)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
Reference in New Issue
Block a user