From b27bcbe4b453c194b9fa3fd3a3040535f2269a0d Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 25 Oct 2020 18:40:01 -0700
Subject: [PATCH] avgpool and test refactor

---
 test/test_ops.py    | 35 +++++++++++++++++++++++++++++++++++
 test/test_tensor.py | 38 --------------------------------------
 tinygrad/ops.py     | 35 ++++++++++++++++++++++++++++-------
 3 files changed, 63 insertions(+), 45 deletions(-)
 create mode 100644 test/test_ops.py

diff --git a/test/test_ops.py b/test/test_ops.py
new file mode 100644
index 0000000000..c9ff8675f6
--- /dev/null
+++ b/test/test_ops.py
@@ -0,0 +1,35 @@
+import torch
+import numpy as np
+import unittest
+from tinygrad.tensor import Tensor
+
+def test_op(shps, f1, f2, atol=1e-7, grad_atol=1e-7):
+  ts = [torch.rand(x, requires_grad=True) for x in shps]
+  tst = [Tensor(x.detach().numpy()) for x in ts]
+
+  out = f1(*ts)
+  ret = f2(*tst)
+  # TODO: why so inaccurate?
+  np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=atol)
+
+  out.mean().backward()
+  ret.mean().backward()
+
+  for t, tt in zip(ts, tst):
+    np.testing.assert_allclose(t.grad, tt.grad, atol=grad_atol)
+
+class TestOps(unittest.TestCase):
+  def test_conv2d(self):
+    for cin in [1,2,3]:
+      for H in [2,3,5]:
+        for W in [2,3,5]:
+          test_op([(5,cin,10,7), (4,cin,H,W)], torch.nn.functional.conv2d, Tensor.conv2d, atol=1e-5)
+
+  def test_maxpool2x2(self):
+    test_op([(5,2,11,8)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d)
+
+  def test_avgpool2x2(self):
+    test_op([(5,2,11,8)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 676adef4af..c358f1737c 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -64,44 +64,6 @@ class TestTinygrad(unittest.TestCase):
     # coarse approx. since a "big" eps and the non-linearities of the model
     self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
 
-class TestOps(unittest.TestCase):
-  def test_conv2d(self):
-    for cin in [1,2,3]:
-      for H in [2,3,5]:
-        for W in [2,3,5]:
-          x = torch.randn((5,cin,10,7), requires_grad=True)
-          w = torch.randn((4,cin,H,W), requires_grad=True)
-          xt = Tensor(x.detach().numpy())
-          wt = Tensor(w.detach().numpy())
-
-          out = torch.nn.functional.conv2d(x,w)
-          ret = Tensor.conv2d(xt, wt)
-          # TODO: why so inaccurate?
-          np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
-
-          out.relu().mean().backward()
-          ret.relu().mean().backward()
-
-          np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
-          np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
-
-  def test_maxpool2x2(self):
-    x = torch.randn((5,2,10,8), requires_grad=True)
-    xt = Tensor(x.detach().numpy())
-
-    # in tinygrad
-    ret = xt.max_pool2d()
-    assert ret.shape == (5,2,10//2,8//2)
-    ret.mean().backward()
-
-    # in torch
-    out = torch.nn.MaxPool2d((2,2))(x)
-    out.mean().backward()
-
-    # forward and backward the same
-    np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
-    np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
-
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 18fa9c579e..50695b92fa 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -124,16 +124,19 @@ class Conv2D(Function):
     return dx, dw
 register('conv2d', Conv2D)
 
+def stack_for_pool(x, py, px):
+  my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
+  stack = []
+  xup = x[:, :, :my, :mx]
+  for Y in range(py):
+    for X in range(px):
+      stack.append(xup[:, :, Y::py, X::px][None])
+  return np.concatenate(stack, axis=0)
+
 class MaxPool2D(Function):
   @staticmethod
   def forward(ctx, x):
-    my, mx = (x.shape[2]//2)*2, (x.shape[3]//2)*2
-    stack = []
-    xup = x[:, :, :my, :mx]
-    for Y in range(2):
-      for X in range(2):
-        stack.append(xup[:, :, Y::2, X::2][None])
-    stack = np.concatenate(stack, axis=0)
+    stack = stack_for_pool(x, 2, 2)
     idxs = np.argmax(stack, axis=0)
     ctx.save_for_backward(idxs, x.shape)
     return np.max(stack, axis=0)
@@ -149,3 +152,21 @@ class MaxPool2D(Function):
     return ret
 register('max_pool2d', MaxPool2D)
 
+class AvgPool2D(Function):
+  @staticmethod
+  def forward(ctx, x):
+    stack = stack_for_pool(x, 2, 2)
+    ctx.save_for_backward(x.shape)
+    return np.mean(stack, axis=0)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    s, = ctx.saved_tensors
+    my, mx = (s[2]//2)*2, (s[3]//2)*2
+    ret = np.zeros(s, dtype=grad_output.dtype)
+    for Y in range(2):
+      for X in range(2):
+        ret[:, :, Y:my:2, X:mx:2] = grad_output/4
+    return ret
+register('avg_pool2d', AvgPool2D)
+
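
Note for review (illustrative sketch, not part of the patch): stack_for_pool
crops H and W down to multiples of the pooling window, then collects one
strided view per position inside the window, so a single reduction over
axis 0 performs the pooling. The NumPy snippet below restates the trick with
made-up shapes, assuming the NCHW layout tinygrad uses:

  import numpy as np

  np.random.seed(0)
  x = np.random.rand(5, 2, 11, 8).astype(np.float32)  # N, C, H, W

  py, px = 2, 2
  # crop H and W to multiples of the window, as stack_for_pool does
  my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
  xup = x[:, :, :my, :mx]
  # one strided view per window position; shape (py*px, N, C, H//py, W//px)
  stack = np.stack([xup[:, :, Y::py, X::px] for Y in range(py) for X in range(px)])

  max_pooled = stack.max(axis=0)   # what MaxPool2D.forward returns
  avg_pooled = stack.mean(axis=0)  # what AvgPool2D.forward returns

  # spot-check one window: output (0,0,0,0) covers input rows 0:2, cols 0:2
  assert max_pooled[0,0,0,0] == x[0,0,:2,:2].max()
  assert np.isclose(avg_pooled[0,0,0,0], x[0,0,:2,:2].mean())

The backward passes invert the same view: AvgPool2D spreads grad_output/4
into each of the four window positions, while MaxPool2D routes grad_output
only to the positions recorded by the argmax saved in forward.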