dev scripts for retinanet (#9968)
also renames BASE_DIR -> BASEDIR for consistency, and moves the wandb init up a bit for more accurate timing
@@ -359,7 +359,7 @@ def train_retinanet():
   config, target_metric = {}, 0.34
 
   NUM_CLASSES = len(MLPERF_CLASSES)
-  BASE_DIR = getenv("BASE_DIR", BASEDIR)
+  BASEDIR = getenv("BASEDIR", BASEDIR)
   BENCHMARK = getenv("BENCHMARK")
   INITMLPERF = getenv("INITMLPERF")
   config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
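For context on the rename itself: tinygrad-style `getenv` reads an environment variable and falls back to a default, so after this change only BASEDIR is consulted and a stale BASE_DIR export is silently ignored. A minimal standalone sketch of the pattern (my own `getenv` stand-in, not the tinygrad helper; the default path is borrowed from the new dev scripts below):

import os

# Stand-in for a getenv-style helper: coerce the env var to the type of the default.
def getenv(key, default=0):
  return type(default)(os.environ.get(key, default))

# After the rename, only BASEDIR matters; an old BASE_DIR export has no effect.
BASEDIR = getenv("BASEDIR", "/raid/datasets/openimages")
print(BASEDIR)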
@@ -424,6 +424,11 @@ def train_retinanet():
   config["default_float"] = dtypes.default_float.name
   config["eval_freq"] = eval_freq = getenv("EVAL_FREQ", 1)
 
+  # ** initialize wandb **
+  if (WANDB:=getenv("WANDB")):
+    import wandb
+    wandb.init(config=config, project="MLPerf-RetinaNet")
+
   if SEED: Tensor.manual_seed(SEED)
 
   # ** model initializers **
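The diff moves the wandb block from after the dataset setup (it is removed again in the -452 hunk below) to right after the config dict is filled in, so the run exists before the expensive model and dataloader setup. A rough sketch of that ordering, with made-up stand-ins for the setup and step functions and an illustrative "train/step_time" key (none of this is the actual training loop):

import os, time

# Illustrative stand-ins for the real setup and per-step work in model_train.py.
def setup_model_and_data(): time.sleep(0.1)
def train_step(): time.sleep(0.01)

config = {"bs": 96, "epochs": 1}
if (WANDB := int(os.environ.get("WANDB", "0"))):
  import wandb
  wandb.init(config=config, project="MLPerf-RetinaNet")  # create the run before the heavy setup

setup_model_and_data()  # setup cost now falls inside the run, not before it
for step in range(3):
  st = time.perf_counter()
  train_step()
  if WANDB: wandb.log({"train/step_time": time.perf_counter() - st})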
@@ -437,6 +442,7 @@ def train_retinanet():
 
   # ** model setup **
   backbone = resnet.ResNeXt50_32X4D(num_classes=None)
+  # TODO: should not load_from_pretrained during setup
   backbone.load_from_pretrained()
   _freeze_backbone_layers(backbone, 3)
 
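`_freeze_backbone_layers(backbone, 3)` is defined elsewhere in the file and not shown here. Assuming it follows the usual recipe of disabling gradients on the early ResNet/ResNeXt stages, a hypothetical version could look like this (attribute names assumed from a torchvision-style ResNet, and the stage-counting convention may differ from the real helper):

from tinygrad.nn.state import get_parameters

# Hypothetical sketch: freeze the stem plus the first residual stages by turning
# off gradient tracking, so the optimizer skips their parameters.
def freeze_backbone_layers(backbone, n):
  frozen = [backbone.conv1, backbone.bn1] + [getattr(backbone, f"layer{i}") for i in range(1, n)]
  for block in frozen:
    for p in get_parameters(block):
      p.requires_grad = False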
@@ -452,19 +458,14 @@ def train_retinanet():
   optim = Adam(params, lr=lr)
 
   # ** dataset **
-  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASE_DIR)), False), BS) // BS
+  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASEDIR)), False), BS) // BS
   config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)
 
   if not INITMLPERF:
-    train_dataset = COCO(download_dataset(BASE_DIR, "train"))
-    val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
+    train_dataset = COCO(download_dataset(BASEDIR, "train"))
+    val_dataset = COCO(download_dataset(BASEDIR, "validation"))
     coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
 
-  # ** initialize wandb **
-  if (WANDB:=getenv("WANDB")):
-    import wandb
-    wandb.init(config=config, project="MLPerf-RetinaNet")
-
   print(f"training with batch size {BS} for {EPOCHS} epochs")
 
   for e in range(start_epoch, EPOCHS):
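The steps-per-epoch expressions are a ceiling division: `round_up(count, BS) // BS` pads the sample count up to a multiple of the batch size before dividing, so a partial final batch still counts as a step. A quick standalone check of that arithmetic (local `round_up` standing in for the tinygrad helper, dataset size made up):

# Stand-in for tinygrad's round_up helper: round num up to the next multiple of amt.
def round_up(num, amt):
  return ((num + amt - 1) // amt) * amt

BS = 96
train_count = 1000  # made-up dataset size for illustration
steps_in_train_epoch = round_up(train_count, BS) // BS
assert steps_in_train_epoch == 11  # 1000/96 = 10.42 -> 11 steps, last batch is partial
print(steps_in_train_epoch)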
@@ -547,7 +548,7 @@ def train_retinanet():
       if INITMLPERF:
         i, proc = 0, _fake_data_get(EVAL_BS, val=(val:=True))
       else:
-        val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
+        val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASEDIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
         it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
         i, proc = 0, _data_get(it, val=val)
       val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="retinanet"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=96 EVAL_BS=96
+export BASEDIR="/raid/datasets/openimages"
+
+# export RESET_STEP=0
+
+export TRAIN_BEAM=2 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
+
+export INITMLPERF=1
+export BENCHMARK=10 DEBUG=2
+
+python examples/mlperf/model_train.py
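This first script is the benchmarking variant: INITMLPERF=1 routes the loop through `_fake_data_get` (see the -547 hunk above) so no dataset is needed, and BENCHMARK=10 stops after ten steps. A rough sketch of how such a pair of switches is typically consumed (my own toy loop, not the one in model_train.py):

import os

# Toy consumption of the two switches: BENCHMARK caps the number of steps,
# INITMLPERF swaps in fake batches so no dataset download is required.
BENCHMARK = int(os.environ.get("BENCHMARK", "0"))
INITMLPERF = int(os.environ.get("INITMLPERF", "0"))

def fake_batch(bs): return [0] * bs            # stands in for _fake_data_get
real_batches = iter([[1] * 96] * 1000)         # stands in for the real dataloader

steps = BENCHMARK if BENCHMARK else 1000       # full "epoch" unless benchmarking
for step in range(steps):
  batch = fake_batch(96) if INITMLPERF else next(real_batches)
  # ... forward/backward/optimizer step would go here ...
print(f"ran {steps} steps with {'fake' if INITMLPERF else 'real'} data")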
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+export PYTHONPATH="."
+export MODEL="retinanet"
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=96 EVAL_BS=96
+export BASEDIR="/raid/datasets/openimages"
+
+# export RESET_STEP=0
+
+export TRAIN_BEAM=2 IGNORE_JIT_FIRST_BEAM=1 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
+
+export WANDB=1 PARALLEL=0
+
+python examples/mlperf/model_train.py