From 871df1436a961ab8071d9137a8e293c43b2809ea Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 28 May 2025 20:48:20 -0700 Subject: [PATCH] more beautiful cifar (#10551) * enumerate cases of Tensors in the JIT * optional fused optimizers * add fused optimizer test * move that there * ugh * work on beautiful_cifar * speed close to hlb_cifar * schedule to corealize all * one line sched step * less lines --- examples/beautiful_cifar.py | 48 +++++++++++++++---------------------- extra/lr_scheduler.py | 5 ++-- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/examples/beautiful_cifar.py b/examples/beautiful_cifar.py index bd8c414bcc..316d07b611 100644 --- a/examples/beautiful_cifar.py +++ b/examples/beautiful_cifar.py @@ -7,7 +7,9 @@ from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes from tinygrad.helpers import partition, trange, getenv, Context from extra.lr_scheduler import OneCycleLR +# override tinygrad defaults dtypes.default_float = dtypes.half +Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__() # from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py batchsize = getenv("BS", 1024) @@ -67,7 +69,7 @@ class ConvGroup: cast(Tensor, self.norm2.weight).requires_grad = False def __call__(self, x:Tensor) -> Tensor: x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu() - return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu() + return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu() + x class SpeedyConvNet: def __init__(self): @@ -78,25 +80,22 @@ class SpeedyConvNet: self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False) def __call__(self, x:Tensor) -> Tensor: x = self.whiten(x).quick_gelu() + # ************* HACKS ************* + x = x.pad((1,0,0,1)) # TODO: this pad should not be here! copied from hlb_cifar10 for speed + # ************* HACKS ************* x = x.sequential([self.conv_group_1, self.conv_group_2, self.conv_group_3]) return self.linear(x.max(axis=(2,3))) * hyp['opt']['scaling_factor'] if __name__ == "__main__": # *** dataset *** X_train, Y_train, X_test, Y_test = nn.datasets.cifar() - # TODO: without this line indexing doesn't fuse! - X_train, Y_train, X_test, Y_test = [x.contiguous() for x in [X_train, Y_train, X_test, Y_test]] cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3)) - def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]: - return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float), Y.one_hot(depths['num_classes']) + def preprocess(X:Tensor) -> Tensor: return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float) # *** model *** model = SpeedyConvNet() - state_dict = nn.state.get_state_dict(model) - #for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k) - - params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0]) + params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0]) opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay']) opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay']) opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias) @@ -111,40 +110,31 @@ if __name__ == "__main__": lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps) lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps) - def loss_fn(out, Y): - return out.cross_entropy(Y, reduction='none', label_smoothing=0.2).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) + def loss_fn(out:Tensor, Y:Tensor) -> Tensor: + ret = out.sparse_categorical_crossentropy(Y, reduction='none', label_smoothing=0.2) + return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) @TinyJit @Tensor.train() def train_step(idxs:Tensor) -> Tensor: - with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1): - X = X_train[idxs] - Y = Y_train[idxs].realize(X) - X, Y = preprocess(X, Y) - out = model(X) - loss = loss_fn(out, Y) + out = model(preprocess(X_train[idxs])) + loss = loss_fn(out, Y_train[idxs]) opt.zero_grad() loss.backward() - opt.step() - lr_sched_bias.step() - lr_sched_non_bias.step() - return loss / (batchsize*loss_batchsize_scaler) + return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(), + *lr_sched_bias.schedule_step(), *lr_sched_non_bias.schedule_step()) eval_batchsize = 2500 @TinyJit @Tensor.test() def val_step() -> Tuple[Tensor, Tensor]: - # TODO with Tensor.no_grad() - Tensor.no_grad = True loss, acc = [], [] for i in range(0, X_test.size(0), eval_batchsize): - X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize]) - out = model(X) + Y = Y_test[i:i+eval_batchsize] + out = model(preprocess(X_test[i:i+eval_batchsize])) loss.append(loss_fn(out, Y)) - acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize) - ret = Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean() - Tensor.no_grad = False - return ret + acc.append((out.argmax(-1) == Y).sum() / eval_batchsize) + return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean() np.random.seed(1337) for epoch in range(math.ceil(hyp['misc']['train_epochs'])): diff --git a/extra/lr_scheduler.py b/extra/lr_scheduler.py index 9b2756e4fa..87ff077a40 100644 --- a/extra/lr_scheduler.py +++ b/extra/lr_scheduler.py @@ -10,9 +10,8 @@ class LR_Scheduler: def get_lr(self): pass - def step(self) -> None: - self.epoch_counter.assign(self.epoch_counter + 1).realize() - self.optimizer.lr.assign(self.get_lr()).realize() + def schedule_step(self) -> list[Tensor]: return [self.epoch_counter.assign(self.epoch_counter + 1), self.optimizer.lr.assign(self.get_lr())] + def step(self) -> None: Tensor.realize(*self.schedule_step()) class LRSchedulerGroup: def __init__(self, *schedulers: LR_Scheduler): self.schedulers = schedulers