TRAIN_BEAM and EVAL_BEAM for resnet (#4177)
working on measuring compile time
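The diff below wraps the jitted train and eval steps in tinygrad's Context manager so the BEAM kernel-search setting can be overridden per phase: TRAIN_BEAM applies while the train step compiles, EVAL_BEAM while the eval step compiles, and both fall back to the global BEAM value. A minimal sketch of the pattern, using only tinygrad's public getenv, BEAM, and Context helpers (the toy step function is hypothetical, not from this commit):

from tinygrad.helpers import getenv, BEAM, Context

def step():
  # BEAM.value reads whatever the innermost active Context set
  print("BEAM =", BEAM.value)

step()                                                # global BEAM (env var, default 0)
with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
  step()                                              # TRAIN_BEAM if set, else global BEAM
step()                                                # restored to global BEAM on exit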
@@ -5,7 +5,7 @@ from tqdm import tqdm
 import multiprocessing

 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
-from tinygrad.helpers import getenv, BEAM, WINO
+from tinygrad.helpers import getenv, BEAM, WINO, Context
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
 from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
@@ -105,23 +105,27 @@ def train_resnet():
   def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
   @TinyJit
   def train_step(X, Y):
-    optimizer_group.zero_grad()
-    X = normalize(X)
-    out = model.forward(X)
-    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
-    top_1 = (out.argmax(-1) == Y).sum()
-    (loss * loss_scaler).backward()
-    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
-    optimizer_group.step()
-    scheduler_group.step()
-    return loss.realize(), top_1.realize()
+    with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
+      optimizer_group.zero_grad()
+      X = normalize(X)
+      out = model.forward(X)
+      loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
+      top_1 = (out.argmax(-1) == Y).sum()
+      (loss * loss_scaler).backward()
+      for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
+      optimizer_group.step()
+      scheduler_group.step()
+      return loss.realize(), top_1.realize()

   @TinyJit
   def eval_step(X, Y):
-    X = normalize(X)
-    out = model.forward(X)
-    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
-    top_1 = (out.argmax(-1) == Y).sum()
-    return loss.realize(), top_1.realize()
+    with Context(BEAM=getenv("EVAL_BEAM", BEAM.value)):
+      X = normalize(X)
+      out = model.forward(X)
+      loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
+      top_1 = (out.argmax(-1) == Y).sum()
+      return loss.realize(), top_1.realize()

   def data_get(it):
     x, y, cookie = next(it)
     return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie
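Per the commit message, this is aimed at measuring compile time: BEAM search runs when kernels are compiled, so separate knobs let the train and eval steps be beamed (or not) independently. A hypothetical invocation, assuming the surrounding code is tinygrad's examples/mlperf/model_train.py (the path is inferred from train_resnet and is not shown in the diff):

TRAIN_BEAM=4 EVAL_BEAM=0 python3 examples/mlperf/model_train.py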