From d2e3c391e851892af9c703acebf68746e1e6e16c Mon Sep 17 00:00:00 2001
From: Elias Wahl <82230675+Eliulm@users.noreply.github.com>
Date: Wed, 12 Jun 2024 22:09:18 +0200
Subject: [PATCH] Residual in MLM loss + Change default steps (#4935)

* Residual in mlm loss

* Reduce default steps to 160K * 24

* oops

* comment
---
 examples/mlperf/model_train.py | 2 +-
 extra/models/bert.py           | 9 ++++++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index cac68051db..3eac557ce3 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -402,7 +402,7 @@ def train_bert():
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 4 * len(GPUS))
 
   max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000004166 * BS)
-  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
+  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3840000 // BS)
   warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
   max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
   eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
diff --git a/extra/models/bert.py b/extra/models/bert.py
index d5f0c2b4df..fdf702c35e 100644
--- a/extra/models/bert.py
+++ b/extra/models/bert.py
@@ -50,7 +50,14 @@ class BertForPretraining:
     return self.cls(output, masked_lm_positions)
 
   def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
-    masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
+    # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
+    def sparse_categorical_crossentropy(predictions:Tensor, labels:Tensor, ignore_index=-1):
+      log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
+      y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
+      y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
+      return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
+
+    masked_lm_loss = sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
     next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
     return masked_lm_loss + next_sentence_loss
 
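
Note on the step-count change (not part of the patch): the new default of 3840000 // BS matches the commit message's "160K * 24" — 160,000 steps x 24 samples = 3,840,000 samples, so a run with an effective batch size of 24 defaults to 160K optimizer steps; the previous 4800000 // BS gave 200K steps at the same batch size.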
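Note on the residual denominator (not part of the patch): a minimal standalone sketch in plain NumPy rather than tinygrad, illustrating what the patch's loss_mask.sum() + 1e-5 buys you. The function name masked_lm_loss, the eps parameter, and the scalar ignore_index are illustrative only; the patch itself passes the masked_lm_weights tensor as ignore_index.

import numpy as np

def masked_lm_loss(logits, labels, ignore_index=-1, eps=1e-5):
  # logits: (N, V) prediction scores; labels: (N,) target ids, ignore_index marks padding
  log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))  # log_softmax
  mask = labels != ignore_index
  # pick the target log-prob per row, zeroed where the position is ignored
  picked = log_probs[np.arange(len(labels)), np.where(mask, labels, 0)] * mask
  return -picked.sum() / (mask.sum() + eps)  # residual keeps an all-ignored batch finite

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 8))
print(masked_lm_loss(logits, np.array([3, 1, -1, 5])))     # averages over the 3 valid tokens
print(masked_lm_loss(logits, np.array([-1, -1, -1, -1])))  # ~0.0 instead of nan

Without the residual, a batch whose masked-LM positions are all ignored divides 0 by 0 and poisons training with NaN; with it, such a batch contributes a loss of ~0, which is what the MLCommons reference implementation linked in the patch comment does.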