yayay test_sgd_gpu passes

This commit is contained in:
George Hotz
2020-11-07 08:48:17 -08:00
parent 98d1a4f740
commit 3302286e68
3 changed files with 24 additions and 13 deletions

View File

@@ -58,7 +58,7 @@ class TestOps(unittest.TestCase):
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5, gpu=self.gpu)
def test_sum(self):
helper_test_op([(45,1)], lambda x: x.sum(), Tensor.sum, atol=1e-5, gpu=self.gpu)
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, atol=1e-4, gpu=self.gpu)
def test_logsoftmax(self):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-5, gpu=self.gpu)