more beautiful cifar (#10551)

* enumerate cases of Tensors in the JIT

* optional fused optimizers

* add fused optimizer test

* move that there

* ugh

* work on beautiful_cifar

* speed close to hlb_cifar

* schedule to corealize all

* one line sched step

* less lines
This commit is contained in:
George Hotz
2025-05-28 20:48:20 -07:00
committed by GitHub
parent ee12e801a3
commit 871df1436a
2 changed files with 21 additions and 32 deletions

View File

@@ -7,7 +7,9 @@ from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import partition, trange, getenv, Context
from extra.lr_scheduler import OneCycleLR
# override tinygrad defaults
dtypes.default_float = dtypes.half
Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()
# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
batchsize = getenv("BS", 1024)
@@ -67,7 +69,7 @@ class ConvGroup:
cast(Tensor, self.norm2.weight).requires_grad = False
def __call__(self, x:Tensor) -> Tensor:
x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu()
return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu()
return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu() + x
class SpeedyConvNet:
def __init__(self):
@@ -78,25 +80,22 @@ class SpeedyConvNet:
self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False)
def __call__(self, x:Tensor) -> Tensor:
x = self.whiten(x).quick_gelu()
# ************* HACKS *************
x = x.pad((1,0,0,1)) # TODO: this pad should not be here! copied from hlb_cifar10 for speed
# ************* HACKS *************
x = x.sequential([self.conv_group_1, self.conv_group_2, self.conv_group_3])
return self.linear(x.max(axis=(2,3))) * hyp['opt']['scaling_factor']
if __name__ == "__main__":
# *** dataset ***
X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
# TODO: without this line indexing doesn't fuse!
X_train, Y_train, X_test, Y_test = [x.contiguous() for x in [X_train, Y_train, X_test, Y_test]]
cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3))
def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]:
return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float), Y.one_hot(depths['num_classes'])
def preprocess(X:Tensor) -> Tensor: return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float)
# *** model ***
model = SpeedyConvNet()
state_dict = nn.state.get_state_dict(model)
#for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k)
params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
@@ -111,40 +110,31 @@ if __name__ == "__main__":
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
def loss_fn(out, Y):
return out.cross_entropy(Y, reduction='none', label_smoothing=0.2).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
def loss_fn(out:Tensor, Y:Tensor) -> Tensor:
ret = out.sparse_categorical_crossentropy(Y, reduction='none', label_smoothing=0.2)
return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
@TinyJit
@Tensor.train()
def train_step(idxs:Tensor) -> Tensor:
with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1):
X = X_train[idxs]
Y = Y_train[idxs].realize(X)
X, Y = preprocess(X, Y)
out = model(X)
loss = loss_fn(out, Y)
out = model(preprocess(X_train[idxs]))
loss = loss_fn(out, Y_train[idxs])
opt.zero_grad()
loss.backward()
opt.step()
lr_sched_bias.step()
lr_sched_non_bias.step()
return loss / (batchsize*loss_batchsize_scaler)
return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(),
*lr_sched_bias.schedule_step(), *lr_sched_non_bias.schedule_step())
eval_batchsize = 2500
@TinyJit
@Tensor.test()
def val_step() -> Tuple[Tensor, Tensor]:
# TODO with Tensor.no_grad()
Tensor.no_grad = True
loss, acc = [], []
for i in range(0, X_test.size(0), eval_batchsize):
X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize])
out = model(X)
Y = Y_test[i:i+eval_batchsize]
out = model(preprocess(X_test[i:i+eval_batchsize]))
loss.append(loss_fn(out, Y))
acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize)
ret = Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
Tensor.no_grad = False
return ret
acc.append((out.argmax(-1) == Y).sum() / eval_batchsize)
return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
np.random.seed(1337)
for epoch in range(math.ceil(hyp['misc']['train_epochs'])):

View File

@@ -10,9 +10,8 @@ class LR_Scheduler:
def get_lr(self): pass
def step(self) -> None:
self.epoch_counter.assign(self.epoch_counter + 1).realize()
self.optimizer.lr.assign(self.get_lr()).realize()
def schedule_step(self) -> list[Tensor]: return [self.epoch_counter.assign(self.epoch_counter + 1), self.optimizer.lr.assign(self.get_lr())]
def step(self) -> None: Tensor.realize(*self.schedule_step())
class LRSchedulerGroup:
def __init__(self, *schedulers: LR_Scheduler): self.schedulers = schedulers