mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
small beautiful_mnist update (#11596)
Gather is fast now. There's a conv/backward kernel that only gets fast with BEAM, but the whole thing now runs in under 5 seconds regardless.
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import List, Callable
|
||||
from typing import Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||
self.layers: list[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
@@ -28,7 +28,6 @@ if __name__ == "__main__":
|
||||
def train_step() -> Tensor:
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user