diff --git a/examples/beautiful_cifar.py b/examples/beautiful_cifar.py
index 316d07b611..8619550ad6 100644
--- a/examples/beautiful_cifar.py
+++ b/examples/beautiful_cifar.py
@@ -3,16 +3,19 @@ start_tm = time.perf_counter()
 import math
 from typing import Tuple, cast
 import numpy as np
-from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
+from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes, Device
 from tinygrad.helpers import partition, trange, getenv, Context
 from extra.lr_scheduler import OneCycleLR
 
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
+
 # override tinygrad defaults
 dtypes.default_float = dtypes.half
 Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()
 
 # from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
 batchsize = getenv("BS", 1024)
+assert batchsize % len(GPUS) == 0, f"{batchsize=} is not a multiple of {len(GPUS)=}"
 bias_scaler = 64
 hyp = {
   'opt': {
@@ -94,8 +97,13 @@ if __name__ == "__main__":
 
   # *** model ***
   model = SpeedyConvNet()
+  state_dict = nn.state.get_state_dict(model)
+  if len(GPUS) > 1:
+    cifar10_std.to_(GPUS)
+    cifar10_mean.to_(GPUS)
+    for x in state_dict.values(): x.to_(GPUS)
 
-  params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
+  params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
   opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
   opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
   opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
@@ -117,8 +125,12 @@ if __name__ == "__main__":
   @TinyJit
   @Tensor.train()
   def train_step(idxs:Tensor) -> Tensor:
-    out = model(preprocess(X_train[idxs]))
-    loss = loss_fn(out, Y_train[idxs])
+    X, Y = X_train[idxs], Y_train[idxs]
+    if len(GPUS) > 1:
+      X.shard_(GPUS, axis=0)
+      Y.shard_(GPUS, axis=0)
+    out = model(preprocess(X))
+    loss = loss_fn(out, Y)
     opt.zero_grad()
     loss.backward()
     return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(),
@@ -130,8 +142,11 @@ if __name__ == "__main__":
   def val_step() -> Tuple[Tensor, Tensor]:
     loss, acc = [], []
     for i in range(0, X_test.size(0), eval_batchsize):
-      Y = Y_test[i:i+eval_batchsize]
-      out = model(preprocess(X_test[i:i+eval_batchsize]))
+      X, Y = X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize]
+      if len(GPUS) > 1:
+        X.shard_(GPUS, axis=0)
+        Y.shard_(GPUS, axis=0)
+      out = model(preprocess(X))
       loss.append(loss_fn(out, Y))
       acc.append((out.argmax(-1) == Y).sum() / eval_batchsize)
     return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
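
A minimal sketch (not part of the patch) of the data-parallel pattern the diff introduces: replicate parameters on every device with .to_ and split the batch along axis 0 with .shard_. It assumes a machine that actually exposes at least two tinygrad devices; the layer, shapes, and GPUS default of 2 are illustrative, not taken from the example.

    # sketch only: replicate weights across devices, shard the batch on axis 0
    from tinygrad import Tensor, nn, Device
    from tinygrad.helpers import getenv

    GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]  # assumes >= 2 devices exist

    layer = nn.Linear(64, 10)
    if len(GPUS) > 1:
      # copy every parameter to all devices, same replication as the model state_dict in the patch
      for p in nn.state.get_state_dict(layer).values(): p.to_(GPUS)

    x = Tensor.randn(256, 64)
    if len(GPUS) > 1:
      x.shard_(GPUS, axis=0)  # each device holds 256 // len(GPUS) rows of the batch
    out = layer(x)            # forward runs per device; out still behaves as a (256, 10) tensor
    print(out.shape)

With the patch applied, the script is driven the same way as before via environment variables (for example GPUS=2 BS=1024 python examples/beautiful_cifar.py); the new assert only enforces that BS divides evenly across the listed devices.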