diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 6f265eb9d9..3bb15935cc 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -411,6 +411,7 @@ def train_retinanet():
   config["lr"] = lr = getenv("LR", 8.5e-5 * (BS / 96))
   config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
   config["default_float"] = dtypes.default_float.name
+  config["eval_freq"] = eval_freq = getenv("EVAL_FREQ", 1)
 
   if SEED:
     Tensor.manual_seed(SEED)
@@ -518,68 +519,69 @@ def train_retinanet():
         return
 
     # ** eval loop **
-    if getenv("RESET_STEP", 1): _train_step.reset()
+    if (e + 1) % eval_freq == 0:
+      if getenv("RESET_STEP", 1): _train_step.reset()
 
-    with Tensor.train(mode=False), Tensor.test():
-      val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
-      it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
-      i, proc = 0, _data_get(it, val=val)
+      with Tensor.train(mode=False), Tensor.test():
+        val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
+        it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
+        i, proc = 0, _data_get(it, val=val)
 
-      eval_times, prev_cookies = [], []
-      val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
+        eval_times, prev_cookies = [], []
+        val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
 
-      while proc is not None:
-        GlobalCounters.reset()
-        st = time.time()
+        while proc is not None:
+          GlobalCounters.reset()
+          st = time.time()
 
-        out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
-        out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
-        coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
-          for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
+          out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
+          out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
+          coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
+            for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
 
-        with redirect_stdout(None):
-          coco_val.cocoDt = val_dataset.loadRes(coco_results)
-          coco_val.params.imgIds = img_ids
-          coco_val.evaluate()
+          with redirect_stdout(None):
+            coco_val.cocoDt = val_dataset.loadRes(coco_results)
+            coco_val.params.imgIds = img_ids
+            coco_val.evaluate()
 
-        val_img_ids.extend(img_ids)
-        val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
+          val_img_ids.extend(img_ids)
+          val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
 
-        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
-        try:
-          next_proc = _data_get(it, val=val)
-        except StopIteration:
-          next_proc = None
+          if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
+          try:
+            next_proc = _data_get(it, val=val)
+          except StopIteration:
+            next_proc = None
 
-        prev_cookies.append(proc)
-        proc, next_proc = next_proc, None
-        i += 1
+          prev_cookies.append(proc)
+          proc, next_proc = next_proc, None
+          i += 1
 
-        if i == BENCHMARK:
-          return
+          if i == BENCHMARK:
+            return
 
-        et = time.time()
-        eval_times.append(et - st)
+          et = time.time()
+          eval_times.append(et - st)
 
-      if getenv("RESET_STEP", 1): _eval_step.reset()
-      total_fw_time = sum(eval_times) / len(eval_times)
+        if getenv("RESET_STEP", 1): _eval_step.reset()
+        total_fw_time = sum(eval_times) / len(eval_times)
 
-      coco_val.params.imgIds = val_img_ids
-      coco_val._paramsEval.imgIds = val_img_ids
-      coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
-      coco_val.accumulate()
-      coco_val.summarize()
+        coco_val.params.imgIds = val_img_ids
+        coco_val._paramsEval.imgIds = val_img_ids
+        coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
+        coco_val.accumulate()
+        coco_val.summarize()
 
-      val_metric = coco_val.stats[0]
+        val_metric = coco_val.stats[0]
 
-      tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
+        tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
 
-      if WANDB:
-        wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
+        if WANDB:
+          wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
 
-      if val_metric >= target_metric:
-        print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
-        break
+        if val_metric >= target_metric:
+          print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
+          break
 
 def train_unet3d():
   """
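
For reference, a minimal sketch of the eval-frequency gating this diff introduces, assuming tinygrad's getenv helper; the epoch count and loop bodies below are placeholders, not the actual training script.

  # Illustrative only: EVAL_FREQ=1 (the default) keeps the old behavior of evaluating after every epoch.
  from tinygrad.helpers import getenv

  eval_freq = getenv("EVAL_FREQ", 1)
  for e in range(epochs := 4):      # placeholder epoch count
    ...                             # train loop for epoch e
    if (e + 1) % eval_freq == 0:
      ...                           # eval loop runs only on every eval_freq-th epoch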