mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add tf/torch/mhlo/tosa support for SharkDownloader (#151)
This commit is contained in:
@@ -40,9 +40,12 @@ class SharkDownloader:
|
||||
self.local_tank_dir = local_tank_dir
|
||||
self.tank_url = tank_url
|
||||
self.model_type = model_type
|
||||
self.input_json = input_json
|
||||
self.input_type = input_type_to_np_dtype[input_type]
|
||||
self.input_json = input_json # optional if you don't have input
|
||||
self.input_type = input_type_to_np_dtype[
|
||||
input_type
|
||||
] # optional if you don't have input
|
||||
self.mlir_file = None # .mlir file local address.
|
||||
self.mlir_url = None
|
||||
self.inputs = None # Input has to be (list of np.array) for sharkInference.forward use
|
||||
self.mlir_model = []
|
||||
|
||||
@@ -73,51 +76,78 @@ class SharkDownloader:
|
||||
]
|
||||
else:
|
||||
print(
|
||||
"No json input required for current model. You could call setup_inputs(you_inputs)."
|
||||
"No json input required for current model type. You could call setup_inputs(YOU_INPUTS)."
|
||||
)
|
||||
return self.inputs
|
||||
|
||||
def load_mlir_model(self):
|
||||
if self.model_type in ["tflite-tosa"]:
|
||||
workdir = os.path.join(
|
||||
os.path.dirname(__file__), self.local_tank_dir
|
||||
workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
|
||||
os.makedirs(workdir, exist_ok=True)
|
||||
print(f"TMP_MODEL_DIR = {workdir}")
|
||||
# use model name get dir.
|
||||
model_name_dir = os.path.join(workdir, str(self.model_name))
|
||||
if not os.path.exists(model_name_dir):
|
||||
print(
|
||||
"Model has not been download."
|
||||
"shark_downloader will automatically download by tank_url if provided."
|
||||
" You can also manually to download the model from shark_tank by yourself."
|
||||
)
|
||||
os.makedirs(workdir, exist_ok=True)
|
||||
print(f"TMP_MODEL_DIR = {workdir}")
|
||||
os.makedirs(model_name_dir, exist_ok=True)
|
||||
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
|
||||
|
||||
# use model name get dir.
|
||||
model_name_dir = os.path.join(workdir, str(self.model_name))
|
||||
if not os.path.exists(model_name_dir):
|
||||
print(
|
||||
"Model has not been download."
|
||||
"shark_downloader will automatically download by tank_url if provided."
|
||||
" You can also manually to download the model from shark_tank by yourself."
|
||||
)
|
||||
os.makedirs(model_name_dir, exist_ok=True)
|
||||
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
|
||||
|
||||
mlir_url = (
|
||||
if self.model_type in ["tflite-tosa"]:
|
||||
self.mlir_url = (
|
||||
self.tank_url
|
||||
+ "/tflite/"
|
||||
+ "/"
|
||||
+ str(self.model_name)
|
||||
+ "/"
|
||||
+ str(self.model_name)
|
||||
+ "_tosa.mlir"
|
||||
+ "_tflite.mlir"
|
||||
)
|
||||
self.mlir_file = "/".join(
|
||||
[model_name_dir, str(self.model_name) + "_tosa.mlir"]
|
||||
[model_name_dir, str(self.model_name) + "_tfite.mlir"]
|
||||
)
|
||||
elif self.model_type in ["tensorflow"]:
|
||||
self.mlir_url = (
|
||||
self.tank_url
|
||||
+ "/"
|
||||
+ str(self.model_name)
|
||||
+ "/"
|
||||
+ str(self.model_name)
|
||||
+ "_tf.mlir"
|
||||
)
|
||||
self.mlir_file = "/".join(
|
||||
[model_name_dir, str(self.model_name) + "_tf.mlir"]
|
||||
)
|
||||
elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
|
||||
self.mlir_url = (
|
||||
self.tank_url
|
||||
+ "/"
|
||||
+ str(self.model_name)
|
||||
+ "/"
|
||||
+ str(self.model_name)
|
||||
+ "_"
|
||||
+ str(self.model_type)
|
||||
+ ".mlir"
|
||||
)
|
||||
self.mlir_file = "/".join(
|
||||
[
|
||||
model_name_dir,
|
||||
str(self.model_name) + "_" + str(self.model_type) + ".mlir",
|
||||
]
|
||||
)
|
||||
if os.path.exists(self.mlir_file):
|
||||
print("Model has been downloaded before.", self.mlir_file)
|
||||
else:
|
||||
print("Download mlir model", mlir_url)
|
||||
urllib.request.urlretrieve(mlir_url, self.mlir_file)
|
||||
|
||||
print("Get tosa.mlir model return")
|
||||
with open(self.mlir_file) as f:
|
||||
self.mlir_model = f.read()
|
||||
else:
|
||||
print("Unsupported mlir model")
|
||||
|
||||
if os.path.exists(self.mlir_file):
|
||||
print("Model has been downloaded before.", self.mlir_file)
|
||||
else:
|
||||
print("Download mlir model", self.mlir_url)
|
||||
urllib.request.urlretrieve(self.mlir_url, self.mlir_file)
|
||||
|
||||
print("Get .mlir model return")
|
||||
with open(self.mlir_file) as f:
|
||||
self.mlir_model = f.read()
|
||||
return self.mlir_model
|
||||
|
||||
def setup_inputs(self, inputs):
|
||||
|
||||
@@ -24,7 +24,7 @@ class AlbertTfliteModuleTester:
|
||||
self.shark_downloader = SharkDownloader(
|
||||
model_name="albert_lite_base",
|
||||
tank_url="https://storage.googleapis.com/shark_tank",
|
||||
local_tank_dir="./../gen_shark_tank/tflite",
|
||||
local_tank_dir="./../gen_shark_tank",
|
||||
model_type="tflite-tosa",
|
||||
input_json="input.json",
|
||||
input_type="int32",
|
||||
|
||||
Reference in New Issue
Block a user