small beautiful_mnist update (#11596)

gather is fast now. there's a conv/bw kernel that only gets fast with BEAM, but whole thing runs < 5 seconds now regardless
This commit is contained in:
chenyu
2025-08-09 16:51:14 -07:00
committed by GitHub
parent 45baec1aab
commit 7338ffead0

View File

@@ -1,12 +1,12 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
class Model:
def __init__(self):
self.layers: List[Callable[[Tensor], Tensor]] = [
self.layers: list[Callable[[Tensor], Tensor]] = [
nn.Conv2d(1, 32, 5), Tensor.relu,
nn.Conv2d(32, 32, 5), Tensor.relu,
nn.BatchNorm(32), Tensor.max_pool2d,
@@ -28,7 +28,6 @@ if __name__ == "__main__":
def train_step() -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss