mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Fix .std() tests on torch=1.13 (#904)
This commit is contained in:
@@ -305,8 +305,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
|
||||
def test_std(self):
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x), lambda x: Tensor.std(x))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0), lambda x: Tensor.std(x, correction=0))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=5), lambda x: Tensor.std(x, correction=5))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=0), lambda x: Tensor.std(x, correction=0))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=5), lambda x: Tensor.std(x, correction=5))
|
||||
def test_std_axis(self):
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2))
|
||||
@@ -317,8 +317,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0))
|
||||
def test_std_keepdim(self):
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True), lambda x: Tensor.std(x, keepdim=True))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True))
|
||||
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
|
||||
def test_log_softmax(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
|
||||
@@ -886,7 +886,7 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
Tensor.stack([x], dim=77)
|
||||
|
||||
|
||||
a = Tensor(3.14)
|
||||
np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user