huh...this is faster

This commit is contained in:
George Hotz
2023-04-18 17:36:31 -07:00
parent dbc99c243b
commit aedd4685fa

View File

@@ -30,9 +30,9 @@ class SGD(Optimizer):
def step(self) -> None:
for i, t in enumerate(self.params):
assert t.grad is not None
- g = t.grad
+ g = t.grad.realize()
if self.momentum:
- self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
+ self.b[i].assign((self.momentum * self.b[i] + g).realize()) # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign((t.detach() - g * self.lr).realize())
self.realize(self.b)