mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 15:38:29 -05:00
insert float() in bert acc (#9726)
sum() of a bool tensor uses default_float for the accumulator by default, so without a float() cast it can overflow with a large BS when default_float=HALF. Fixes clsf_accuracy being inf in the mi300x BERT run.
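A minimal sketch of the failure mode (not from the repo; the 100000-element tensor and the DEFAULT_FLOAT=HALF setting are illustrative assumptions): per the commit message, summing a bool tensor accumulates in default_float, so in half precision the count can saturate or overflow to inf, while casting to float32 before the sum keeps it exact.

from tinygrad import Tensor, dtypes

# assumption for illustration: running with DEFAULT_FLOAT=HALF, so dtypes.default_float is half
correct = Tensor.ones(100000, dtype=dtypes.bool)  # stand-in for masked_lm_correct at a large BS

risky = correct.sum()          # per the commit message, a bool sum accumulates in default_float (half here)
safe  = correct.float().sum()  # cast to float32 first, as the fix does, so the count stays exact

print(risky.item(), safe.item())  # the half-accumulated count can saturate or hit inf; the float32 one is 100000.0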
@@ -71,8 +71,9 @@ class BertForPretraining:
     seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels)
     next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
 
+    # NOTE: .float().sum() to prevent overflow with large BS since default acc of bool is in default_float
     # TODO: is it okay that next_sentence_loss is half here?
-    return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
+    return masked_lm_correct.float().sum() / valid.float().sum(), seq_relationship_correct.float().mean(), masked_lm_loss, next_sentence_loss.float()
 
   def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
     os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info