improve conv testing

This commit is contained in:
George Hotz
2020-10-25 12:46:04 -07:00
parent ef24aac09e
commit bb98cdfef7
2 changed files with 14 additions and 12 deletions

View File

@@ -74,7 +74,6 @@ def evaluate(model):
assert accuracy > 0.95
class TestMNIST(unittest.TestCase):
@unittest.skip(reason="mad slow")
def test_conv(self):
np.random.seed(1337)
model = TinyConvNet()

View File

@@ -66,20 +66,23 @@ class TestTinygrad(unittest.TestCase):
class TestOps(unittest.TestCase):
def test_conv2d(self):
x = torch.randn((5,2,10,7), requires_grad=True)
w = torch.randn((4,2,3,2), requires_grad=True)
xt = Tensor(x.detach().numpy())
wt = Tensor(w.detach().numpy())
for cin in [1,2,3]:
for H in [2,3,5]:
for W in [2,3,5]:
x = torch.randn((5,cin,10,7), requires_grad=True)
w = torch.randn((4,cin,H,W), requires_grad=True)
xt = Tensor(x.detach().numpy())
wt = Tensor(w.detach().numpy())
out = torch.nn.functional.conv2d(x,w)
ret = Tensor.conv2d(xt, wt)
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
out = torch.nn.functional.conv2d(x,w)
ret = Tensor.conv2d(xt, wt)
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
out.mean().backward()
ret.mean().backward()
out.mean().backward()
ret.mean().backward()
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
def test_maxpool2x2(self):
x = torch.randn((5,2,10,8), requires_grad=True)