diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml
index 10ea19a4..8dd2e8fa 100644
--- a/.github/workflows/test-models.yml
+++ b/.github/workflows/test-models.yml
@@ -120,7 +120,7 @@ jobs:
         if: matrix.suite == 'cuda'
         run: |
           cd $GITHUB_WORKSPACE
-          PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
+          PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
           source shark.venv/bin/activate
           pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
           gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
diff --git a/conftest.py b/conftest.py
index aa952713..89cf39aa 100644
--- a/conftest.py
+++ b/conftest.py
@@ -70,3 +70,9 @@ def pytest_addoption(parser):
         default="./temp_dispatch_benchmarks",
         help="Directory in which dispatch benchmarks are saved.",
     )
+    parser.addoption(
+        "--batchsize",
+        default=1,
+        type=int,
+        help="Batch size for the tested model.",
+    )
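For context, options registered in pytest_addoption are read back through pytest's config object. A minimal sketch of how a test could consume the new flag; the "batchsize" fixture here is hypothetical and not part of this diff:

    import pytest

    @pytest.fixture
    def batchsize(request):
        # Values registered via parser.addoption() live on the config object.
        return request.config.getoption("--batchsize")

    def test_batched_model(batchsize):
        # Matches the import_args dict shape threaded through the code below.
        import_args = {"batch_size": batchsize}
        assert isinstance(import_args["batch_size"], int)

An invocation would then look like: pytest -k cuda --batchsize 4.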
diff --git a/requirements-importer.txt b/requirements-importer.txt
index 0ac1d3cb..8acab5ba 100644
--- a/requirements-importer.txt
+++ b/requirements-importer.txt
@@ -2,8 +2,8 @@
 --pre

 numpy>1.22.4
-torchvision
 pytorch-triton
+torchvision==0.16.0.dev20230322
 tabulate
 tqdm
@@ -15,8 +15,8 @@ iree-tools-tf

 # TensorFlow and JAX.
 gin-config
-tf-nightly
-keras>=2.10
+tensorflow>2.11
+keras
 #tf-models-nightly
 #tensorflow-text-nightly
 transformers
diff --git a/setup_venv.sh b/setup_venv.sh
index 1100571e..ef99201b 100755
--- a/setup_venv.sh
+++ b/setup_venv.sh
@@ -129,11 +129,11 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
   TV_VERSION=${TV_VER:9:18}
   $PYTHON -m pip uninstall -y torch torchvision
   $PYTHON -m pip install -U --pre --no-warn-conflicts triton
-  $PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp311-cp311-linux_x86_64.whl
+  $PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
   if [ $? -eq 0 ];then
-    echo "Successfully Installed torch + cu117."
+    echo "Successfully installed torch + cu118."
   else
-    echo "Could not install torch + cu117." >&2
+    echo "Could not install torch + cu118." >&2
   fi
 fi
diff --git a/shark/shark_benchmark_runner.py b/shark/shark_benchmark_runner.py
index dd14bc71..8a287a91 100644
--- a/shark/shark_benchmark_runner.py
+++ b/shark/shark_benchmark_runner.py
@@ -78,6 +78,7 @@ class SharkBenchmarkRunner(SharkRunner):
         self.vmfb_file = None
         self.mlir_dialect = mlir_dialect
         self.extra_args = extra_args
+        self.import_args = {}
         SharkRunner.__init__(
             self,
             mlir_module,
@@ -112,24 +113,31 @@ class SharkBenchmarkRunner(SharkRunner):
     def benchmark_torch(self, modelname):
         import torch
-        import torch._dynamo as dynamo
         from tank.model_utils import get_torch_model

         if self.device == "cuda":
             torch.set_default_tensor_type(torch.cuda.FloatTensor)
             if self.enable_tf32:
-                torch.backends.cuda.matmul.allow_tf32 = True
+                print(
+                    "TensorFloat32 is currently disabled in PyTorch benchmarks."
+                )
+                # torch.backends.cuda.matmul.allow_tf32 = True
         else:
             torch.set_default_tensor_type(torch.FloatTensor)
         torch_device = torch.device(
             "cuda:0" if self.device == "cuda" else "cpu"
         )
-        HFmodel, input = get_torch_model(modelname)[:2]
+        HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
         frontend_model = HFmodel.model
         frontend_model.to(torch_device)
         input.to(torch_device)
-        # frontend_model = torch.compile(frontend_model, mode="max-autotune", backend="inductor")
+        try:
+            frontend_model = torch.compile(
+                frontend_model, mode="max-autotune", backend="inductor"
+            )
+        except RuntimeError:
+            frontend_model = HFmodel.model

         for i in range(shark_args.num_warmup_iterations):
             frontend_model.forward(input)
@@ -178,7 +186,7 @@ class SharkBenchmarkRunner(SharkRunner):
             model,
             input,
         ) = get_tf_model(
-            modelname
+            modelname, self.import_args
         )[:2]
         frontend_model = model

@@ -338,11 +346,19 @@ for currently supported models. Exiting benchmark ONNX."
         return comp_str

     def benchmark_all_csv(
-        self, inputs: tuple, modelname, dynamic, device_str, frontend
+        self,
+        inputs: tuple,
+        modelname,
+        dynamic,
+        device_str,
+        frontend,
+        import_args,
     ):
         self.setup_cl(inputs)
+        self.import_args = import_args
         field_names = [
             "model",
+            "batch_size",
             "engine",
             "dialect",
             "device",
@@ -375,6 +391,7 @@ for currently supported models. Exiting benchmark ONNX."
             writer = csv.DictWriter(f, fieldnames=field_names)
             bench_info = {}
             bench_info["model"] = modelname
+            bench_info["batch_size"] = str(import_args["batch_size"])
             bench_info["dialect"] = self.mlir_dialect
             bench_info["iterations"] = shark_args.num_iterations
             if dynamic == True:
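The benchmark_torch change above un-comments torch.compile but guards it so benchmarks still run where Inductor is unavailable. A standalone sketch of that fallback pattern (toy model, illustrative only, not SHARK code):

    import torch

    model = torch.nn.Linear(8, 8)
    try:
        model = torch.compile(model, mode="max-autotune", backend="inductor")
    except RuntimeError:
        # torch.compile raises RuntimeError on unsupported setups,
        # e.g. builds without a working Triton/Inductor backend;
        # fall back to eager execution.
        pass
    out = model(torch.randn(1, 8))

Note that torch.compile is lazy: many compilation failures only surface at the first forward call, so the except clause catches only errors raised eagerly at wrap time.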
diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py
index a4be6768..280decf1 100644
--- a/shark/shark_downloader.py
+++ b/shark/shark_downloader.py
@@ -139,11 +139,21 @@ def download_model(
     tank_url="gs://shark_tank/latest",
     frontend=None,
     tuned=None,
+    import_args={"batch_size": 1},
 ):
     model_name = model_name.replace("/", "_")
     dyn_str = "_dynamic" if dynamic else ""
     os.makedirs(WORKDIR, exist_ok=True)
-    model_dir_name = model_name + "_" + frontend
+    if import_args["batch_size"] != 1:
+        model_dir_name = (
+            model_name
+            + "_"
+            + frontend
+            + "_BS"
+            + str(import_args["batch_size"])
+        )
+    else:
+        model_dir_name = model_name + "_" + frontend
     model_dir = os.path.join(WORKDIR, model_dir_name)
     full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
@@ -201,7 +211,9 @@ def download_model(
         from tank.generate_sharktank import gen_shark_files

         tank_dir = WORKDIR
-        gen_shark_files(model_name, frontend, tank_dir)
+        gen_shark_files(model_name, frontend, tank_dir, import_args)
+        with open(filename, mode="rb") as f:
+            mlir_file = f.read()

     function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
     inputs = np.load(os.path.join(model_dir, "inputs.npz"))
diff --git a/tank/all_models.csv b/tank/all_models.csv
index 41bf9fb8..7b1cd38a 100644
--- a/tank/all_models.csv
+++ b/tank/all_models.csv
@@ -35,12 +35,12 @@ squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
 wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
 efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
 mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
-t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"",""
-t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
-t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"",""
-t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
-efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"",""
-efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
+efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/1243",""
+efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,False,False,"Torchvision imports issue",""
 efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
 efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
 gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
+t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
+t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
+t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
+t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
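The "_BS<n>" suffix introduced in shark_downloader.py is the same naming rule generate_sharktank.py applies when writing artifacts, so uploads and downloads agree on directory names. A minimal sketch of the rule; the helper function is illustrative, since both files inline this logic:

    def model_dir_name(model_name: str, frontend: str, batch_size: int) -> str:
        name = model_name.replace("/", "_") + "_" + frontend
        if batch_size != 1:  # the default batch size keeps the legacy name
            name += f"_BS{batch_size}"
        return name

    assert model_dir_name("resnet50", "torch", 1) == "resnet50_torch"
    assert model_dir_name("t5-base", "tf", 4) == "t5-base_tf_BS4"

This is also why batch_size should stay an integer end-to-end: a string "1" would fail the != 1 check and mislabel default artifacts with a _BS1 suffix.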
diff --git a/tank/generate_sharktank.py b/tank/generate_sharktank.py
index d89f23d4..29906008 100644
--- a/tank/generate_sharktank.py
+++ b/tank/generate_sharktank.py
@@ -33,7 +33,7 @@ def create_hash(file_name):
     return file_hash.hexdigest()


-def save_torch_model(torch_model_list, local_tank_cache):
+def save_torch_model(torch_model_list, local_tank_cache, import_args):
     from tank.model_utils import (
         get_hf_model,
         get_hf_seq2seq_model,
@@ -59,7 +59,6 @@ def save_torch_model(torch_model_list, local_tank_cache):
             if model_type == "stable_diffusion":
                 args.use_tuned = False
                 args.import_mlir = True
-                args.use_tuned = False
                 args.local_tank_cache = local_tank_cache

                 precision_values = ["fp16"]
@@ -75,6 +74,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
                         width=512,
                         height=512,
                         use_base_vae=False,
+                        custom_vae="",
                         debug=True,
                         sharktank_dir=local_tank_cache,
                         generate_vmfb=False,
@@ -82,19 +82,33 @@ def save_torch_model(torch_model_list, local_tank_cache):
                     model()
                 continue
             if model_type == "vision":
-                model, input, _ = get_vision_model(torch_model_name)
+                model, input, _ = get_vision_model(
+                    torch_model_name, import_args
+                )
             elif model_type == "hf":
-                model, input, _ = get_hf_model(torch_model_name)
+                model, input, _ = get_hf_model(torch_model_name, import_args)
             elif model_type == "hf_seq2seq":
-                model, input, _ = get_hf_seq2seq_model(torch_model_name)
+                model, input, _ = get_hf_seq2seq_model(
+                    torch_model_name, import_args
+                )
             elif model_type == "hf_img_cls":
-                model, input, _ = get_hf_img_cls_model(torch_model_name)
+                model, input, _ = get_hf_img_cls_model(
+                    torch_model_name, import_args
+                )
             elif model_type == "fp16":
-                model, input, _ = get_fp16_model(torch_model_name)
+                model, input, _ = get_fp16_model(torch_model_name, import_args)
             torch_model_name = torch_model_name.replace("/", "_")
-            torch_model_dir = os.path.join(
-                local_tank_cache, str(torch_model_name) + "_torch"
-            )
+            if import_args["batch_size"] != 1:
+                torch_model_dir = os.path.join(
+                    local_tank_cache,
+                    str(torch_model_name)
+                    + "_torch"
+                    + f"_BS{str(import_args['batch_size'])}",
+                )
+            else:
+                torch_model_dir = os.path.join(
+                    local_tank_cache, str(torch_model_name) + "_torch"
+                )
             os.makedirs(torch_model_dir, exist_ok=True)

             mlir_importer = SharkImporter(
@@ -118,7 +132,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
     )


-def save_tf_model(tf_model_list, local_tank_cache):
+def save_tf_model(tf_model_list, local_tank_cache, import_args):
     from tank.model_utils_tf import (
         get_causal_image_model,
         get_masked_lm_model,
@@ -150,20 +164,38 @@ def save_tf_model(tf_model_list, local_tank_cache):
             input = None
             print(f"Generating artifacts for model {tf_model_name}")
             if model_type == "hf":
-                model, input, _ = get_causal_lm_model(tf_model_name)
+                model, input, _ = get_masked_lm_model(
+                    tf_model_name, import_args
+                )
             elif model_type == "img":
-                model, input, _ = get_causal_image_model(tf_model_name)
+                model, input, _ = get_causal_image_model(
+                    tf_model_name, import_args
+                )
             elif model_type == "keras":
-                model, input, _ = get_keras_model(tf_model_name)
+                model, input, _ = get_keras_model(tf_model_name, import_args)
             elif model_type == "TFhf":
-                model, input, _ = get_TFhf_model(tf_model_name)
+                model, input, _ = get_TFhf_model(tf_model_name, import_args)
             elif model_type == "tfhf_seq2seq":
-                model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
+                model, input, _ = get_tfhf_seq2seq_model(
+                    tf_model_name, import_args
+                )
+            elif model_type == "hf_causallm":
+                model, input, _ = get_causal_lm_model(
+                    tf_model_name, import_args
+                )
             tf_model_name = tf_model_name.replace("/", "_")
-            tf_model_dir = os.path.join(
-                local_tank_cache, str(tf_model_name) + "_tf"
-            )
+            if import_args["batch_size"] != 1:
+                tf_model_dir = os.path.join(
+                    local_tank_cache,
+                    str(tf_model_name)
+                    + "_tf"
+                    + f"_BS{str(import_args['batch_size'])}",
+                )
+            else:
+                tf_model_dir = os.path.join(
+                    local_tank_cache, str(tf_model_name) + "_tf"
+                )
             os.makedirs(tf_model_dir, exist_ok=True)
             mlir_importer = SharkImporter(
                 model,
@@ -175,13 +207,9 @@ def save_tf_model(tf_model_list, local_tank_cache):
                 dir=tf_model_dir,
                 model_name=tf_model_name,
             )
-            mlir_hash = create_hash(
-                os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
-            )
-            np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))


-def save_tflite_model(tflite_model_list, local_tank_cache):
+def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
     from shark.tflite_utils import TFLitePreprocessor

     with open(tflite_model_list) as csvfile:
@@ -198,13 +226,13 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
             os.makedirs(tflite_model_name_dir, exist_ok=True)
             print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")

-            # Preprocess to get SharkImporter input args
+            # Preprocess to get the SharkImporter input arguments.
             tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
             raw_model_file_path = tflite_preprocessor.get_raw_model_file()
             inputs = tflite_preprocessor.get_inputs()
             tflite_interpreter = tflite_preprocessor.get_interpreter()

-            # Use SharkImporter to get SharkInference input args
+            # Use SharkImporter to get the SharkInference input arguments.
             my_shark_importer = SharkImporter(
                 module=tflite_interpreter,
                 inputs=inputs,
@@ -228,43 +256,69 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
     )


-def gen_shark_files(modelname, frontend, tank_dir):
+def check_requirements(frontend):
+    import importlib.util
+
+    has_pkgs = False
+    if frontend == "torch":
+        tv_spec = importlib.util.find_spec("torchvision")
+        has_pkgs = tv_spec is not None
+
+    elif frontend in ["tensorflow", "tf"]:
+        tf_spec = importlib.util.find_spec("tensorflow")
+        has_pkgs = tf_spec is not None
+
+    return has_pkgs
+
+
+class NoImportException(Exception):
+    """Raised when requirements are not met for OTF model artifact generation."""
+
+
+def gen_shark_files(modelname, frontend, tank_dir, import_args):
     # If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
     # TODO: Add TFLite support.
     import tempfile

-    torch_model_csv = os.path.join(
-        os.path.dirname(__file__), "torch_model_list.csv"
-    )
-    tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
-    custom_model_csv = tempfile.NamedTemporaryFile(
-        dir=os.path.dirname(__file__),
-        delete=True,
-    )
-    # Create a temporary .csv with only the desired entry.
-    if frontend == "tf":
-        with open(tf_model_csv, mode="r") as src:
-            reader = csv.reader(src)
-            for row in reader:
-                if row[0] == modelname:
-                    target = row
-        with open(custom_model_csv.name, mode="w") as trg:
-            writer = csv.writer(trg)
-            writer.writerow(["modelname", "src"])
-            writer.writerow(target)
-        save_tf_model(custom_model_csv.name, tank_dir)
+    if check_requirements(frontend):
+        torch_model_csv = os.path.join(
+            os.path.dirname(__file__), "torch_model_list.csv"
+        )
+        tf_model_csv = os.path.join(
+            os.path.dirname(__file__), "tf_model_list.csv"
+        )
+        custom_model_csv = tempfile.NamedTemporaryFile(
+            dir=os.path.dirname(__file__),
+            delete=True,
+        )
+        # Create a temporary .csv with only the desired entry.
+        if frontend == "tf":
+            with open(tf_model_csv, mode="r") as src:
+                reader = csv.reader(src)
+                for row in reader:
+                    if row[0] == modelname:
+                        target = row
+            with open(custom_model_csv.name, mode="w") as trg:
+                writer = csv.writer(trg)
+                writer.writerow(["modelname", "src"])
+                writer.writerow(target)
+            save_tf_model(custom_model_csv.name, tank_dir, import_args)

-    if frontend == "torch":
-        with open(torch_model_csv, mode="r") as src:
-            reader = csv.reader(src)
-            for row in reader:
-                if row[0] == modelname:
-                    target = row
-        with open(custom_model_csv.name, mode="w") as trg:
-            writer = csv.writer(trg)
-            writer.writerow(["modelname", "src"])
-            writer.writerow(target)
-        save_torch_model(custom_model_csv.name, tank_dir)
+        elif frontend == "torch":
+            with open(torch_model_csv, mode="r") as src:
+                reader = csv.reader(src)
+                for row in reader:
+                    if row[0] == modelname:
+                        target = row
+            with open(custom_model_csv.name, mode="w") as trg:
+                writer = csv.writer(trg)
+                writer.writerow(["modelname", "src"])
+                writer.writerow(target)
+            save_torch_model(custom_model_csv.name, tank_dir, import_args)
+    else:
+        raise NoImportException


 # Validates whether the file is present or not.
@@ -276,7 +330,7 @@ def is_valid_file(arg):


 if __name__ == "__main__":
-    # Note, all of these flags are overridden by the import of args from stable_args.py, flags are duplicated temporarily to preserve functionality
+    # Note: all of these flags are overridden by the `args` imported from stable_args.py; the flags are duplicated here temporarily to preserve functionality.
     # parser = argparse.ArgumentParser()
     # parser.add_argument(
     #     "--torch_model_csv",
@@ -304,8 +358,11 @@ if __name__ == "__main__":
     # )
     # parser.add_argument("--upload", type=bool, default=False)

     # old_args = parser.parse_args()

+    import_args = {
+        "batch_size": 1,
+    }
+    print(import_args)
+
     home = str(Path.home())
     WORKDIR = os.path.join(os.path.dirname(__file__), "..", "gen_shark_tank")
     torch_model_csv = os.path.join(
@@ -319,7 +376,8 @@ if __name__ == "__main__":
     save_torch_model(
         os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
         WORKDIR,
+        import_args,
     )
-    save_torch_model(torch_model_csv, WORKDIR)
-    save_tf_model(tf_model_csv, WORKDIR)
-    save_tflite_model(tflite_model_csv, WORKDIR)
+    save_torch_model(torch_model_csv, WORKDIR, import_args)
+    save_tf_model(tf_model_csv, WORKDIR, import_args)
+    save_tflite_model(tflite_model_csv, WORKDIR, import_args)
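With check_requirements in place, shark_downloader can degrade gracefully when the importer stacks are absent. A sketch of the intended call path under the new signature; the model name and paths are illustrative:

    from tank.generate_sharktank import NoImportException, gen_shark_files

    try:
        gen_shark_files("resnet50", "torch", "./gen_shark_tank", {"batch_size": 1})
    except NoImportException:
        # torchvision (or tensorflow, for tf models) is not installed,
        # so artifacts cannot be generated on the fly.
        print("Install the importer requirements to enable on-the-fly generation.")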
diff --git a/tank/model_utils.py b/tank/model_utils.py
index e7e5cc5b..f641fd32 100644
--- a/tank/model_utils.py
+++ b/tank/model_utils.py
@@ -1,5 +1,4 @@
 from shark.shark_inference import SharkInference
-from shark.parser import shark_args

 import torch
 import numpy as np
@@ -35,17 +34,17 @@ hf_seq2seq_models = [
 ]


-def get_torch_model(modelname):
+def get_torch_model(modelname, import_args):
     if modelname in vision_models:
-        return get_vision_model(modelname)
+        return get_vision_model(modelname, import_args)
     elif modelname in hf_img_cls_models:
-        return get_hf_img_cls_model(modelname)
+        return get_hf_img_cls_model(modelname, import_args)
     elif modelname in hf_seq2seq_models:
-        return get_hf_seq2seq_model(modelname)
+        return get_hf_seq2seq_model(modelname, import_args)
     elif "fp16" in modelname:
-        return get_fp16_model(modelname)
+        return get_fp16_model(modelname, import_args)
     else:
-        return get_hf_model(modelname)
+        return get_hf_model(modelname, import_args)


 ##################### Hugging Face Image Classification Models ###################################
@@ -88,14 +87,14 @@ class HuggingFaceImageClassification(torch.nn.Module):
         return self.model.forward(inputs)[0]


-def get_hf_img_cls_model(name):
+def get_hf_img_cls_model(name, import_args):
     model = HuggingFaceImageClassification(name)
     # you can use preprocess_input_image to get the test_input or just random value.
     test_input = preprocess_input_image(name)
     # test_input = torch.FloatTensor(1, 3, 224, 224).uniform_(-1, 1)
     # print("test_input.shape: ", test_input.shape)
     # test_input.shape: torch.Size([1, 3, 224, 224])
-    test_input = test_input.repeat(BATCH_SIZE, 1, 1, 1)
+    test_input = test_input.repeat(import_args["batch_size"], 1, 1, 1)
     actual_out = model(test_input)
     # print("actual_out.shape: ", actual_out.shape)
     # actual_out.shape: torch.Size([1, 1000])
@@ -125,14 +124,13 @@ class HuggingFaceLanguage(torch.nn.Module):
         return self.model.forward(tokens)[0]


-def get_hf_model(name):
+def get_hf_model(name, import_args):
     from transformers import (
         BertTokenizer,
     )

     model = HuggingFaceLanguage(name)
-    # TODO: Currently the test input is set to (1,128)
-    test_input = torch.randint(2, (BATCH_SIZE, 128))
+    test_input = torch.randint(2, (import_args["batch_size"], 128))
     actual_out = model(test_input)
     return model, test_input, actual_out
@@ -165,7 +163,7 @@ class HFSeq2SeqLanguageModel(torch.nn.Module):
         )[0]


-def get_hf_seq2seq_model(name):
+def get_hf_seq2seq_model(name, import_args):
     m = HFSeq2SeqLanguageModel(name)
     encoded_input_ids = m.preprocess_input(
         "Studies have been shown that owning a dog is good for you"
@@ -193,7 +191,7 @@ class VisionModule(torch.nn.Module):
         return self.model.forward(input)


-def get_vision_model(torch_model):
+def get_vision_model(torch_model, import_args):
     import torchvision.models as models

     default_image_size = (224, 224)
@@ -239,7 +237,7 @@ def get_vision_model(torch_model):
         fp16_model = True
     torch_model, input_image_size = vision_models_dict[torch_model]
     model = VisionModule(torch_model)
-    test_input = torch.randn(BATCH_SIZE, 3, 224, 224)
+    test_input = torch.randn(import_args["batch_size"], 3, *input_image_size)
     actual_out = model(test_input)
     if fp16_model is not None:
         test_input_fp16 = test_input.to(
@@ -280,14 +278,14 @@ class BertHalfPrecisionModel(torch.nn.Module):
         return self.model.forward(tokens)[0]


-def get_fp16_model(torch_model):
+def get_fp16_model(torch_model, import_args):
     from transformers import AutoTokenizer

     modelname = torch_model.replace("_fp16", "")
     model = BertHalfPrecisionModel(modelname)
     tokenizer = AutoTokenizer.from_pretrained(modelname)
     text = "Replace me by any text you like."
-    text = [text] * BATCH_SIZE
+    text = [text] * import_args["batch_size"]
     test_input_fp16 = tokenizer(
         text,
         truncation=True,
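All of the get_*_model helpers above now derive their sample inputs from import_args["batch_size"]. A standalone illustration of the two input shapes involved (stand-in tensors only, no SHARK imports):

    import torch

    import_args = {"batch_size": 4}
    # Vision models: random NCHW input at the model's input resolution.
    vision_input = torch.randn(import_args["batch_size"], 3, 224, 224)
    # HF language models: random token ids of shape (batch_size, seq_len=128).
    lm_input = torch.randint(2, (import_args["batch_size"], 128))
    assert vision_input.shape[0] == lm_input.shape[0] == 4

Note that get_hf_seq2seq_model still encodes a single unbatched sentence, which matches the "Inputs for seq2seq models in torch currently unsupported" notes added to all_models.csv.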
diff --git a/tank/model_utils_tf.py b/tank/model_utils_tf.py
index d46b5ceb..b1b34a3c 100644
--- a/tank/model_utils_tf.py
+++ b/tank/model_utils_tf.py
@@ -1,10 +1,5 @@
 import tensorflow as tf
 import numpy as np
-from transformers import (
-    AutoModelForSequenceClassification,
-    BertTokenizer,
-    TFBertModel,
-)

 BATCH_SIZE = 1
@@ -52,19 +47,19 @@ img_models = [
 ]


-def get_tf_model(name):
+def get_tf_model(name, import_args):
     if name in keras_models:
-        return get_keras_model(name)
+        return get_keras_model(name, import_args)
     elif name in maskedlm_models:
-        return get_masked_lm_model(name)
+        return get_masked_lm_model(name, import_args)
     elif name in causallm_models:
-        return get_causal_lm_model(name)
+        return get_causal_lm_model(name, import_args)
     elif name in tfhf_models:
-        return get_TFhf_model(name)
+        return get_TFhf_model(name, import_args)
     elif name in img_models:
-        return get_causal_image_model(name)
+        return get_causal_image_model(name, import_args)
     elif name in tfhf_seq2seq_models:
-        return get_tfhf_seq2seq_model(name)
+        return get_tfhf_seq2seq_model(name, import_args)
     else:
         raise Exception(
             "TF model not found! Please check that the modelname has been input correctly."
         )
@@ -72,6 +67,12 @@ def get_tf_model(name):

 ##################### Tensorflow Hugging Face Bert Models ###################################
+from transformers import (
+    AutoModelForSequenceClassification,
+    BertTokenizer,
+    TFBertModel,
+)
+
 BERT_MAX_SEQUENCE_LENGTH = 128

 # Create a set of 2-dimensional inputs
@@ -104,7 +105,7 @@ class TFHuggingFaceLanguage(tf.Module):
         return self.m.predict(input_ids, attention_mask, token_type_ids)


-def get_TFhf_model(name):
+def get_TFhf_model(name, import_args):
     model = TFHuggingFaceLanguage(name)
     tokenizer = BertTokenizer.from_pretrained(
         "microsoft/MiniLM-L12-H384-uncased"
@@ -166,7 +167,6 @@ def preprocess_input(

 ##################### Tensorflow Hugging Face Masked LM Models ###################################
 from transformers import TFAutoModelForMaskedLM, AutoTokenizer
-import tensorflow as tf

 MASKED_LM_MAX_SEQUENCE_LENGTH = 128
@@ -196,7 +196,9 @@ class MaskedLM(tf.Module):
         return self.m.predict(input_ids, attention_mask)


-def get_masked_lm_model(hf_name, text="Hello, this is the default text."):
+def get_masked_lm_model(
+    hf_name, import_args, text="Hello, this is the default text."
+):
     model = MaskedLM(hf_name)
     encoded_input = preprocess_input(
         hf_name, MASKED_LM_MAX_SEQUENCE_LENGTH, text
@@ -251,7 +253,9 @@ class CausalLM(tf.Module):
         return self.model.predict(input_ids, attention_mask)


-def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
+def get_causal_lm_model(
+    hf_name, import_args, text="Hello, this is the default text."
+):
     model = CausalLM(hf_name)
     batched_text = [text] * BATCH_SIZE
     encoded_input = model.preprocess_input(batched_text)
@@ -306,7 +310,7 @@ class TFHFSeq2SeqLanguageModel(tf.Module):
         return self.model.predict(input_ids, decoder_input_ids)


-def get_tfhf_seq2seq_model(name):
+def get_tfhf_seq2seq_model(name, import_args):
     m = TFHFSeq2SeqLanguageModel(name)
     text = "Studies have been shown that owning a dog is good for you"
     batched_text = [text] * BATCH_SIZE
@@ -442,7 +446,7 @@ def load_image(path_to_image, width, height, channels):
     return image


-def get_keras_model(modelname):
+def get_keras_model(modelname, import_args):
     if modelname == "efficientnet-v2-s":
         model = EfficientNetV2SModule()
     elif modelname == "efficientnet_b0":
@@ -530,7 +534,7 @@ def preprocess_input_image(model_name):
     return [inputs[str(*inputs)]]


-def get_causal_image_model(hf_name):
+def get_causal_image_model(hf_name, import_args):
     model = AutoModelImageClassfication(hf_name)
     test_input = preprocess_input_image(hf_name)
     # TFSequenceClassifierOutput(loss=None, logits=
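Taken together with tank/model_utils.py above, get_tf_model now threads import_args through every TF helper, though several of them (for example get_causal_lm_model and get_tfhf_seq2seq_model) still batch their sample text with the module-level BATCH_SIZE = 1 constant, so non-default batch sizes are not yet effective on the TF side. A sketch of the updated entry point; the model choice is illustrative:

    from tank.model_utils_tf import get_tf_model

    # Returns (model, test_input, golden_output) for a supported model name.
    model, test_input, golden_out = get_tf_model("gpt2", {"batch_size": 1})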