From 3b37cc898be62bfa3efae94e8750872f21bbb7e6 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 19 Feb 2025 14:57:03 -0500
Subject: [PATCH] add bert tiny config (#9177)

set with BERT_SIZE=tiny. easier to study embedding and fusion
---
 examples/mlperf/helpers.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py
index f3aa16be12..0c01db2a9a 100644
--- a/examples/mlperf/helpers.py
+++ b/examples/mlperf/helpers.py
@@ -195,18 +195,16 @@ def get_bert_qa_prediction(features, example, start_end_logits):
   return "empty"
 
 def get_mlperf_bert_config():
-  """Config is BERT-large"""
-  return {
-    "attention_probs_dropout_prob": 0.1,
-    "hidden_dropout_prob": 0.1,
-    "hidden_size": 1024,
-    "intermediate_size": 4096,
-    "max_position_embeddings": 512,
-    "num_attention_heads": 16,
-    "num_hidden_layers": getenv("BERT_LAYERS", 24),
-    "type_vocab_size": 2,
-    "vocab_size": 30522
-  }
+  """benchmark is BERT-large"""
+  ret = {"attention_probs_dropout_prob": 0.1, "hidden_dropout_prob": 0.1, "vocab_size": 30522, "type_vocab_size": 2, "max_position_embeddings": 512}
+
+  match (bert_size:=getenv("BERT_SIZE", "large")):
+    case "large": ret.update({"hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16, "num_hidden_layers": 24})
+    case "tiny": ret.update({"hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2})
+    case _: raise RuntimeError(f"unhandled {bert_size=}")
+
+  if (bert_layers:=getenv("BERT_LAYERS")): ret["num_hidden_layers"] = bert_layers
+  return ret
 
 def get_mlperf_bert_model():
   from extra.models import bert
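
Usage sketch (not part of the patch): one way to exercise the new config
path. It assumes running from a tinygrad repo checkout so that
examples.mlperf.helpers imports cleanly; BERT_SIZE and BERT_LAYERS are the
env vars the diff reads via tinygrad's getenv, everything else here is
illustrative.

  import os
  # set before the first getenv call, since tinygrad's getenv caches lookups
  os.environ["BERT_SIZE"] = "tiny"    # select the new tiny preset
  os.environ["BERT_LAYERS"] = "1"     # optional layer-count override

  from examples.mlperf.helpers import get_mlperf_bert_config

  cfg = get_mlperf_bert_config()
  assert cfg["hidden_size"] == 128        # tiny: 128 vs large's 1024
  assert cfg["num_hidden_layers"] == 1    # BERT_LAYERS override applies last
  print(cfg)

Note the design choice in the diff: BERT_LAYERS is applied after the size
preset, so the layer count can still be trimmed independently of BERT_SIZE,
which suits the embedding and fusion studies the commit message mentions.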