remove realize from optimizer (#2880)

* remove realize from optimizer

* one still needed

* opt realize
This commit is contained in:
George Hotz
2023-12-20 16:42:41 -08:00
committed by GitHub
parent 1765849937
commit e1861ab65e
3 changed files with 38 additions and 16 deletions

View File

@@ -1,6 +1,6 @@
# sorted in order of increasing complexity
from typing import List
from tinygrad.helpers import dedup
from tinygrad.helpers import dedup, getenv
from tinygrad.tensor import Tensor
class Optimizer:
@@ -13,7 +13,7 @@ class Optimizer:
assert len(self.params) != 0, "optimizer must have at least one param"
self.device = self.params[0].device
self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()
self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous()
def zero_grad(self):
for param in self.params: param.grad = None
@@ -32,9 +32,12 @@ class SGD(Optimizer):
def step(self) -> None:
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize() + self.wd * t.detach()
# this is needed since the grads can form a "diamond"
# TODO: fix this in lazy.py
t.grad.realize()
g = t.grad + self.wd * t.detach()
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g * self.lr)
self.realize(self.b)
@@ -51,12 +54,11 @@ class LAMB(Optimizer):
self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
def step(self) -> None:
self.t.assign(self.t + 1).realize()
self.t.assign(self.t + 1)
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize()
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
m_hat = self.m[i] / (1.0 - self.b1**self.t)
v_hat = self.v[i] / (1.0 - self.b2**self.t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()