From 89f9e1dcd5fa6a2c2533e2fa75f31a486d654e7d Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 4 Dec 2025 12:17:29 -0500
Subject: [PATCH] add SGD to beautiful_mnist (#13571)

---
 examples/beautiful_mnist.py | 2 +-
 test/test_schedule.py       | 2 +-
 tinygrad/nn/optim.py        | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py
index 0a73f51011..39c1cf2554 100644
--- a/examples/beautiful_mnist.py
+++ b/examples/beautiful_mnist.py
@@ -21,7 +21,7 @@ if __name__ == "__main__":
   X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))

   model = Model()
-  opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))
+  opt = (nn.optim.Muon if getenv("MUON") else nn.optim.SGD if getenv("SGD") else nn.optim.Adam)(nn.state.get_parameters(model))

   @TinyJit
   @Tensor.train()
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 5d513808e2..9e10966980 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -27,7 +27,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
   # test lowering all the ScheduleItems to ExecItems
   kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
   if kernel_cnt != allowed:
-    print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
+    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
     if DEBUG >= 3:
       for i,s in enumerate(sched):
         print("kernel", i+1)
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 791e974c87..bcca52f0e0 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -9,7 +9,7 @@ class Optimizer:
   Base class for all optimizers.
   """
   def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
-    # if it's None, but being put into an optimizer, set it to True
+    # if requires_grad is None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad = True
@@ -51,6 +51,7 @@ class Optimizer:
         - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
     if self.fused:
       # optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
+      # NOTE: contiguous is for speed
       out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)], [Tensor.cat(*[unwrap(t.grad).contiguous().flatten() for t in self.params], dim=0)])
       updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
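
Note on the optimizer selection in examples/beautiful_mnist.py: the chained conditional is evaluated left to right, so MUON takes precedence over SGD if both environment variables are set, and Adam stays the default when neither is. Assuming the usual tinygrad getenv semantics (any non-zero value enables the flag), SGD=1 python examples/beautiful_mnist.py would train with nn.optim.SGD, and MUON=1 with nn.optim.Muon.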
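
For readers unfamiliar with the fused path touched in tinygrad/nn/optim.py, below is a minimal standalone sketch of the concatenate/step/split bookkeeping it performs. It uses plain Python lists in place of tinygrad Tensors and a plain SGD update for illustration, and it omits the reshape back to each parameter's original shape; only the concatenate-then-slice-by-positions idea mirrors the real code, and the helper name fused_update is hypothetical.

# sketch: flatten every parameter, concatenate into one buffer, run a single
# update, then split the result back into per-parameter slices
def fused_update(params, grads, lr=0.1):
  # record where each flattened parameter starts/ends (mirrors self.pos_params)
  pos = [0]
  for p in params: pos.append(pos[-1] + len(p))

  # "flatten and concatenate" all params and grads into single 1-D buffers
  flat_p = [x for p in params for x in p]
  flat_g = [x for g in grads for x in g]

  # one update over the whole concatenated buffer (plain SGD here, for illustration)
  flat_out = [p - lr * g for p, g in zip(flat_p, flat_g)]

  # split the result back into per-parameter slices (the real code also reshapes)
  return [flat_out[pos[i]:pos[i+1]] for i in range(len(params))]

if __name__ == "__main__":
  params = [[1.0, 2.0], [3.0, 4.0, 5.0]]
  grads  = [[0.1, 0.1], [0.2, 0.2, 0.2]]
  print(fused_update(params, grads))  # approx [[0.99, 1.99], [2.98, 3.98, 4.98]]

The point of the pattern, as the existing comment in optim.py says, is to concatenate all the buffers, run a single _step, and split the results back up; the recorded positions (pos here, self.pos_params in the real code) are what make the split possible, presumably so the whole update can be scheduled as fewer kernels than one _step per parameter.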