Compare commits

...

2 Commits

Author SHA1 Message Date
Ean Garvey
0f3b62032b Only generate artifacts OTF if requirements are met. 2023-03-22 17:04:08 +00:00
Ean Garvey
57e99885e1 Set missing arg in SD tank generation 2023-03-21 14:12:42 -05:00
5 changed files with 74 additions and 46 deletions

View File

@@ -46,7 +46,7 @@ efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
1 resnet50 mhlo tf 1e-2 1e-3 default nhcw-nhwc False False False macos
46 efficientnet_b0 mhlo tf 1e-2 1e-3 default None nhcw-nhwc False False
47 efficientnet_b7 mhlo tf 1e-2 1e-3 default None nhcw-nhwc False False
48 gpt2 mhlo tf 1e-2 1e-3 default None False False False
49 t5-base linalg torch 1e-2 1e-3 default None False True False True False True Inputs for seq2seq models in torch currently unsupported.
50 t5-base mhlo tf 1e-2 1e-3 default None False False False
51 t5-large linalg torch 1e-2 1e-3 default None False True False True False True Inputs for seq2seq models in torch currently unsupported
52 t5-large mhlo tf 1e-2 1e-3 default None False False False

View File

@@ -75,6 +75,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
@@ -150,7 +151,7 @@ def save_tf_model(tf_model_list, local_tank_cache):
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_causal_lm_model(tf_model_name)
model, input, _ = get_masked_lm_model(tf_model_name)
elif model_type == "img":
model, input, _ = get_causal_image_model(tf_model_name)
elif model_type == "keras":
@@ -159,6 +160,8 @@ def save_tf_model(tf_model_list, local_tank_cache):
model, input, _ = get_TFhf_model(tf_model_name)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(tf_model_name)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(tf_model_name)
tf_model_name = tf_model_name.replace("/", "_")
tf_model_dir = os.path.join(
@@ -175,10 +178,6 @@ def save_tf_model(tf_model_list, local_tank_cache):
dir=tf_model_dir,
model_name=tf_model_name,
)
mlir_hash = create_hash(
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
)
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
def save_tflite_model(tflite_model_list, local_tank_cache):
@@ -228,43 +227,69 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
)
def check_requirements(frontend):
    """Return True if the packages required to generate artifacts for
    `frontend` on the fly are importable.

    Args:
        frontend: One of "torch", "tf", or "tensorflow". Any other value
            (including the empty string) returns False.

    Returns:
        bool: True when the frontend's extra packages are installed.
    """
    # BUGFIX: `import importlib` alone does not guarantee that the
    # `importlib.util` submodule is loaded; import it explicitly so
    # `find_spec` is always available.
    import importlib.util

    has_pkgs = False
    if frontend == "torch":
        # Torch artifact generation additionally requires torchvision.
        tv_spec = importlib.util.find_spec("torchvision")
        has_pkgs = tv_spec is not None
    elif frontend in ["tensorflow", "tf"]:
        keras_spec = importlib.util.find_spec("keras")
        tf_spec = importlib.util.find_spec("tensorflow")
        has_pkgs = keras_spec is not None and tf_spec is not None
    # Unknown frontends fall through with has_pkgs == False.
    return has_pkgs
class NoImportException(Exception):
    """Raised when requirements are not met for OTF model artifact generation."""
