Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 14:58:46 -05:00
fix Tensor.var with 0 in reduce dim. (#3324)
Fixes var when the correction is too big, though the fix seems to only work when the input size is 0. torch can output -inf in var when the correction is too big, which does not make sense.
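For context, Tensor.var divides the summed squared deviations by N - correction, where N is the number of input elements collapsed per output element. A minimal sketch of the divisor arithmetic in plain Python (shapes taken from the tests below):

    from math import prod

    # zero-size reduce dim: var over axes (1,3) of shape (1,0,3,0,5)
    N = prod((1, 0, 3, 0, 5)) / prod((1, 3, 5))  # 0/15 = 0.0 elements reduced per output
    print(N - 1)                                 # -1.0: the old divisor goes negative
    print(max(0, N - 1))                         # 0: the fixed divisor; 0/0 then gives nan

    # oversized correction on a non-empty input (the still-broken TODO case)
    N = prod((10, 2))      # full reduce collapses all 20 elements
    print(max(0, N - 50))  # 0: a positive square_sum / 0 gives inf, not torch's answer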
test/test_ops.py

@@ -684,12 +684,14 @@ class TestOps(unittest.TestCase):
   def test_mean_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)))
   def test_mean_zero_axis(self):
     helper_test_op([], lambda: torch.ones((1,0,3,0,5)).mean(axis=(1,3)), lambda: Tensor.ones((1,0,3,0,5)).mean(axis=(1,3)), forward_only=True)
     helper_test_op([(1,0,3,0,5)], lambda x: x.mean(axis=(1,3)))
   def test_var(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var())
     helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5))
+    # TODO: fix this
+    # helper_test_op([(10, 2)], lambda x: x.var(correction=50))
   def test_var_axis(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var(0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(2))
@@ -697,6 +699,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0))
+  def test_var_zero_axis(self):
+    helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3)))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5))
   def test_var_keepdim(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True))
     helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0))
@@ -712,6 +718,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0))
+  def test_std_zero_axis(self):
+    helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3)))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5))
   def test_std_keepdim(self):
     helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True))
     helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0))
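The zero-axis tests compare against torch's behavior for reductions over empty dims. A quick standalone check of that reference behavior (a sketch, assuming a recent torch with the correction keyword):

    import torch

    x = torch.ones(1, 0, 3, 0, 5)
    # no samples along the reduced axes, so every output element should be nan
    print(x.var(dim=(1, 3)))
    print(x.std(dim=(1, 3), correction=0))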
tinygrad/tensor.py

@@ -546,7 +546,7 @@ class Tensor:
   def var(self, axis=None, keepdim=False, correction=1):
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction)
+    return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
   def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()

   def _softmax(self, axis):
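In effect the patched method computes square_sum / max(0, N - correction), with N = prod(self.shape)/prod(square_sum.shape) being the element count per reduced slot. A rough numpy rendition of the patched formula (an illustration, not the tinygrad code path):

    import numpy as np

    def var_fixed(x, axis=None, correction=1):
        # squared deviations from the mean, summed over the reduce axes
        square_sum = ((x - x.mean(axis=axis, keepdims=True)) ** 2).sum(axis=axis)
        n = x.size / square_sum.size  # elements collapsed per output element
        return square_sum / max(0, n - correction)  # divisor clamped at zero

    print(var_fixed(np.ones((1, 0, 3, 0, 5)), axis=(1, 3)))  # all nan, matching torch

For a non-empty input with an oversized correction, e.g. shape (10, 2) with correction=50, the clamped divisor is 0 and the positive square_sum becomes inf, which is why that case stays commented out in the tests above.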