From 9366a23eb031c22390bd5fd2b642c73f96e74732 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 16 Aug 2025 20:29:39 -0700
Subject: [PATCH] test backward in test_tiny (#11697)

* test backward in test_tiny

* empty
---
 examples/beautiful_mnist.py |  3 +--
 test/test_tiny.py           | 18 ++++++++++++++++++
 tinygrad/tensor.py          |  2 +-
 3 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py
index 6c8ba03c56..0a73f51011 100644
--- a/examples/beautiful_mnist.py
+++ b/examples/beautiful_mnist.py
@@ -29,8 +29,7 @@ if __name__ == "__main__":
     opt.zero_grad()
     samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
     loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
-    opt.step()
-    return loss
+    return loss.realize(*opt.schedule_step())
 
   @TinyJit
   def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
diff --git a/test/test_tiny.py b/test/test_tiny.py
index a38c7ed628..e11e74813b 100644
--- a/test/test_tiny.py
+++ b/test/test_tiny.py
@@ -106,6 +106,24 @@ class TestTiny(unittest.TestCase):
     probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
     self.assertEqual(len(probs[0]), 10)
 
+  # TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
+  @unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images")
+  def test_mnist_backward(self):
+    # NOTE: we don't have the whole model here for speed
+    layers = [
+      nn.Conv2d(1, 32, 5), Tensor.relu,
+      nn.Conv2d(32, 32, 5), Tensor.relu]
+
+    # replace random weights with ones
+    # TODO: there's a bug here where it's tying two of the biases together. we need UNIQUE const
+    #Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
+    for p in nn.state.get_parameters(layers): p.replace(Tensor.empty(p.shape))
+
+    # realize gradients
+    for x in nn.state.get_parameters(layers): x.requires_grad_()
+    Tensor.empty(4, 1, 28, 28).sequential(layers).sum().backward()
+    Tensor.realize(*[x.grad for x in nn.state.get_parameters(layers) if x.grad is not None])
+
   # *** image ***
 
   @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 52267c732f..b8ce5873fc 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -252,7 +252,7 @@ class Tensor(MathTrait):
     # create the schedule
     schedule, var_vals = create_schedule_with_vars(sink)
     schedule = memory_planner(schedule)
-    if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
+    if DEBUG >= 1 and len(schedule) > 1: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
     return schedule, var_vals
 
   def schedule(self, *lst:Tensor) -> list[ScheduleItem]: