From 7dadbf3697bc786a57948a951d114e31c92f3e73 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 3 Apr 2025 05:44:09 -0400 Subject: [PATCH] insert float() in bert acc (#9726) Sum of bool by default uses default_float for acc, so without float() it might overflow with a large BS and default_float=HALF. Fixed clsf_accuracy to not be inf in mi300x bert. --- extra/models/bert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extra/models/bert.py b/extra/models/bert.py index 966b092ffb..a3edcb2c29 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -71,8 +71,9 @@ class BertForPretraining: seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels) + # NOTE: .float().sum() to prevent overflow with large BS since default acc of bool is in default_float # TODO: is it okay that next_sentence_loss is half here? - return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float() + return masked_lm_correct.float().sum() / valid.float().sum(), seq_relationship_correct.float().mean(), masked_lm_loss, next_sentence_loss.float() def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"): os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info