From 4562f217e1d387eb009e259989a390da1e0cf693 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 6 Dec 2025 08:32:43 -0500
Subject: [PATCH] more bert updates (#13597)

prep split jit

also lower BS to 72
---
 examples/mlperf/model_train.py                | 56 +++++++++----------
 .../implementations/tinybox_green/dev_beam.sh |  2 +-
 .../implementations/tinybox_green/dev_run.sh  |  2 +-
 .../tinybox_green/run_and_time.sh             |  2 +-
 4 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 16c784c2cf..002fe2a62f 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -918,32 +918,6 @@ def train_rnnt():
   # TODO: RNN-T
   pass

-@TinyJit
-def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
-                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
-    else: t.to_(GPUS[0])
-  optimizer.zero_grad()
-
-  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
-  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
-  (loss * loss_scaler).backward()
-
-  global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer[0].device)
-  for p in optimizer.params:
-    p.grad = p.grad / loss_scaler
-    global_norm += p.grad.float().square().sum()
-  global_norm = global_norm.sqrt().contiguous()
-  for p in optimizer.params:
-    p.grad = (global_norm > 1.0).where((p.grad/global_norm).cast(p.grad.dtype), p.grad)
-
-  optimizer.step()
-  scheduler.step()
-  # TODO: no to("CPU") here because it blocks and messes the python time
-  Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr)
-  return loss, global_norm, optimizer.optimizers[0].lr
-
 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor,
                    masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
@@ -1130,6 +1104,32 @@ def train_bert():
     if MLLOGGER:
       MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*GBS, metadata={"epoch_num": i*GBS})

+  @TinyJit
+  def train_step_bert(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
+                      masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
+    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+      if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+      else: t.to_(GPUS[0])
+    optimizer_group.zero_grad()
+
+    lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+    loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+    (loss * loss_scaler).backward()
+
+    global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer_group[0].device)
+    for p in optimizer_group.params:
+      p.grad = p.grad / loss_scaler
+      global_norm += p.grad.float().square().sum()
+    global_norm = global_norm.sqrt().contiguous()
+    for p in optimizer_group.params:
+      p.grad = (global_norm > 1.0).where((p.grad/global_norm).cast(p.grad.dtype), p.grad)
+
+    optimizer_group.step()
+    scheduler_group.step()
+    # TODO: no to("CPU") here because it blocks and messes the python time
+    Tensor.realize(loss, global_norm, optimizer_group.optimizers[0].lr)
+    return loss, global_norm, optimizer_group.optimizers[0].lr
+
   while train_data is not None and i < train_steps and not achieved:
     if getenv("TRAIN", 1):
       Tensor.training = True
@@ -1137,9 +1137,9 @@ def train_bert():
       st = time.perf_counter()
       GlobalCounters.reset()
       with WallTimeEvent(BenchEvent.STEP):
-        loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+        loss, global_norm, lr = train_step_bert(
           train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
-          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
+          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])

       pt = time.perf_counter()
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh
index a22eb3e987..265455d1db 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh
@@ -2,7 +2,7 @@
 export PYTHONPATH="." NV=1
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=72 EVAL_BS=72
 export IGNORE_OOB=1
 export REWRITE_STACK_LIMIT=500000
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh
index c906579887..38c7966a29 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh
@@ -2,7 +2,7 @@
 export PYTHONPATH="." NV=1
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=72 EVAL_BS=72
 export IGNORE_OOB=1
 export REWRITE_STACK_LIMIT=500000
diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
index 4b81469316..8dd27b3ea8 100755
--- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
+++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh
@@ -5,7 +5,7 @@ set -o pipefail # Make pipeline fail if any command fails
 export PYTHONPATH="." NV=1
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_green"
-export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=96 EVAL_BS=96
+export DEFAULT_FLOAT="HALF" SUM_DTYPE="HALF" GPUS=6 BS=72 EVAL_BS=72
 export IGNORE_OOB=1
 export REWRITE_STACK_LIMIT=500000
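Note on the pattern this patch prepares: train_step_bert moves inside train_bert and becomes a @TinyJit closure, so model, optimizer_group, scheduler_group, loss_scaler, and GPUS are captured from the enclosing scope and only the per-step data tensors cross the JIT boundary. Below is a minimal, self-contained sketch of that closure pattern in tinygrad; the toy model, shapes, and hyperparameters are invented for illustration and are not the PR's code.

from tinygrad import Tensor, TinyJit, nn

# hypothetical toy model standing in for BERT; shapes are invented
class ToyModel:
  def __init__(self): self.l1, self.l2 = nn.Linear(16, 32), nn.Linear(32, 1)
  def __call__(self, x:Tensor) -> Tensor: return self.l2(self.l1(x).relu())

model = ToyModel()
opt = nn.optim.SGD(nn.state.get_parameters(model), lr=1e-2)

@TinyJit
def train_step(x:Tensor, y:Tensor) -> Tensor:
  # model and opt are captured from the enclosing scope, mirroring how the
  # patched train_step_bert closes over model/optimizer_group/scheduler_group;
  # only the data tensors are JIT inputs
  opt.zero_grad()
  loss = (model(x) - y).square().mean()
  loss.backward()
  opt.step()
  return loss.realize()

Tensor.training = True
for _ in range(4):
  x, y = Tensor.rand(8, 16), Tensor.rand(8, 1)
  print(train_step(x, y).item())

Keeping the JIT signature tensors-only means the non-tensor training state is bound once when the step is traced, rather than re-threaded through every call, which presumably is what makes the "prep split jit" step in the commit message tractable. The BS=96 to BS=72 change is confined to the exported env vars in the three tinybox_green scripts.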