Mean axis (doesn't work) (#154)

* mean axis

* fixed
George Hotz
2020-12-07 22:58:34 -08:00
committed by GitHub
parent 38f97c8c80
commit b355cd2571
2 changed files with 7 additions and 3 deletions


@@ -61,6 +61,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, gpu=self.gpu)
   def test_sum_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), gpu=self.gpu)
+  def test_mean_axis(self):
+    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), gpu=self.gpu)
   def test_logsoftmax(self):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
   def test_tanh(self):
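
For context (not part of the diff): the new test asks helper_test_op to check tinygrad's Tensor.mean(axis=(1,2)) against torch's mean over the same axes on a (3,4,5,6) input, so the reduced result should have shape (3,6) with each output element averaging 4*5 = 20 values. A minimal NumPy-only sketch of that expectation:

```python
import numpy as np

# Standalone illustration (not the test itself) of what test_mean_axis exercises:
# reducing a (3,4,5,6) tensor over axes (1,2) yields shape (3,6), and each
# output element is the average of 4*5 = 20 input values.
x = np.random.randn(3, 4, 5, 6).astype(np.float32)
out = x.mean(axis=(1, 2))
assert out.shape == (3, 6)
np.testing.assert_allclose(out, x.sum(axis=(1, 2)) / 20, rtol=1e-5)
```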


@@ -194,9 +194,11 @@ class Tensor:
   # ***** non first class ops *****
-  def mean(self):
-    div = Tensor(np.array([1/np.prod(self.shape)], dtype=self.dtype), gpu=self.gpu, requires_grad=False)
-    return self.sum().mul(div)
+  def mean(self, axis=None):
+    out = self.sum(axis=axis)
+    coeff = np.prod(out.shape)/np.prod(self.shape)
+    div = Tensor(coeff+np.zeros(out.shape, dtype=self.dtype), gpu=self.gpu, requires_grad=False)
+    return out.mul(div)
   def sqrt(self):
     root = Tensor(np.zeros(self.shape, dtype=self.dtype)+0.5, gpu=self.gpu, requires_grad=False)
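
A note on the new mean (a NumPy stand-in, not the tinygrad code itself): since sum(axis) collapses the reduced axes, the ratio np.prod(out.shape)/np.prod(self.shape) is exactly 1 over the number of elements folded into each output, so multiplying the sum by that coefficient recovers the mean. A quick sketch of the identity:

```python
import numpy as np

# NumPy stand-in for the scaling trick above: after sum(axis) collapses the
# reduced axes, prod(out.shape)/prod(x.shape) equals 1 / (elements folded into
# each output), so multiplying by that coefficient turns the sum into a mean.
x = np.random.randn(3, 4, 5, 6).astype(np.float32)
out = x.sum(axis=(1, 2))
coeff = np.prod(out.shape) / np.prod(x.shape)   # (3*6) / (3*4*5*6) = 1/20
np.testing.assert_allclose(out * coeff, x.mean(axis=(1, 2)), rtol=1e-5)
```

In the diff, that coefficient is materialized as a constant tensor of out's shape with requires_grad=False, so the scaling contributes no gradient of its own and backprop flows only through the sum.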