Mirror of https://github.com/tinygrad/tinygrad.git
logsoftmax good, div bad
@@ -233,13 +233,17 @@ class Tensor:
     return e.div(ss)
 
   def logsoftmax(self):
-    return self.softmax().log()
+    ns = list(self.shape)[:-1]+[1]
+    # TODO: logsumexp stability with max
+    ss = self.exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
+    return self - ss
 
   def dropout(self, p=0.5):
     # TODO: this needs a test
     if Tensor.training:
       _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
       ret = self * Tensor(_mask, requires_grad=False, device=self.device)
-      return ret.div(1.0 - p)
+      return ret * (1/(1.0 - p))
     else:
       return self
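The logsoftmax rewrite computes log(softmax(x)) directly as x - log(sum(exp(x))), instead of taking softmax().log(), which avoids the div whose backward pass the commit title calls bad. The "logsumexp stability with max" TODO refers to shifting by the row max before exponentiating, so exp() cannot overflow. A minimal NumPy sketch of both forms (illustration only, not tinygrad code):

import numpy as np

def logsoftmax_naive(x):
    # mirrors the new tinygrad code: x - log(sum(exp(x), last axis))
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def logsoftmax_stable(x):
    # what the "logsumexp stability with max" TODO points toward:
    # shift by the row max so exp() never overflows
    m = x.max(axis=-1, keepdims=True)
    return x - (m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True)))

x = np.array([[1.0, 2.0, 3.0]])
assert np.allclose(logsoftmax_naive(x), logsoftmax_stable(x))
print(logsoftmax_stable(np.array([[1000.0, 1001.0]])))  # the naive form overflows here

The dropout change is the "div bad" half: ret.div(1.0 - p) and ret * (1/(1.0 - p)) produce the same value, but the multiply sidesteps the div op entirely. The scaling itself is standard inverted dropout: dividing the surviving activations by the keep probability 1-p at training time keeps their expectation equal to the untouched eval-mode output. A quick NumPy check of that property (a hypothetical example, not code from the repo):

import numpy as np

rng = np.random.default_rng(0)
p = 0.5
x = rng.standard_normal((1000, 1000))

mask = rng.binomial(1, 1.0 - p, size=x.shape)   # keep each unit with prob 1-p
train_out = (x * mask) * (1/(1.0 - p))          # scale survivors, as in the diff

print(x.mean(), train_out.mean())  # nearly identical in expectation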
@@ -254,7 +258,7 @@ class Tensor:
     return self._pool2d(*kernel_size).mean(axis=(3,5))
 
   def max_pool2d(self, kernel_size=(2,2)):
-    # TODO: support tuples in max
+    # TODO: support tuples in max and avoid a copy
     return self._pool2d(*kernel_size).max(axis=5).max(axis=3)
 
   # An instantiation of the Function is the Context
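max_pool2d reduces axis 5 and then axis 3 in two steps because tinygrad's max did not yet accept a tuple of axes (the TODO), and each reduction materializes a copy, hence "avoid a copy"; avg_pool2d above can pass the tuple directly with mean(axis=(3,5)). A NumPy sketch of the layout this relies on; the (N, C, Y, ky, X, kx) reshape is an assumption about what _pool2d produces:

import numpy as np

def pool2d_view(x, ky=2, kx=2):
    # assumed _pool2d layout: (N, C, Y, ky, X, kx)
    N, C, H, W = x.shape
    return x.reshape(N, C, H//ky, ky, W//kx, kx)

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
p = pool2d_view(x)
out = p.max(axis=5).max(axis=3)   # two reductions, as in the diff
avg = p.mean(axis=(3, 5))         # mean already takes a tuple of axes
print(out)  # [[[[ 5.  7.] [13. 15.]]]]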