continue work on beautiful cifar (#10555)

George Hotz committed by GitHub
2025-05-28 21:42:01 -07:00
parent e140f8f0d8
commit e4e7b5d7e1

@@ -3,16 +3,19 @@ start_tm = time.perf_counter()
 import math
 from typing import Tuple, cast
 import numpy as np
-from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
+from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes, Device
 from tinygrad.helpers import partition, trange, getenv, Context
 from extra.lr_scheduler import OneCycleLR
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
 # override tinygrad defaults
 dtypes.default_float = dtypes.half
 Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()
 # from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
 batchsize = getenv("BS", 1024)
+assert batchsize % len(GPUS) == 0, f"{batchsize=} is not a multiple of {len(GPUS)=}"
 bias_scaler = 64
 hyp = {
   'opt': {
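The hunk above wires in device enumeration. A minimal standalone sketch of how the new GPUS list and the batch-size check behave, assuming a machine whose Device.DEFAULT is, say, "CUDA" and GPUS=4 in the environment (the device name is only an assumption for illustration):

from tinygrad import Device
from tinygrad.helpers import getenv

GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
# with GPUS=4 on a CUDA box this resolves to ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3'];
# with the default of 1 it is just ['CUDA:0'], so the single-device path is unchanged

batchsize = getenv("BS", 1024)
# each device receives batchsize // len(GPUS) samples, so the split must be even
assert batchsize % len(GPUS) == 0, f"{batchsize=} is not a multiple of {len(GPUS)=}"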
@@ -94,8 +97,13 @@ if __name__ == "__main__":
   # *** model ***
   model = SpeedyConvNet()
+  state_dict = nn.state.get_state_dict(model)
+  if len(GPUS) > 1:
+    cifar10_std.to_(GPUS)
+    cifar10_mean.to_(GPUS)
+    for x in state_dict.values(): x.to_(GPUS)
-  params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
+  params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
   opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
   opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
   opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
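The replication step above moves every model weight, plus the cifar10_mean / cifar10_std normalization constants, onto all devices with the in-place .to_(GPUS). A hedged standalone sketch of the two placement modes this commit relies on, assuming two devices are available; the tensors below are stand-ins, not taken from the script:

from tinygrad import Tensor, Device

GPUS = [f'{Device.DEFAULT}:{i}' for i in range(2)]

w = Tensor.ones(8, 8)
w.to_(GPUS)                # replicate: every device holds the full 8x8 tensor
x = Tensor.ones(4, 8)
x.shard_(GPUS, axis=0)     # shard: each device holds a 2x8 slice along the batch axis
print(w.device, x.device)  # both should now report the tuple of devices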
@@ -117,8 +125,12 @@ if __name__ == "__main__":
   @TinyJit
   @Tensor.train()
   def train_step(idxs:Tensor) -> Tensor:
-    out = model(preprocess(X_train[idxs]))
-    loss = loss_fn(out, Y_train[idxs])
+    X, Y = X_train[idxs], Y_train[idxs]
+    if len(GPUS) > 1:
+      X.shard_(GPUS, axis=0)
+      Y.shard_(GPUS, axis=0)
+    out = model(preprocess(X))
+    loss = loss_fn(out, Y)
     opt.zero_grad()
     loss.backward()
     return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(),
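With weights replicated and each batch sharded along axis 0, train_step becomes a standard data-parallel step. A hedged sketch of that pattern in isolation, using a stand-in nn.Linear model and a plain SGD optimizer (both hypothetical; the real script uses SpeedyConvNet, an OptimizerGroup, and TinyJit):

from tinygrad import Tensor, nn, Device

GPUS = [f'{Device.DEFAULT}:{i}' for i in range(2)]

model = nn.Linear(32, 10)
for p in nn.state.get_parameters(model): p.to_(GPUS)   # replicate weights on every device
opt = nn.optim.SGD(nn.state.get_parameters(model), lr=0.01)

with Tensor.train():
  X, Y = Tensor.randn(64, 32), Tensor.randint(64, high=10)
  X.shard_(GPUS, axis=0)                               # each device sees 32 samples
  Y.shard_(GPUS, axis=0)
  loss = model(X).sparse_categorical_crossentropy(Y)
  opt.zero_grad()
  loss.backward()                                      # gradients follow the sharded compute
  opt.step()
  print(loss.item())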
@@ -130,8 +142,11 @@ if __name__ == "__main__":
   def val_step() -> Tuple[Tensor, Tensor]:
     loss, acc = [], []
     for i in range(0, X_test.size(0), eval_batchsize):
-      Y = Y_test[i:i+eval_batchsize]
-      out = model(preprocess(X_test[i:i+eval_batchsize]))
+      X, Y = X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize]
+      if len(GPUS) > 1:
+        X.shard_(GPUS, axis=0)
+        Y.shard_(GPUS, axis=0)
+      out = model(preprocess(X))
       loss.append(loss_fn(out, Y))
       acc.append((out.argmax(-1) == Y).sum() / eval_batchsize)
     return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
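The eval path shards each test batch the same way; reductions such as the .sum() in the accuracy term then span all shards, so no explicit gather is written. A small hedged check of that behavior with made-up logits and labels, assuming two devices (both tensors are stand-ins for illustration):

from tinygrad import Tensor, Device

GPUS = [f'{Device.DEFAULT}:{i}' for i in range(2)]

out = Tensor.randn(8, 10)          # stand-in logits for one eval batch
Y = Tensor.randint(8, high=10)     # stand-in labels
out.shard_(GPUS, axis=0)
Y.shard_(GPUS, axis=0)

acc = (out.argmax(-1) == Y).sum() / 8   # the reduction covers both shards
print(acc.item())                       # a single Python float gathered from all devices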