Residual in MLM loss + Change default steps (#4935)

* Residual in MLM loss

* Reduce default steps to 160K * 24

* oops

* comment
Author: Elias Wahl
Date: 2024-06-12 22:09:18 +02:00
Committed by: GitHub
Parent: a21ea165bc
Commit: d2e3c391e8
2 changed files with 9 additions and 2 deletions

@@ -402,7 +402,7 @@ def train_bert():
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 4 * len(GPUS))
   max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000004166 * BS)
-  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
+  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3840000 // BS)
   warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
   max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS)  # EVAL_BS * MAX_EVAL_STEPS >= 10000
   eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS))  # Round down
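The new default encodes the "160K * 24" budget from the commit message: 160,000 reference steps at 24 samples per step is 3,840,000 training samples, and TRAIN_STEPS is that sample budget divided by the actual global batch size BS. A minimal sketch of the arithmetic (the example BS values below are hypothetical, not defaults set by this PR):

  # Sketch only: reproduces the 3840000 // BS default; the BS values are made-up examples.
  def default_train_steps(bs: int) -> int:
    # 160_000 steps * 24 samples/step = 3_840_000 total training samples
    return 3_840_000 // bs

  assert default_train_steps(24) == 160_000  # reference point from the commit message
  assert default_train_steps(96) == 40_000   # hypothetical BS of 96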

@@ -50,7 +50,14 @@ class BertForPretraining:
     return self.cls(output, masked_lm_positions)
   def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
-    masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
+    # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
+    def sparse_categorical_crossentropy(predictions:Tensor, labels:Tensor, ignore_index=-1):
+      log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
+      y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
+      y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
+      return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5)  # Small constant to avoid division by zero
+    masked_lm_loss = sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
     next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
     return masked_lm_loss + next_sentence_loss
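For context, the "residual" is the 1e-5 term added to the denominator of the masked mean, mirroring the MLCommons reference linked above, where the masked LM loss is divided by the sum of label weights plus 1e-5 so the loss stays finite even when a batch contains no masked positions. The sketch below re-derives that masked mean in plain NumPy; it is illustrative only (not tinygrad code) and uses a scalar ignore_index where the PR passes the masked_lm_weights tensor:

  import numpy as np

  # Illustrative sketch of the residual-denominator masked cross-entropy, not tinygrad code.
  def sparse_cce_with_residual(predictions, labels, ignore_index=-1):
    # predictions: (N, V) logits, labels: (N,) class ids; ignore_index marks entries to skip
    z = predictions - predictions.max(-1, keepdims=True)         # numerically stable log_softmax
    log_probs = z - np.log(np.exp(z).sum(-1, keepdims=True))
    mask = (labels != ignore_index)
    y = np.zeros_like(predictions)
    y[np.arange(labels.size), np.where(mask, labels, 0)] = mask  # one-hot targets, zeroed where ignored
    return -(log_probs * y).sum() / (mask.sum() + 1e-5)          # residual avoids 0/0 when mask.sum() == 0

  # Example: two real targets, two ignored slots.
  print(sparse_cce_with_residual(np.random.randn(4, 10), np.array([1, 3, -1, -1])))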