make optim _step return update (#14906)
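In short: `Optimizer._step` now returns the raw per-parameter update (plus any extra tensors to schedule) instead of already-updated parameters, and the base class applies it through the new `_apply_update` (`t.detach() - up.to(t.device)`). The optimizers also gain a `device=` argument so the learning rate and optimizer state can live on a chosen device. As a rough illustration of the new contract (not part of the commit; the import path and the `PlainSGD` name are assumptions), a minimal optimizer might look like:

from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer

class PlainSGD(Optimizer):
  # sketch: return lr * grad as the update for each param; the base class
  # subtracts it via _apply_update, so no assign happens here
  def _step(self, params: list[Tensor], grads: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    return [(self.lr * g).cast(t.dtype) for t, g in zip(params, grads)], []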
@@ -8,7 +8,7 @@ class Optimizer:
   """
   Base class for all optimizers.
   """
-  def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
+  def __init__(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM):
     # if requires_grad is None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad_(True)
@@ -16,18 +16,18 @@ class Optimizer:
     self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad])   # buffers are still realized
+    self.device = device or self.params[0].device
     self.fused = fused
     # store lr in at least float32 precision
     self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
                      dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
     if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
 
-  @property
-  def device(self): return self.params[0].device
-
   def _new_optim_param(self) -> list[Tensor]:
     param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
     if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
+    if self.device is not None:
+      return [Tensor.zeros(t.shape, dtype=param_dtype, device=self.device, requires_grad=False).contiguous() for t in self.params]
     return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
 
   def zero_grad(self):
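With the hunk above, `device` becomes a plain attribute resolved at construction time rather than a property, and `_new_optim_param` allocates optimizer state on it. A small usage sketch (assuming this commit is applied; "CPU" is only an example device string):

from tinygrad import Tensor
from tinygrad.nn.optim import Adam

w = Tensor.randn(8, requires_grad=True)
print(Adam([w]).device)                # falls back to the parameter's device
print(Adam([w], device="CPU").device)  # optimizer state is allocated on CPU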
@@ -54,13 +54,14 @@ class Optimizer:
       # NOTE: contiguous is for speed
       out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
                               [Tensor.cat(*[unwrap(t.grad).contiguous().flatten() for t in self.params], dim=0)])
-      updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
+      updates = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
     else:
-      updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
-    for i, tt in enumerate(self.params): tt.assign(updated_params[i])
+      updates, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
+    for i, tt in enumerate(self.params): tt.assign(self._apply_update(tt, updates[i]))
     return extra+self.params+self.buffers
 
   def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError
+  def _apply_update(self, t:Tensor, up:Tensor) -> Tensor: return t.detach() - up.to(t.device)
 
 class OptimizerGroup(Optimizer):
   """
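Both the fused and unfused paths now funnel through `_apply_update`, which moves the returned update to the parameter's device and subtracts it from the detached parameter. A tiny sanity sketch of that equivalence (made-up values, not from the commit):

from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0])
up = Tensor([0.1, 0.1, 0.1])                 # what _step now returns for this param
old_style = t.detach() - up                  # what _step used to hand back directly
new_style = t.detach() - up.to(t.device)     # what _apply_update computes
print((old_style - new_style).abs().max().item())  # 0.0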
@@ -74,17 +75,17 @@ class OptimizerGroup(Optimizer):
   def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
 
 # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
-def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
+def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, device=None, fused=FUSE_OPTIM):
   """
   Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
 
   `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
   """
-  return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
+  return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, device=device, fused=fused)
 
 # Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
 def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_steps=5, ns_coefficients=(3.4445, -4.775, 2.0315),
-         nesterov=True, fused=FUSE_OPTIM):
+         nesterov=True, device=None, fused=FUSE_OPTIM):
   """
   SGD with newton-schulz iteration and post momentum weight decay.
 
@@ -92,7 +93,8 @@ def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_ste
   - Paper: https://arxiv.org/pdf/2502.16982
   """
   assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
-  return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
+  return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov,
+              classic=False, pre_wd=False, tcoef=0.0, device=None, fused=fused)
 
 class LARS(Optimizer):
   """
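SGD forwards the new `device=` argument into LARS, which keeps its momentum buffer (`self.b`, built by `_new_optim_param`) on that device while `_apply_update` moves each update back to its parameter's device. A hedged usage sketch (the device string, shapes, and training-context usage are only examples):

from tinygrad import Tensor
from tinygrad.nn.optim import SGD

w = Tensor.randn(4, 4, requires_grad=True)
opt = SGD([w], lr=0.01, momentum=0.9, device="CPU")   # momentum state kept on CPU
with Tensor.train():
  w.sum().backward()
  opt.step()   # update computed on the optimizer's device, applied to w on its own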
@@ -101,8 +103,8 @@ class LARS(Optimizer):
   - Paper: https://arxiv.org/abs/1708.03888v3
   """
   def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_coefficients=None,
-               nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
-    super().__init__(params, lr, fused)
+               nesterov=False, classic=True, pre_wd=True, tcoef=0.001, device=None, fused=FUSE_OPTIM):
+    super().__init__(params, lr, device, fused)
     self.momentum, self.wd, self.ns_steps, self.ns_coefficients = momentum, weight_decay, ns_steps, ns_coefficients
     self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
     self.b = self._new_optim_param() if self.momentum else []
@@ -126,24 +128,24 @@ class LARS(Optimizer):
       if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
       # popular momentum does pre learning rate update
       if not self.classic: g = g * r * self.lr
-      ret.append((t.detach() - g).cast(t.dtype))
+      ret.append(g.cast(t.dtype))
     return ret, self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
-def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
+def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, device=None, fused=FUSE_OPTIM):
   """
   AdamW optimizer with optional weight decay.
 
   - Paper: https://arxiv.org/abs/1711.05101v3
   """
-  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
-def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
+  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, device=device, fused=fused)
+def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, device=None, fused=FUSE_OPTIM):
   """
   Adam optimizer.
 
   - Paper: https://arxiv.org/abs/1412.6980
   """
-  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)
+  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, device=device, fused=fused)
 
 class LAMB(Optimizer):
   """
@@ -151,8 +153,8 @@ class LAMB(Optimizer):
 
   - Paper: https://arxiv.org/abs/1904.00962
   """
-  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
-    super().__init__(params, lr, fused)
+  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
+    super().__init__(params, lr, device, fused)
     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
     self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
     self.m = self._new_optim_param()
@@ -175,5 +177,5 @@ class LAMB(Optimizer):
         r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
       else:
         r = 1.0
-      ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
+      ret.append((self.lr * r * up).cast(t.dtype))
     return ret, [self.b1_t, self.b2_t] + self.m + self.v
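The LAMB hunk mirrors the LARS one: `_step` now appends the scaled update `lr * r * up` and leaves the subtraction to `Optimizer._apply_update`. With toy numbers (purely illustrative, not from the commit):

lr, r, up, t = 0.01, 2.0, 0.5, 1.0
update = lr * r * up              # what LAMB._step now returns for this param: 0.01
new_t = t - update                # what _apply_update then assigns: 0.99
assert new_t == t - lr * r * up   # same value the old in-_step subtraction produced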