speed up sum

This commit is contained in:
George Hotz
2021-06-17 16:38:34 -07:00
parent e8eb7d1b7e
commit 2affd226b3
2 changed files with 34 additions and 12 deletions

View File

@@ -83,6 +83,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))