From bec2aaf404aa3bdc569330a8d66c6678f8bfc459 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 2 Apr 2024 00:54:04 +0000
Subject: [PATCH] add beautiful_mnist_multigpu example

---
 examples/beautiful_mnist_multigpu.py | 57 ++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 examples/beautiful_mnist_multigpu.py

diff --git a/examples/beautiful_mnist_multigpu.py b/examples/beautiful_mnist_multigpu.py
new file mode 100644
index 0000000000..6c5cb70dcc
--- /dev/null
+++ b/examples/beautiful_mnist_multigpu.py
@@ -0,0 +1,57 @@
+# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
+from typing import List, Callable
+from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
+from tinygrad.helpers import getenv, colored
+from extra.datasets import fetch_mnist
+from tqdm import trange
+
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
+
+class Model:
+  def __init__(self):
+    self.layers: List[Callable[[Tensor], Tensor]] = [
+      nn.Conv2d(1, 32, 5), Tensor.relu,
+      nn.Conv2d(32, 32, 5), Tensor.relu,
+      nn.BatchNorm2d(32), Tensor.max_pool2d,
+      nn.Conv2d(32, 64, 3), Tensor.relu,
+      nn.Conv2d(64, 64, 3), Tensor.relu,
+      nn.BatchNorm2d(64), Tensor.max_pool2d,
+      lambda x: x.flatten(1), nn.Linear(576, 10)]
+
+  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
+
+if __name__ == "__main__":
+  X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
+  # we shard the test data on axis 0
+  X_test.shard_(GPUS, axis=0)
+  Y_test.shard_(GPUS, axis=0)
+
+  model = Model()
+  for k, x in nn.state.get_state_dict(model).items(): x.to_(GPUS)  # we put a copy of the model on every GPU
+  opt = nn.optim.Adam(nn.state.get_parameters(model))
+
+  @TinyJit
+  def train_step() -> Tensor:
+    with Tensor.train():
+      opt.zero_grad()
+      samples = Tensor.randint(512, high=X_train.shape[0])
+      Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0)  # we shard the data on axis 0
+      # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
+      loss = model(Xt).sparse_categorical_crossentropy(Yt).backward()
+      opt.step()
+      return loss
+
+  @TinyJit
+  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
+
+  test_acc = float('nan')
+  for i in (t:=trange(70)):
+    GlobalCounters.reset()  # NOTE: this makes it nice for DEBUG=2 timing
+    loss = train_step()
+    if i%10 == 9: test_acc = get_test_acc().item()
+    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
+
+  # verify eval acc
+  if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
+    if test_acc >= target: print(colored(f"{test_acc=} >= {target}", "green"))
+    else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
\ No newline at end of file
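
The example reads two environment variables via getenv(): GPUS (number of devices, default 2) and TARGET_EVAL_ACC_PCT (optional accuracy assertion), e.g. GPUS=4 TARGET_EVAL_ACC_PCT=98 python examples/beautiful_mnist_multigpu.py. Below is a minimal sketch of the data-parallel pattern the patch relies on, using the same Tensor.shard_/to_ calls as above; the 2-device setup and tensor shapes are illustrative, not from the patch:

    from tinygrad import Tensor, Device

    # illustrative 2-device setup; the example builds this list from getenv("GPUS", 2)
    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]

    x = Tensor.rand(8, 4)            # a batch of 8 samples
    x.shard_(GPUS, axis=0)           # split the batch along axis 0: 4 samples per device
    w = Tensor.rand(4, 2).to_(GPUS)  # no axis given: replicate the weight on every device
    y = x @ w                        # each device multiplies its shard by its copy of w
    print(y.numpy().shape)           # (8, 2): shards are gathered when the result is read back

This mirrors what the patch does: the model parameters are replicated with to_(GPUS), while the training and test batches are sharded on axis 0 so each device processes its own slice of the batch.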