more beautiful cifar (#10551)
* enumerate cases of Tensors in the JIT
* optional fused optimizers
* add fused optimizer test
* move that there
* ugh
* work on beautiful_cifar
* speed close to hlb_cifar
* schedule to corealize all
* one line sched step
* less lines
@@ -7,7 +7,9 @@ from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import partition, trange, getenv, Context
from extra.lr_scheduler import OneCycleLR

# override tinygrad defaults
dtypes.default_float = dtypes.half
Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()

# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
batchsize = getenv("BS", 1024)
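Note on the Context override above: tinygrad's Context from tinygrad.helpers temporarily sets context variables such as FUSE_ARANGE, and calling __enter__() by hand without ever exiting keeps the override active for the whole process. A minimal sketch of the two usage patterns (the indexing example is illustrative, not from this file):

from tinygrad import Tensor
from tinygrad.helpers import Context

# scoped: the override only holds inside the with block
with Context(FUSE_ARANGE=1):
  idxs = Tensor([0, 2, 4])
  print(Tensor.arange(10)[idxs].tolist())

# global: enter by hand and never exit, as beautiful_cifar now does at import time
Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()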
@@ -67,7 +69,7 @@ class ConvGroup:
    cast(Tensor, self.norm2.weight).requires_grad = False
  def __call__(self, x:Tensor) -> Tensor:
    x = self.norm1(self.conv1(x).max_pool2d().float()).cast(dtypes.default_float).quick_gelu()
    return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu()
    return self.norm2(self.conv2(x).float()).cast(dtypes.default_float).quick_gelu() + x

class SpeedyConvNet:
  def __init__(self):
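The only change in this hunk is the residual connection on the second conv of each ConvGroup. Stripped of the surrounding model, the pattern looks like the sketch below (channel count and layer choices are illustrative, not the hlb-CIFAR10 values):

from tinygrad import Tensor, nn, dtypes

class TinyResBlock:
  # conv -> batchnorm in float32 -> gelu, with a skip connection back to the input
  def __init__(self, ch:int):
    self.conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
    self.norm = nn.BatchNorm2d(ch)
  def __call__(self, x:Tensor) -> Tensor:
    # the trailing "+ x" is the residual path this commit adds
    return self.norm(self.conv(x).float()).cast(dtypes.default_float).quick_gelu() + x

x = Tensor.randn(1, 8, 32, 32)
print(TinyResBlock(8)(x).shape)  # the residual add keeps the shape unchanged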
@@ -78,25 +80,22 @@ class SpeedyConvNet:
    self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False)
  def __call__(self, x:Tensor) -> Tensor:
    x = self.whiten(x).quick_gelu()
    # ************* HACKS *************
    x = x.pad((1,0,0,1)) # TODO: this pad should not be here! copied from hlb_cifar10 for speed
    # ************* HACKS *************
    x = x.sequential([self.conv_group_1, self.conv_group_2, self.conv_group_3])
    return self.linear(x.max(axis=(2,3))) * hyp['opt']['scaling_factor']

if __name__ == "__main__":
  # *** dataset ***
  X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
  # TODO: without this line indexing doesn't fuse!
  X_train, Y_train, X_test, Y_test = [x.contiguous() for x in [X_train, Y_train, X_test, Y_test]]
  cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3))
  def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]:
    return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float), Y.one_hot(depths['num_classes'])
  def preprocess(X:Tensor) -> Tensor: return ((X - cifar10_mean.view(1, -1, 1, 1)) / cifar10_std.view(1, -1, 1, 1)).cast(dtypes.default_float)

  # *** model ***
  model = SpeedyConvNet()
  state_dict = nn.state.get_state_dict(model)

  #for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k)

  params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
  params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
  opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
  opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
  opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
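For reference on the bias/non-bias split: partition from tinygrad.helpers splits an iterable into (items matching the predicate, items not matching), and nn.optim.OptimizerGroup wraps several optimizers so that zero_grad/step/schedule_step apply to all of them. A minimal sketch with a throwaway two-layer model (the hyperparameters are placeholders, not the values used above):

from tinygrad import Tensor, nn
from tinygrad.helpers import partition

class TinyModel:
  def __init__(self):
    self.l1, self.l2 = nn.Linear(4, 4), nn.Linear(4, 2)   # each layer has a weight and a bias
  def __call__(self, x:Tensor) -> Tensor: return self.l2(self.l1(x).relu())

model = TinyModel()
# partition returns (items where the predicate is true, items where it is false)
params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
opt_bias     = nn.optim.SGD([p for _, p in params_bias], lr=0.01, weight_decay=0.0)
opt_non_bias = nn.optim.SGD([p for _, p in params_non_bias], lr=0.01, weight_decay=5e-4)
opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)  # one zero_grad()/step() drives both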
@@ -111,40 +110,31 @@ if __name__ == "__main__":
  lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)
  lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps)

  def loss_fn(out, Y):
    return out.cross_entropy(Y, reduction='none', label_smoothing=0.2).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
  def loss_fn(out:Tensor, Y:Tensor) -> Tensor:
    ret = out.sparse_categorical_crossentropy(Y, reduction='none', label_smoothing=0.2)
    return ret.mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])

  @TinyJit
  @Tensor.train()
  def train_step(idxs:Tensor) -> Tensor:
    with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1):
      X = X_train[idxs]
      Y = Y_train[idxs].realize(X)
    X, Y = preprocess(X, Y)
    out = model(X)
    loss = loss_fn(out, Y)
    out = model(preprocess(X_train[idxs]))
    loss = loss_fn(out, Y_train[idxs])
    opt.zero_grad()
    loss.backward()
    opt.step()
    lr_sched_bias.step()
    lr_sched_non_bias.step()
    return loss / (batchsize*loss_batchsize_scaler)
    return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(),
      *lr_sched_bias.schedule_step(), *lr_sched_non_bias.schedule_step())

  eval_batchsize = 2500
  @TinyJit
  @Tensor.test()
  def val_step() -> Tuple[Tensor, Tensor]:
    # TODO with Tensor.no_grad()
    Tensor.no_grad = True
    loss, acc = [], []
    for i in range(0, X_test.size(0), eval_batchsize):
      X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize])
      out = model(X)
      Y = Y_test[i:i+eval_batchsize]
      out = model(preprocess(X_test[i:i+eval_batchsize]))
      loss.append(loss_fn(out, Y))
      acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize)
    ret = Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
    Tensor.no_grad = False
    return ret
      acc.append((out.argmax(-1) == Y).sum() / eval_batchsize)
    return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()

  np.random.seed(1337)
  for epoch in range(math.ceil(hyp['misc']['train_epochs'])):
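The rewritten train_step above leans on the new schedule_step APIs: instead of opt.step() and two scheduler .step() calls each realizing separately, the returned loss is realized together with every scheduled assign, so the whole update lands in one schedule and one graph for TinyJit to capture. Roughly, the same pattern on a toy problem (a standalone sketch, not the beautiful_cifar code):

from tinygrad import Tensor, TinyJit, nn

W = Tensor.randn(4, 1, requires_grad=True)
opt = nn.optim.SGD([W], lr=0.1)

@TinyJit
@Tensor.train()
def toy_train_step(x:Tensor, y:Tensor) -> Tensor:
  loss = ((x @ W) - y).square().mean()
  opt.zero_grad()
  loss.backward()
  # schedule_step returns the optimizer's assign tensors instead of realizing them;
  # realizing them alongside the loss co-realizes the whole update
  return loss.realize(*opt.schedule_step())

for _ in range(3): print(toy_train_step(Tensor.randn(8, 4), Tensor.randn(8, 1)).item())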
@@ -10,9 +10,8 @@ class LR_Scheduler:

  def get_lr(self): pass

  def step(self) -> None:
    self.epoch_counter.assign(self.epoch_counter + 1).realize()
    self.optimizer.lr.assign(self.get_lr()).realize()
  def schedule_step(self) -> list[Tensor]: return [self.epoch_counter.assign(self.epoch_counter + 1), self.optimizer.lr.assign(self.get_lr())]
  def step(self) -> None: Tensor.realize(*self.schedule_step())

class LRSchedulerGroup:
  def __init__(self, *schedulers: LR_Scheduler): self.schedulers = schedulers
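With this change a scheduler mirrors the optimizer API: schedule_step() returns the assigns (epoch counter and lr) without realizing them, and step() becomes a one-line realize over them. Callers can either step eagerly or hand the assigns to whatever realizes the training step, as beautiful_cifar does above. A small usage sketch (the OneCycleLR numbers are placeholders):

from tinygrad import Tensor, nn
from extra.lr_scheduler import OneCycleLR

W = Tensor.randn(4, 1, requires_grad=True)
opt = nn.optim.SGD([W], lr=0.01)
sched = OneCycleLR(opt, max_lr=0.1, pct_start=0.3, div_factor=10, final_div_factor=0.01, total_steps=100)

sched.step()                       # eager: realizes the epoch_counter and lr assigns now
pending = sched.schedule_step()    # deferred: [epoch_counter assign, lr assign], not yet realized
Tensor.realize(*pending)           # realize them together (or alongside a loss, as in train_step)
print(opt.lr.item())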