get the proper dataset count (#9962)

Author: Francis Lata, 2025-04-21 12:11:37 -04:00 (committed by GitHub)
parent 36ed3c3253
commit defa1e77f6
2 changed files with 11 additions and 9 deletions
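In short: the per-epoch step counts are now derived from the image files actually present on disk, via a new get_dataset_count helper in extra/datasets/openimages.py, instead of from the COCO annotation image list. This also lets the step counts be computed before the COCO objects are constructed, and the walrus-bound base_dir_path is reused when building the train dataloader.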

examples/mlperf/model_train.py

@@ -346,7 +346,7 @@ def train_retinanet():
   from contextlib import redirect_stdout
   from examples.mlperf.dataloader import batch_load_retinanet
   from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
-  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
+  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize, get_dataset_count
   from extra.models import resnet
   from pycocotools.coco import COCO
   from pycocotools.cocoeval import COCOeval
@@ -452,17 +452,14 @@ def train_retinanet():
   optim = Adam(params, lr=lr)
   # ** dataset **
   if INITMLPERF:
     config["steps_in_train_epoch"] = steps_in_train_epoch = BS
     config["steps_in_val_epoch"] = steps_in_val_epoch = EVAL_BS
   else:
+    config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASE_DIR)), False), BS) // BS
+    config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)
+  if not INITMLPERF:
     train_dataset = COCO(download_dataset(BASE_DIR, "train"))
     val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
     coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
-    config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
-    config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
   # ** initialize wandb **
   if (WANDB:=getenv("WANDB")):
     import wandb
@@ -477,7 +474,7 @@ def train_retinanet():
     if INITMLPERF:
       i, proc = 0, _fake_data_get(BS)
     else:
-      train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
+      train_dataloader = batch_load_retinanet(train_dataset, False, base_dir_path, batch_size=BS, seed=SEED)
       it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
       i, proc = 0, _data_get(it)
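Aside on the step-count arithmetic: round_up(n, bs) // bs is a ceiling division, i.e. the number of batches needed to cover n samples. A minimal sketch, assuming tinygrad's round_up(num, amt) rounds num up to the nearest multiple of amt (the count and batch size below are made up):

def round_up(num: int, amt: int) -> int:
  # round num up to the nearest multiple of amt (tinygrad-style helper)
  return (num + amt - 1) // amt * amt

BS = 96                           # hypothetical batch size
dataset_count = 1000              # hypothetical number of images on disk
steps = round_up(dataset_count, BS) // BS
assert steps == -(-dataset_count // BS)  # equivalent to ceil(dataset_count / BS)
print(steps)  # 11: ten full batches plus one partial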

extra/datasets/openimages.py

@@ -1,3 +1,4 @@
+import glob
 import sys
 import json
 import numpy as np
@@ -199,6 +200,10 @@ def normalize(img:Tensor, device:list[str]|None = None):
   img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std
   return img.cast(dtypes.default_float)
 
+def get_dataset_count(base_dir:Path, val:bool) -> int:
+  if not (files:=glob.glob(p:=str(base_dir / f"{'validation' if val else 'train'}/data/*.jpg"))): raise FileNotFoundError(f"No files in {p}")
+  return len(files)
+
 if __name__ == "__main__":
   download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
   download_dataset(base_dir, "validation")
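To see what the new helper does end to end, here is a hedged usage sketch against a throwaway directory tree. The layout mirrors the expected BASE_DIR/{train,validation}/data/*.jpg structure, and the five empty files stand in for real images:

import glob, tempfile
from pathlib import Path

def get_dataset_count(base_dir:Path, val:bool) -> int:
  if not (files:=glob.glob(p:=str(base_dir / f"{'validation' if val else 'train'}/data/*.jpg"))): raise FileNotFoundError(f"No files in {p}")
  return len(files)

with tempfile.TemporaryDirectory() as tmp:
  data_dir = Path(tmp) / "train" / "data"
  data_dir.mkdir(parents=True)
  for i in range(5): (data_dir / f"{i}.jpg").touch()  # five empty stand-in images
  print(get_dataset_count(Path(tmp), val=False))      # 5
  try:
    get_dataset_count(Path(tmp), val=True)            # no validation/ dir created above
  except FileNotFoundError as e:
    print(e)                                          # No files in .../validation/data/*.jpg

Because the count is a plain glob over the data directory, it reflects exactly what download_dataset left on disk, which is what the steps-per-epoch computation in model_train.py now consumes.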