strided Pool funcs (#74)

* MaxPool2D/AvgPool2D GPU forward supports stride

* kernel_size from ctx instead of saved_tensors

* MaxPool2D/AvgPool2D CPU forward supports stride

* update ctx.stride properly
Author: Ryan Neph
Date: 2020-11-08 11:45:55 -08:00
Committed by: GitHub
Parent: 06504a5824
Commit: b0c0c5d0d6
3 changed files with 76 additions and 56 deletions
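
In short: MaxPool2D and AvgPool2D on both backends now take an optional stride that defaults to kernel_size when omitted, matching torch.nn.functional, and the output spatial dims follow out = (in - kernel)//stride + 1. A minimal sketch of that shape rule (the helper name is illustrative, not part of the diff):

  def pool_out_shape(in_hw, kernel_size, stride=None):
    # stride defaults to kernel_size, mirroring torch.nn.functional.max_pool2d
    stride = stride or kernel_size
    return tuple((i - k) // s + 1 for i, k, s in zip(in_hw, kernel_size, stride))

  assert pool_out_shape((110, 28), (2, 2)) == (55, 14)          # stride omitted
  assert pool_out_shape((110, 28), (5, 1), (4, 2)) == (27, 14)  # strided combo from the tests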


@@ -87,18 +87,23 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w,stride=(2,1)).relu(),
       lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)
-  def test_maxpool2x2(self):
-    helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d, gpu=self.gpu, forward_only=self.gpu)
+  def test_maxpool2d(self):
+    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
+      for strd in [(1,1), (2,1), (2,2), (4,2)]:
+        # TODO Grad tolerance for CPU implementation needs to be slightly relaxed; why?
+        with self.subTest(kernel_size=ksz, stride=strd):
+          helper_test_op([(32,2,110,28)],
+            lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, stride=strd),
+            lambda x: Tensor.max_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=self.gpu, grad_atol=1e-3)
-  def test_maxpool_sizes(self):
-    for sz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
-      helper_test_op([(32,2,110,28)],
-        lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz),
-        lambda x: Tensor.max_pool2d(x, kernel_size=sz), gpu=self.gpu, forward_only=self.gpu)
-  def test_avgpool2x2(self):
-    # TODO Grad tolerance needs to be slightly relaxed; why?
-    helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d, gpu=self.gpu, grad_atol=1e-5)
+  def test_avgpool2d(self):
+    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
+      for strd in [(1,1), (2,1), (2,2), (4,2)]:
+        # TODO Grad tolerance needs to be slightly relaxed; why?
+        with self.subTest(kernel_size=ksz, stride=strd):
+          helper_test_op([(32,2,111,28)],
+            lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, stride=strd),
+            lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, grad_atol=1e-5)

 if GPU:
   class TestOpsGPU(TestOps):
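
The rewritten tests sweep every (kernel_size, stride) combination through unittest's subTest, so a failing combination is reported individually instead of aborting the whole test method. A stripped-down sketch of the pattern (shapes and combos are illustrative):

  import unittest

  class TestPoolShapes(unittest.TestCase):
    def test_combos(self):
      for ksz in [(2, 2), (3, 3)]:
        for strd in [(1, 1), (2, 2)]:
          with self.subTest(kernel_size=ksz, stride=strd):  # each combo reported on its own
            self.assertGreater((110 - ksz[0]) // strd[0] + 1, 0)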


@@ -1,3 +1,4 @@
+import warnings
 import numpy as np
 from .tensor import Function, register
@@ -236,47 +237,54 @@ register('conv2d', Conv2D)
 # ************* pooling ops *************

-def stack_for_pool(x, py, px):
-  my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
-  stack = []
-  xup = x[:, :, :my, :mx]
-  for Y in range(py):
-    for X in range(px):
-      stack.append(xup[:, :, Y::py, X::px][None])
-  return np.concatenate(stack, axis=0)
+def stack_for_pool(x, kernel_size, stride, fill_value=0):
+  (ky, kx), (py, px) = kernel_size, stride
+  my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1
+  stack = fill_value*np.ones((ky, kx, *x.shape[:2], my+ky, mx+kx), dtype=x.dtype)
+  for Y in range(ky):
+    for X in range(kx):
+      sl = x[..., Y:Y+my*py+ky:py, X:X+mx*px+kx:px]
+      stack[Y, X, ..., :sl.shape[2], :sl.shape[3]] = sl
+  return stack.reshape(-1, *stack.shape[2:]), (my, mx)

-def unstack_for_pool(fxn, s, py, px):
-  my, mx = (s[2]//py)*py, (s[3]//px)*px
-  for Y in range(py):
-    for X in range(px):
-      ll = fxn(Y*px+X)
+def unstack_for_pool(fxn, s, kernel_size, stride):
+  (ky, kx), (py, px) = kernel_size, stride
+  for Y in range(ky):
+    for X in range(kx):
+      ll = fxn(Y*kx+X)
       if X == 0 and Y == 0:
-        ret = np.zeros(s, dtype=ll.dtype)
-      ret[:, :, Y:my:py, X:mx:px] = ll
-  return ret
+        ret = np.zeros((*s[:2], s[2]+ky, s[3]+kx), dtype=ll.dtype)
+      ret[..., Y:Y+ll.shape[2]*py:py, X:X+ll.shape[3]*px:px] = ll
+  return ret[..., :s[2], :s[3]]

 class MaxPool2D(Function):
   @staticmethod
-  def forward(ctx, x, kernel_size=(2, 2)):
-    stack = stack_for_pool(x, *kernel_size)
-    idxs = np.argmax(stack, axis=0)
+  def forward(ctx, x, kernel_size=(2, 2), stride=None):
+    if not stride:
+      ctx.stride = stride = kernel_size
+    stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=-np.inf)
+    idxs = np.nanargmax(stack, axis=0)[..., :my, :mx]
     ctx.save_for_backward(idxs, x.shape)
-    return np.max(stack, axis=0)
+    return np.amax(stack, axis=0)[..., :my, :mx]

   @staticmethod
   def backward(ctx, grad_output):
     idxs,s = ctx.saved_tensors
     return unstack_for_pool(
       lambda idx: grad_output * (idxs == idx),
-      s, *ctx.kernel_size)
+      s, ctx.kernel_size, ctx.stride)
 register('max_pool2d', MaxPool2D)

 class AvgPool2D(Function):
   @staticmethod
-  def forward(ctx, x, kernel_size=(2, 2)):
-    stack = stack_for_pool(x, *kernel_size)
+  def forward(ctx, x, kernel_size=(2, 2), stride=None):
+    if not stride:
+      ctx.stride = stride = kernel_size
+    stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=np.nan)
     ctx.save_for_backward(x.shape)
-    return np.mean(stack, axis=0)
+    with warnings.catch_warnings():
+      warnings.simplefilter("ignore")
+      return np.nanmean(stack, axis=0)[..., :my, :mx]

   @staticmethod
   def backward(ctx, grad_output):
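
For readers puzzling over the CPU path: stack_for_pool materializes one strided slice of x per in-window offset (ky*kx slices in total), stacks them on a new leading axis, and the pooling op reduces across that axis; unstack_for_pool scatters gradients back through the same offset grid. A self-contained check, with illustrative names not taken from the diff, that stack-and-reduce agrees with an explicit window loop:

  import numpy as np

  def naive_max_pool2d(x, kernel_size, stride):
    (ky, kx), (py, px) = kernel_size, stride
    my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1
    out = np.empty((*x.shape[:2], my, mx), dtype=x.dtype)
    for i in range(my):
      for j in range(mx):
        out[..., i, j] = x[..., i*py:i*py+ky, j*px:j*px+kx].max(axis=(2, 3))
    return out

  def stacked_max_pool2d(x, kernel_size, stride):
    (ky, kx), (py, px) = kernel_size, stride
    my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1
    # one slice per in-window offset (Y, X); element [n, c, wy, wx] of that slice
    # is the (Y, X) entry of window (wy, wx), so reducing over the stack pools every window
    stack = [x[..., Y:Y+my*py:py, X:X+mx*px:px][None] for Y in range(ky) for X in range(kx)]
    return np.concatenate(stack, axis=0).max(axis=0)

  x = np.random.randn(2, 3, 11, 9).astype(np.float32)
  assert np.allclose(naive_max_pool2d(x, (3, 2), (2, 2)), stacked_max_pool2d(x, (3, 2), (2, 2)))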
@@ -284,6 +292,6 @@ class AvgPool2D(Function):
     py, px = ctx.kernel_size
     return unstack_for_pool(
       lambda idx: grad_output/py/px,
-      s, py, px)
+      s, ctx.kernel_size, ctx.stride)
 register('avg_pool2d', AvgPool2D)


@@ -28,14 +28,15 @@ def clbuild(cl_ctx, prg):
 def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0):
   prg = """
   __kernel void subsample(
-    __global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size, int nelem
+    __global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size,
+    uint2 stride, int nelem
   ) {
     int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0));
     int oid = gid.x + osize.x*(gid.y + osize.y*gid.z);
     float group_res = """+str(init_val)+""";
     for (uint j=0; j<kernel_size.y; ++j) {
       for (uint i=0; i<kernel_size.x; ++i) {
-        int iid = (gid.x*kernel_size.x+i) + isize.x*((gid.y*kernel_size.y+j) + isize.y*gid.z);
+        int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
         if (iid < nelem)
           """+iter_op+""";
       }
@@ -45,17 +46,19 @@ def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0):
"""
return clbuild(cl_ctx, prg)
def subsample_op(ctx, input, kernel_size, iter_op, result_op, init_val=0):
N, C, Y, X = input.shape
py,px = kernel_size
ret = buffer_new(ctx, (N, C, Y//py, X//px))
osize = np.array((X//px, Y//py), dtype=cl.cltypes.uint2)
isize = np.array((X, Y), dtype=cl.cltypes.uint2)
ksize = np.array((px,py), dtype=cl.cltypes.uint2)
def subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=0):
py, px = stride
N, C, Yin, Xin = input.shape
Yout, Xout = (Yin-kernel_size[0])//py+1, (Xin-kernel_size[1])//px+1
ret = buffer_new(ctx, (N, C, Yout, Xout))
osize = np.array((Xout, Yout), dtype=cl.cltypes.uint2)
isize = np.array((Xin, Yin), dtype=cl.cltypes.uint2)
ksize = np.array(kernel_size[::-1], dtype=cl.cltypes.uint2)
strd = np.array((px, py), dtype=cl.cltypes.uint2)
prg = cl_subsample_krnl_build(ctx.cl_ctx, iter_op, result_op, init_val=init_val)
prg.subsample(ctx.cl_queue, (N*C, Y//py, X//px), None,
ret, input, osize, isize, ksize, np.int32(input.size))
ctx.data = np.empty((N, C, Y, X)) # set shape expectation on tensor instance
prg.subsample(ctx.cl_queue, (N*C, Yout, Xout), None,
ret, input, osize, isize, ksize, strd, np.int32(input.size))
ctx.data = np.empty((N, C, Yout, Xout)) # set shape expectation on tensor instance
return ret
@functools.lru_cache
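
In the GPU path each work-item owns one output pixel: the launch grid is (N*C, Yout, Xout), and iid is the flattened (image, row, col) coordinate of window element (i, j), whose window origin is now scaled by stride rather than kernel_size. A quick NumPy check of that addressing (hypothetical helper mirroring the kernel expression):

  import numpy as np

  def iid(gid, isize, stride, i, j):
    # mirrors: (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z)
    (gx, gy, gz), (ix, iy), (sx, sy) = gid, isize, stride
    return (gx*sx + i) + ix*((gy*sy + j) + iy*gz)

  NC, Yin, Xin, sy, sx = 6, 110, 28, 4, 2
  for gz, gy, gx, j, i in [(0, 0, 0, 0, 0), (5, 26, 13, 3, 1)]:
    assert iid((gx, gy, gz), (Xin, Yin), (sx, sy), i, j) == \
           np.ravel_multi_index((gz, gy*sy + j, gx*sx + i), (NC, Yin, Xin))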
@@ -377,27 +380,31 @@ register('sigmoid', Sigmoid, gpu=True)
 class AvgPool2D(Function):
   @staticmethod
-  def forward(ctx, input, kernel_size=(2, 2)):
+  def forward(ctx, input, kernel_size=(2, 2), stride=None):
+    if not stride:
+      ctx.stride = stride = kernel_size
     iter_op = "group_res += input[iid]"
     result_op = "group_res / (kernel_size.x * kernel_size.y)"
-    ret = subsample_op(ctx, input, kernel_size, iter_op, result_op)
-    ctx.save_for_backward(kernel_size, input.shape)
+    ret = subsample_op(ctx, input, kernel_size, stride, iter_op, result_op)
+    ctx.save_for_backward(input.shape)
     return ret

   @staticmethod
   def backward(ctx, grad_output):
-    kernel_size, orig_shape = ctx.saved_tensors
-    result_op = "input[iid] / (float)(kernel_size.x * kernel_size.y)"
-    return supersample_op(ctx, grad_output, orig_shape, kernel_size, result_op)
+    orig_shape, = ctx.saved_tensors
+    result_op = "input[iid] / (kernel_size.x * kernel_size.y)"
+    return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size, result_op)
 register('avg_pool2d', AvgPool2D, gpu=True)

 class MaxPool2D(Function):
   @staticmethod
-  def forward(ctx, input, kernel_size=(2, 2)):
+  def forward(ctx, input, kernel_size=(2, 2), stride=None):
+    if not stride:
+      ctx.stride = stride = kernel_size
     init_val = "FLT_MIN"
     iter_op = "group_res = max(group_res, input[iid])"
     result_op = "group_res"
-    return subsample_op(ctx, input, kernel_size, iter_op, result_op, init_val=init_val)
+    return subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=init_val)

   @staticmethod
   def backward(ctx, grad_output):
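
One caveat worth flagging while reading the max-pool kernel (an observation about the code as shown, not something this diff changes): in OpenCL C, FLT_MIN is the smallest positive normalized float, not the most negative value, so seeding the running max with it only behaves when a window contains a non-negative element; -FLT_MAX or -INFINITY is the usual identity for a max reduction. In NumPy terms:

  import numpy as np

  FLT_MIN = np.finfo(np.float32).tiny  # ~1.18e-38, smallest positive normal float
  FLT_MAX = np.finfo(np.float32).max
  window = np.float32([-3.0, -1.5, -2.0])
  assert max(FLT_MIN, *window) == FLT_MIN            # FLT_MIN seed wins: wrong result
  assert max(-FLT_MAX, *window) == np.float32(-1.5)  # -FLT_MAX seed: correct max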