From 90c958bca2229e737e35d1501a95c575de977403 Mon Sep 17 00:00:00 2001 From: mariecwhite Date: Tue, 21 Mar 2023 09:32:50 +1100 Subject: [PATCH] Add T5-base and T5-large Torch and TF Models (#1116) --- requirements.txt | 1 + tank/all_models.csv | 6 ++++ tank/generate_sharktank.py | 14 +++++--- tank/model_metadata.csv | 3 ++ tank/model_utils.py | 47 ++++++++++++++++++++++++++ tank/model_utils_tf.py | 68 ++++++++++++++++++++++++++++++++++++++ tank/tf_model_list.csv | 2 ++ tank/torch_model_list.csv | 2 ++ 8 files changed, 139 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index ae4844ae..679c727d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,6 +26,7 @@ safetensors opencv-python scikit-image pytorch_lightning # for runwayml models +sentencepiece # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors pefile diff --git a/tank/all_models.csv b/tank/all_models.csv index 8bd80506..44d89c80 100644 --- a/tank/all_models.csv +++ b/tank/all_models.csv @@ -35,8 +35,14 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","mac wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos" efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos" mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos" +t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"","" +t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","" +t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"","" +t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","" efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","" efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","" +efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"","" +efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"","" efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"","" efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"","" gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"","" diff --git a/tank/generate_sharktank.py b/tank/generate_sharktank.py index cf2d632b..e80b4404 100644 --- a/tank/generate_sharktank.py +++ b/tank/generate_sharktank.py @@ -36,6 +36,7 @@ def create_hash(file_name): def save_torch_model(torch_model_list, local_tank_cache): from tank.model_utils import ( get_hf_model, + get_hf_seq2seq_model, get_vision_model, get_hf_img_cls_model, get_fp16_model, @@ -84,6 +85,8 @@ def save_torch_model(torch_model_list, local_tank_cache): model, input, _ = get_vision_model(torch_model_name) elif model_type == "hf": model, input, _ = get_hf_model(torch_model_name) + elif model_type == "hf_seq2seq": + model, input, _ = get_hf_seq2seq_model(torch_model_name) elif model_type == "hf_img_cls": model, input, _ = get_hf_img_cls_model(torch_model_name) elif model_type == "fp16": @@ -122,6 +125,7 @@ def save_tf_model(tf_model_list, local_tank_cache): get_causal_lm_model, get_keras_model, get_TFhf_model, + get_tfhf_seq2seq_model, ) import tensorflow as tf @@ -147,13 +151,15 @@ def save_tf_model(tf_model_list, local_tank_cache): print(f"Generating artifacts for model {tf_model_name}") if model_type == "hf": model, input, _ = get_masked_lm_model(tf_model_name) - if model_type == "img": + elif model_type == "img": model, input, _ = get_causal_image_model(tf_model_name) - if model_type == "keras": + elif model_type == "keras": model, input, _ = get_keras_model(tf_model_name) - if model_type == "TFhf": + elif model_type == "TFhf": model, input, _ = get_TFhf_model(tf_model_name) - if model_type == "hf_causallm": + elif model_type == "tfhf_seq2seq": + model, input, _ = get_tfhf_seq2seq_model(tf_model_name) + elif model_type == "hf_causallm": model, input, _ = get_causal_lm_model(tf_model_name) tf_model_name = tf_model_name.replace("/", "_") diff --git a/tank/model_metadata.csv b/tank/model_metadata.csv index 00692160..7e9920ba 100644 --- a/tank/model_metadata.csv +++ b/tank/model_metadata.csv @@ -31,6 +31,9 @@ xlm-roberta-base,False,False,-,-,- facebook/convnext-tiny-224,False,False,-,-,- efficientnet-v2-s,False,False,22M,"image-classification,cnn","Includes MBConv and Fused-MBConv" mnasnet1_0,False,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency" +bert-large-uncased,True,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads" +t5-base,True,False,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer" +t5-large,True,False,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer" bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads" efficientnet_b0,True,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input" efficientnet_b7,True,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input" diff --git a/tank/model_utils.py b/tank/model_utils.py index 5b7d0daf..e7e5cc5b 100644 --- a/tank/model_utils.py +++ b/tank/model_utils.py @@ -29,6 +29,10 @@ hf_img_cls_models = [ "microsoft/beit-base-patch16-224-pt22k-ft22k", "nvidia/mit-b0", ] +hf_seq2seq_models = [ + "t5-base", + "t5-large", +] def get_torch_model(modelname): @@ -36,6 +40,8 @@ def get_torch_model(modelname): return get_vision_model(modelname) elif modelname in hf_img_cls_models: return get_hf_img_cls_model(modelname) + elif modelname in hf_seq2seq_models: + return get_hf_seq2seq_model(modelname) elif "fp16" in modelname: return get_fp16_model(modelname) else: @@ -131,6 +137,47 @@ def get_hf_model(name): return model, test_input, actual_out +##################### Hugging Face Seq2SeqLM Models ################################### + +# We use a maximum sequence length of 512 since this is the default used in the T5 config. +T5_MAX_SEQUENCE_LENGTH = 512 + + +class HFSeq2SeqLanguageModel(torch.nn.Module): + def __init__(self, model_name): + super().__init__() + from transformers import AutoTokenizer, T5Model + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenization_kwargs = { + "pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH, + "padding": True, + "return_tensors": "pt", + } + self.model = T5Model.from_pretrained(model_name, return_dict=True) + + def preprocess_input(self, text): + return self.tokenizer(text, **self.tokenization_kwargs) + + def forward(self, input_ids, decoder_input_ids): + return self.model.forward( + input_ids, decoder_input_ids=decoder_input_ids + )[0] + + +def get_hf_seq2seq_model(name): + m = HFSeq2SeqLanguageModel(name) + encoded_input_ids = m.preprocess_input( + "Studies have been shown that owning a dog is good for you" + ).input_ids + decoder_input_ids = m.preprocess_input("Studies show that").input_ids + decoder_input_ids = m.model._shift_right(decoder_input_ids) + + test_input = (encoded_input_ids, decoder_input_ids) + actual_out = m.forward(*test_input) + return m, test_input, actual_out + + ################################################################################ ##################### Torch Vision Models ################################### diff --git a/tank/model_utils_tf.py b/tank/model_utils_tf.py index 8baeae0e..d46b5ceb 100644 --- a/tank/model_utils_tf.py +++ b/tank/model_utils_tf.py @@ -42,6 +42,10 @@ causallm_models = [ tfhf_models = [ "microsoft/MiniLM-L12-H384-uncased", ] +tfhf_seq2seq_models = [ + "t5-base", + "t5-large", +] img_models = [ "google/vit-base-patch16-224", "facebook/convnext-tiny-224", @@ -59,6 +63,8 @@ def get_tf_model(name): return get_TFhf_model(name) elif name in img_models: return get_causal_image_model(name) + elif name in tfhf_seq2seq_models: + return get_tfhf_seq2seq_model(name) else: raise Exception( "TF model not found! Please check that the modelname has been input correctly." @@ -254,6 +260,68 @@ def get_causal_lm_model(hf_name, text="Hello, this is the default text."): return model, test_input, actual_out +##################### TensorflowHugging Face Seq2SeqLM Models ################################### + +# We use a maximum sequence length of 512 since this is the default used in the T5 config. +T5_MAX_SEQUENCE_LENGTH = 512 + +input_signature_t5 = [ + tf.TensorSpec( + shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH], + dtype=tf.int32, + name="input_ids", + ), + tf.TensorSpec( + shape=[BATCH_SIZE, T5_MAX_SEQUENCE_LENGTH], + dtype=tf.int32, + name="attention_mask", + ), +] + + +class TFHFSeq2SeqLanguageModel(tf.Module): + def __init__(self, model_name): + super(TFHFSeq2SeqLanguageModel, self).__init__() + from transformers import ( + AutoTokenizer, + AutoConfig, + TFAutoModelForSeq2SeqLM, + TFT5Model, + ) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenization_kwargs = { + "pad_to_multiple_of": T5_MAX_SEQUENCE_LENGTH, + "padding": True, + "return_tensors": "tf", + } + self.model = TFT5Model.from_pretrained(model_name, return_dict=True) + self.model.predict = lambda x, y: self.model(x, decoder_input_ids=y)[0] + + def preprocess_input(self, text): + return self.tokenizer(text, **self.tokenization_kwargs) + + @tf.function(input_signature=input_signature_t5, jit_compile=True) + def forward(self, input_ids, decoder_input_ids): + return self.model.predict(input_ids, decoder_input_ids) + + +def get_tfhf_seq2seq_model(name): + m = TFHFSeq2SeqLanguageModel(name) + text = "Studies have been shown that owning a dog is good for you" + batched_text = [text] * BATCH_SIZE + encoded_input_ids = m.preprocess_input(batched_text).input_ids + + text = "Studies show that" + batched_text = [text] * BATCH_SIZE + decoder_input_ids = m.preprocess_input(batched_text).input_ids + decoder_input_ids = m.model._shift_right(decoder_input_ids) + + test_input = (encoded_input_ids, decoder_input_ids) + actual_out = m.forward(*test_input) + return m, test_input, actual_out + + ##################### TensorFlow Keras Resnet Models ######################################################### # Static shape, including batch size (1). # Can be dynamic once dynamic shape support is ready. diff --git a/tank/tf_model_list.csv b/tank/tf_model_list.csv index 29f47e5b..e9dd240c 100644 --- a/tank/tf_model_list.csv +++ b/tank/tf_model_list.csv @@ -19,6 +19,8 @@ facebook/convnext-tiny-224,img google/vit-base-patch16-224,img efficientnet-v2-s,keras bert-large-uncased,hf +t5-base,tfhf_seq2seq +t5-large,tfhf_seq2seq efficientnet_b0,keras efficientnet_b7,keras gpt2,hf_causallm diff --git a/tank/torch_model_list.csv b/tank/torch_model_list.csv index b3ce59bb..107f29a2 100644 --- a/tank/torch_model_list.csv +++ b/tank/torch_model_list.csv @@ -19,5 +19,7 @@ mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search"," resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)" bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads" bert-large-uncased,True,hf,True,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads" +t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer" +t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer" efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input" efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"