avgpool and test refactor

This commit is contained in:
George Hotz
2020-10-25 18:40:01 -07:00
parent 4c42676cb6
commit b27bcbe4b4
3 changed files with 63 additions and 45 deletions

35
test/test_ops.py Normal file
View File

@@ -0,0 +1,35 @@
import torch
import numpy as np
import unittest
from tinygrad.tensor import Tensor
def test_op(shps, f1, f2, atol=1e-7, grad_atol=1e-7):
ts = [torch.rand(x, requires_grad=True) for x in shps]
tst = [Tensor(x.detach().numpy()) for x in ts]
out = f1(*ts)
ret = f2(*tst)
# TODO: why so inaccurate?
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=atol)
out.mean().backward()
ret.mean().backward()
for t, tt in zip(ts, tst):
np.testing.assert_allclose(t.grad, tt.grad, atol=grad_atol)
class TestOps(unittest.TestCase):
def test_conv2d(self):
for cin in [1,2,3]:
for H in [2,3,5]:
for W in [2,3,5]:
test_op([(5,cin,10,7), (4,cin,H,W)], torch.nn.functional.conv2d, Tensor.conv2d, atol=1e-5)
def test_maxpool2x2(self):
test_op([(5,2,11,8)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d)
def test_avgpool2x2(self):
test_op([(5,2,11,8)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d)
if __name__ == '__main__':
unittest.main()

View File

@@ -64,44 +64,6 @@ class TestTinygrad(unittest.TestCase):
# coarse approx. since a "big" eps and the non-linearities of the model
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
class TestOps(unittest.TestCase):
def test_conv2d(self):
for cin in [1,2,3]:
for H in [2,3,5]:
for W in [2,3,5]:
x = torch.randn((5,cin,10,7), requires_grad=True)
w = torch.randn((4,cin,H,W), requires_grad=True)
xt = Tensor(x.detach().numpy())
wt = Tensor(w.detach().numpy())
out = torch.nn.functional.conv2d(x,w)
ret = Tensor.conv2d(xt, wt)
# TODO: why so inaccurate?
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
out.relu().mean().backward()
ret.relu().mean().backward()
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
def test_maxpool2x2(self):
x = torch.randn((5,2,10,8), requires_grad=True)
xt = Tensor(x.detach().numpy())
# in tinygrad
ret = xt.max_pool2d()
assert ret.shape == (5,2,10//2,8//2)
ret.mean().backward()
# in torch
out = torch.nn.MaxPool2d((2,2))(x)
out.mean().backward()
# forward and backward the same
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
if __name__ == '__main__':
unittest.main()

View File

@@ -124,16 +124,19 @@ class Conv2D(Function):
return dx, dw
register('conv2d', Conv2D)
def stack_for_pool(x, py, px):
my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
stack = []
xup = x[:, :, :my, :mx]
for Y in range(py):
for X in range(px):
stack.append(xup[:, :, Y::2, X::2][None])
return np.concatenate(stack, axis=0)
class MaxPool2D(Function):
@staticmethod
def forward(ctx, x):
my, mx = (x.shape[2]//2)*2, (x.shape[3]//2)*2
stack = []
xup = x[:, :, :my, :mx]
for Y in range(2):
for X in range(2):
stack.append(xup[:, :, Y::2, X::2][None])
stack = np.concatenate(stack, axis=0)
stack = stack_for_pool(x, 2, 2)
idxs = np.argmax(stack, axis=0)
ctx.save_for_backward(idxs, x.shape)
return np.max(stack, axis=0)
@@ -149,3 +152,21 @@ class MaxPool2D(Function):
return ret
register('max_pool2d', MaxPool2D)
class AvgPool2D(Function):
@staticmethod
def forward(ctx, x):
stack = stack_for_pool(x, 2, 2)
ctx.save_for_backward(x.shape)
return np.mean(stack, axis=0)
@staticmethod
def backward(ctx, grad_output):
s, = ctx.saved_tensors
my, mx = (s[2]//2)*2, (s[3]//2)*2
ret = np.zeros(s, dtype=grad_output.dtype)
for Y in range(2):
for X in range(2):
ret[:, :, Y:my:2, X:mx:2] = grad_output/4
return ret
register('avg_pool2d', AvgPool2D)