refactor/softmax (#201)
* generalized logsoftmax and sigmoid with softmax
* reverted sigmoid impl

Co-authored-by: Iain Wong <iainwong@outlook.com>
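For context (an illustrative check, not part of the commit): the new forward computes log(softmax(x)) from a max-shifted exponential, which is algebraically the same quantity as the old x - logsumexp(x) formulation it replaces. A quick NumPy sketch of that equivalence, assuming a 2-D batch input:

import numpy as np

def _exp_normalize(x, axis=None):
  # helper added in this commit: numerically stable softmax
  y = np.exp(x - x.max(axis=axis, keepdims=True))
  return y / y.sum(axis=axis, keepdims=True)

def logsumexp(x):
  # inner helper used by LogSoftmax.forward before this commit
  c = x.max(axis=1)
  return c + np.log(np.exp(x - c.reshape((-1, 1))).sum(axis=1))

x = np.random.randn(4, 10)
old = x - logsumexp(x).reshape((-1, 1))   # old forward
new = np.log(_exp_normalize(x, axis=1))   # new forward
assert np.allclose(old, new)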
@@ -128,6 +128,10 @@ class ReLU(Function):
     return grad_output * (input >= 0)
 register('relu', ReLU)
 
+def _exp_normalize(x, axis=None):
+  y = np.exp(x - x.max(axis=axis, keepdims=True))
+  return y / y.sum(axis=axis, keepdims=True)
+
 class Sigmoid(Function):
   @staticmethod
   def forward(ctx, input):
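An annotation on the helper (mine, not from the diff): softmax is unchanged by subtracting a per-row constant, so shifting by the row max is mathematically a no-op, but it keeps np.exp from overflowing on large logits. A minimal sketch with hypothetical inputs:

import numpy as np

def _exp_normalize(x, axis=None):
  y = np.exp(x - x.max(axis=axis, keepdims=True))
  return y / y.sum(axis=axis, keepdims=True)

x = np.array([[1000.0, 1001.0, 1002.0]])                  # large logits
naive = np.exp(x) / np.exp(x).sum(axis=1, keepdims=True)  # exp overflows -> nan (RuntimeWarning)
stable = _exp_normalize(x, axis=1)                        # finite, rows sum to 1
print(naive)   # [[nan nan nan]]
print(stable)  # approx [[0.090 0.245 0.665]]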
@@ -149,18 +153,14 @@ register('sigmoid', Sigmoid)
 class LogSoftmax(Function):
   @staticmethod
   def forward(ctx, input):
-    def logsumexp(x):
-      #return np.log(np.exp(x).sum(axis=1))
-      c = x.max(axis=1)
-      return c + np.log(np.exp(x-c.reshape((-1, 1))).sum(axis=1))
-    output = input - logsumexp(input).reshape((-1, 1))
-    ctx.save_for_backward(output)
-    return output
+    softmax = _exp_normalize(input, axis=1)
+    ctx.save_for_backward(softmax)
+    return np.log(softmax)
 
   @staticmethod
   def backward(ctx, grad_output):
-    output, = ctx.saved_tensors
-    return grad_output - np.exp(output)*(grad_output.sum(axis=1).reshape((-1, 1)))
+    softmax, = ctx.saved_tensors
+    return grad_output - grad_output.sum(axis=1, keepdims=True)*softmax
 register('logsoftmax', LogSoftmax)
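The rewritten backward is the vector-Jacobian product of log-softmax: each row's input gradient is grad_output minus that row's sum of grad_output times the softmax, which is exactly the expression above. A hedged finite-difference check (the test data, epsilon, and tolerance are my choices, not part of the commit):

import numpy as np

def _exp_normalize(x, axis=None):
  y = np.exp(x - x.max(axis=axis, keepdims=True))
  return y / y.sum(axis=axis, keepdims=True)

def logsoftmax(x):
  return np.log(_exp_normalize(x, axis=1))

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 5))
g = rng.standard_normal((3, 5))   # upstream grad_output

# analytic gradient, matching the new LogSoftmax.backward
analytic = g - g.sum(axis=1, keepdims=True) * _exp_normalize(x, axis=1)

# central finite differences on the scalar loss (g * logsoftmax(x)).sum()
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
  xp, xm = x.copy(), x.copy()
  xp[idx] += eps
  xm[idx] -= eps
  numeric[idx] = ((g * logsoftmax(xp)).sum() - (g * logsoftmax(xm)).sum()) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)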