mirror of https://github.com/tinygrad/tinygrad.git
logsoftmax works
tensor.py (35 lines changed)
@@ -14,7 +14,11 @@ class Context:
     self.saved_tensors.extend(x)
 
 class Tensor:
-  def __init__(self, data, _children=()):
+  def __init__(self, data):
+    #print(type(data), data)
+    if type(data) != np.ndarray:
+      print("error constructing tensor with %r" % data)
+      assert(False)
     self.data = data
     self.grad = None
 
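The constructor is stricter now: Tensor only accepts a np.ndarray and fails loudly otherwise. A minimal sketch of what callers have to do (not part of the commit; it assumes the tensor.py above is importable):

import numpy as np
from tensor import Tensor  # assumes tensor.py from this commit is on the path

ok = Tensor(np.array([[1.0, -2.0, 3.0]]))  # fine: already an ndarray
print(ok)                                  # prints via Tensor.__str__
# Tensor([1.0, -2.0, 3.0]) would print the error message and hit assert(False)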
@@ -22,7 +26,7 @@ class Tensor:
     self._ctx = None
 
   def __str__(self):
-    return "Tensor of shape %r with grad %r" % (self.data.shape, self.grad)
+    return "Tensor %r with grad %r" % (self.data, self.grad)
 
   def backward(self, allow_fill=True):
     #print("running backward on", self)
@@ -37,7 +41,12 @@ class Tensor:
     assert(self.grad is not None)
 
     grads = self._ctx.arg.backward(self._ctx, self.grad)
+    if len(self._ctx.parents) == 1:
+      grads = [grads]
     for t,g in zip(self._ctx.parents, grads):
+      if g.shape != t.data.shape:
+        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx.arg, g.shape, t.data.shape))
+        assert(False)
       t.grad = g
       t.backward(False)
 
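The list wrapping is needed because single-input ops now return a bare ndarray from backward instead of a one-element tuple (see the ReLU change below), so Tensor.backward normalizes the result before zipping it against the parents; the shape assertion catches gradients that were silently broadcast to the wrong shape. A standalone sketch of just that dispatch step (illustration only, not code from the commit):

import numpy as np

def dispatch_grads(parents, grads):
  # mirrors the new Tensor.backward logic: wrap a bare grad, then shape-check
  if len(parents) == 1:
    grads = [grads]
  for t, g in zip(parents, grads):
    assert g.shape == t.shape, "grad shape must match tensor shape"

x = np.ones((2, 3))
dispatch_grads([x], np.zeros((2, 3)))            # single parent, bare ndarray
dispatch_grads([x, x], [np.zeros((2, 3))] * 2)   # two parents, list of grads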
@@ -64,7 +73,7 @@ class ReLU(Function):
     input, = ctx.saved_tensors
     grad_input = grad_output.copy()
     grad_input[input < 0] = 0
-    return grad_input,
+    return grad_input
 register('relu', ReLU)
 
 class Dot(Function):
@@ -85,11 +94,27 @@ class Sum(Function):
   @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    return input.sum()
+    return np.array(input.sum())
 
   @staticmethod
   def backward(ctx, grad_output):
-    input = ctx.saved_tensors
+    input, = ctx.saved_tensors
     return grad_output * np.ones_like(input)
 register('sum', Sum)
 
+class LogSoftmax(Function):
+  @staticmethod
+  def forward(ctx, input):
+    def logsumexp(x):
+      c = x.max(axis=1)
+      return c + np.log(np.exp(x-c.reshape((-1, 1))).sum(axis=1))
+    output = input - logsumexp(input).reshape((-1, 1))
+    ctx.save_for_backward(output)
+    return output
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    output, = ctx.saved_tensors
+    return grad_output - np.exp(output)*grad_output.sum(axis=1).reshape((-1, 1))
+register('logsoftmax', LogSoftmax)
+
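LogSoftmax.forward uses the usual max-subtraction trick so logsumexp stays numerically stable for large logits, and the backward is grad_output minus softmax (which is exp(output)) times the row-sum of grad_output. A standalone numpy sketch (not part of the commit) that checks that backward formula against finite differences of the forward:

import numpy as np

def logsoftmax(x):
  c = x.max(axis=1)
  lse = c + np.log(np.exp(x - c.reshape((-1, 1))).sum(axis=1))
  return x - lse.reshape((-1, 1))

np.random.seed(0)
x = np.random.randn(3, 5)
go = np.random.randn(3, 5)                    # upstream grad_output

out = logsoftmax(x)
analytic = go - np.exp(out) * go.sum(axis=1).reshape((-1, 1))

eps = 1e-6
numeric = np.zeros_like(x)
for i in range(x.shape[0]):
  for j in range(x.shape[1]):
    xp, xm = x.copy(), x.copy()
    xp[i, j] += eps
    xm[i, j] -= eps
    numeric[i, j] = ((logsoftmax(xp) - logsoftmax(xm)) * go).sum() / (2 * eps)

print(np.abs(analytic - numeric).max())       # expected to be tiny (~1e-9), i.e. the formulas agree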