diff --git a/test/test_ops.py b/test/test_ops.py
index 1751be483c..a5e7dec334 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -61,6 +61,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, gpu=self.gpu)
   def test_sum_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), gpu=self.gpu)
+  def test_mean_axis(self):
+    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), gpu=self.gpu)
   def test_logsoftmax(self):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
   def test_tanh(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e40b465d1a..4f0eed933c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -194,9 +194,14 @@ class Tensor:
 
   # ***** non first class ops *****
 
-  def mean(self):
-    div = Tensor(np.array([1/np.prod(self.shape)], dtype=self.dtype), gpu=self.gpu, requires_grad=False)
-    return self.sum().mul(div)
+  def mean(self, axis=None):
+    # mean = sum(axis) scaled by a constant; axis is forwarded to sum() unchanged
+    out = self.sum(axis=axis)
+    # ratio of output to input element counts, i.e. 1/(inputs folded into each output element)
+    coeff = np.prod(out.shape)/np.prod(self.shape)
+    # np.full pins div to self.dtype; coeff is a float64 scalar and `coeff+np.zeros(...)` promotes to float64 on NumPy 2
+    div = Tensor(np.full(out.shape, coeff, dtype=self.dtype), gpu=self.gpu, requires_grad=False)
+    return out.mul(div)
 
   def sqrt(self):
     root = Tensor(np.zeros(self.shape, dtype=self.dtype)+0.5, gpu=self.gpu, requires_grad=False)