diff --git a/test/mnist.py b/test/mnist.py
index 3bc1aa95fa..ffcead0a24 100644
--- a/test/mnist.py
+++ b/test/mnist.py
@@ -35,9 +35,20 @@ class TinyBobNet:
   def forward(self, x):
     return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
 
-model = TinyBobNet()
+# optimizer
+
+class SGD:
+  def __init__(self, tensors, lr):
+    self.tensors = tensors
+    self.lr = lr
+
+  def step(self):
+    for t in self.tensors:
+      t.data -= self.lr * t.grad
+
+model = TinyBobNet()
+optim = SGD([model.l1, model.l2], lr=0.01)
 
-lr = 0.01
 BS = 128
 losses, accuracies = [], []
 for i in (t := trange(1000)):
@@ -55,13 +66,11 @@ for i in (t := trange(1000)):
   # NLL loss function
   loss = outs.mul(y).mean()
   loss.backward()
+  optim.step()
 
   cat = np.argmax(outs.data, axis=1)
   accuracy = (cat == Y).mean()
 
-  # SGD
-  model.l1.data = model.l1.data - lr*model.l1.grad
-  model.l2.data = model.l2.data - lr*model.l2.grad
 
   # printing
   loss = loss.data
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 170cd9c909..13ef238b0c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -37,7 +37,8 @@ class Tensor:
       grads = [grads]
     for t,g in zip(self._ctx.parents, grads):
       if g.shape != t.data.shape:
-        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape))
+        print("grad shape must match tensor shape in %r, %r != %r" %
+          (self._ctx, g.shape, t.data.shape))
         assert(False)
       t.grad = g
       t.backward(False)
@@ -46,7 +47,7 @@ class Tensor:
     div = Tensor(np.array([1/self.data.size]))
     return self.sum().mul(div)
 
-# The Function is the Context
+# An instantiation of the Function is the Context
 class Function:
   def __init__(self, *tensors):
     self.parents = tensors
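
For reference, the new SGD.step() performs one vanilla gradient-descent update per parameter, w <- w - lr * dL/dw, replacing the hand-written updates that were inlined in the training loop. Below is a minimal standalone sketch of that update: the SGD class body is taken from the diff above, while FakeTensor (a stand-in for tinygrad's Tensor) and the example values are hypothetical, for illustration only.

import numpy as np

# FakeTensor is a hypothetical stand-in for tinygrad's Tensor; it only
# carries the two attributes SGD.step() touches: .data and .grad.
class FakeTensor:
  def __init__(self, data, grad):
    self.data, self.grad = data, grad

class SGD:
  def __init__(self, tensors, lr):
    self.tensors = tensors
    self.lr = lr

  def step(self):
    # vanilla gradient descent, applied in place: w <- w - lr * dL/dw
    for t in self.tensors:
      t.data -= self.lr * t.grad

# usage: one parameter with a known gradient
w = FakeTensor(np.ones(3), np.array([0.5, -0.5, 1.0]))
SGD([w], lr=0.01).step()
print(w.data)  # [0.995 1.005 0.99]

Pulling the update out of the loop means adding a parameter is a one-line change to the SGD constructor call instead of another hand-written update line, and it opens the door to other optimizers with the same step() interface.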