mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
refactor on training loop and start the work on val looop
This commit is contained in:
@@ -351,6 +351,7 @@ def train_retinanet():
|
||||
from extra.models import resnet
|
||||
from extra.lr_scheduler import LambdaLR
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from tinygrad.helpers import get_child
|
||||
|
||||
import numpy as np
|
||||
@@ -385,7 +386,6 @@ def train_retinanet():
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
return LambdaLR(optim, _lr_lambda)
|
||||
|
||||
@Tensor.train()
|
||||
@TinyJit
|
||||
def _train_step(model, optim, lr_scheduler, x, **kwargs):
|
||||
optim.zero_grad()
|
||||
@@ -398,6 +398,12 @@ def train_retinanet():
|
||||
lr_scheduler.step()
|
||||
|
||||
return loss.realize(), losses
|
||||
|
||||
@TinyJit
|
||||
def _eval_step(model, x, **kwargs):
|
||||
# TODO: Consider returning loss here as well
|
||||
out = model(normalize(x, GPUS), **kwargs)
|
||||
return out.realize()
|
||||
|
||||
# ** hyperparameters **
|
||||
# using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3
|
||||
@@ -433,9 +439,11 @@ def train_retinanet():
|
||||
# ** dataset **
|
||||
train_dataset = COCO(download_dataset(BASE_DIR, "train"))
|
||||
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
|
||||
coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
|
||||
|
||||
# ** lr scheduler **
|
||||
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
|
||||
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), BS) // BS)
|
||||
start_iter, warmup_iters = start_epoch * steps_in_train_epoch, lr_warmup_epochs * steps_in_train_epoch
|
||||
lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, lr_warmup_factor)
|
||||
|
||||
@@ -458,61 +466,90 @@ def train_retinanet():
|
||||
|
||||
print(f"training with batch size {BS} for {EPOCHS} epochs")
|
||||
|
||||
# ** training loop **
|
||||
for e in range(start_epoch, EPOCHS):
|
||||
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
|
||||
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, _data_get(it)
|
||||
# ** training loop **
|
||||
with Tensor.train():
|
||||
train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
|
||||
it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, _data_get(it)
|
||||
|
||||
prev_cookies = []
|
||||
st = time.perf_counter()
|
||||
prev_cookies = []
|
||||
st = time.perf_counter()
|
||||
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
|
||||
x, y_bboxes, y_labels, matches, anchors, proc = proc
|
||||
loss, losses = _train_step(model, optim, lr_scheduler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
|
||||
x, y_bboxes, y_labels, matches, anchors, proc = proc
|
||||
loss, losses = _train_step(model, optim, lr_scheduler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
|
||||
|
||||
pt = time.perf_counter()
|
||||
pt = time.perf_counter()
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
dt = time.perf_counter()
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
dt = time.perf_counter()
|
||||
|
||||
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
|
||||
loss = loss.item()
|
||||
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
|
||||
loss = loss.item()
|
||||
|
||||
cl = time.perf_counter()
|
||||
if BENCHMARK: step_times.append(cl - st)
|
||||
cl = time.perf_counter()
|
||||
if BENCHMARK: step_times.append(cl - st)
|
||||
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
|
||||
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "
|
||||
f"{optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
|
||||
)
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
|
||||
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "
|
||||
f"{optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
|
||||
)
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/classification_loss": losses["classification_loss"].item(), "train/regression_loss": losses["regression_loss"].item(),
|
||||
"train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
|
||||
if WANDB:
|
||||
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/classification_loss": losses["classification_loss"].item(), "train/regression_loss": losses["regression_loss"].item(),
|
||||
"train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
|
||||
|
||||
st = cl
|
||||
prev_cookies.append(proc)
|
||||
proc, next_proc = next_proc, None # return old cookie
|
||||
i += 1
|
||||
st = cl
|
||||
prev_cookies.append(proc)
|
||||
proc, next_proc = next_proc, None # return old cookie
|
||||
i += 1
|
||||
|
||||
if i == BENCHMARK:
|
||||
assert not math.isnan(loss)
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * EPOCHS / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
|
||||
f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
|
||||
return
|
||||
if i == BENCHMARK:
|
||||
assert not math.isnan(loss)
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * EPOCHS / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
|
||||
f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
|
||||
return
|
||||
|
||||
# ** eval loop **
|
||||
with Tensor.train(mode=False):
|
||||
val_dataloader = batch_load_retinanet(val_dataset, True, Path(BASE_DIR), batch_size=BS, seed=SEED)
|
||||
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch, desc=f"epoch {e}", disable=BENCHMARK))
|
||||
i, proc = 0, _data_get(it)
|
||||
|
||||
eval_times, prev_cookies = [], []
|
||||
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
st = time.time()
|
||||
|
||||
out, proc = _eval_step(model, proc[0]), proc[1]
|
||||
out = model.postprocess_detections(out)
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
prev_cookies.append(proc)
|
||||
proc, next_proc = next_proc, None
|
||||
i += 1
|
||||
|
||||
et = time.time()
|
||||
eval_times.append(et - st)
|
||||
|
||||
if getenv("CKPT"):
|
||||
if not os.path.exists(ckpt_dir := Path(getenv("CKPT_DIR", "./ckpts"))): os.mkdir(ckpt_dir)
|
||||
|
||||
Reference in New Issue
Block a user