From fb1004103dd12abad819ab0d45e247fbb8ebfeaa Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 18 Oct 2020 09:22:32 -0700
Subject: [PATCH] logsoftmax works

---
 tensor.py | 35 ++++++++++++++++++++++++++++++-----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/tensor.py b/tensor.py
index 598a6f83a1..d69028c04e 100644
--- a/tensor.py
+++ b/tensor.py
@@ -14,7 +14,11 @@ class Context:
     self.saved_tensors.extend(x)
 
 class Tensor:
-  def __init__(self, data, _children=()):
+  def __init__(self, data):
+    #print(type(data), data)
+    if type(data) != np.ndarray:
+      print("error constructing tensor with %r" % data)
+      assert(False)
     self.data = data
     self.grad = None
 
@@ -22,7 +26,7 @@ class Tensor:
     self._ctx = None
 
   def __str__(self):
-    return "Tensor of shape %r with grad %r" % (self.data.shape, self.grad)
+    return "Tensor %r with grad %r" % (self.data, self.grad)
 
   def backward(self, allow_fill=True):
     #print("running backward on", self)
@@ -37,7 +41,12 @@ class Tensor:
     assert(self.grad is not None)
 
     grads = self._ctx.arg.backward(self._ctx, self.grad)
+    if len(self._ctx.parents) == 1:
+      grads = [grads]
     for t,g in zip(self._ctx.parents, grads):
+      if g.shape != t.data.shape:
+        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx.arg, g.shape, t.data.shape))
+        assert(False)
       t.grad = g
       t.backward(False)
 
@@ -64,7 +73,7 @@ class ReLU(Function):
     input, = ctx.saved_tensors
     grad_input = grad_output.copy()
     grad_input[input < 0] = 0
-    return grad_input,
+    return grad_input
 register('relu', ReLU)
 
 class Dot(Function):
@@ -85,11 +94,27 @@ class Sum(Function):
   @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return input.sum()
+    return np.array(input.sum())
 
   @staticmethod
   def backward(ctx, grad_output):
-    input = ctx.saved_tensors
+    input, = ctx.saved_tensors
     return grad_output * np.ones_like(input)
 register('sum', Sum)
 
+class LogSoftmax(Function):
+  @staticmethod
+  def forward(ctx, input):
+    def logsumexp(x):
+      c = x.max(axis=1)
+      return c + np.log(np.exp(x-c.reshape((-1, 1))).sum(axis=1))
+    output = input - logsumexp(input).reshape((-1, 1))
+    ctx.save_for_backward(output)
+    return output
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    output, = ctx.saved_tensors
+    return grad_output - np.exp(output)*grad_output.sum(axis=1).reshape((-1, 1))
+register('logsoftmax', LogSoftmax)
+
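
Not part of the commit above: a standalone numpy sketch that checks the LogSoftmax math this patch introduces. It re-implements the forward pass (the logsumexp trick: subtract the per-row max before exponentiating so large logits don't overflow) and the backward formula, then compares the analytic gradient against a central finite difference. Rows are the batch dimension and axis=1 is the class axis, matching the patch. The helper names (logsoftmax_forward, logsoftmax_backward) are illustrative, not from tensor.py.

import numpy as np

def logsoftmax_forward(x):
  # same logsumexp trick as the patch: subtract the per-row max for numerical stability
  c = x.max(axis=1)
  lse = c + np.log(np.exp(x - c.reshape((-1, 1))).sum(axis=1))
  return x - lse.reshape((-1, 1))

def logsoftmax_backward(output, grad_output):
  # gradient of log-softmax: dL/dx = g - softmax(x) * sum(g, axis=1), with softmax(x) = exp(output)
  return grad_output - np.exp(output) * grad_output.sum(axis=1).reshape((-1, 1))

np.random.seed(0)
x = np.random.randn(4, 10)
g = np.random.randn(4, 10)   # upstream gradient dL/d(logsoftmax(x))

analytic = logsoftmax_backward(logsoftmax_forward(x), g)

# central finite difference of L = sum(g * logsoftmax(x))
eps, numeric = 1e-6, np.zeros_like(x)
for i in range(x.shape[0]):
  for j in range(x.shape[1]):
    xp, xm = x.copy(), x.copy()
    xp[i, j] += eps
    xm[i, j] -= eps
    numeric[i, j] = ((g * logsoftmax_forward(xp)).sum() -
                     (g * logsoftmax_forward(xm)).sum()) / (2 * eps)

print("max abs diff:", np.abs(analytic - numeric).max())  # expect ~1e-8 or smaller

Assuming register attaches the op as a Tensor method, as the existing relu/dot/sum registrations in tensor.py suggest, the new op would then be invoked as t.logsoftmax() on a Tensor t.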