zero out the weight in bert init run (#9076)
`DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 EVAL_BS=6 GPUS=6 MODEL=bert python3 examples/mlperf/model_train.py` no longer OOMs. I think the buffers holding the randomly initialized weights were causing the OOM.
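For context, a minimal, self-contained sketch of the zero-out pattern this commit applies. `TinyModel` and its shape are made up for illustration; `Tensor.zeros_like`, `assign`, `contiguous`, `realize`, and `get_parameters` are the same tinygrad calls the diff below uses:

```python
# Hypothetical toy model; only the zeroing loop mirrors the actual commit.
from tinygrad import Tensor
from tinygrad.nn.state import get_parameters

class TinyModel:
  def __init__(self):
    # lazy random init: realizing this would materialize a full random buffer
    self.w = Tensor.uniform(4096, 4096)

model = TinyModel()
# Replace each parameter's contents with zeros in place: .contiguous() gives
# the assign a concrete zero buffer and .realize() executes it immediately.
# Per the commit message, this avoids keeping the random-init buffers around.
for p in get_parameters(model):
  p.assign(Tensor.zeros_like(p).contiguous()).realize()
```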
```diff
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -208,7 +208,7 @@ def get_mlperf_bert_config():
     "vocab_size": 30522
   }
 
-def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
+def get_mlperf_bert_model():
   from extra.models import bert
   from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
 
@@ -220,8 +220,7 @@ def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
   config = get_mlperf_bert_config()
   if getenv("DISABLE_DROPOUT", 0):
     config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
-  model = BertForPretraining(**config)
-  return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
+  return BertForPretraining(**config)
 
 def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
@@ -683,8 +683,14 @@ def train_bert():
 
   # ** init model **
 
-  model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
+  model = get_mlperf_bert_model()
+  if RUNMLPERF:
+    model.load_from_pretrained(init_ckpt)
+  else:
+    # for init, zero out all weights
+    for p in get_parameters(model):
+      p = p.assign(Tensor.zeros_like(p).contiguous()).realize()
 
   parameters = get_parameters(model)
   for p in parameters:
     p.to_(GPUS)
```
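One consequence of the signature change: callers that previously passed a checkpoint path to `get_mlperf_bert_model` must now load it explicitly. A sketch of the migration, with a placeholder path (no such external caller is shown in this diff):

```python
# before (hypothetical caller): model = get_mlperf_bert_model("/path/to/ckpt")
# after: construct first, then load; the benchmark init path can skip loading
# entirely and zero the weights instead, as in the train_bert() hunk above
model = get_mlperf_bert_model()
model.load_from_pretrained("/path/to/ckpt")  # placeholder checkpoint path
```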