mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Reduce redundant AutoConfig.from_pretrained calls (#411)
* Load the checkpoint config once in load_model and reuse it instead of calling AutoConfig.from_pretrained repeatedly
This commit is contained in:
@@ -408,10 +408,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
|
||||
)
|
||||
from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION
|
||||
|
||||
this_model_type = AutoConfig.from_pretrained(checkpoint_path).model_type
|
||||
this_vocab_size = AutoConfig.from_pretrained(checkpoint_path).vocab_size
|
||||
|
||||
def get_this_model(task):
|
||||
def get_this_model(task, model_config):
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
from transformers import AutoModelForMultipleChoice
|
||||
@@ -460,28 +457,34 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
|
||||
model_config = AutoConfig.from_pretrained(checkpoint_path)
|
||||
return model_config
|
||||
|
||||
current_config = AutoConfig.from_pretrained(checkpoint_path)
|
||||
this_model_type, this_vocab_size = (
|
||||
current_config.model_type,
|
||||
current_config.vocab_size,
|
||||
)
|
||||
|
||||
if task == SEQCLASSIFICATION:
|
||||
num_labels_old = AutoConfig.from_pretrained(checkpoint_path).num_labels
|
||||
num_labels_old = current_config.num_labels
|
||||
if is_pretrained_model_in_classification_head_list(this_model_type):
|
||||
model_config_num_labels = num_labels_old
|
||||
else:
|
||||
model_config_num_labels = num_labels
|
||||
model_config = _set_model_config(checkpoint_path)
|
||||
new_config = _set_model_config(checkpoint_path)
|
||||
|
||||
if is_pretrained_model_in_classification_head_list(this_model_type):
|
||||
if num_labels != num_labels_old:
|
||||
this_model = get_this_model(task)
|
||||
model_config.num_labels = num_labels
|
||||
this_model = get_this_model(task, new_config)
|
||||
new_config.num_labels = num_labels
|
||||
this_model.num_labels = num_labels
|
||||
this_model.classifier = (
|
||||
AutoSeqClassificationHead.from_model_type_and_config(
|
||||
this_model_type, model_config
|
||||
this_model_type, new_config
|
||||
)
|
||||
)
|
||||
else:
|
||||
this_model = get_this_model(task)
|
||||
this_model = get_this_model(task, new_config)
|
||||
else:
|
||||
this_model = get_this_model(task)
|
||||
this_model = get_this_model(task, new_config)
|
||||
this_model.resize_token_embeddings(this_vocab_size)
|
||||
return this_model
|
||||
else:
|
||||
@@ -490,7 +493,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
|
||||
elif task == TOKENCLASSIFICATION:
|
||||
model_config_num_labels = num_labels
|
||||
model_config = _set_model_config(checkpoint_path)
|
||||
this_model = get_this_model(task)
|
||||
this_model = get_this_model(task, model_config)
|
||||
return this_model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user