mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
add support for adam
@@ -40,5 +40,4 @@ print(y.grad) # dz/dy
 
 * Implement gradcheck (numeric)
 * Implement convolutions
-* Implement Adam optimizer
 
@@ -28,6 +28,7 @@ class TinyBobNet:
 
 model = TinyBobNet()
 optim = optim.SGD([model.l1, model.l2], lr=0.01)
+#optim = optim.Adam([model.l1, model.l2], lr=0.001)
 
 BS = 128
 losses, accuracies = [], []
@@ -51,7 +52,6 @@ for i in (t := trange(1000)):
   cat = np.argmax(outs.data, axis=1)
   accuracy = (cat == Y).mean()
-
 
   # printing
   loss = loss.data
   losses.append(loss)
@@ -1,9 +1,37 @@
-class SGD:
-  def __init__(self, tensors, lr):
-    self.tensors = tensors
+import numpy as np
+
+class Optimizer:
+  def __init__(self, params):
+    self.params = params
+
+class SGD(Optimizer):
+  def __init__(self, params, lr=0.001):
+    super(SGD, self).__init__(params)
     self.lr = lr
 
   def step(self):
-    for t in self.tensors:
+    for t in self.params:
       t.data -= self.lr * t.grad
 
+# 80% sure this is right?
+class Adam(Optimizer):
+  def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+    super(Adam, self).__init__(params)
+    self.lr = lr
+    self.b1 = b1
+    self.b2 = b2
+    self.eps = eps
+    self.t = 0
+
+    self.m = [np.zeros_like(t.data) for t in self.params]
+    self.v = [np.zeros_like(t.data) for t in self.params]
+
+  def step(self):
+    for i,t in enumerate(self.params):
+      self.t += 1
+      self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
+      self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
+      mhat = self.m[i] / (1. - self.b1**self.t)
+      vhat = self.v[i] / (1. - self.b2**self.t)
+      t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
+
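The Adam class added above follows the update rule from Kingma & Ba's Adam paper, but, as the "# 80% sure this is right?" comment hints, self.t is incremented inside the per-parameter loop, so with more than one parameter tensor the bias-correction exponent runs ahead of the actual step count; this mainly changes the effective step size during the first few iterations. Below is a minimal standalone sketch of the textbook update for comparison, with the timestep advanced once per step() call. It is not tinygrad code: the AdamSketch name and the dict-based parameters are made up for illustration.

# Standalone sketch of the textbook Adam update, for comparison only.
# Parameters are plain dicts with "data" and "grad" arrays; AdamSketch
# and the dict layout are illustrative, not part of tinygrad.
import numpy as np

class AdamSketch:
  def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    self.params, self.lr, self.b1, self.b2, self.eps = params, lr, b1, b2, eps
    self.t = 0
    self.m = [np.zeros_like(p["data"]) for p in params]
    self.v = [np.zeros_like(p["data"]) for p in params]

  def step(self):
    self.t += 1                                   # once per step, not per tensor
    for i, p in enumerate(self.params):
      g = p["grad"]
      self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
      self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(g)
      mhat = self.m[i] / (1 - self.b1 ** self.t)  # bias-corrected first moment
      vhat = self.v[i] / (1 - self.b2 ** self.t)  # bias-corrected second moment
      p["data"] -= self.lr * mhat / (np.sqrt(vhat) + self.eps)

# Quick sanity check on a 1-D quadratic: minimize 0.5 * x**2, whose gradient is x.
p = {"data": np.array([5.0]), "grad": np.zeros(1)}
opt = AdamSketch([p], lr=0.1)
for _ in range(300):
  p["grad"] = p["data"].copy()
  opt.step()
print(p["data"])  # ends up close to 0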