The compiler error was due to `error: call to 'max' is ambiguous` when a kernel contained max(int, float). It was first fixed in 4380ccb1 (the non-fp32 math PR) and further solidified by the dtype refactor.
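A minimal sketch of the kind of expression that could have triggered it (hypothetical repro, not taken from the commit): mixing an int tensor with a float scalar puts a max(int, float) into the generated kernel, which C-style backends reject as ambiguous unless dtype promotion casts first.

from tinygrad import Tensor
# hypothetical repro, not from the commit: an int tensor vs. a float scalar
# used to emit max(int, float); dtype promotion now casts the int side to float
print(Tensor([1, 2, 3]).maximum(1.5).numpy())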
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from extra.datasets import fetch_mnist
from tqdm import trange
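# NOTE: fetch_mnist lives in the repo's extra/ directory, which isn't shipped
# in the pip package, so run this from a tinygrad checkout (or with the repo
# root on PYTHONPATH).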
class Model:
  def __init__(self):
    self.layers: List[Callable[[Tensor], Tensor]] = [
      nn.Conv2d(1, 32, 5), Tensor.relu,
      nn.Conv2d(32, 32, 5), Tensor.relu,
      nn.BatchNorm2d(32), Tensor.max_pool2d,
      nn.Conv2d(32, 64, 3), Tensor.relu,
      nn.Conv2d(64, 64, 3), Tensor.relu,
      nn.BatchNorm2d(64), Tensor.max_pool2d,
      lambda x: x.flatten(1), nn.Linear(576, 10)]

  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
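# Shape check for nn.Linear(576, 10) above: a 1x28x28 MNIST image shrinks to
# 24 -> 20 under the two 5x5 convs, 10 after max_pool2d, then 8 -> 6 under the
# two 3x3 convs and 3 after the second pool, so flatten(1) gives 64*3*3 = 576.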
if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)

  model = Model()
  opt = nn.optim.Adam(nn.state.get_parameters(model))
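  # NOTE: TinyJit captures the kernels a function launches and replays them on
  # later calls, so a jitted function should see same-shaped tensor inputs.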
  @TinyJit
  def train_step(samples:Tensor) -> Tensor:
    with Tensor.train():
      opt.zero_grad()
      # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
      loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
      opt.step()
      return loss.realize()
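  # loss.realize() above forces the scheduled kernels to actually run inside
  # the JIT capture; tinygrad tensors are lazy, so nothing executes without it.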
  @TinyJit
  def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
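  # argmax over the 10 logits picks the predicted digit; comparing against the
  # labels and taking the mean gives accuracy as a fraction, scaled to percent.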
  test_acc = float('nan')
  for i in (t:=trange(70)):
    GlobalCounters.reset()  # NOTE: this makes it nice for DEBUG=2 timing
    samples = Tensor.randint(512, high=X_train.shape[0])  # TODO: put this in the JIT when rand is fixed
    loss = train_step(samples)
    if i%10 == 9: test_acc = get_test_acc().item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")