reducing AutoConfig.from_pretrained (#411)

* reducing AutoConfig.from_pretrained
This commit is contained in:
Xueqing Liu
2022-01-17 14:44:11 -05:00
committed by GitHub
parent 1c911da9f8
commit 3ef758cd7b

View File

@@ -408,10 +408,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
)
from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION
this_model_type = AutoConfig.from_pretrained(checkpoint_path).model_type
this_vocab_size = AutoConfig.from_pretrained(checkpoint_path).vocab_size
def get_this_model(task):
def get_this_model(task, model_config):
from transformers import AutoModelForSequenceClassification
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoModelForMultipleChoice
@@ -460,28 +457,34 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
model_config = AutoConfig.from_pretrained(checkpoint_path)
return model_config
current_config = AutoConfig.from_pretrained(checkpoint_path)
this_model_type, this_vocab_size = (
current_config.model_type,
current_config.vocab_size,
)
if task == SEQCLASSIFICATION:
num_labels_old = AutoConfig.from_pretrained(checkpoint_path).num_labels
num_labels_old = current_config.num_labels
if is_pretrained_model_in_classification_head_list(this_model_type):
model_config_num_labels = num_labels_old
else:
model_config_num_labels = num_labels
model_config = _set_model_config(checkpoint_path)
new_config = _set_model_config(checkpoint_path)
if is_pretrained_model_in_classification_head_list(this_model_type):
if num_labels != num_labels_old:
this_model = get_this_model(task)
model_config.num_labels = num_labels
this_model = get_this_model(task, new_config)
new_config.num_labels = num_labels
this_model.num_labels = num_labels
this_model.classifier = (
AutoSeqClassificationHead.from_model_type_and_config(
this_model_type, model_config
this_model_type, new_config
)
)
else:
this_model = get_this_model(task)
this_model = get_this_model(task, new_config)
else:
this_model = get_this_model(task)
this_model = get_this_model(task, new_config)
this_model.resize_token_embeddings(this_vocab_size)
return this_model
else:
@@ -490,7 +493,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
elif task == TOKENCLASSIFICATION:
model_config_num_labels = num_labels
model_config = _set_model_config(checkpoint_path)
this_model = get_this_model(task)
this_model = get_this_model(task, model_config)
return this_model