From 4b9825c82903da9a2f7b27bd7d4300b4e00816c5 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Fri, 20 Feb 2026 02:43:56 -0800
Subject: [PATCH] make optim _step return update (#14906)

---
 tinygrad/nn/optim.py | 46 ++++++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index fc54f5190d..9fe71d1b86 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -8,7 +8,7 @@ class Optimizer:
   """
   Base class for all optimizers.
   """
-  def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
+  def __init__(self, params: list[Tensor], lr: float, device=None, fused=FUSE_OPTIM):
     # if requires_grad is None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad_(True)
@@ -16,18 +16,18 @@ class Optimizer:
     self.params: list[Tensor] = dedup([x for x in params if x.requires_grad])
     assert len(self.params) != 0, "optimizer must have at least one param"
     self.buffers: list[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
+    self.device = device or self.params[0].device
     self.fused = fused
     # store lr in at least float32 precision
     self.lr = Tensor(lr if getenv("CONST_LR") else [lr], requires_grad=False, device=self.device,
                      dtype=least_upper_dtype(dtypes.default_float, dtypes.float32))
     if self.fused: self.pos_params = list(itertools.accumulate(self.params, lambda x,y: x+y.numel(), initial=0))
 
-  @property
-  def device(self): return self.params[0].device
-
   def _new_optim_param(self) -> list[Tensor]:
     param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
     if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
+    if self.device is not None:
+      return [Tensor.zeros(t.shape, dtype=param_dtype, device=self.device, requires_grad=False).contiguous() for t in self.params]
     return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
 
   def zero_grad(self):
@@ -54,13 +54,14 @@ class Optimizer:
       # NOTE: contiguous is for speed
       out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
                               [Tensor.cat(*[unwrap(t.grad).contiguous().flatten() for t in self.params], dim=0)])
-      updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
+      updates = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
     else:
-      updated_params, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
-    for i, tt in enumerate(self.params): tt.assign(updated_params[i])
+      updates, extra = self._step(self.params, [unwrap(t.grad) for t in self.params])
+    for i, tt in enumerate(self.params): tt.assign(self._apply_update(tt, updates[i]))
     return extra+self.params+self.buffers
 
   def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: raise NotImplementedError
+  def _apply_update(self, t:Tensor, up:Tensor) -> Tensor: return t.detach() - up.to(t.device)
 
 class OptimizerGroup(Optimizer):
   """
@@ -74,17 +75,17 @@ class OptimizerGroup(Optimizer):
   def schedule_step(self) -> list[Tensor]: return [x for o in self.optimizers for x in o.schedule_step()]
 
 # LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 it's just standard SGD.
-def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, fused=FUSE_OPTIM):
+def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False, device=None, fused=FUSE_OPTIM):
   """
   Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.
 
   `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
   """
-  return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
+  return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, device=device, fused=fused)
 
 # Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
 def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_steps=5, ns_coefficients=(3.4445, -4.775, 2.0315),
-         nesterov=True, fused=FUSE_OPTIM):
+         nesterov=True, device=None, fused=FUSE_OPTIM):
   """
   SGD with newton-schulz iteration and post momentum weight decay.
 
@@ -92,7 +93,8 @@ def Muon(params: list[Tensor], lr=0.001, momentum=0.95, weight_decay=0.1, ns_ste
   - Paper: https://arxiv.org/pdf/2502.16982
   """
   assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
-  return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
+  return LARS(params, lr, momentum, weight_decay, ns_steps, ns_coefficients, nesterov,
+              classic=False, pre_wd=False, tcoef=0.0, device=device, fused=fused)
 
 class LARS(Optimizer):
   """
@@ -101,8 +103,8 @@ class LARS(Optimizer):
   - Paper: https://arxiv.org/abs/1708.03888v3
   """
   def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_coefficients=None,
-               nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
-    super().__init__(params, lr, fused)
+               nesterov=False, classic=True, pre_wd=True, tcoef=0.001, device=None, fused=FUSE_OPTIM):
+    super().__init__(params, lr, device, fused)
     self.momentum, self.wd, self.ns_steps, self.ns_coefficients = momentum, weight_decay, ns_steps, ns_coefficients
     self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
     self.b = self._new_optim_param() if self.momentum else []
@@ -126,24 +128,24 @@ class LARS(Optimizer):
-      if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
       # popular momentum does pre learning rate update
       if not self.classic: g = g * r * self.lr
-      ret.append((t.detach() - g).cast(t.dtype))
+      if not self.pre_wd and self.wd > 0: g = g + t.detach() * (self.wd * self.lr)  # fold post wd into the returned update
+      ret.append(g.cast(t.dtype))
     return ret, self.b
 
 # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 it's just Adam/W.
-def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, fused=FUSE_OPTIM):
+def AdamW(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01, device=None, fused=FUSE_OPTIM):
   """
   AdamW optimizer with optional weight decay.
 
   - Paper: https://arxiv.org/abs/1711.05101v3
   """
-  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, fused=fused)
-def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, fused=FUSE_OPTIM):
+  return LAMB(params, lr, b1, b2, eps, weight_decay, adam=True, device=device, fused=fused)
+def Adam(params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, device=None, fused=FUSE_OPTIM):
   """
   Adam optimizer.
 
   - Paper: https://arxiv.org/abs/1412.6980
   """
-  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, fused=fused)
+  return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, device=device, fused=fused)
 
 class LAMB(Optimizer):
   """
@@ -151,8 +153,8 @@ class LAMB(Optimizer):
 
   - Paper: https://arxiv.org/abs/1904.00962
   """
-  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
-    super().__init__(params, lr, fused)
+  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
+    super().__init__(params, lr, device, fused)
     self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
     self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
     self.m = self._new_optim_param()
@@ -175,5 +177,5 @@ class LAMB(Optimizer):
         r2 = up.square().sum().sqrt()
         r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
       else: r = 1.0
-      ret.append((t.detach() - self.lr * r * up).cast(t.dtype))
+      ret.append((self.lr * r * up).cast(t.dtype))
     return ret, [self.b1_t, self.b2_t] + self.m + self.v
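Not part of the diff above: a minimal sketch of how the new contract is consumed, assuming the public tinygrad API (Tensor, nn.optim.Optimizer, nn.optim.Adam) and a "CPU" device string; PlainSGD is an illustrative name, not something this patch adds. After the change, a subclass's _step returns the raw per-parameter update and the base class applies it through _apply_update (t.detach() - up), while the new device argument picks where optimizer state such as Adam's moment buffers is allocated.

from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer, Adam

class PlainSGD(Optimizer):
  def _step(self, params: list[Tensor], grads: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    # return the update to subtract (lr * grad) per parameter; no extra state tensors to realize
    return [(self.lr * g).cast(t.dtype) for t, g in zip(params, grads)], []

Tensor.training = True  # tinygrad expects training mode when stepping an optimizer
w = Tensor.randn(4, 4, requires_grad=True)
opt = PlainSGD([w], lr=0.1)
(w * w).sum().backward()
opt.step()  # base class does w.assign(self._apply_update(w, update)), i.e. w - update

# the device argument keeps optimizer state on a chosen device (assuming "CPU" is a valid device string in this build)
opt2 = Adam([w], lr=1e-3, device="CPU")

Returning the update instead of the already-updated parameter lets the base class own the subtraction and the .to(t.device) move, which is what allows optimizer state to live on a different device than the parameters it updates.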