logsoftmax works

This commit is contained in:
George Hotz
2020-10-18 09:22:32 -07:00
parent 3b81dbe41c
commit fb1004103d

View File

@@ -14,7 +14,11 @@ class Context:
self.saved_tensors.extend(x)
class Tensor:
def __init__(self, data, _children=()):
def __init__(self, data):
#print(type(data), data)
if type(data) != np.ndarray:
print("error constructing tensor with %r" % data)
assert(False)
self.data = data
self.grad = None
@@ -22,7 +26,7 @@ class Tensor:
self._ctx = None
# Printable summary of the tensor: %r-formats tensor information together with self.grad.
def __str__(self):
# NOTE(review): the two return lines below appear to be the removed/added pair of this diff hunk;
# the first reports only self.data.shape, the second prints the full self.data array.
return "Tensor of shape %r with grad %r" % (self.data.shape, self.grad)
return "Tensor %r with grad %r" % (self.data, self.grad)
def backward(self, allow_fill=True):
#print("running backward on", self)
@@ -37,7 +41,12 @@ class Tensor:
assert(self.grad is not None)
grads = self._ctx.arg.backward(self._ctx, self.grad)
if len(self._ctx.parents) == 1:
grads = [grads]
for t,g in zip(self._ctx.parents, grads):
if g.shape != t.data.shape:
print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx.arg, g.shape, t.data.shape))
assert(False)
t.grad = g
t.backward(False)
@@ -64,7 +73,7 @@ class ReLU(Function):
input, = ctx.saved_tensors
grad_input = grad_output.copy()
grad_input[input < 0] = 0
return grad_input,
return grad_input
register('relu', ReLU)
class Dot(Function):
@@ -85,11 +94,27 @@ class Sum(Function):
# Forward pass of Sum: stores the input on ctx via save_for_backward, then reduces it with .sum().
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
# NOTE(review): the two return lines below appear to be the removed/added pair of this diff hunk.
# The second wraps the reduction in np.array, presumably so the result is an np.ndarray — confirm against the caller.
return input.sum()
return np.array(input.sum())
# Backward pass of Sum: multiplies grad_output by an all-ones array shaped like the saved input.
@staticmethod
def backward(ctx, grad_output):
# NOTE(review): the two assignments below appear to be the removed/added pair of this diff hunk.
# The second uses a trailing comma to unpack the single entry of ctx.saved_tensors instead of
# binding the whole container.
input = ctx.saved_tensors
input, = ctx.saved_tensors
return grad_output * np.ones_like(input)
register('sum', Sum)
class LogSoftmax(Function):
    """Log-softmax taken along axis 1 of a 2-D array (one distribution per row)."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input - logsumexp(input)`` with the reduction done per row.

        The result is stored on ``ctx`` because ``backward`` only needs the
        output (``exp(output)`` is the softmax), not the original input.
        """
        def logsumexp(x):
            # Shift by the row maximum before exponentiating so np.exp cannot
            # overflow; the shift is added back outside the log.
            c = x.max(axis=1, keepdims=True)
            return c + np.log(np.exp(x - c).sum(axis=1, keepdims=True))

        # logsumexp keeps a trailing axis of length 1, i.e. shape (rows, 1).
        # A flat (rows,) vector would be broadcast against the *columns* of
        # input: an error when rows != cols, silently wrong when rows == cols.
        output = input - logsumexp(input)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to the forward input.

        For each row: ``grad_output - softmax * sum(grad_output)``, where
        ``softmax`` is recovered as ``exp(output)`` from the saved output.
        """
        output, = ctx.saved_tensors
        row_grad_sum = grad_output.sum(axis=1, keepdims=True)
        return grad_output - np.exp(output) * row_grad_sum
register('logsoftmax', LogSoftmax)