mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
add support for adam
@@ -40,5 +40,4 @@ print(y.grad) # dz/dy
 
 * Implement gradcheck (numeric)
 * Implement convolutions
-* Implement Adam optimizer
 
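Of the remaining TODO items, the numeric gradcheck is the most mechanical: compare autograd's gradient against central differences. A minimal NumPy sketch of the idea, independent of tinygrad's API (numeric_grad and the quadratic test function here are illustrative, not part of this commit):

import numpy as np

def numeric_grad(f, x, h=1e-5):
  # central differences: df/dx_i ~ (f(x + h*e_i) - f(x - h*e_i)) / (2h)
  g = np.zeros_like(x)
  for i in range(x.size):
    e = np.zeros_like(x)
    e.flat[i] = h
    g.flat[i] = (f(x + e) - f(x - e)) / (2 * h)
  return g

# sanity check against the analytic gradient of f(x) = sum(x**2), which is 2*x
x = np.array([1.0, -2.0, 3.0])
assert np.allclose(numeric_grad(lambda v: np.sum(v ** 2), x), 2 * x)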
@@ -28,6 +28,7 @@ class TinyBobNet:
 
 model = TinyBobNet()
 optim = optim.SGD([model.l1, model.l2], lr=0.01)
+#optim = optim.Adam([model.l1, model.l2], lr=0.001)
 
 BS = 128
 losses, accuracies = [], []
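The two constructor lines are interchangeable because both classes end up sharing the Optimizer base and the step() interface introduced in the optim.py hunk below; only the learning-rate convention differs. A self-contained sketch of the swap, using a bare stand-in object since the optimizers only touch .data and .grad (the import path is an assumption, not shown in this diff):

import numpy as np
from types import SimpleNamespace
from tinygrad.optim import SGD, Adam   # assumed module path

p = SimpleNamespace(data=np.ones(2), grad=np.zeros(2))
for opt in (SGD([p], lr=0.01), Adam([p], lr=0.001)):
  p.grad = np.ones_like(p.data)  # as if backward() had run
  opt.step()                     # identical call for either optimizer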
@@ -51,7 +52,6 @@ for i in (t := trange(1000)):
   cat = np.argmax(outs.data, axis=1)
   accuracy = (cat == Y).mean()
-
 
   # printing
   loss = loss.data
   losses.append(loss)
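The two accuracy lines reduce the whole batch at once; a tiny worked example with a made-up 3-sample batch of 2-class outputs:

import numpy as np

outs_data = np.array([[0.1, 0.9],
                      [0.8, 0.2],
                      [0.3, 0.7]])    # network outputs for 3 samples
Y = np.array([1, 0, 0])               # ground-truth labels
cat = np.argmax(outs_data, axis=1)    # predicted class per row: [1, 0, 1]
accuracy = (cat == Y).mean()          # 2 of 3 correct -> 0.666...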
@@ -1,9 +1,37 @@
-class SGD:
-  def __init__(self, tensors, lr):
-    self.tensors = tensors
+import numpy as np
+
+class Optimizer:
+  def __init__(self, params):
+    self.params = params
+
+class SGD(Optimizer):
+  def __init__(self, params, lr=0.001):
+    super(SGD, self).__init__(params)
     self.lr = lr
 
   def step(self):
-    for t in self.tensors:
+    for t in self.params:
       t.data -= self.lr * t.grad
+
+# 80% sure this is right?
+class Adam(Optimizer):
+  def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+    super(Adam, self).__init__(params)
+    self.lr = lr
+    self.b1 = b1
+    self.b2 = b2
+    self.eps = eps
+    self.t = 0
+
+    self.m = [np.zeros_like(t.data) for t in self.params]
+    self.v = [np.zeros_like(t.data) for t in self.params]
+
+  def step(self):
+    self.t += 1  # one global timestep per step(), shared by all params
+    for i,t in enumerate(self.params):
+      self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
+      self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
+      mhat = self.m[i] / (1. - self.b1**self.t)
+      vhat = self.v[i] / (1. - self.b2**self.t)
+      t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
 
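The "80% sure" comment can be settled against the published update rule (Kingma & Ba, 2014): m <- b1*m + (1-b1)*g, v <- b2*v + (1-b2)*g^2, bias corrections mhat = m/(1-b1^t) and vhat = v/(1-b2^t), then theta <- theta - lr*mhat/(sqrt(vhat)+eps). The code above matches once t is incremented exactly once per step. A quick self-contained check, assuming the Adam class above is importable and exploiting that on the first step mhat/sqrt(vhat) collapses to sign(g):

import numpy as np
from types import SimpleNamespace

# stand-in parameter: Adam only reads .grad and updates .data
x = SimpleNamespace(data=np.array([3.0, -2.0]), grad=np.zeros(2))
opt = Adam([x], lr=0.1)

x.grad = 2 * x.data            # gradient of f(x) = sum(x**2)
opt.step()
# at t=1: mhat = g and vhat = g*g, so the move is lr*g/(|g| + eps) ~ lr*sign(g)
assert np.allclose(x.data, [2.9, -1.9])

for _ in range(300):           # the quadratic should then be driven to its minimum
  x.grad = 2 * x.data
  opt.step()
print(x.data)                  # -> near [0. 0.]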