TRAIN_BEAM and EVAL_BEAM for resnet (#4177)

working on measuring compile time
chenyu authored 2024-04-15 14:57:21 -04:00, committed by GitHub
parent 4592fc8fe7
commit 6a2168e698


@@ -5,7 +5,7 @@ from tqdm import tqdm
 import multiprocessing
 
 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
-from tinygrad.helpers import getenv, BEAM, WINO
+from tinygrad.helpers import getenv, BEAM, WINO, Context
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
 from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
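
For reference, tinygrad's Context is a context manager that temporarily overrides ContextVars such as BEAM for the enclosed block and restores the previous value on exit, while getenv reads an environment variable with a fallback default. A minimal standalone sketch of the pattern the hunk below uses (illustrative, not part of this commit):

from tinygrad.helpers import getenv, BEAM, Context

print(BEAM.value)  # global beam width: the BEAM env var, else 0
with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
  print(BEAM.value)  # TRAIN_BEAM here, falling back to the global BEAM when unset
print(BEAM.value)  # previous value restored on exit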
@@ -105,23 +105,27 @@ def train_resnet():
   def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
   @TinyJit
   def train_step(X, Y):
-    optimizer_group.zero_grad()
-    X = normalize(X)
-    out = model.forward(X)
-    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
-    top_1 = (out.argmax(-1) == Y).sum()
-    (loss * loss_scaler).backward()
-    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
-    optimizer_group.step()
-    scheduler_group.step()
-    return loss.realize(), top_1.realize()
+    with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
+      optimizer_group.zero_grad()
+      X = normalize(X)
+      out = model.forward(X)
+      loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
+      top_1 = (out.argmax(-1) == Y).sum()
+      (loss * loss_scaler).backward()
+      for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
+      optimizer_group.step()
+      scheduler_group.step()
+      return loss.realize(), top_1.realize()
 
   @TinyJit
   def eval_step(X, Y):
-    X = normalize(X)
-    out = model.forward(X)
-    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
-    top_1 = (out.argmax(-1) == Y).sum()
-    return loss.realize(), top_1.realize()
+    with Context(BEAM=getenv("EVAL_BEAM", BEAM.value)):
+      X = normalize(X)
+      out = model.forward(X)
+      loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
+      top_1 = (out.argmax(-1) == Y).sum()
+      return loss.realize(), top_1.realize()
 
   def data_get(it):
     x, y, cookie = next(it)
     return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie
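
With both hunks applied, beam search can be tuned per JIT function: TRAIN_BEAM controls kernel search inside train_step, EVAL_BEAM inside eval_step, and each defaults to the existing global BEAM. For example, TRAIN_BEAM=4 EVAL_BEAM=0 python examples/mlperf/model_train.py (script path assumed from the tinygrad repo, not shown in this view) beam-searches only the training kernels. In line with the commit message, a hedged sketch of measuring compile time under a given beam width; the workload and timing harness are illustrative, only Context, BEAM, and getenv come from this diff:

import time
from tinygrad import Tensor
from tinygrad.helpers import getenv, BEAM, Context

with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
  t0 = time.perf_counter()
  # beam search runs while these kernels are compiled; BEAM=0 would skip it
  (Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()
  print(f"compile+run at BEAM={BEAM.value}: {time.perf_counter() - t0:.3f}s")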