mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
continue work on beautiful cifar (#10555)
This commit is contained in:
@@ -3,16 +3,19 @@ start_tm = time.perf_counter()
|
||||
import math
|
||||
from typing import Tuple, cast
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
|
||||
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes, Device
|
||||
from tinygrad.helpers import partition, trange, getenv, Context
|
||||
from extra.lr_scheduler import OneCycleLR
|
||||
|
||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
|
||||
|
||||
# override tinygrad defaults
|
||||
dtypes.default_float = dtypes.half
|
||||
Context(FUSE_ARANGE=1, FUSE_OPTIM=1).__enter__()
|
||||
|
||||
# from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
|
||||
batchsize = getenv("BS", 1024)
|
||||
assert batchsize % len(GPUS) == 0, f"{batchsize=} is not a multiple of {len(GPUS)=}"
|
||||
bias_scaler = 64
|
||||
hyp = {
|
||||
'opt': {
|
||||
@@ -94,8 +97,13 @@ if __name__ == "__main__":
|
||||
|
||||
# *** model ***
|
||||
model = SpeedyConvNet()
|
||||
state_dict = nn.state.get_state_dict(model)
|
||||
if len(GPUS) > 1:
|
||||
cifar10_std.to_(GPUS)
|
||||
cifar10_mean.to_(GPUS)
|
||||
for x in state_dict.values(): x.to_(GPUS)
|
||||
|
||||
params_bias, params_non_bias = partition(nn.state.get_state_dict(model).items(), lambda x: 'bias' in x[0])
|
||||
params_bias, params_non_bias = partition(state_dict.items(), lambda x: 'bias' in x[0])
|
||||
opt_bias = nn.optim.SGD([x[1] for x in params_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||
opt_non_bias = nn.optim.SGD([x[1] for x in params_non_bias], lr=0.01, momentum=.85, nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
|
||||
opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)
|
||||
@@ -117,8 +125,12 @@ if __name__ == "__main__":
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step(idxs:Tensor) -> Tensor:
|
||||
out = model(preprocess(X_train[idxs]))
|
||||
loss = loss_fn(out, Y_train[idxs])
|
||||
X, Y = X_train[idxs], Y_train[idxs]
|
||||
if len(GPUS) > 1:
|
||||
X.shard_(GPUS, axis=0)
|
||||
Y.shard_(GPUS, axis=0)
|
||||
out = model(preprocess(X))
|
||||
loss = loss_fn(out, Y)
|
||||
opt.zero_grad()
|
||||
loss.backward()
|
||||
return (loss / (batchsize*loss_batchsize_scaler)).realize(*opt.schedule_step(),
|
||||
@@ -130,8 +142,11 @@ if __name__ == "__main__":
|
||||
def val_step() -> Tuple[Tensor, Tensor]:
|
||||
loss, acc = [], []
|
||||
for i in range(0, X_test.size(0), eval_batchsize):
|
||||
Y = Y_test[i:i+eval_batchsize]
|
||||
out = model(preprocess(X_test[i:i+eval_batchsize]))
|
||||
X, Y = X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize]
|
||||
if len(GPUS) > 1:
|
||||
X.shard_(GPUS, axis=0)
|
||||
Y.shard_(GPUS, axis=0)
|
||||
out = model(preprocess(X))
|
||||
loss.append(loss_fn(out, Y))
|
||||
acc.append((out.argmax(-1) == Y).sum() / eval_batchsize)
|
||||
return Tensor.stack(*loss).mean() / (batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
|
||||
|
||||
Reference in New Issue
Block a user