always keep batch size out front

This commit is contained in:
George Hotz
2020-10-25 08:14:07 -07:00
parent b91fd3afad
commit 935f5ddaaa
4 changed files with 25 additions and 11 deletions

View File

@@ -78,8 +78,8 @@ class TestOps(unittest.TestCase):
out.mean().backward()
ret.mean().backward()
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-7)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-7)
def test_maxpool2x2(self):
x = torch.randn((5,2,10,8), requires_grad=True)