Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 06:58:11 -05:00)
add bert tiny config (#9177)
Set with BERT_SIZE=tiny; makes it easier to study embedding and fusion.
@@ -195,18 +195,16 @@ def get_bert_qa_prediction(features, example, start_end_logits):
   return "empty"

 def get_mlperf_bert_config():
-  """Config is BERT-large"""
-  return {
-    "attention_probs_dropout_prob": 0.1,
-    "hidden_dropout_prob": 0.1,
-    "hidden_size": 1024,
-    "intermediate_size": 4096,
-    "max_position_embeddings": 512,
-    "num_attention_heads": 16,
-    "num_hidden_layers": getenv("BERT_LAYERS", 24),
-    "type_vocab_size": 2,
-    "vocab_size": 30522
-  }
+  """benchmark is BERT-large"""
+  ret = {"attention_probs_dropout_prob": 0.1, "hidden_dropout_prob": 0.1, "vocab_size": 30522, "type_vocab_size": 2, "max_position_embeddings": 512}
+
+  match (bert_size:=getenv("BERT_SIZE", "large")):
+    case "large": ret.update({"hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16, "num_hidden_layers": 24})
+    case "tiny": ret.update({"hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2})
+    case _: raise RuntimeError(f"unhandled {bert_size=}")
+
+  if (bert_layers:=getenv("BERT_LAYERS")): ret["num_hidden_layers"] = bert_layers
+  return ret

 def get_mlperf_bert_model():
   from extra.models import bert
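For quick experimentation outside the repo, here is a minimal, self-contained sketch of the new selection logic. It is an illustration only, not part of the commit: the getenv helper below is a simplified stand-in for tinygrad.helpers.getenv, and the usage lines at the bottom are a hypothetical example of switching sizes via environment variables.

# Minimal sketch (assumption, not from the commit): reproduces the new
# BERT_SIZE / BERT_LAYERS selection so it can be tried without tinygrad installed.
import os

def getenv(key, default=0):
  # simplified stand-in for tinygrad.helpers.getenv: cast the env var to the default's type
  return type(default)(os.environ.get(key, default))

def get_mlperf_bert_config():
  """benchmark is BERT-large"""
  ret = {"attention_probs_dropout_prob": 0.1, "hidden_dropout_prob": 0.1, "vocab_size": 30522,
         "type_vocab_size": 2, "max_position_embeddings": 512}
  match (bert_size:=getenv("BERT_SIZE", "large")):
    case "large": ret.update({"hidden_size": 1024, "intermediate_size": 4096, "num_attention_heads": 16, "num_hidden_layers": 24})
    case "tiny": ret.update({"hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2})
    case _: raise RuntimeError(f"unhandled {bert_size=}")
  if (bert_layers:=getenv("BERT_LAYERS")): ret["num_hidden_layers"] = bert_layers
  return ret

# hypothetical usage: pick the tiny config, then override the layer count
os.environ["BERT_SIZE"] = "tiny"
print(get_mlperf_bert_config()["hidden_size"])        # -> 128
os.environ["BERT_LAYERS"] = "4"
print(get_mlperf_bert_config()["num_hidden_layers"])  # -> 4

Note that BERT_LAYERS now applies as an override on top of whichever BERT_SIZE was chosen, whereas before it only replaced the default of 24 in the hardcoded BERT-large config.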