mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
resnet benchmarks use DEFAULT_FLOAT=HALF (#4285)
Also update the default LR so that it is scaled relative to a batch size of 1536 (the BS we are submitting).
This commit is contained in:
4
.github/workflows/benchmark.yml
vendored
4
.github/workflows/benchmark.yml
vendored
@@ -292,9 +292,9 @@ jobs:
|
||||
- name: Run MLPerf resnet eval on training data
|
||||
run: time HSA=1 MODEL=resnet python3 examples/mlperf/model_eval.py
|
||||
- name: Run 10 MLPerf ResNet50 training steps (1 gpu)
|
||||
run: HSA=1 BENCHMARK=10 BS=128 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
|
||||
run: HSA=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=128 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
|
||||
- name: Run 10 MLPerf ResNet50 training steps (6 gpu)
|
||||
run: HSA=1 BENCHMARK=10 BS=768 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
|
||||
run: HSA=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=768 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: Speed (AMD Training)
|
||||
|
||||
@@ -5,7 +5,7 @@ from tqdm import tqdm
|
||||
import multiprocessing
|
||||
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, Context
|
||||
from tinygrad.helpers import getenv, BEAM, WINO
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.optim import LARS, SGD, OptimizerGroup
|
||||
|
||||
@@ -50,7 +50,7 @@ def train_resnet():
|
||||
epochs = config["epochs"] = getenv("EPOCHS", 37)
|
||||
BS = config["BS"] = getenv("BS", 104 * len(GPUS)) # fp32 GPUS<=6 7900xtx can fit BS=112
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
|
||||
base_lr = config["base_lr"] = getenv("LR", 7.4 * (BS/1632))
|
||||
base_lr = config["base_lr"] = getenv("LR", 7.4 * (BS/1536))
|
||||
lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
|
||||
decay = config["decay"] = getenv("DECAY", 5e-5)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user