From f13610bd78a69d26adab7af982021c24f3384944 Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Mon, 17 Mar 2025 20:11:43 +0000
Subject: [PATCH] move BEAM to their respective steps

---
 examples/mlperf/model_train.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index a94199b41b..43ba85561f 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -351,7 +351,7 @@ def train_retinanet():
   from extra.lr_scheduler import LambdaLR
   from pycocotools.coco import COCO
   from pycocotools.cocoeval import COCOeval
-  from tinygrad.helpers import colored
+  from tinygrad.helpers import colored, Context, DEBUG
   from tinygrad.nn.optim import Optimizer
   from typing import Iterator
   import extra.models.retinanet as retinanet
@@ -392,23 +392,25 @@ def train_retinanet():
 
   @TinyJit
   def _train_step(model, optim, lr_scheduler, loss_scaler, x, **kwargs):
-    optim.zero_grad()
+    with Context(BEAM=TRAIN_BEAM):
+      optim.zero_grad()
 
-    losses = model(normalize(x, GPUS), **kwargs)
-    loss = sum([l for l in losses.values()])
+      losses = model(normalize(x, GPUS), **kwargs)
+      loss = sum([l for l in losses.values()])
 
-    (loss * loss_scaler).backward()
-    for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler
+      (loss * loss_scaler).backward()
+      for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler
 
-    optim.step()
-    lr_scheduler.step()
+      optim.step()
+      lr_scheduler.step()
 
-    return loss.realize(), losses
+      return loss.realize(), losses
 
   @TinyJit
   def _eval_step(model, x, **kwargs):
-    out = model(normalize(x, GPUS), **kwargs)
-    return out.realize()
+    with Context(BEAM=EVAL_BEAM):
+      out = model(normalize(x, GPUS), **kwargs)
+      return out.realize()
 
   # ** hyperparameters **
   config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
@@ -546,8 +548,6 @@ def train_retinanet():
     if getenv("RESET_STEP", 1): _train_step.reset()
 
     with Tensor.train(mode=False), Tensor.test():
-      BEAM.value = EVAL_BEAM
-
       val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
       it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
       i, proc = 0, _data_get(it, val=val)