From 7338ffead0b4e31b21fdecabf651ef03413f0f44 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 9 Aug 2025 16:51:14 -0700 Subject: [PATCH] small beautiful_mnist update (#11596) Gather is fast now. There's a conv/bw kernel that only gets fast with BEAM, but the whole thing runs in < 5 seconds now regardless. --- examples/beautiful_mnist.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index 685a413116..590d6c3697 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -1,12 +1,12 @@ # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 -from typing import List, Callable +from typing import Callable from tinygrad import Tensor, TinyJit, nn, GlobalCounters from tinygrad.helpers import getenv, colored, trange from tinygrad.nn.datasets import mnist class Model: def __init__(self): - self.layers: List[Callable[[Tensor], Tensor]] = [ + self.layers: list[Callable[[Tensor], Tensor]] = [ nn.Conv2d(1, 32, 5), Tensor.relu, nn.Conv2d(32, 32, 5), Tensor.relu, nn.BatchNorm(32), Tensor.max_pool2d, @@ -28,7 +28,6 @@ if __name__ == "__main__": def train_step() -> Tensor: opt.zero_grad() samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) - # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward() opt.step() return loss