From defa1e77f60a5bdc3568dc094d62533e72fcc633 Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Mon, 21 Apr 2025 12:11:37 -0400
Subject: [PATCH] get the proper dataset count (#9962)

---
 examples/mlperf/model_train.py | 15 ++++++---------
 extra/datasets/openimages.py   |  5 +++++
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index e85c0f1887..6134ff32d2 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -346,7 +346,7 @@ def train_retinanet():
   from contextlib import redirect_stdout
   from examples.mlperf.dataloader import batch_load_retinanet
   from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
-  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
+  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize, get_dataset_count
   from extra.models import resnet
   from pycocotools.coco import COCO
   from pycocotools.cocoeval import COCOeval
@@ -452,17 +452,14 @@ def train_retinanet():
   optim = Adam(params, lr=lr)
 
   # ** dataset **
-  if INITMLPERF:
-    config["steps_in_train_epoch"] = steps_in_train_epoch = BS
-    config["steps_in_val_epoch"] = steps_in_val_epoch = EVAL_BS
-  else:
+  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASE_DIR)), False), BS) // BS
+  config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)
+
+  if not INITMLPERF:
     train_dataset = COCO(download_dataset(BASE_DIR, "train"))
     val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
     coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
 
-    config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
-    config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
-
   # ** initialize wandb **
   if (WANDB:=getenv("WANDB")):
     import wandb
@@ -477,7 +474,7 @@ def train_retinanet():
     if INITMLPERF:
       i, proc = 0, _fake_data_get(BS)
     else:
-      train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
+      train_dataloader = batch_load_retinanet(train_dataset, False, base_dir_path, batch_size=BS, seed=SEED)
       it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
       i, proc = 0, _data_get(it)
 
diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py
index 996b286126..54bfee2eaa 100644
--- a/extra/datasets/openimages.py
+++ b/extra/datasets/openimages.py
@@ -1,3 +1,4 @@
+import glob
 import sys
 import json
 import numpy as np
@@ -199,6 +200,10 @@ def normalize(img:Tensor, device:list[str]|None = None):
   img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std
   return img.cast(dtypes.default_float)
 
+def get_dataset_count(base_dir:Path, val:bool) -> int:
+  if not (files:=glob.glob(p:=str(base_dir / f"{'validation' if val else 'train'}/data/*.jpg"))): raise FileNotFoundError(f"No files in {p}")
+  return len(files)
+
 if __name__ == "__main__":
   download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
   download_dataset(base_dir, "validation")