fix Tensor.var with 0 in reduce dim. (#3324)

fix when correction is too big. it seems to only work when input size is 0 though.
torch can output -inf in var when correction is too big, which does not make sense.
This commit is contained in:
chenyu
2024-02-05 20:59:13 -05:00
committed by GitHub
parent ee25f73283
commit d9ef8e25b3
2 changed files with 12 additions and 2 deletions

View File

@@ -684,12 +684,14 @@ class TestOps(unittest.TestCase):
def test_mean_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)))
def test_mean_zero_axis(self):
helper_test_op([], lambda: torch.ones((1,0,3,0,5)).mean(axis=(1,3)), lambda: Tensor.ones((1,0,3,0,5)).mean(axis=(1,3)), forward_only=True)
helper_test_op([(1,0,3,0,5)], lambda x: x.mean(axis=(1,3)))
def test_var(self):
helper_test_op([(15, 25, 35)], lambda x: x.var())
helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0))
helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5))
# TODO: fix this
# helper_test_op([(10, 2)], lambda x: x.var(correction=50))
def test_var_axis(self):
helper_test_op([(15, 25, 35)], lambda x: x.var(0))
helper_test_op([(15, 25, 35)], lambda x: x.var(2))
@@ -697,6 +699,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0))
helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0))
helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0))
def test_var_zero_axis(self):
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3)))
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0))
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5))
def test_var_keepdim(self):
helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True))
helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0))
@@ -712,6 +718,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0))
helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0))
helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0))
def test_std_zero_axis(self):
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3)))
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0))
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5))
def test_std_keepdim(self):
helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True))
helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0))

View File

@@ -546,7 +546,7 @@ class Tensor:
def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction)
return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()
def _softmax(self, axis):