def gen_shark_files(modelname, frontend, tank_dir):
    """Generate SHARK tank artifacts for a single model on-the-fly.

    Called by shark_downloader when a model's artifacts are requested but
    don't exist in the cloud. Filters the frontend's model-list CSV down to
    the requested entry in a temporary CSV, then runs the corresponding
    save_*_model() generator over it.

    Args:
        modelname: Name of the model as it appears in the model-list CSV.
        frontend: "tf" or "torch" (TFlite not yet supported — see TODO).
        tank_dir: Local tank cache directory to write artifacts into.

    Raises:
        NoImportException: If the packages required for `frontend` are not
            installed (see check_requirements).
        ValueError: If `modelname` is not present in the frontend's CSV.
    """
    # TODO: Add TFlite support.
    import tempfile

    # Guard clause: refuse early when the frontend's packages are missing.
    if not check_requirements(frontend):
        raise NoImportException

    base_dir = os.path.dirname(__file__)
    torch_model_csv = os.path.join(base_dir, "torch_model_list.csv")
    tf_model_csv = os.path.join(base_dir, "tf_model_list.csv")
    # Temporary .csv holding only the desired entry; removed on close.
    # NOTE(review): reopening a NamedTemporaryFile by name while it is open
    # fails on Windows — presumably this only runs on POSIX; confirm.
    custom_model_csv = tempfile.NamedTemporaryFile(
        dir=base_dir,
        delete=True,
    )

    def _write_single_entry(src_csv):
        # Copy the row for `modelname` from src_csv into the temp CSV,
        # preceded by a minimal header row.
        target = None
        with open(src_csv, mode="r") as src:
            for row in csv.reader(src):
                if row[0] == modelname:
                    target = row
        if target is None:
            # Previously this crashed with an unhelpful NameError.
            raise ValueError(f"Model {modelname} not found in {src_csv}")
        with open(custom_model_csv.name, mode="w") as trg:
            writer = csv.writer(trg)
            writer.writerow(["modelname", "src"])
            writer.writerow(target)

    if frontend == "tf":
        _write_single_entry(tf_model_csv)
        save_tf_model(custom_model_csv.name, tank_dir)
    elif frontend == "torch":
        _write_single_entry(torch_model_csv)
        save_torch_model(custom_model_csv.name, tank_dir)
# Validates whether the file is present or not.

View File

@@ -8,6 +8,7 @@ from parameterized import parameterized
from shark.shark_downloader import download_model
from shark.shark_inference import SharkInference
from shark.parser import shark_args
from tank.generate_sharktank import NoImportException
import iree.compiler as ireec
import pytest
import unittest
@@ -161,11 +162,16 @@ class SharkModuleTester:
if "winograd" in self.config["flags"]:
shark_args.use_winograd = True
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
)
try:
model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
tank_url=self.tank_url,
frontend=self.config["framework"],
)
except NoImportException:
pytest.xfail(
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
)
shark_module = SharkInference(
model,

View File

@@ -23,5 +23,3 @@ t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder",
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
1 model_name use_tracing model_type dynamic param_count tags notes
23 t5-large True hf_seq2seq True 770M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer
24 efficientnet_b0 True vision False 5.3M image-classification;cnn;conv2d;depthwise-conv Smallest EfficientNet variant with 224x224 input
25 efficientnet_b7 True vision False 66M image-classification;cnn;conv2d;depthwise-conv Largest EfficientNet variant with 600x600 input
t5-base True hf_seq2seq True 220M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer
t5-large True hf_seq2seq True 770M nlp;transformer-encoder;transformer-decoder Text-to-Text Transfer Transformer

View File

@@ -1,4 +1,3 @@
model_name, use_tracing, model_type, dynamic, param_count, tags, notes
stabilityai/stable-diffusion-2-1-base,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
stabilityai/stable-diffusion-2-1,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
prompthero/openjourney,True,stable_diffusion,False,??M,"stable diffusion 2.1 base, LLM, Text to image", N/A
1 model_name use_tracing model_type dynamic param_count tags notes
2 stabilityai/stable-diffusion-2-1-base True stable_diffusion False ??M stable diffusion 2.1 base, LLM, Text to image N/A
3 stabilityai/stable-diffusion-2-1 True stable_diffusion False ??M stable diffusion 2.1 base, LLM, Text to image N/A
prompthero/openjourney True stable_diffusion False ??M stable diffusion 2.1 base, LLM, Text to image N/A