Fix .std() tests on torch=1.13 (#904)

This commit is contained in:
Alexey Zaytsev
2023-06-02 21:33:51 +07:00
committed by GitHub
parent 4d28d55683
commit 5feee9c94b

View File

@@ -305,8 +305,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
def test_std(self):
helper_test_op([(45, 65, 85)], lambda x: torch.std(x), lambda x: Tensor.std(x))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0), lambda x: Tensor.std(x, correction=0))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=5), lambda x: Tensor.std(x, correction=5))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=0), lambda x: Tensor.std(x, correction=0))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=5), lambda x: Tensor.std(x, correction=5))
def test_std_axis(self):
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2))
@@ -317,8 +317,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0))
def test_std_keepdim(self):
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True), lambda x: Tensor.std(x, keepdim=True))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True))
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
def test_log_softmax(self):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
@@ -886,7 +886,7 @@ class TestOps(unittest.TestCase):
with self.assertRaises(IndexError):
Tensor.stack([x], dim=77)
a = Tensor(3.14)
np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())