[bounty] Muon optim (#11414)

* newton schulz

* add muon + move newton schulz to tensor

* compact newton schulz

* better tests

* cleanup

* add comments for muon

* cleanup

* add export with tests

* match muon optim with test optim

* cleanup

* unsed import

* correct comment

* whitespace

* move export

* muon test fix

* match reference impl + tests

* remove export by moving muon device

* add credit

* cleanup

* remove print

* spacing

* spacing

* comma

* cleanup

* removal

* fix tests + optim momentum

* consistent is not/ not

* more consistency

* fix test

* cleanup

* fix the nones

* remove comment

* cast

* comment

* comment

* muon teeny test

* muon flag beautiful mnist

* set steps

* steps as hyperparam

* match default test steps

* name

* large cleanup

* dont care about steps

* nesterov false default

* match each other impl

* steps

* switch nest

* swap defaults

* update docstring

* add no nesterov test

* ban fuse_optim

* prints

* classical momentum

* alternative condition

* recon

* pre + post wd

* false default

* detach

* signature changes

* context

* swap order

* big cleanup

* 0 step instead

* parity

* remove fuse

* remove fused

* better paper

* assert message

* correct shape check + eps

* multidim

* add eps

* cleanup

* correct assert message

* lint

* better tests

* naming

* ns_steps,ns_params

* update docstring

* docstring

* match sgd and muon together

* sandwich

* add back fused

* parity

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
kevvz
2025-08-13 11:27:55 -07:00
committed by GitHub
parent 94e6d84e32
commit e2873a3a41
6 changed files with 149 additions and 6 deletions

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model))
opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))
@TinyJit
@Tensor.train()