This commit is contained in:
George Hotz
2020-12-28 23:54:52 -05:00
parent bcb3ceeca3
commit 36579f66bf
3 changed files with 27 additions and 2 deletions

View File

@@ -80,6 +80,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
@cpu_only
def test_max(self):
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
@cpu_only
def test_max_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
def test_sum_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
def test_mean_axis(self):