From 93dceb4bee3964e80dc5ccbbb2914d96000100d3 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 26 Oct 2020 08:38:53 -0700 Subject: [PATCH] fix kernel_size bug, name like torch, add test --- test/test_ops.py | 8 +++++++- tinygrad/ops.py | 14 +++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 04b262b469..76dd17bd0e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -32,7 +32,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7): class TestOps(unittest.TestCase): def test_conv2d(self): - for bs in [1,128]: + for bs in [1,8]: for cin in [1,3]: for H in [2,5]: for W in [2,3,5]: @@ -43,6 +43,12 @@ class TestOps(unittest.TestCase): def test_maxpool2x2(self): helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d) + def test_maxpool_sizes(self): + for sz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: + helper_test_op([(32,2,110,28)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz), + lambda x: Tensor.max_pool2d(x, kernel_size=sz)) + def test_avgpool2x2(self): helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index dda57fb800..d4feffc58a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -130,13 +130,13 @@ def stack_for_pool(x, py, px): xup = x[:, :, :my, :mx] for Y in range(py): for X in range(px): - stack.append(xup[:, :, Y::2, X::2][None]) + stack.append(xup[:, :, Y::py, X::px][None]) return np.concatenate(stack, axis=0) class MaxPool2D(Function): @staticmethod - def forward(ctx, x, pool_size=(2, 2)): - stack = stack_for_pool(x, *pool_size) + def forward(ctx, x, kernel_size=(2, 2)): + stack = stack_for_pool(x, *kernel_size) idxs = np.argmax(stack, axis=0) ctx.save_for_backward(idxs, x.shape) return np.max(stack, axis=0) @@ -144,7 +144,7 @@ class MaxPool2D(Function): @staticmethod def backward(ctx, grad_output): idxs,s = ctx.saved_tensors - py, px = ctx.pool_size + py, px = ctx.kernel_size my, mx = (s[2]//py)*py, (s[3]//px)*px ret = np.zeros(s, dtype=grad_output.dtype) for Y in range(py): @@ -155,15 +155,15 @@ register('max_pool2d', MaxPool2D) class AvgPool2D(Function): @staticmethod - def forward(ctx, x, pool_size=(2, 2)): - stack = stack_for_pool(x, *pool_size) + def forward(ctx, x, kernel_size=(2, 2)): + stack = stack_for_pool(x, *kernel_size) ctx.save_for_backward(x.shape) return np.mean(stack, axis=0) @staticmethod def backward(ctx, grad_output): s, = ctx.saved_tensors - py, px = ctx.pool_size + py, px = ctx.kernel_size my, mx = (s[2]//py)*py, (s[3]//px)*px ret = np.zeros(s, dtype=grad_output.dtype) for Y in range(py):