mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00
remove realize from optimizer (#2880)
* remove realize from optimizer
* one still needed
* opt realize
@@ -1,6 +1,6 @@
 # sorted in order of increasing complexity
 from typing import List
-from tinygrad.helpers import dedup
+from tinygrad.helpers import dedup, getenv
 from tinygrad.tensor import Tensor
 
 class Optimizer:
@@ -13,7 +13,7 @@ class Optimizer:
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.device = self.params[0].device
     self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad])  # buffers are still realized
-    self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()
+    self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
 
   def zero_grad(self):
     for param in self.params: param.grad = None
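The only behavioral addition in these two hunks is the CONST_LR escape hatch: when that environment variable is set, the learning rate stays a plain Python float instead of being materialized as a device Tensor. A minimal sketch of how that could be observed from user code; the tinygrad.nn.optim import path, shapes, and prints are assumptions for illustration, not part of this diff:

# minimal sketch, assuming the CONST_LR path added above; set the env var before tinygrad reads it
import os
os.environ["CONST_LR"] = "1"

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD  # import path assumed for this era of tinygrad

w = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
opt = SGD([w], lr=0.01)
print(type(opt.lr))  # expected: float when CONST_LR is set, a Tensor otherwise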
@@ -32,9 +32,12 @@ class SGD(Optimizer):
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize() + self.wd * t.detach()
+      # this is needed since the grads can form a "diamond"
+      # TODO: fix this in lazy.py
+      t.grad.realize()
+      g = t.grad + self.wd * t.detach()
       if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g).realize()  # NOTE: self.b[i] is zero on the first run, no if required
+        self.b[i].assign(self.momentum * self.b[i] + g)  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g * self.lr)
     self.realize(self.b)
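After this hunk SGD.step mostly just builds the lazy update graph: the single remaining t.grad.realize() is the workaround called out in the TODO for gradients that form a "diamond", and the parameter and momentum buffers are flushed together by self.realize(self.b) at the end. A rough usage sketch under those assumptions; the import path, shapes, and data are illustrative, not from the PR:

# rough sketch of the usual training pattern around this commit; details assumed
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD  # import path assumed

x, y = Tensor.rand(8, 4), Tensor.rand(8, 1)
w = Tensor([[0.1], [0.2], [0.3], [0.4]], requires_grad=True)
opt = SGD([w], lr=0.01, momentum=0.9)

opt.zero_grad()
loss = ((x.matmul(w) - y) ** 2).mean()
loss.backward()
opt.step()  # schedules the update lazily; kernels run in the realize calls shown above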
@@ -51,12 +54,11 @@ class LAMB(Optimizer):
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
 
   def step(self) -> None:
-    self.t.assign(self.t + 1).realize()
+    self.t.assign(self.t + 1)
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize()
-      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
-      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
       m_hat = self.m[i] / (1.0 - self.b1**self.t)
       v_hat = self.v[i] / (1.0 - self.b2**self.t)
       up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
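For reference, the per-parameter quantities this hunk computes are unchanged; only the intermediate .realize() calls are dropped. Writing g for t.grad, \theta for t.detach(), and wd for the weight decay, the visible lines correspond to

m_i \leftarrow \beta_1 m_i + (1-\beta_1)\,g, \qquad v_i \leftarrow \beta_2 v_i + (1-\beta_2)\,g^2
\hat{m} = m_i / (1 - \beta_1^{\,t}), \qquad \hat{v} = v_i / (1 - \beta_2^{\,t})
\text{up} = \hat{m} / (\sqrt{\hat{v}} + \epsilon) + \text{wd}\cdot\theta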