diff --git a/test/test_ops.py b/test/test_ops.py index 8862e54a91..e1272a4efe 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -258,14 +258,21 @@ class TestOps(unittest.TestCase): def test_mean_axis(self): helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2))) def test_std(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, unbiased=False), lambda x: Tensor.std(x)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x), lambda x: Tensor.std(x)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0), lambda x: Tensor.std(x, correction=0)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=5), lambda x: Tensor.std(x, correction=5)) def test_std_axis(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, unbiased=False, dim=0), lambda x: Tensor.std(x, axis=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, unbiased=False, dim=2), lambda x: Tensor.std(x, axis=2)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, unbiased=False, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2])) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, unbiased=False, dim=None), lambda x: Tensor.std(x, axis=None)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2])) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None), lambda x: Tensor.std(x, axis=None)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=0), lambda x: Tensor.std(x, axis=0, correction=0)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=2), lambda x: Tensor.std(x, axis=2, correction=0)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0)) def test_std_keepdim(self): helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True), lambda x: Tensor.std(x, keepdim=True)) + helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0)) def test_log_softmax(self): helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) def test_log_softmax_other_axis(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b31abb9188..c4d8e2a9a7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -336,10 +336,9 @@ class Tensor: def mean(self, axis=None, keepdim=False): out = self.sum(axis=axis, keepdim=keepdim) return out * (prod(out.shape)/prod(self.shape)) - # TODO: implement unbiased True option for torch bessel's correction (subtracting 1 from divisor causes 0.01 error) - def std(self, axis=None, keepdim=False): - square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) - return (square_sum * (prod(square_sum.shape)/prod(self.shape))).sqrt() + def std(self, axis=None, keepdim=False, correction=1): + square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) + return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt() def _softmax(self, axis): m = self - self.max(axis=axis, keepdim=True) e = m.exp()