cherry pick mlperf5.0 branch to master (#10089)

This commit is contained in:
chenyu
2025-04-28 15:36:56 -04:00
committed by GitHub
parent 459a223202
commit 610ee79b22
16 changed files with 242 additions and 35 deletions

View File

@@ -357,11 +357,43 @@ def train_retinanet():
  config, target_metric = {}, 0.34
+  config["SEED"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
+  Tensor.manual_seed(SEED)
  NUM_CLASSES = len(MLPERF_CLASSES)
  BASEDIR = getenv("BASEDIR", BASEDIR)
  BENCHMARK = getenv("BENCHMARK")
-  # INITMLPERF = getenv("INITMLPERF")
+  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")
+  if getenv("LOGMLPERF"):
+    from mlperf_logging import mllog
+    import mlperf_logging.mllog.constants as mllog_constants
+    mllog.config(filename=f"result_retinanet_{SEED}.log")
+    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
+    MLLOGGER = mllog.get_mllogger()
+    MLLOGGER.logger.propagate = False
+    if INITMLPERF:
+      assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
+      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RETINANET)
+      diskcache_clear()
+      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
+      MLLOGGER.start(key=mllog_constants.INIT_START)
+    if RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.RUN_START)
+      MLLOGGER.event(key=mllog_constants.SEED, value=SEED)
+  else:
+    MLLOGGER = None
  config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
  for x in GPUS: Device[x]
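For readers unfamiliar with `mlperf_logging`, the added block follows its standard pattern: `mllog.config` sets logger-wide options, `mllog.get_mllogger` returns the shared logger, and `start`/`event`/`end` emit `:::MLLOG` records for intervals and point-in-time facts. A minimal self-contained sketch (the filename and values are illustrative, not the submission's):
```
from mlperf_logging import mllog
import mlperf_logging.mllog.constants as mllog_constants

mllog.config(filename="result_example.log")  # where the :::MLLOG lines go
logger = mllog.get_mllogger()

logger.start(key=mllog_constants.INIT_START)        # opens an interval
logger.event(key=mllog_constants.SEED, value=1234)  # a point-in-time event
logger.end(key=mllog_constants.INIT_STOP)           # closes the interval
```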
@@ -414,24 +446,21 @@ def train_retinanet():
    return out.to(GPUS[0]).realize()

  # ** hyperparameters **
-  config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
-  config["bs"] = BS = getenv("BS", 16 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))
-  config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
-  config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
-  config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
-  config["eval_beam"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
-  config["lr"] = lr = getenv("LR", 9.5e-5 * (BS / 96))
-  config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
-  config["default_float"] = dtypes.default_float.name
-  config["eval_freq"] = eval_freq = getenv("EVAL_FREQ", 1)
+  config["BS"] = BS = getenv("BS", 16 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))
+  config["EVAL_BS"] = EVAL_BS = getenv("EVAL_BS", BS)
+  config["EPOCHS"] = EPOCHS = getenv("EPOCHS", 4)
+  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
+  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
+  config["LR"] = lr = getenv("LR", 9.5e-5 * (BS / 96))
+  config["LOSS_SCALER"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
+  config["DEFAULT_FLOAT"] = dtypes.default_float.name
+  config["EVAL_FREQ"] = eval_freq = getenv("EVAL_FREQ", 1)

  # ** initialize wandb **
  if (WANDB:=getenv("WANDB")):
    import wandb
    wandb.init(config=config, project="MLPerf-RetinaNet")

-  if SEED: Tensor.manual_seed(SEED)
  # ** model initializers **
  resnet.BatchNorm = FrozenBatchNorm2dRetinaNet
  resnet.Linear = Linear
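The key rename above is cosmetic but useful: every entry in `config` now matches the environment variable it was read from. tinygrad's `getenv` coerces the environment string to the type of the default, roughly like this simplified sketch (the real helper lives in `tinygrad.helpers` and also caches lookups):
```
import os

def getenv(key: str, default=0):
  # the default's type drives parsing: int default -> int(...), float -> float(...)
  return type(default)(os.getenv(key, default))

BS = getenv("BS", 96)      # BS=128 in the environment yields the int 128
lr = getenv("LR", 9.5e-5)  # parsed as a float because the default is a float
```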
@@ -464,8 +493,24 @@ def train_retinanet():
  optim = Adam(params, lr=lr)

  # ** dataset **
-  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASEDIR)), False), BS) // BS
-  config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)
+  config["STEPS_IN_TRAIN_EPOCH"] = steps_in_train_epoch = round_up(get_dataset_count((base_dir_path:=Path(BASEDIR)), False), BS) // BS
+  config["STEPS_IN_VAL_EPOCH"] = steps_in_val_epoch = (round_up(get_dataset_count(base_dir_path, True), EVAL_BS) // EVAL_BS)

+  # log mlperf hparams
+  if MLLOGGER:
+    if RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["BS"])
+      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["STEPS_IN_TRAIN_EPOCH"])
+      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["STEPS_IN_VAL_EPOCH"])
+      MLLOGGER.event(key=mllog_constants.EPOCH_COUNT, value=config["EPOCHS"])
+      MLLOGGER.event(key=mllog_constants.FIRST_EPOCH_NUM, value=start_epoch)
+      MLLOGGER.event(key=mllog_constants.OPT_NAME, value=mllog_constants.ADAM)
+      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["LR"])
+      MLLOGGER.event(key=mllog_constants.OPT_WEIGHT_DECAY, value=0)
+      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_EPOCHS, value=0)
+      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_FACTOR, value=0)
+      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)

  if RUNMLPERF:
    train_dataset = COCO(download_dataset(BASEDIR, "train"))
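`round_up(n, bs) // bs` is ceiling division: the number of batches needed to cover `n` samples at batch size `bs`. A sketch consistent with tinygrad's `round_up` helper:
```
def round_up(num: int, amt: int) -> int:
  return ((num + amt - 1) // amt) * amt  # round num up to a multiple of amt

# e.g. 1000 validation samples at EVAL_BS=96 -> 11 eval steps
assert round_up(1000, 96) // 96 == 11
```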
@@ -476,13 +521,16 @@ def train_retinanet():
  for e in range(start_epoch, EPOCHS):
    # ** training loop **
+    if MLLOGGER and RUNMLPERF:
+      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e + 1, metadata={"epoch_num": e + 1})
    BEAM.value = TRAIN_BEAM
    if not RUNMLPERF:
      i, proc = 0, _fake_data_get(BS)
    else:
      train_dataloader = batch_load_retinanet(train_dataset, False, base_dir_path, batch_size=BS, seed=SEED)
-      it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
+      it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e + 1}", disable=BENCHMARK))
      i, proc = 0, _data_get(it)

    prev_cookies = []
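`_fake_data_get` is what lets `BENCHMARK`/BEAM runs proceed without the dataset mounted; its implementation is not shown in this diff. Conceptually it yields constant-shaped batches so kernels still get compiled and timed, along the lines of this hypothetical sketch (names and shapes are assumptions):
```
import numpy as np

def fake_data_get(bs: int, img_size: int = 800):
  # zero-filled images and dummy targets of fixed shape: enough to drive the
  # train step and its kernel compilation without touching disk
  imgs = np.zeros((bs, img_size, img_size, 3), dtype=np.float32)
  boxes = np.zeros((bs, 120087, 4), dtype=np.float32)  # anchor count illustrative
  return imgs, boxes
```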
@@ -544,8 +592,14 @@ def train_retinanet():
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return

+    if MLLOGGER and RUNMLPERF:
+      MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e + 1, metadata={"epoch_num": e + 1})

    # ** eval loop **
    if (e + 1) % eval_freq == 0:
+      if MLLOGGER and RUNMLPERF:
+        MLLOGGER.start(key=mllog_constants.EVAL_START, value=e + 1, metadata={"epoch_num": e + 1})
      BEAM.value = EVAL_BEAM
      if getenv("RESET_STEP", 1): _train_step.reset()
@@ -593,12 +647,15 @@ def train_retinanet():
        proc, next_proc = next_proc, None
        i += 1
-        if i == BENCHMARK:
-          return

        et = time.time()
        eval_times.append(et - st)

+        if i == BENCHMARK:
+          # assume INITMLPERF has BENCHMARK set
+          if MLLOGGER and INITMLPERF:
+            MLLOGGER.event(key=mllog_constants.INIT_STOP)
+          return

      if getenv("RESET_STEP", 1): _eval_step.reset()
      total_fw_time = sum(eval_times) / len(eval_times)
@@ -616,8 +673,16 @@ def train_retinanet():
      if WANDB:
        wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
+      if MLLOGGER:
+        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=val_metric, metadata={"epoch_num": e + 1}, clear_line=True)
+        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=e + 1, metadata={"epoch_num": e + 1})

      if val_metric >= target_metric:
        print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
+        if MLLOGGER:
+          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata={"status": mllog_constants.SUCCESS})
        break
def train_unet3d():

View File

@@ -0,0 +1,69 @@
# 1. Problem
This problem uses BERT for NLP.
## Requirements
Install tinygrad and mlperf-logging (uncomment the mlperf extras in setup.py) from the `mlperf_training_v5.0` branch.
```
git clone https://github.com/tinygrad/tinygrad.git
cd tinygrad
git checkout mlperf_training_v5.0
python3 -m pip install -e ".[mlperf]"
```
Also install gdown (used to download the dataset), numpy, tqdm, and tensorflow.
```
pip install gdown numpy tqdm tensorflow
```
### tinybox_green
Install the p2p driver per the [README](https://github.com/tinygrad/open-gpu-kernel-modules/blob/550.54.15-p2p/README.md). This driver is the default on production tinybox green.
# 2. Directions
## Steps to download and verify data
### 1. Download raw data
```
BASEDIR="/raid/datasets/wiki" WIKI_TRAIN=1 VERIFY_CHECKSUM=1 python3 extra/datasets/wikipedia_download.py
```
### 2. Preprocess train and validation data
Note: the number of threads used for preprocessing is limited by available memory. With 128 GB of RAM, a maximum of 16 threads is recommended (one way to pick `NUM_WORKERS` is sketched below).
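A hypothetical helper for turning that rule of thumb into a `NUM_WORKERS` value (the ~8 GB-per-worker ratio is an assumption extrapolated from the 128 GB -> 16 threads guidance, not part of the scripts):
```
import os

def pick_num_workers(ram_gb: float, gb_per_worker: float = 8.0) -> int:
  # cap by both CPU threads and available memory: 128 GB / 8 GB -> 16 workers
  return max(1, min(os.cpu_count() or 1, int(ram_gb // gb_per_worker)))

print(pick_num_workers(128))  # 16 on a machine with at least 16 CPU threads
```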
#### Training:
```
BASEDIR="/raid/datasets/wiki" NUM_WORKERS=16 python3 extra/datasets/wikipedia.py pre-train all
```
To preprocess a single topic shard (between 0 and 499), as sketched after this code block:
```
BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-train 42
```
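For orientation, `pre-train all` and `pre-train 42` differ only in which of the 500 topic shards get processed; a hypothetical sketch of that dispatch (`preprocess_topic` is a stand-in name, not the script's real function):
```
def preprocess_topic(topic: int) -> None:
  ...  # tokenize and serialize one of the 500 wikipedia topic shards

def pre_train(arg: str) -> None:
  # "all" -> every shard; otherwise a single shard number 0..499
  for topic in (range(500) if arg == "all" else [int(arg)]):
    preprocess_topic(topic)
```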
#### Validation:
```
BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-eval
```
## Running
### tinybox_green
#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
```
### tinybox_red
#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
```
### tinybox_8xMI300X
#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_8xMI300X/run_and_time.sh
```

View File

@@ -4,14 +4,14 @@ This problem uses BERT for NLP.
## Requirements
-Install tinygrad and mlperf-logging from master.
+Install tinygrad and mlperf-logging (uncomment the mlperf extras in setup.py) from the `mlperf_training_v5.0` branch.
```
git clone https://github.com/tinygrad/tinygrad.git
cd tinygrad
git checkout mlperf_training_v5.0
python3 -m pip install -e ".[mlperf]"
```
-Also install tqdm and tensorflow.
+Also install gdown (used to download the dataset), numpy, tqdm, and tensorflow.
```
-pip install tqdm tensorflow
+pip install gdown numpy tqdm tensorflow
```
### tinybox_green
@@ -52,12 +52,18 @@ BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-eval
#### Steps to run benchmark
```
-examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
+examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
```
### tinybox_red
#### Steps to run benchmark
```
-examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
+examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
```
+### tinybox_8xMI300X
+#### Steps to run benchmark
+```
+examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_8xMI300X/run_and_time.sh
+```

View File

@@ -4,14 +4,14 @@ This problem uses BERT for NLP.
## Requirements
-Install tinygrad and mlperf-logging from master.
+Install tinygrad and mlperf-logging (uncomment the mlperf extras in setup.py) from the `mlperf_training_v5.0` branch.
```
git clone https://github.com/tinygrad/tinygrad.git
cd tinygrad
git checkout mlperf_training_v5.0
python3 -m pip install -e ".[mlperf]"
```
-Also install tqdm and tensorflow.
+Also install gdown (used to download the dataset), numpy, tqdm, and tensorflow.
```
-pip install tqdm tensorflow
+pip install gdown numpy tqdm tensorflow
```
### tinybox_green
@@ -52,12 +52,18 @@ BASEDIR="/raid/datasets/wiki" python3 extra/datasets/wikipedia.py pre-eval
#### Steps to run benchmark
```
-examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
+examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
```
### tinybox_red
#### Steps to run benchmark
```
-examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
+examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh
```
+### tinybox_8xMI300X
+#### Steps to run benchmark
+```
+examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_8xMI300X/run_and_time.sh
+```

View File

@@ -6,7 +6,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export FUSE_ARANGE=1 FUSE_ARANGE_UINT=0
-export BEAM=5 BEAM_UOPS_MAX=10000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
+export BEAM=5 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BEAM_LOG_SURPASS_MAX=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -6,7 +6,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export FUSE_ARANGE=1 FUSE_ARANGE_UINT=0
-export BEAM=5 BEAM_UOPS_MAX=10000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
+export BEAM=5 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -7,7 +7,7 @@ export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
export FUSE_ARANGE=1 FUSE_ARANGE_UINT=0
-export BEAM=5 BEAM_UOPS_MAX=10000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
+export BEAM=5 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -0,0 +1,38 @@
# 1. Problem
This problem uses RetinaNet for single-shot object detection (SSD).
## Requirements
Install tinygrad and mlperf-logging (uncomment the mlperf extras in setup.py) from the `mlperf_training_v5.0` branch.
```
git clone https://github.com/tinygrad/tinygrad.git
cd tinygrad
git checkout mlperf_training_v5.0
python3 -m pip install -e ".[mlperf]"
```
Also install the following dependencies:
```
pip install tqdm numpy pycocotools boto3 pandas torch torchvision
```
### tinybox_green
Install the p2p driver per the [README](https://github.com/tinygrad/open-gpu-kernel-modules/blob/550.54.15-p2p/README.md). This driver is the default on production tinybox green.
# 2. Directions
## Steps to download data
Run the following:
```
BASEDIR=/raid/datasets/openimages python3 extra/datasets/openimages.py
```
## Running
### tinybox_green
#### Steps to run benchmark
```
examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/retinanet/implementations/tinybox_green/run_and_time.sh
```

View File

@@ -0,0 +1,23 @@
#!/bin/bash
export PYTHONPATH="." NV=1
export MODEL="retinanet"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=96 EVAL_BS=96
export TRAIN_BEAM=2 BEAM_UOPS_MAX=1500 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 BEAM_PADTO=0
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/openimages"
# pip install -e ".[mlperf]"
export LOGMLPERF=1
export SEED=$RANDOM
DATETIME=$(date "+%m%d%H%M")
LOGFILE="retinanet_green_${DATETIME}_${SEED}.log"
# init: a short BENCHMARK pass (10 steps) with INITMLPERF=1 clears the kernel cache, BEAM-searches/compiles kernels, and logs INIT_START/INIT_STOP
BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
# run: the timed training run with RUNMLPERF=1, logging RUN_START through RUN_STOP
PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE
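The script invokes training twice: the init phase compiles and BEAM-searches kernels under `INITMLPERF=1`, then the timed run executes under `RUNMLPERF=1` and appends to the same log. A Python equivalent of the same flow, assuming the same entry point:
```
import os, subprocess

def run(extra_env: dict) -> None:
  env = {**os.environ, **extra_env}
  subprocess.run(["python3", "examples/mlperf/model_train.py"], env=env, check=True)

run({"BENCHMARK": "10", "INITMLPERF": "1"})  # init phase: warm caches, INIT_START/STOP
run({"PARALLEL": "0", "RUNMLPERF": "1"})     # timed run: RUN_START through RUN_STOP
```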

View File

@@ -28,7 +28,7 @@
"accelerator_interconnect_topology": "",
"cooling": "air",
"hw_notes": "",
"framework": "tinygrad, commit TBD",
"framework": "tinygrad, branch mlperf_training_v5.0",
"other_software_stack": {
"python": "3.10.16",
"ROCm": "3.0.0+94441cb"

View File

@@ -28,7 +28,7 @@
"accelerator_interconnect_topology": "",
"cooling": "air",
"hw_notes": "",
"framework": "tinygrad, commit b5546912e24e0a864b35924da4efa5d71cfe368b",
"framework": "tinygrad, branch mlperf_training_v5.0",
"other_software_stack": {
"python": "3.10.12",
"CUDA": "12.4"

View File

@@ -28,7 +28,7 @@
"accelerator_interconnect_topology": "",
"cooling": "air",
"hw_notes": "",
"framework": "tinygrad, commit b5546912e24e0a864b35924da4efa5d71cfe368b",
"framework": "tinygrad, branch mlperf_training_v5.0",
"other_software_stack": {
"python": "3.10.12"
},