mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move BEAM to their respective steps
This commit is contained in:
@@ -351,7 +351,7 @@ def train_retinanet():
|
||||
from extra.lr_scheduler import LambdaLR
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
-from tinygrad.helpers import colored
+from tinygrad.helpers import colored, Context, DEBUG
|
||||
from tinygrad.nn.optim import Optimizer
|
||||
from typing import Iterator
|
||||
import extra.models.retinanet as retinanet
|
||||
@@ -392,23 +392,25 @@ def train_retinanet():
|
||||
|
||||
 @TinyJit
 def _train_step(model, optim, lr_scheduler, loss_scaler, x, **kwargs):
-  optim.zero_grad()
+  with Context(BEAM=TRAIN_BEAM):
+    optim.zero_grad()

-  losses = model(normalize(x, GPUS), **kwargs)
-  loss = sum([l for l in losses.values()])
+    losses = model(normalize(x, GPUS), **kwargs)
+    loss = sum([l for l in losses.values()])

-  (loss * loss_scaler).backward()
-  for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler
+    (loss * loss_scaler).backward()
+    for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler

-  optim.step()
-  lr_scheduler.step()
+    optim.step()
+    lr_scheduler.step()

-  return loss.realize(), losses
+    return loss.realize(), losses
|
||||
|
||||
 @TinyJit
 def _eval_step(model, x, **kwargs):
-  out = model(normalize(x, GPUS), **kwargs)
-  return out.realize()
+  with Context(BEAM=EVAL_BEAM):
+    out = model(normalize(x, GPUS), **kwargs)
+    return out.realize()
|
||||
|
||||
# ** hyperparameters **
|
||||
config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
|
||||
@@ -546,8 +548,6 @@ def train_retinanet():
|
||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||
|
||||
 with Tensor.train(mode=False), Tensor.test():
-  BEAM.value = EVAL_BEAM
-
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
|
||||
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
|
||||
i, proc = 0, _data_get(it, val=val)
|
||||
|
||||
Reference in New Issue
Block a user