add bert tiny config (#9177)

set with BERT_SIZE=tiny. makes it easier to study embedding and fusion
Author: chenyu
Date: 2025-02-19 14:57:03 -05:00
Committed by: GitHub
Parent: 5662c898f1
Commit: 3b37cc898b


@@ -195,18 +195,16 @@ def get_bert_qa_prediction(features, example, start_end_logits):
return "empty"
def get_mlperf_bert_config():
"""Config is BERT-large"""
return {
"attention_probs_dropout_prob": 0.1,
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": getenv("BERT_LAYERS", 24),
"type_vocab_size": 2,
"vocab_size": 30522
}
"""benchmark is BERT-large"""
ret = {"attention_probs_dropout_prob": 0.1, "hidden_dropout_prob": 0.1, "vocab_size": 30522, "type_vocab_size": 2, "max_position_embeddings": 512}
match (bert_size:=getenv("BERT_SIZE", "large")):
case "large": ret.update({"hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16, "num_hidden_layers": 24})
case "tiny": ret.update({"hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2})
case _: raise RuntimeError(f"unhandled {bert_size=}")
if (bert_layers:=getenv("BERT_LAYERS")): ret["num_hidden_layers"] = bert_layers
return ret
def get_mlperf_bert_model():
from extra.models import bert
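
For reference, a self-contained sketch of the new selection logic and how the BERT_LAYERS override interacts with it. The getenv below is a minimal stand-in for tinygrad.helpers.getenv (an assumption of this sketch: it reads an environment variable and coerces the value to the type of the default); the config function mirrors the code added in this commit. Requires Python 3.10+ for match.

import os

def getenv(key: str, default=0):
  # minimal stand-in for tinygrad.helpers.getenv (assumption for this sketch):
  # read an env var, coercing the value to the type of the default
  return type(default)(os.getenv(key, default))

def get_mlperf_bert_config():
  """benchmark is BERT-large"""
  ret = {"attention_probs_dropout_prob": 0.1, "hidden_dropout_prob": 0.1,
         "vocab_size": 30522, "type_vocab_size": 2, "max_position_embeddings": 512}
  match (bert_size:=getenv("BERT_SIZE", "large")):
    case "large": ret.update({"hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16, "num_hidden_layers": 24})
    case "tiny": ret.update({"hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2})
    case _: raise RuntimeError(f"unhandled {bert_size=}")
  # BERT_LAYERS still overrides the layer count for either size
  if (bert_layers:=getenv("BERT_LAYERS")): ret["num_hidden_layers"] = bert_layers
  return ret

if __name__ == "__main__":
  os.environ["BERT_SIZE"] = "tiny"   # select the new tiny config
  os.environ["BERT_LAYERS"] = "1"    # override applies on top of either size
  cfg = get_mlperf_bert_config()
  print(cfg["hidden_size"], cfg["num_hidden_layers"])  # 128 1

Keeping BERT_LAYERS as a separate override on top of BERT_SIZE means the layer count can still be trimmed independently of the width settings, e.g. a one-layer tiny model for studying embedding and fusion in isolation.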