From cb4c6324ef9a60941e20fdab36a8b345b6fec411 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 5 Dec 2025 17:30:08 -0500
Subject: [PATCH] revert bert grad accumulation (#13596)

prep for the new split jit style
---
 examples/mlperf/model_train.py | 36 ++++++++++++++--------------------
 1 file changed, 15 insertions(+), 21 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index a8189c4ee5..16c784c2cf 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -919,24 +919,16 @@ def train_rnnt(): pass
 
 @TinyJit
-def train_step_bert(model, optimizer, scheduler, loss_scaler:float, GPUS, grad_acc:int, **kwargs):
+def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
+                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
+  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+    else: t.to_(GPUS[0])
   optimizer.zero_grad()
-  for i in range(grad_acc):
-    input_ids, segment_ids = kwargs[f"input_ids{i}"], kwargs[f"segment_ids{i}"]
-    # NOTE: these two have different names
-    attention_mask, masked_positions = kwargs[f"input_mask{i}"], kwargs[f"masked_lm_positions{i}"]
-    masked_lm_ids, masked_lm_weights, next_sentence_labels = kwargs[f"masked_lm_ids{i}"], kwargs[f"masked_lm_weights{i}"], kwargs[f"next_sentence_labels{i}"]
-
-    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-      if len(GPUS) > 1: t.shard_(GPUS, axis=0)
-      else: t.to_(GPUS[0])
-
-    lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
-    loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
-    (loss * loss_scaler).backward()
-    # TODO: OOM without this realize with large grad_acc
-    Tensor.realize(*[p.grad for p in optimizer.params])
+  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+  (loss * loss_scaler).backward()
 
   global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer[0].device)
   for p in optimizer.params:
@@ -1014,7 +1006,8 @@ def train_bert():
   # ** hyperparameters **
   BS = config["BS"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
-  # TODO: mlperf logging
+  # TODO: implement grad accumulation + mlperf logging
+  assert grad_acc == 1
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
   max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(GBS/96))
@@ -1131,7 +1124,7 @@ def train_bert():
 
   # ** train loop **
   wc_start = time.perf_counter()
-  i, train_data = start_step, [next(train_it) for _ in range(grad_acc)]
+  i, train_data = start_step, next(train_it)
 
   if RUNMLPERF:
     if MLLOGGER:
@@ -1144,13 +1137,14 @@
       st = time.perf_counter()
      GlobalCounters.reset()
      with WallTimeEvent(BenchEvent.STEP):
-        data = {f"{k}{i}":v for i,d in enumerate(train_data) for k,v in d.items()}
-        loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, GPUS, grad_acc, **data)
+        loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+          train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
+          train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
 
        pt = time.perf_counter()
 
        try:
-          next_data = [next(train_it) for _ in range(grad_acc)]
+          next_data = next(train_it)
        except StopIteration:
          next_data = None
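
For context: the removed path is plain gradient accumulation inside the jitted step -- run grad_acc micro-batch forward/backward passes, let .backward() add into the existing .grad buffers (which is why the old code needed the Tensor.realize on grads to avoid OOM), then take a single optimizer step. Below is a minimal sketch of that pattern outside any JIT; the toy model, shapes, and MSE loss are illustrative assumptions, not the mlperf code.

from tinygrad import Tensor, nn

class ToyModel:
  def __init__(self): self.l = nn.Linear(8, 1)
  def __call__(self, x:Tensor) -> Tensor: return self.l(x)

model = ToyModel()
opt = nn.optim.SGD(nn.state.get_parameters(model), lr=0.01)
Tensor.training = True

GRAD_ACC = 4                 # micro-batches per optimizer step
opt.zero_grad()
for _ in range(GRAD_ACC):
  x, y = Tensor.randn(16, 8), Tensor.randn(16, 1)   # stand-in micro-batch
  loss = ((model(x) - y) ** 2).mean() / GRAD_ACC    # scale so the summed grads match one big batch
  loss.backward()            # backward() adds into existing .grad, so grads accumulate across micro-batches
opt.step()                   # one weight update covering GRAD_ACC micro-batches

After this revert, train_step_bert does exactly one forward/backward per call with a fixed tensor signature, and the new assert pins grad_acc to 1 until accumulation is reimplemented per the TODO.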