mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Remove tf dependencies from importer path. (#1874)
* Remove tf dependencies from import path. * Fix formatting.
This commit is contained in:
@@ -6,7 +6,8 @@ class FirstVicuna(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",accumulates="fp32",
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -15,7 +16,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
@@ -58,7 +61,8 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",accumulates="fp32",
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -70,7 +74,9 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
@@ -319,7 +325,9 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
@@ -601,7 +609,8 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",accumulates="fp32",
|
||||
precision="fp32",
|
||||
accumulates="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -613,7 +622,9 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16
|
||||
self.accumulates = (
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
|
||||
@@ -8,19 +8,8 @@ torchvision
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
#these dont work ok osx
|
||||
#iree-tools-tflite
|
||||
#iree-tools-xla
|
||||
#iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
# tflitehub dependencies.
|
||||
|
||||
@@ -9,23 +9,13 @@ tabulate
|
||||
tqdm
|
||||
|
||||
#iree-compiler | iree-runtime should already be installed
|
||||
iree-tools-tflite
|
||||
iree-tools-xla
|
||||
iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
# Modelling and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
keras-nightly
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
transformers
|
||||
diffusers
|
||||
#tensorflow-probability
|
||||
#jax[cpu]
|
||||
|
||||
|
||||
# tflitehub dependencies.
|
||||
Pillow
|
||||
|
||||
# Testing and support.
|
||||
|
||||
@@ -1,24 +1,6 @@
|
||||
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
|
||||
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
|
||||
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
|
||||
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
|
||||
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
|
||||
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
|
||||
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
|
||||
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
|
||||
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
|
||||
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
|
||||
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
|
||||
@@ -32,14 +14,8 @@ resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
|
||||
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
|
||||
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
|
||||
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
|
||||
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
|
||||
|
@@ -130,133 +130,6 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
)
|
||||
|
||||
|
||||
def save_tf_model(tf_model_list, local_tank_cache, import_args):
|
||||
from tank.model_utils_tf import (
|
||||
get_causal_image_model,
|
||||
get_masked_lm_model,
|
||||
get_causal_lm_model,
|
||||
get_keras_model,
|
||||
get_TFhf_model,
|
||||
get_tfhf_seq2seq_model,
|
||||
)
|
||||
import os
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||
import tensorflow as tf
|
||||
|
||||
visible_default = tf.config.list_physical_devices("GPU")
|
||||
try:
|
||||
tf.config.set_visible_devices([], "GPU")
|
||||
visible_devices = tf.config.get_visible_devices()
|
||||
for device in visible_devices:
|
||||
assert device.device_type != "GPU"
|
||||
except:
|
||||
# Invalid device or cannot modify virtual devices once initialized.
|
||||
pass
|
||||
|
||||
with open(tf_model_list) as csvfile:
|
||||
tf_reader = csv.reader(csvfile, delimiter=",")
|
||||
fields = next(tf_reader)
|
||||
for row in tf_reader:
|
||||
tf_model_name = row[0]
|
||||
model_type = row[1]
|
||||
|
||||
model = None
|
||||
input = None
|
||||
print(f"Generating artifacts for model {tf_model_name}")
|
||||
if model_type == "hf":
|
||||
model, input, _ = get_masked_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "img":
|
||||
model, input, _ = get_causal_image_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "keras":
|
||||
model, input, _ = get_keras_model(tf_model_name, import_args)
|
||||
elif model_type == "TFhf":
|
||||
model, input, _ = get_TFhf_model(tf_model_name, import_args)
|
||||
elif model_type == "tfhf_seq2seq":
|
||||
model, input, _ = get_tfhf_seq2seq_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
elif model_type == "hf_causallm":
|
||||
model, input, _ = get_causal_lm_model(
|
||||
tf_model_name, import_args
|
||||
)
|
||||
|
||||
tf_model_name = tf_model_name.replace("/", "_")
|
||||
if import_args["batch_size"] != 1:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(tf_model_name)
|
||||
+ "_tf"
|
||||
+ f"_BS{str(import_args['batch_size'])}",
|
||||
)
|
||||
else:
|
||||
tf_model_dir = os.path.join(
|
||||
local_tank_cache, str(tf_model_name) + "_tf"
|
||||
)
|
||||
os.makedirs(tf_model_dir, exist_ok=True)
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
inputs=input,
|
||||
frontend="tf",
|
||||
)
|
||||
mlir_importer.import_debug(
|
||||
is_dynamic=False,
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
|
||||
from shark.tflite_utils import TFLitePreprocessor
|
||||
|
||||
with open(tflite_model_list) as csvfile:
|
||||
tflite_reader = csv.reader(csvfile, delimiter=",")
|
||||
for row in tflite_reader:
|
||||
print("\n")
|
||||
tflite_model_name = row[0]
|
||||
tflite_model_link = row[1]
|
||||
print("tflite_model_name", tflite_model_name)
|
||||
print("tflite_model_link", tflite_model_link)
|
||||
tflite_model_name_dir = os.path.join(
|
||||
local_tank_cache, str(tflite_model_name) + "_tflite"
|
||||
)
|
||||
os.makedirs(tflite_model_name_dir, exist_ok=True)
|
||||
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
|
||||
|
||||
# Preprocess to get SharkImporter input import_args
|
||||
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
|
||||
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
|
||||
inputs = tflite_preprocessor.get_inputs()
|
||||
tflite_interpreter = tflite_preprocessor.get_interpreter()
|
||||
|
||||
# Use SharkImporter to get SharkInference input import_args
|
||||
my_shark_importer = SharkImporter(
|
||||
module=tflite_interpreter,
|
||||
inputs=inputs,
|
||||
frontend="tflite",
|
||||
raw_model_file=raw_model_file_path,
|
||||
)
|
||||
my_shark_importer.import_debug(
|
||||
dir=tflite_model_name_dir,
|
||||
model_name=tflite_model_name,
|
||||
func_name="main",
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(
|
||||
tflite_model_name_dir,
|
||||
tflite_model_name + "_tflite" + ".mlir",
|
||||
)
|
||||
)
|
||||
np.save(
|
||||
os.path.join(tflite_model_name_dir, "hash"),
|
||||
np.array(mlir_hash),
|
||||
)
|
||||
|
||||
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
@@ -265,10 +138,6 @@ def check_requirements(frontend):
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
@@ -287,27 +156,11 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tf_model_list.csv"
|
||||
)
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir, import_args)
|
||||
|
||||
elif frontend == "torch":
|
||||
if frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
@@ -341,18 +194,6 @@ if __name__ == "__main__":
|
||||
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tf_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tf_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--tflite_model_csv",
|
||||
# type=lambda x: is_valid_file(x),
|
||||
# default="./tank/tflite/tflite_model_list.csv",
|
||||
# help="Contains the file with tf model name and args.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--ci_tank_dir",
|
||||
# type=bool,
|
||||
# default=False,
|
||||
@@ -369,11 +210,5 @@ if __name__ == "__main__":
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
tflite_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(torch_model_csv, WORKDIR, import_args)
|
||||
# save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
# save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
model_name, model_type
|
||||
albert-base-v2,hf
|
||||
bert-base-uncased,hf
|
||||
camembert-base,hf
|
||||
dbmdz/convbert-base-turkish-cased,hf
|
||||
distilbert-base-uncased,hf
|
||||
google/electra-small-discriminator,hf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/layoutlm-base-uncased,hf
|
||||
google/mobilebert-uncased,hf
|
||||
microsoft/mpnet-base,hf
|
||||
roberta-base,hf
|
||||
resnet50,keras
|
||||
xlm-roberta-base,hf
|
||||
microsoft/MiniLM-L12-H384-uncased,TFhf
|
||||
funnel-transformer/small,hf
|
||||
microsoft/mpnet-base,hf
|
||||
facebook/convnext-tiny-224,img
|
||||
google/vit-base-patch16-224,img
|
||||
efficientnet-v2-s,keras
|
||||
bert-large-uncased,hf
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
efficientnet_b0,keras
|
||||
efficientnet_b7,keras
|
||||
gpt2,hf_causallm
|
||||
t5-base,tfhf_seq2seq
|
||||
t5-large,tfhf_seq2seq
|
||||
|
Reference in New Issue
Block a user