From 22fc0a2e3670ef64be34281d2d1b631360eaf2de Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 11 Mar 2025 23:03:15 -0400 Subject: [PATCH] bert sum acc in half (#9412) also BS=96 --- examples/mlperf/model_train.py | 4 ++-- .../benchmarks/bert/implementations/tinybox_green/dev_beam.sh | 2 +- .../benchmarks/bert/implementations/tinybox_green/dev_run.sh | 2 +- .../bert/implementations/tinybox_green/run_and_time.sh | 2 +- .../benchmarks/bert/implementations/tinybox_red/dev_beam.sh | 2 +- .../benchmarks/bert/implementations/tinybox_red/dev_run.sh | 2 +- .../bert/implementations/tinybox_red/run_and_time.sh | 2 +- test/test_tensor.py | 2 +- tinygrad/dtype.py | 4 ++-- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index e1d7d4a022..83e1a2a107 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -658,7 +658,7 @@ def train_bert(): # ** hyperparameters ** BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS)) EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS)) - max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00011 * math.sqrt(BS/66)) + max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96)) train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS) warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1) @@ -669,7 +669,7 @@ def train_bert(): save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts") init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR) - loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**10 if dtypes.default_float == dtypes.float16 else 1.0) + loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**11 if dtypes.default_float == dtypes.float16 else 1.0) decay = config["DECAY"] = getenv("DECAY", 0.01) epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6) poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0) diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh index 6f17109784..013a61820c 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh @@ -2,7 +2,7 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78 +export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh index 05c5a75619..f70edf4ccb 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh @@ -2,7 +2,7 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78 +export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh index 6a77928d89..dd71f162a0 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh @@ -3,7 +3,7 @@ export PYTHONPATH="." export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_green" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78 +export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh index f72acd8942..5cf5771d0e 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh @@ -2,7 +2,7 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78 +export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96 export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh index 503c91aa93..b9529deb4f 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh @@ -2,7 +2,7 @@ export PYTHONPATH="." export MODEL="bert" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78 +export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96 export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh index ec2554b25f..4e150d74a2 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh @@ -3,7 +3,7 @@ export PYTHONPATH="." export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_red" -export DEFAULT_FLOAT="HALF" GPUS=6 BS=78 EVAL_BS=78 +export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96 export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 diff --git a/test/test_tensor.py b/test/test_tensor.py index 2b4d758848..fb9bbc72cb 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -414,7 +414,7 @@ class TestTinygrad(unittest.TestCase): def test_tensor_dtype_errors(self): with self.assertRaises(AttributeError): Tensor([3], dtype="typo") - with self.assertRaises(TypeError): Tensor([3], dtype=(dtypes.int,)) + with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,)) def test_tensor_bytes(self): data = b"abc123" diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 1ee9bfa8f4..c0b80d5528 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -156,7 +156,7 @@ if (env_default_float := getenv("DEFAULT_FLOAT", "")): assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype" DTypeLike = Union[str, DType] -def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype) +def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype.lower()) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type @@ -180,7 +180,7 @@ def sum_acc_dtype(dt:DType): # default acc dtype for sum if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint) if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int) - return least_upper_dtype(dt, dtypes.float) + return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32"))) def truncate_fp16(x): try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]