GlobalCounters cache + assign in optim

This commit is contained in:
George Hotz
2023-02-08 17:10:55 -06:00
parent d9555bc478
commit a5a55ac19e
4 changed files with 12 additions and 12 deletions

View File

@@ -48,7 +48,7 @@ class RMSprop(Optimizer):
def step(self) -> None:
for i, t in enumerate(self.params):
assert t.grad is not None
self.v[i] = self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad)
self.v[i].assign(self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad))
t.assign(t.detach() - (t.grad * self.lr).div(self.v[i].sqrt() + self.eps))
self.realize(self.v)
@@ -65,8 +65,8 @@ class Adam(Optimizer):
a = self.lr * ((1.0 - self.b2**self.t)**0.5) / (1.0 - self.b1**self.t)
for i, t in enumerate(self.params):
assert t.grad is not None
self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * t.grad
self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
self.realize([self.t] + self.m + self.v)