Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00.
GlobalCounters cache + assign in optim
This commit is contained in:
@@ -48,7 +48,7 @@ class RMSprop(Optimizer):
|
||||
def step(self) -> None:
|
||||
for i, t in enumerate(self.params):
|
||||
assert t.grad is not None
|
||||
-      self.v[i] = self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad)
|
||||
+      self.v[i].assign(self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad))
|
||||
t.assign(t.detach() - (t.grad * self.lr).div(self.v[i].sqrt() + self.eps))
|
||||
self.realize(self.v)
|
||||
|
||||
@@ -65,8 +65,8 @@ class Adam(Optimizer):
|
||||
a = self.lr * ((1.0 - self.b2**self.t)**0.5) / (1.0 - self.b1**self.t)
|
||||
for i, t in enumerate(self.params):
|
||||
assert t.grad is not None
|
||||
-      self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * t.grad
|
||||
-      self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
|
||||
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
|
||||
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
|
||||
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
|
||||
self.realize([self.t] + self.m + self.v)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user