insert float() in bert acc (#9726)

Calling .sum() on a bool tensor accumulates in default_float. So without an explicit .float() cast, the accumulator might overflow with a large batch size when default_float=HALF.

Fixes clsf_accuracy reporting inf in the MI300X BERT run.
This commit is contained in:
chenyu
2025-04-03 05:44:09 -04:00
committed by GitHub
parent 79145e3d40
commit 7dadbf3697

View File

@@ -71,8 +71,9 @@ class BertForPretraining:
seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
# NOTE: .float().sum() to prevent overflow with large BS since default acc of bool is in default_float
# TODO: is it okay that next_sentence_loss is half here?
return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
return masked_lm_correct.float().sum() / valid.float().sum(), seq_relationship_correct.float().mean(), masked_lm_loss, next_sentence_loss.float()
def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info