mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
refactor tinygrad to be more tiny
This commit is contained in:
@@ -38,6 +38,7 @@ print(y.grad) # dz/dy
|
||||
|
||||
### TODO (to make real neural network library)
|
||||
|
||||
* Implement gradcheck (numeric)
|
||||
* Implement convolutions
|
||||
* Implement Adam optimizer
|
||||
|
||||
|
||||
@@ -2,16 +2,7 @@
|
||||
from functools import partialmethod
|
||||
import numpy as np
|
||||
|
||||
# **** start with three base classes ****
|
||||
|
||||
class Context:
|
||||
def __init__(self, arg, *tensors):
|
||||
self.arg = arg
|
||||
self.parents = tensors
|
||||
self.saved_tensors = []
|
||||
|
||||
def save_for_backward(self, *x):
|
||||
self.saved_tensors.extend(x)
|
||||
# **** start with two base classes ****
|
||||
|
||||
class Tensor:
|
||||
def __init__(self, data):
|
||||
@@ -35,17 +26,18 @@ class Tensor:
|
||||
|
||||
if self.grad is None and allow_fill:
|
||||
# fill in the first grad with one
|
||||
# this is "implicit gradient creation"
|
||||
assert self.data.size == 1
|
||||
self.grad = np.ones_like(self.data)
|
||||
|
||||
assert(self.grad is not None)
|
||||
|
||||
grads = self._ctx.arg.backward(self._ctx, self.grad)
|
||||
grads = self._ctx.backward(self._ctx, self.grad)
|
||||
if len(self._ctx.parents) == 1:
|
||||
grads = [grads]
|
||||
for t,g in zip(self._ctx.parents, grads):
|
||||
if g.shape != t.data.shape:
|
||||
print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx.arg, g.shape, t.data.shape))
|
||||
print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape))
|
||||
assert(False)
|
||||
t.grad = g
|
||||
t.backward(False)
|
||||
@@ -54,9 +46,18 @@ class Tensor:
|
||||
div = Tensor(np.array([1/self.data.size]))
|
||||
return self.sum().mul(div)
|
||||
|
||||
# The Function is the Context
|
||||
class Function:
|
||||
def __init__(self, *tensors):
|
||||
self.parents = tensors
|
||||
self.saved_tensors = []
|
||||
|
||||
def save_for_backward(self, *x):
|
||||
self.saved_tensors.extend(x)
|
||||
|
||||
# note that due to how partialmethod works, self and arg are switched
|
||||
def apply(self, arg, *x):
|
||||
ctx = Context(arg, self, *x)
|
||||
ctx = arg(self, *x)
|
||||
ret = Tensor(arg.forward(ctx, self.data, *[t.data for t in x]))
|
||||
ret._ctx = ctx
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user