mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
find gsutil on linux (#557)
* find gsutil on linux * cleaned up downloader and ditched gsutil Co-authored-by: dan <dan@nod-labs.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
|
||||
|
||||
def load_and_preprocess_image(fname: str):
|
||||
|
||||
@@ -6,7 +6,7 @@ pyinstaller
|
||||
tqdm
|
||||
|
||||
# SHARK Downloader
|
||||
gsutil
|
||||
google-cloud-storage
|
||||
|
||||
# Testing
|
||||
pytest
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("bloom")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bloom", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
@@ -66,7 +66,9 @@ labels = load_labels()
|
||||
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, func_name, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
|
||||
@@ -37,10 +37,12 @@ args = p.parse_args()
|
||||
|
||||
|
||||
def fp16_unet():
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"stable_diff_f16_18_OCT", tank_url="gs://shark_tank/prashant_nod"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"stable_diff_f16_18_OCT",
|
||||
tank_url="gs://shark_tank/prashant_nod",
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
|
||||
@@ -17,7 +17,7 @@ from keras_cv.models.generative.stable_diffusion.text_encoder import (
|
||||
)
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
from PIL import Image
|
||||
|
||||
# pip install "git+https://github.com/keras-team/keras-cv.git"
|
||||
@@ -75,8 +75,8 @@ class SharkStableDiffusion:
|
||||
# Create models
|
||||
self.text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_tf_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"stable_diff", tank_url="gs://shark_tank/quinn", frontend="tf"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="mhlo"
|
||||
|
||||
@@ -39,10 +39,12 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
model_name, tank_url=tank_url
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"v_diffusion", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="vulkan", mlir_dialect="linalg"
|
||||
|
||||
@@ -17,17 +17,22 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from shark.parser import shark_args
|
||||
from google.cloud import storage
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
def download_public_file(full_gs_url, destination_file_name):
|
||||
"""Downloads a public blob from the bucket."""
|
||||
# bucket_name = "gs://your-bucket-name/path/to/file"
|
||||
# destination_file_name = "local/path/to/file"
|
||||
|
||||
storage_client = storage.Client.create_anonymous_client()
|
||||
bucket_name = full_gs_url.split("/")[2]
|
||||
source_blob_name = "/".join(full_gs_url.split("/")[3:])
|
||||
bucket = storage_client.bucket(bucket_name)
|
||||
blob = bucket.blob(source_blob_name)
|
||||
blob.download_to_filename(destination_file_name)
|
||||
|
||||
|
||||
GSUTIL_PATH = resource_path("gsutil")
|
||||
GSUTIL_FLAGS = ' -o "GSUtil:parallel_process_count=1" -m cp -r '
|
||||
|
||||
|
||||
@@ -98,103 +103,23 @@ def check_dir_exists(model_name, frontend="torch", dynamic=""):
|
||||
|
||||
|
||||
# Downloads the torch model from gs://shark_tank dir.
|
||||
def download_torch_model(
|
||||
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
|
||||
def download_model(
|
||||
model_name,
|
||||
dynamic=False,
|
||||
tank_url="gs://shark_tank/latest",
|
||||
frontend=None,
|
||||
tuned=None,
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_torch"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
GSUTIL_PATH
|
||||
+ GSUTIL_FLAGS
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ ' "'
|
||||
+ WORKDIR
|
||||
+ '"'
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="torch", dynamic=dyn_str):
|
||||
gs_download_model()
|
||||
else:
|
||||
if not _internet_connected():
|
||||
print(
|
||||
"No internet connection. Using the model already present in the tank."
|
||||
)
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
GSUTIL_PATH
|
||||
+ GSUTIL_FLAGS
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_torch.mlir"),
|
||||
mode="rb",
|
||||
) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
# Downloads the tflite model from gs://shark_tank dir.
|
||||
def download_tflite_model(
|
||||
model_name, dynamic=False, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
dyn_str = "_dynamic" if dynamic else ""
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tflite"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
GSUTIL_PATH
|
||||
+ GSUTIL_FLAGS
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ ' "'
|
||||
+ WORKDIR
|
||||
+ '"'
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
model_dir_name = model_name + "_" + frontend
|
||||
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
|
||||
|
||||
if not check_dir_exists(
|
||||
model_dir_name, frontend="tflite", dynamic=dyn_str
|
||||
model_dir_name, frontend=frontend, dynamic=dyn_str
|
||||
):
|
||||
gs_download_model()
|
||||
download_public_file(full_gs_url, WORKDIR)
|
||||
else:
|
||||
if not _internet_connected():
|
||||
print(
|
||||
@@ -203,104 +128,34 @@ def download_tflite_model(
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
GSUTIL_PATH
|
||||
+ GSUTIL_FLAGS
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
gs_hash_url = (
|
||||
tank_url.rstrip("/") + "/" + model_dir_name + "/hash.npy"
|
||||
)
|
||||
download_public_file(
|
||||
gs_hash_url, os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
download_public_file(full_gs_url, WORKDIR)
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
with open(
|
||||
os.path.join(model_dir, model_name + dyn_str + "_tflite.mlir"),
|
||||
mode="rb",
|
||||
) as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
|
||||
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
|
||||
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
inputs_tuple = tuple([inputs[key] for key in inputs])
|
||||
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
|
||||
return mlir_file, function_name, inputs_tuple, golden_out_tuple
|
||||
|
||||
|
||||
def download_tf_model(
|
||||
model_name, tuned=None, tank_url="gs://shark_tank/latest"
|
||||
):
|
||||
model_name = model_name.replace("/", "_")
|
||||
os.makedirs(WORKDIR, exist_ok=True)
|
||||
model_dir_name = model_name + "_tf"
|
||||
|
||||
def gs_download_model():
|
||||
gs_command = (
|
||||
GSUTIL_PATH
|
||||
+ GSUTIL_FLAGS
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ ' "'
|
||||
+ WORKDIR
|
||||
+ '"'
|
||||
)
|
||||
if os.system(gs_command) != 0:
|
||||
raise Exception("model not present in the tank. Contact Nod Admin")
|
||||
|
||||
if not check_dir_exists(model_dir_name, frontend="tf"):
|
||||
gs_download_model()
|
||||
else:
|
||||
if not _internet_connected():
|
||||
print(
|
||||
"No internet connection. Using the model already present in the tank."
|
||||
)
|
||||
else:
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
|
||||
gs_hash = (
|
||||
GSUTIL_PATH
|
||||
+ GSUTIL_FLAGS
|
||||
+ tank_url
|
||||
+ "/"
|
||||
+ model_dir_name
|
||||
+ "/hash.npy"
|
||||
+ " "
|
||||
+ os.path.join(model_dir, "upstream_hash.npy")
|
||||
)
|
||||
if os.system(gs_hash) != 0:
|
||||
raise Exception("hash of the model not present in the tank.")
|
||||
upstream_hash = str(
|
||||
np.load(os.path.join(model_dir, "upstream_hash.npy"))
|
||||
)
|
||||
if local_hash != upstream_hash:
|
||||
if shark_args.update_tank == True:
|
||||
gs_download_model()
|
||||
else:
|
||||
print(
|
||||
"Hash does not match upstream in gs://shark_tank/. If you are using SHARK Downloader with locally generated artifacts, this is working as intended."
|
||||
)
|
||||
|
||||
model_dir = os.path.join(WORKDIR, model_dir_name)
|
||||
suffix = "_tf.mlir" if tuned is None else "_tf_" + tuned + ".mlir"
|
||||
suffix = (
|
||||
"_" + frontend + ".mlir"
|
||||
if tuned is None
|
||||
else "_" + frontend + "_" + tuned + ".mlir"
|
||||
)
|
||||
filename = os.path.join(model_dir, model_name + suffix)
|
||||
if not os.path.isfile(filename):
|
||||
filename = os.path.join(model_dir, model_name + "_tf.mlir")
|
||||
filename = os.path.join(
|
||||
model_dir, model_name + "_" + frontend + ".mlir"
|
||||
)
|
||||
|
||||
with open(filename, mode="rb") as f:
|
||||
mlir_file = f.read()
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"bert-base-uncased_tosa"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bert-base-uncased_tosa",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
from tank.test_utils import get_valid_test_params, shark_test_name_func
|
||||
from parameterized import parameterized
|
||||
@@ -21,8 +21,8 @@ class DebertaBaseModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
"microsoft/deberta-base"
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/deberta-base", frontend="tf"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from shark.shark_downloader import download_tflite_model
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
import pytest
|
||||
import unittest
|
||||
@@ -58,8 +58,8 @@ class GptTfliteModuleTester:
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
# Preprocess to get SharkImporter input args
|
||||
mlir_model, func_name, inputs, tflite_results = download_tflite_model(
|
||||
model_name="gpt2-64"
|
||||
mlir_model, func_name, inputs, tflite_results = download_model(
|
||||
model_name="gpt2-64", backend="tflite"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module=mlir_model,
|
||||
|
||||
@@ -20,10 +20,6 @@ class OPTModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device, model_name):
|
||||
# model_mlir, func_name, input, act_out = download_torch_model(
|
||||
# "opt", dynamic
|
||||
# )
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
||||
# config = OPTConfig()
|
||||
# opt_model = OPTModel(config)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
from tank.test_utils import get_valid_test_params, shark_test_name_func
|
||||
from parameterized import parameterized
|
||||
|
||||
@@ -18,8 +18,8 @@ class RemBertModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
"google/rembert"
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
"google/rembert", frontend="tf"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_tf_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
import iree.compiler as ireec
|
||||
import unittest
|
||||
@@ -16,8 +16,9 @@ class TapasBaseModuleTester:
|
||||
self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
"google/tapas-base"
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
"google/tapas-base",
|
||||
frontend="tf",
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
|
||||
@@ -15,7 +15,7 @@ from torchvision.transforms import functional as TF
|
||||
from tqdm import trange
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
import numpy as np
|
||||
|
||||
import sys
|
||||
@@ -191,7 +191,9 @@ x_in = x[0:min_batch_size, :, :, :]
|
||||
ts = x_in.new_ones([x_in.shape[0]])
|
||||
t_in = t[0] * ts
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"v_diffusion", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
|
||||
|
||||
@@ -5,11 +5,7 @@ from shark.iree_utils._common import (
|
||||
)
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
from parameterized import parameterized
|
||||
from shark.shark_downloader import (
|
||||
download_tf_model,
|
||||
download_torch_model,
|
||||
download_tflite_model,
|
||||
)
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
import iree.compiler as ireec
|
||||
@@ -133,23 +129,11 @@ class SharkModuleTester:
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.update_tank = self.update_tank
|
||||
if self.config["framework"] == "tf":
|
||||
model, func_name, inputs, golden_out = download_tf_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
)
|
||||
elif self.config["framework"] == "torch":
|
||||
model, func_name, inputs, golden_out = download_torch_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
)
|
||||
elif self.config["framework"] == "tflite":
|
||||
model, func_name, inputs, golden_out = download_tflite_model(
|
||||
model_name=self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
)
|
||||
else:
|
||||
model, func_name, inputs, golden_out = None, None, None, None
|
||||
model, func_name, inputs, golden_out = download_model(
|
||||
self.config["model_name"],
|
||||
tank_url=self.tank_url,
|
||||
frontend=self.config["framework"],
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
|
||||
@@ -3,7 +3,7 @@ import requests
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
################################## Preprocessing inputs and helper functions ########
|
||||
|
||||
@@ -69,8 +69,8 @@ def resnet_inf(numpy_img, device):
|
||||
if device not in compiled_module.keys():
|
||||
if DEBUG:
|
||||
log_write.write("Compiling the Resnet50 module.\n")
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
"resnet50"
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=device, mlir_dialect="linalg"
|
||||
|
||||
@@ -36,10 +36,10 @@ def _compile_module(args, shark_module, model_name, extra_args=[]):
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(args, tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_torch_model
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model(
|
||||
model_name, tank_url=tank_url
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name, tank_url=tank_url, frontend="torch"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
|
||||
Reference in New Issue
Block a user