mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
[bounty] Muon optim (#11414)
* newton schulz
* add muon + move newton schulz to tensor
* compact newton schulz
* better tests
* cleanup
* add comments for muon
* cleanup
* add export with tests
* match muon optim with test optim
* cleanup
* unsed import
* correct comment
* whitespace
* move export
* muon test fix
* match reference impl + tests
* remove export by moving muon device
* add credit
* cleanup
* remove print
* spacing
* spacing
* comma
* cleanup
* removal
* fix tests + optim momentum
* consistent is not/ not
* more consistency
* fix test
* cleanup
* fix the nones
* remove comment
* cast
* comment
* comment
* muon teeny test
* muon flag beautiful mnist
* set steps
* steps as hyperparam
* match default test steps
* name
* large cleanup
* dont care about steps
* nesterov false default
* match each other impl
* steps
* switch nest
* swap defaults
* update docstring
* add no nesterov test
* ban fuse_optim
* prints
* classical momentum
* alternative condition
* recon
* pre + post wd
* false default
* detach
* signature changes
* context
* swap order
* big cleanup
* 0 step instead
* parity
* remove fuse
* remove fused
* better paper
* assert message
* correct shape check + eps
* multidim
* add eps
* cleanup
* correct assert message
* lint
* better tests
* naming
* ns_steps,ns_params
* update docstring
* docstring
* match sgd and muon together
* sandwich
* add back fused
* parity

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
@@ -21,7 +21,7 @@ if __name__ == "__main__":
   X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
 
   model = Model()
-  opt = nn.optim.Adam(nn.state.get_parameters(model))
+  opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))
 
   @TinyJit
   @Tensor.train()
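The changed line picks the optimizer class with a conditional expression and then calls whichever class was chosen with the same arguments. A minimal sketch of that pattern follows, with no tinygrad dependency. The names `getenv_int`, `Adam`, `Muon`, and `pick_optimizer` are hypothetical stand-ins for illustration, not the library's API, and the sketch assumes the flag is read as an integer with a default of 0.

```python
import os


def getenv_int(key: str, default: int = 0) -> int:
    """Hypothetical helper: read an environment variable as an integer flag."""
    return int(os.environ.get(key, default))


class Adam:
    """Placeholder optimizer; it only records the parameters it was given."""

    def __init__(self, params):
        self.params = list(params)


class Muon:
    """Placeholder for the alternative optimizer, with the same constructor."""

    def __init__(self, params):
        self.params = list(params)


def pick_optimizer(params):
    # Select the class first, then call it once with the shared arguments.
    # An unset flag or MUON=0 keeps the default; any nonzero value switches.
    opt_cls = Adam if not getenv_int("MUON") else Muon
    return opt_cls(params)
```

Because both classes accept the same constructor arguments, the call site stays a single expression and the default behaviour is unchanged when the variable is unset. Under the same assumption about how the flag is parsed, setting `MUON=1` in the environment before launching the script would select the Muon branch.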