logsoftmax good, div bad

commit ea341c84fe
parent f18801c7db
Author: George Hotz
Date: 2020-12-29 13:59:39 -05:00

@@ -233,13 +233,17 @@ class Tensor:
     return e.div(ss)
   def logsoftmax(self):
-    return self.softmax().log()
+    ns = list(self.shape)[:-1]+[1]
+    # TODO: logsumexp stability with max
+    ss = self.exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
+    return self - ss
   def dropout(self, p=0.5):
     # TODO: this needs a test
     if Tensor.training:
       _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
       ret = self * Tensor(_mask, requires_grad=False, device=self.device)
-      return ret.div(1.0 - p)
+      return ret * (1/(1.0 - p))
     else:
       return self
@@ -254,7 +258,7 @@ class Tensor:
     return self._pool2d(*kernel_size).mean(axis=(3,5))
   def max_pool2d(self, kernel_size=(2,2)):
-    # TODO: support tuples in max
+    # TODO: support tuples in max and avoid a copy
     return self._pool2d(*kernel_size).max(axis=5).max(axis=3)
 
 # An instantiation of the Function is the Context
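
The TODO left in logsoftmax refers to the usual max-subtraction trick that keeps logsumexp from overflowing in exp(). A minimal NumPy sketch of that stabilization, assuming a last-axis reduction (stable_logsoftmax is an illustrative name, not part of this commit):

import numpy as np

def stable_logsoftmax(x):
  # subtract the row-wise max before exp() so the largest exponent is 0;
  # the max is added back inside the log, so the result is unchanged
  m = x.max(axis=-1, keepdims=True)
  ss = m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True))  # stable logsumexp
  return x - ss

The dropout change matches the commit title: ret.div(1.0 - p) becomes a multiply by the scalar reciprocal 1/(1.0 - p), sidestepping the div op entirely.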