mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -27,8 +27,8 @@ class Adam(Optimizer):
|
|||||||
self.v = [np.zeros_like(t.data) for t in self.params]
|
self.v = [np.zeros_like(t.data) for t in self.params]
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
|
self.t += 1
|
||||||
for i,t in enumerate(self.params):
|
for i,t in enumerate(self.params):
|
||||||
self.t += 1
|
|
||||||
self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
|
self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
|
||||||
self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
|
self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
|
||||||
mhat = self.m[i] / (1. - self.b1**self.t)
|
mhat = self.m[i] / (1. - self.b1**self.t)
|
||||||
|
|||||||
Reference in New Issue
Block a user