mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
3 Commits
fp16cpu
...
20230321.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ac36e0cc8 | ||
|
|
2533f751a2 | ||
|
|
57e99885e1 |
@@ -46,7 +46,7 @@ efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,False,"",""
|
||||
efficientnet_b0,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
efficientnet_b7,mhlo,tf,1e-2,1e-3,default,None,nhcw-nhwc,False,False,"",""
|
||||
gpt2,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.",""
|
||||
t5-base,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported",""
|
||||
t5-large,mhlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
|
||||
|
@@ -75,6 +75,7 @@ def save_torch_model(torch_model_list, local_tank_cache):
|
||||
width=512,
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
custom_vae="",
|
||||
debug=True,
|
||||
sharktank_dir=local_tank_cache,
|
||||
generate_vmfb=False,
|
||||
@@ -175,10 +176,6 @@ def save_tf_model(tf_model_list, local_tank_cache):
|
||||
dir=tf_model_dir,
|
||||
model_name=tf_model_name,
|
||||
)
|
||||
mlir_hash = create_hash(
|
||||
os.path.join(tf_model_dir, tf_model_name + "_tf" + ".mlir")
|
||||
)
|
||||
np.save(os.path.join(tf_model_dir, "hash"), np.array(mlir_hash))
|
||||
|
||||
|
||||
def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
@@ -228,43 +225,69 @@ def save_tflite_model(tflite_model_list, local_tank_cache):
|
||||
)
|
||||
|
||||
|
||||
def check_requirements(frontend):
|
||||
import importlib
|
||||
|
||||
has_pkgs = False
|
||||
if frontend == "torch":
|
||||
tv_spec = importlib.util.find_spec("torchvision")
|
||||
has_pkgs = tv_spec is not None
|
||||
|
||||
elif frontend in ["tensorflow", "tf"]:
|
||||
keras_spec = importlib.util.find_spec("keras")
|
||||
tf_spec = importlib.util.find_spec("tensorflow")
|
||||
has_pkgs = keras_spec is not None and tf_spec is not None
|
||||
|
||||
return has_pkgs
|
||||
|
||||
|
||||
class NoImportException(Exception):
|
||||
"Raised when requirements are not met for OTF model artifact generation."
|
||||
pass
|
||||
|
||||
|
||||
def gen_shark_files(modelname, frontend, tank_dir):
|
||||
# If a model's artifacts are requested by shark_downloader but they don't exist in the cloud, we call this function to generate the artifacts on-the-fly.
|
||||
# TODO: Add TFlite support.
|
||||
import tempfile
|
||||
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir)
|
||||
if check_requirements(frontend):
|
||||
torch_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "torch_model_list.csv"
|
||||
)
|
||||
tf_model_csv = os.path.join(
|
||||
os.path.dirname(__file__), "tf_model_list.csv"
|
||||
)
|
||||
custom_model_csv = tempfile.NamedTemporaryFile(
|
||||
dir=os.path.dirname(__file__),
|
||||
delete=True,
|
||||
)
|
||||
# Create a temporary .csv with only the desired entry.
|
||||
if frontend == "tf":
|
||||
with open(tf_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_tf_model(custom_model_csv.name, tank_dir)
|
||||
|
||||
if frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_torch_model(custom_model_csv.name, tank_dir)
|
||||
elif frontend == "torch":
|
||||
with open(torch_model_csv, mode="r") as src:
|
||||
reader = csv.reader(src)
|
||||
for row in reader:
|
||||
if row[0] == modelname:
|
||||
target = row
|
||||
with open(custom_model_csv.name, mode="w") as trg:
|
||||
writer = csv.writer(trg)
|
||||
writer.writerow(["modelname", "src"])
|
||||
writer.writerow(target)
|
||||
save_torch_model(custom_model_csv.name, tank_dir)
|
||||
else:
|
||||
raise NoImportException
|
||||
|
||||
|
||||
# Validates whether the file is present or not.
|
||||
|
||||
@@ -8,6 +8,7 @@ from parameterized import parameterized
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
from tank.generate_sharktank import NoImportException
|
||||
import iree.compiler as ireec
|
||||
import pytest
|
||||
import unittest
|
||||
@@ -161,11 +162,16 @@ class SharkModuleTester:
|
||||
if "winograd" in self.config["flags"]:
|
||||
shark_args.use_winograd = True
|
||||
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
)
|
||||
try:
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
)
|
||||
except NoImportException:
|
||||
pytest.xfail(
|
||||
reason=f"Artifacts for this model/config must be generated locally. Please make sure {self.config['framework']} is installed."
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
|
||||
@@ -23,5 +23,3 @@ t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder",
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
efficientnet_b0,True,vision,False,5.3M,"image-classification;cnn;conv2d;depthwise-conv","Smallest EfficientNet variant with 224x224 input"
|
||||
efficientnet_b7,True,vision,False,66M,"image-classification;cnn;conv2d;depthwise-conv","Largest EfficientNet variant with 600x600 input"
|
||||
t5-base,True,hf_seq2seq,True,220M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
t5-large,True,hf_seq2seq,True,770M,"nlp;transformer-encoder;transformer-decoder","Text-to-Text Transfer Transformer"
|
||||
|
||||
|
Reference in New Issue
Block a user