get the proper dataset count (#9962)
examples/mlperf/model_train.py
@@ -346,7 +346,7 @@ def train_retinanet():
from contextlib import redirect_stdout
from examples.mlperf.dataloader import batch_load_retinanet
from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize, get_dataset_count
from extra.models import resnet
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
@@ -452,17 +452,14 @@ def train_retinanet():
  optim = Adam(params, lr=lr)

  # ** dataset **
  if INITMLPERF:
    config["steps_in_train_epoch"] = steps_in_train_epoch = BS
    config["steps_in_val_epoch"] = steps_in_val_epoch = EVAL_BS
  else:
    config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASE_DIR)), False), BS) // BS
    config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)

  if not INITMLPERF:
    train_dataset = COCO(download_dataset(BASE_DIR, "train"))
    val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
    coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")

    config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
    config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)

  # ** initialize wandb **
  if (WANDB:=getenv("WANDB")):
    import wandb
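A side note on the arithmetic in this hunk: round_up(n, k) in tinygrad's helpers rounds n up to the next multiple of k, so round_up(count, BS) // BS is a plain ceiling division that turns an image count into steps per epoch. A minimal sketch with made-up numbers, assuming that round_up behavior:

import math

def round_up(n: int, k: int) -> int:
  # assumed to mirror tinygrad.helpers.round_up: round n up to the next multiple of k
  return (n + k - 1) // k * k

# hypothetical image count and batch size, purely illustrative
n_train, BS = 1000, 96
steps_in_train_epoch = round_up(n_train, BS) // BS
assert steps_in_train_epoch == math.ceil(n_train / BS)  # 11 steps for 1000 images at BS=96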
@@ -477,7 +474,7 @@ def train_retinanet():
    if INITMLPERF:
      i, proc = 0, _fake_data_get(BS)
    else:
      train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
      train_dataloader = batch_load_retinanet(train_dataset, False, base_dir_path, batch_size=BS, seed=SEED)
      it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
      i, proc = 0, _data_get(it)

extra/datasets/openimages.py
@@ -1,3 +1,4 @@
import glob
import sys
import json
import numpy as np
@@ -199,6 +200,10 @@ def normalize(img:Tensor, device:list[str]|None = None):
  img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std
  return img.cast(dtypes.default_float)

def get_dataset_count(base_dir:Path, val:bool) -> int:
  if not (files:=glob.glob(p:=str(base_dir / f"{'validation' if val else 'train'}/data/*.jpg"))): raise FileNotFoundError(f"No files in {p}")
  return len(files)

if __name__ == "__main__":
  download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
  download_dataset(base_dir, "validation")
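As a quick sanity check after the downloads, the new helper can be called directly. A minimal sketch (not part of this commit), assuming the images land under BASE_DIR/train/data/*.jpg and BASE_DIR/validation/data/*.jpg as the glob in get_dataset_count expects:

from pathlib import Path

from extra.datasets.openimages import BASEDIR, get_dataset_count
from tinygrad.helpers import getenv

# count the downloaded .jpg files directly; this is the same count train_retinanet
# now uses to derive steps_in_train_epoch and steps_in_val_epoch
base_dir = Path(getenv("BASE_DIR", BASEDIR))
print("train images:", get_dataset_count(base_dir, val=False))
print("validation images:", get_dataset_count(base_dir, val=True))  # raises FileNotFoundError if nothing was downloaded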