import torch

# credit to KellerJordan at https://github.com/KellerJordan/Muon/tree/master
# some changes: classical momentum instead of weighting the gradient
# added ns_steps, ns_params, nesterov as hyperparameters
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int, params: tuple[float, ...]):
  """
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
  quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
  of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
  zero even beyond the point where the iteration no longer converges all the way to one everywhere
  on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
  where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
  performance at all relative to UV^T, where USV^T = G is the SVD.
  """
  assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng

  a, b, c = params
  X = G
  if G.size(-2) > G.size(-1):
    X = X.mT

  # Ensure spectral norm is at most 1
  X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
  # Perform the NS iterations
  for _ in range(steps):
    A = X @ X.mT
    B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
    X = a * X + B @ X

  if G.size(-2) > G.size(-1):
    X = X.mT
  return X
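
# A quick sanity-check sketch (illustrative only): with the default quintic coefficients, a few
# NS steps should push the singular values of a random matrix toward 1, roughly within the
# Uniform(0.5, 1.5) band described in the docstring above.
if __name__ == "__main__":
  test_mat = torch.randn(64, 32)
  ortho = zeropower_via_newtonschulz5(test_mat, steps=5, params=(3.4445, -4.7750, 2.0315))
  print("singular values after NS:", torch.linalg.svdvals(ortho))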

def muon_update(grad, momentum, beta=0.95, ns_steps=5, ns_params=(3.4445, -4.7750, 2.0315), nesterov=True):
  if beta:
    momentum.mul_(beta).add_(grad)  # classical (heavy-ball) momentum accumulation
    update = grad.add(momentum, alpha=beta) if nesterov else momentum
  else: update = grad
  if update.ndim == 4:  # for the case of conv filters
    update = update.view(len(update), -1)
  update = zeropower_via_newtonschulz5(update, steps=ns_steps, params=ns_params)
  return update
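
# Illustrative sketch: for a conv-filter-shaped (4D) gradient, muon_update flattens the filter to a
# matrix before orthogonalizing, so the returned update has shape (out_channels, -1); the optimizer
# below reshapes it back to the parameter's shape.
if __name__ == "__main__":
  conv_grad = torch.randn(8, 3, 3, 3)
  conv_buf = torch.zeros_like(conv_grad)
  print("conv update shape:", muon_update(conv_grad, conv_buf).shape)  # torch.Size([8, 27])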

class SingleDeviceMuon(torch.optim.Optimizer):
  """
  Muon variant for usage in non-distributed settings.
  """
  def __init__(self, params, lr=0.02, weight_decay=0.0, momentum=0.95, ns_steps=5, ns_params=(3.4445, -4.7750, 2.0315), nesterov=True):
    defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_steps=ns_steps, ns_params=ns_params, nesterov=nesterov)
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group["params"]:
        if p.grad is None:
          p.grad = torch.zeros_like(p)  # Force synchronization
        state = self.state[p]
        if len(state) == 0:
          state["momentum_buffer"] = torch.zeros_like(p)
        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"], ns_steps=group["ns_steps"],
                             ns_params=group["ns_params"], nesterov=group["nesterov"])
        p.mul_(1.0 - group["lr"] * group["weight_decay"])  # decoupled weight decay, applied before the update
        p.add_(update.reshape(p.shape), alpha=-group["lr"])

    return loss
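
# A minimal usage sketch: Muon is intended for 2D weight matrices, so a bias-free Linear layer is
# enough to exercise the optimizer here; in a real model, non-matrix parameters (biases, norms,
# embeddings) are typically handled by a separate optimizer such as AdamW.
if __name__ == "__main__":
  model = torch.nn.Linear(16, 4, bias=False)
  opt = SingleDeviceMuon(model.parameters(), lr=0.02, momentum=0.95)
  x, y = torch.randn(8, 16), torch.randn(8, 4)
  for _ in range(3):
    opt.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    opt.step()
    print("loss:", loss.item())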