add SGD to beautiful_mnist (#13571)

This commit is contained in:
chenyu
2025-12-04 12:17:29 -05:00
committed by GitHub
parent 512a8f3dd4
commit 89f9e1dcd5
3 changed files with 4 additions and 3 deletions

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
model = Model()
opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))
opt = (nn.optim.Muon if getenv("MUON") else nn.optim.SGD if getenv("SGD") else nn.optim.Adam)(nn.state.get_parameters(model))
@TinyJit
@Tensor.train()