zero out the weight in bert init run (#9076)

`DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 EVAL_BS=6 GPUS=6 MODEL=bert python3 examples/mlperf/model_train.py` no longer OOMs. I think the buffers of randomly initialized weights caused the OOM.
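
For context, the fix replaces the randomly initialized weight buffers with zeros before the benchmark run starts. A minimal sketch of the zeroing idiom, assuming the public tinygrad API (the toy model here is a hypothetical stand-in for BertForPretraining):

from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

# hypothetical stand-in model; the real script builds BertForPretraining
class Toy:
  def __init__(self):
    self.l1, self.l2 = Linear(64, 64), Linear(64, 64)

model = Toy()
# overwrite every parameter in place with zeros: zeros_like is lazy,
# contiguous() gives it its own buffer, realize() executes the assign now
for p in get_parameters(model):
  p.assign(Tensor.zeros_like(p).contiguous()).realize()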
Author: chenyu
Date: 2025-02-14 08:40:41 -05:00
Committed by: GitHub
Parent: 82ad0d2e65
Commit: b58e7b1898

2 changed files with 10 additions and 5 deletions


@@ -208,7 +208,7 @@ def get_mlperf_bert_config():
     "vocab_size": 30522
   }
 
-def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
+def get_mlperf_bert_model():
   from extra.models import bert
   from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
@@ -220,8 +220,7 @@ def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
   config = get_mlperf_bert_config()
   if getenv("DISABLE_DROPOUT", 0):
     config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
-  model = BertForPretraining(**config)
-  return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
+  return BertForPretraining(**config)
 
 def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
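
With checkpoint_path gone, the helper only constructs the model, and loading becomes the caller's job. A hedged sketch of the new call pattern (checkpoint_path here is a hypothetical caller variable; the actual call site is in the second file below):

model = get_mlperf_bert_model()
if checkpoint_path is not None:
  model.load_from_pretrained(checkpoint_path)  # explicit load replaces the old conditional return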


@@ -683,8 +683,14 @@ def train_bert():
   # ** init model **
-  model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
+  model = get_mlperf_bert_model()
+  if RUNMLPERF:
+    model.load_from_pretrained(init_ckpt)
+  else:
+    # for init, zero out all weights
+    for p in get_parameters(model):
+      p = p.assign(Tensor.zeros_like(p).contiguous()).realize()
 
   parameters = get_parameters(model)
   for p in parameters:
     p.to_(GPUS)
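
After init, each parameter is placed onto all devices in place via to_. A small sketch of that multi-device move, assuming the usual GPUS naming convention from these scripts (the device count here is an assumption):

from tinygrad import Tensor, Device

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]  # assumed 2 devices; the repro above uses GPUS=6

x = Tensor.zeros(8, 8).contiguous().realize()  # realized on the default device
x.to_(GPUS)      # in-place move: x now lives on both devices
print(x.device)  # tuple of device names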