mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test case for Tensor.var reducing over size = 1 axis (#4902)
backward failed when correction >= reducing n
This commit is contained in:
@@ -805,10 +805,18 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0))
|
||||
def test_var_zero_axis(self):
|
||||
def test_var_zero_in_axis(self):
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3)))
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0))
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5))
|
||||
# TODO: fix backward when correction >= n
|
||||
def test_var_one_in_axis(self):
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3)), forward_only=True)
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=0))
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=5), forward_only=True)
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4)))
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4), correction=0))
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4), correction=5), forward_only=True)
|
||||
def test_var_keepdim(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0))
|
||||
@@ -824,10 +832,18 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0))
|
||||
def test_std_zero_axis(self):
|
||||
def test_std_zero_in_axis(self):
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3)))
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0))
|
||||
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5))
|
||||
# TODO: fix backward when correction >= n
|
||||
def test_std_one_in_axis(self):
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3)), forward_only=True)
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3), correction=0))
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3), correction=5), forward_only=True)
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4)))
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4), correction=0))
|
||||
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4), correction=5))
|
||||
def test_std_keepdim(self):
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True))
|
||||
helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0))
|
||||
|
||||
Reference in New Issue
Block a user