diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index b73639298a..c9c8fa2c03 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -292,9 +292,9 @@ jobs: - name: Run MLPerf resnet eval on training data run: time HSA=1 MODEL=resnet python3 examples/mlperf/model_eval.py - name: Run 10 MLPerf ResNet50 training steps (1 gpu) - run: HSA=1 BENCHMARK=10 BS=128 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt + run: HSA=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=128 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt - name: Run 10 MLPerf ResNet50 training steps (6 gpu) - run: HSA=1 BENCHMARK=10 BS=768 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt + run: HSA=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=768 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt - uses: actions/upload-artifact@v4 with: name: Speed (AMD Training) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 6738b8c88d..7de32cee56 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -5,7 +5,7 @@ from tqdm import tqdm import multiprocessing from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes -from tinygrad.helpers import getenv, BEAM, WINO, Context +from tinygrad.helpers import getenv, BEAM, WINO from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save from tinygrad.nn.optim import LARS, SGD, OptimizerGroup @@ -50,7 +50,7 @@ def train_resnet(): epochs = config["epochs"] = getenv("EPOCHS", 37) BS = config["BS"] = getenv("BS", 104 * len(GPUS)) # fp32 GPUS<=6 7900xtx can fit BS=112 EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS) - base_lr = config["base_lr"] = getenv("LR", 7.4 * (BS/1632)) + base_lr = config["base_lr"] = getenv("LR", 7.4 * (BS/1536)) lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2) decay = config["decay"] = getenv("DECAY", 5e-5)