mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
add eval_freq flag (#9894)
This commit is contained in:
@@ -411,6 +411,7 @@ def train_retinanet():
|
||||
config["lr"] = lr = getenv("LR", 8.5e-5 * (BS / 96))
|
||||
config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
|
||||
config["default_float"] = dtypes.default_float.name
|
||||
config["eval_freq"] = eval_freq = getenv("EVAL_FREQ", 1)
|
||||
|
||||
if SEED: Tensor.manual_seed(SEED)
|
||||
|
||||
@@ -518,68 +519,69 @@ def train_retinanet():
|
||||
return
|
||||
|
||||
# ** eval loop **
|
||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||
if (e + 1) % eval_freq == 0:
|
||||
if getenv("RESET_STEP", 1): _train_step.reset()
|
||||
|
||||
with Tensor.train(mode=False), Tensor.test():
|
||||
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
|
||||
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
|
||||
i, proc = 0, _data_get(it, val=val)
|
||||
with Tensor.train(mode=False), Tensor.test():
|
||||
val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
|
||||
it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
|
||||
i, proc = 0, _data_get(it, val=val)
|
||||
|
||||
eval_times, prev_cookies = [], []
|
||||
val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
|
||||
eval_times, prev_cookies = [], []
|
||||
val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
|
||||
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
st = time.time()
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
st = time.time()
|
||||
|
||||
out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
|
||||
out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
|
||||
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
|
||||
for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
|
||||
out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
|
||||
out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
|
||||
coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
|
||||
for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
|
||||
|
||||
with redirect_stdout(None):
|
||||
coco_val.cocoDt = val_dataset.loadRes(coco_results)
|
||||
coco_val.params.imgIds = img_ids
|
||||
coco_val.evaluate()
|
||||
with redirect_stdout(None):
|
||||
coco_val.cocoDt = val_dataset.loadRes(coco_results)
|
||||
coco_val.params.imgIds = img_ids
|
||||
coco_val.evaluate()
|
||||
|
||||
val_img_ids.extend(img_ids)
|
||||
val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
|
||||
val_img_ids.extend(img_ids)
|
||||
val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it, val=val)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = _data_get(it, val=val)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
prev_cookies.append(proc)
|
||||
proc, next_proc = next_proc, None
|
||||
i += 1
|
||||
prev_cookies.append(proc)
|
||||
proc, next_proc = next_proc, None
|
||||
i += 1
|
||||
|
||||
if i == BENCHMARK:
|
||||
return
|
||||
if i == BENCHMARK:
|
||||
return
|
||||
|
||||
et = time.time()
|
||||
eval_times.append(et - st)
|
||||
et = time.time()
|
||||
eval_times.append(et - st)
|
||||
|
||||
if getenv("RESET_STEP", 1): _eval_step.reset()
|
||||
total_fw_time = sum(eval_times) / len(eval_times)
|
||||
if getenv("RESET_STEP", 1): _eval_step.reset()
|
||||
total_fw_time = sum(eval_times) / len(eval_times)
|
||||
|
||||
coco_val.params.imgIds = val_img_ids
|
||||
coco_val._paramsEval.imgIds = val_img_ids
|
||||
coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
|
||||
coco_val.accumulate()
|
||||
coco_val.summarize()
|
||||
coco_val.params.imgIds = val_img_ids
|
||||
coco_val._paramsEval.imgIds = val_img_ids
|
||||
coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
|
||||
coco_val.accumulate()
|
||||
coco_val.summarize()
|
||||
|
||||
val_metric = coco_val.stats[0]
|
||||
val_metric = coco_val.stats[0]
|
||||
|
||||
tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
|
||||
tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
|
||||
if WANDB:
|
||||
wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
|
||||
|
||||
if val_metric >= target_metric:
|
||||
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
|
||||
break
|
||||
if val_metric >= target_metric:
|
||||
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
|
||||
break
|
||||
|
||||
def train_unet3d():
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user