add JIT reset support

This commit is contained in:
Francis Lata
2025-02-25 10:52:26 +00:00
parent 30d5daa121
commit 8737020d75

View File

@@ -523,6 +523,8 @@ def train_retinanet():
return
# ** eval loop **
if getenv("RESET_STEP", 1): _train_step.reset()
with Tensor.train(mode=False), Tensor.test():
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=BS, shuffle=False, seed=SEED)
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
@@ -561,6 +563,7 @@ def train_retinanet():
et = time.time()
eval_times.append(et - st)
if getenv("RESET_STEP", 1): _eval_step.reset()
total_fw_time = sum(eval_times) / len(eval_times)
tqdm.write(f"eval time: {total_fw_time:.2f}")