add eval_freq flag (#9894)

Francis Lata
2025-04-15 06:42:40 -04:00
committed by GitHub
parent 83ae83d871
commit 31483050c0

@@ -411,6 +411,7 @@ def train_retinanet():
   config["lr"] = lr = getenv("LR", 8.5e-5 * (BS / 96))
   config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
   config["default_float"] = dtypes.default_float.name
+  config["eval_freq"] = eval_freq = getenv("EVAL_FREQ", 1)

   if SEED: Tensor.manual_seed(SEED)
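
For context, getenv here is tinygrad's typed environment-variable helper, so the new knob is supplied at launch time (for example EVAL_FREQ=4). Below is a minimal standard-library sketch of the same pattern; read_int_env is a hypothetical stand-in, not tinygrad's helper.

# Minimal sketch of the env-driven config pattern used above.
# read_int_env is a hypothetical stand-in for tinygrad's getenv helper.
import os

def read_int_env(key: str, default: int) -> int:
  """Return the environment variable `key` parsed as an int, or `default` if unset."""
  value = os.getenv(key)
  return default if value is None else int(value)

# EVAL_FREQ defaults to 1, i.e. evaluate after every epoch (the previous behavior);
# EVAL_FREQ=4 would run the eval loop only after every 4th epoch.
eval_freq = read_int_env("EVAL_FREQ", 1)
print(f"eval_freq={eval_freq}")
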
@@ -518,68 +519,69 @@ def train_retinanet():
         return

     # ** eval loop **
-    if getenv("RESET_STEP", 1): _train_step.reset()
+    if (e + 1) % eval_freq == 0:
+      if getenv("RESET_STEP", 1): _train_step.reset()

-    with Tensor.train(mode=False), Tensor.test():
-      val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
-      it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
-      i, proc = 0, _data_get(it, val=val)
+      with Tensor.train(mode=False), Tensor.test():
+        val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
+        it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
+        i, proc = 0, _data_get(it, val=val)

-      eval_times, prev_cookies = [], []
-      val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
+        eval_times, prev_cookies = [], []
+        val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)

-      while proc is not None:
-        GlobalCounters.reset()
-        st = time.time()
+        while proc is not None:
+          GlobalCounters.reset()
+          st = time.time()

-        out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
-        out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
-        coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
-                        for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
+          out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
+          out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
+          coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
+                          for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]

-        with redirect_stdout(None):
-          coco_val.cocoDt = val_dataset.loadRes(coco_results)
-          coco_val.params.imgIds = img_ids
-          coco_val.evaluate()
+          with redirect_stdout(None):
+            coco_val.cocoDt = val_dataset.loadRes(coco_results)
+            coco_val.params.imgIds = img_ids
+            coco_val.evaluate()

-        val_img_ids.extend(img_ids)
-        val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
+          val_img_ids.extend(img_ids)
+          val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))

-        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
-        try:
-          next_proc = _data_get(it, val=val)
-        except StopIteration:
-          next_proc = None
+          if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
+          try:
+            next_proc = _data_get(it, val=val)
+          except StopIteration:
+            next_proc = None

-        prev_cookies.append(proc)
-        proc, next_proc = next_proc, None
-        i += 1
+          prev_cookies.append(proc)
+          proc, next_proc = next_proc, None
+          i += 1

-        if i == BENCHMARK:
-          return
+          if i == BENCHMARK:
+            return

-        et = time.time()
-        eval_times.append(et - st)
+          et = time.time()
+          eval_times.append(et - st)

-      if getenv("RESET_STEP", 1): _eval_step.reset()
-      total_fw_time = sum(eval_times) / len(eval_times)
+        if getenv("RESET_STEP", 1): _eval_step.reset()
+        total_fw_time = sum(eval_times) / len(eval_times)

-      coco_val.params.imgIds = val_img_ids
-      coco_val._paramsEval.imgIds = val_img_ids
-      coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
-      coco_val.accumulate()
-      coco_val.summarize()
+        coco_val.params.imgIds = val_img_ids
+        coco_val._paramsEval.imgIds = val_img_ids
+        coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
+        coco_val.accumulate()
+        coco_val.summarize()

-      val_metric = coco_val.stats[0]
+        val_metric = coco_val.stats[0]

-      tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
+        tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")

-      if WANDB:
-        wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
+        if WANDB:
+          wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})

-      if val_metric >= target_metric:
-        print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
-        break
+        if val_metric >= target_metric:
+          print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
+          break
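
To make the gating concrete: with the 0-based epoch index e, the condition (e + 1) % eval_freq == 0 runs evaluation after epoch numbers eval_freq, 2*eval_freq, and so on when counted 1-based, and the default EVAL_FREQ=1 keeps the old every-epoch behavior. A small self-contained check of that arithmetic, separate from the training script:

# Standalone illustration of the (e + 1) % eval_freq == 0 gating introduced above.
def eval_epochs(epochs: int, eval_freq: int) -> list[int]:
  """Return the 1-based epoch numbers after which evaluation would run."""
  return [e + 1 for e in range(epochs) if (e + 1) % eval_freq == 0]

assert eval_epochs(epochs=6, eval_freq=1) == [1, 2, 3, 4, 5, 6]  # default: eval every epoch
assert eval_epochs(epochs=6, eval_freq=2) == [2, 4, 6]           # eval every other epoch
assert eval_epochs(epochs=6, eval_freq=4) == [4]                 # a single mid-run eval
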
def train_unet3d():
"""