Fix load json input bug in SharkDownloader albert test
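The bug: load_json_input opened the bare self.input_json file name (e.g. "input.json") directly, so reading the reference inputs failed unless that file already sat in the working directory. The change below creates the per-model directory in __init__, downloads input.json from the tank into that directory, and reads from the downloaded path. A minimal sketch of the fixed flow, using the albert_lite_base values from the test in this commit (an illustration only, not the exact SharkDownloader code):

    import json
    import os
    import urllib.request

    import numpy as np

    # Values taken from the albert_lite_base test below.
    tank_url = "https://storage.googleapis.com/shark_tank"
    model_name = "albert_lite_base"
    model_name_dir = os.path.join("./../gen_shark_tank", model_name)
    os.makedirs(model_name_dir, exist_ok=True)

    input_url = tank_url + "/" + model_name + "/" + "input.json"
    input_file = os.path.join(model_name_dir, "input.json")
    if not os.path.exists(input_file):
        # Fetch the reference inputs from shark_tank once and cache them locally.
        urllib.request.urlretrieve(input_url, input_file)

    # Read from the downloaded file; the old code opened the bare "input.json" name.
    with open(input_file, "r") as f:
        args = json.load(f)
    inputs = [np.asarray(arg, dtype="int32") for arg in args]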
@@ -52,6 +52,21 @@ class SharkDownloader:
             print("Error. No tank_url, no model name. Please input either one.")
             return

+        self.workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
+        os.makedirs(self.workdir, exist_ok=True)
+        print(f"TMP_MODEL_DIR = {self.workdir}")
+        # use the model name to build the per-model directory.
+        self.model_name_dir = os.path.join(self.workdir, str(self.model_name))
+        if not os.path.exists(self.model_name_dir):
+            print(
+                "Model has not been downloaded. "
+                "shark_downloader will automatically download it from "
+                "tank_url if provided. You can also manually "
+                "download the model from shark_tank yourself."
+            )
+        os.makedirs(self.model_name_dir, exist_ok=True)
+        print(f"TMP_MODELNAME_DIR = {self.model_name_dir}")
+
         # read inputs from json file
         self.load_json_input()
         # get mlir model file
@@ -66,42 +81,65 @@ class SharkDownloader:
     def load_json_input(self):
         print("load json inputs")
         if self.model_type in ["tflite-tosa"]:
+            input_url = (
+                self.tank_url
+                + "/"
+                + str(self.model_name)
+                + "/"
+                + "input.json"
+            )
+            input_file = "/".join(
+                [self.model_name_dir, str(self.input_json)]
+            )
+            if os.path.exists(input_file):
+                print("Input has been downloaded before.", input_file)
+            else:
+                print("Download input", input_url)
+                urllib.request.urlretrieve(input_url, input_file)
+
             args = []
-            with open(self.input_json, "r") as f:
+            with open(input_file, "r") as f:
                 args = json.load(f)
             self.inputs = [np.asarray(arg, dtype=self.input_type) for arg in args]
         else:
-            print("No json input required for current model type. You could call setup_inputs(YOU_INPUTS).")
+            print(
+                "No json input required for the current model type. "
+                "You can call setup_inputs(YOUR_INPUTS)."
+            )
         return self.inputs

     def load_mlir_model(self):
-        workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
-        os.makedirs(workdir, exist_ok=True)
-        print(f"TMP_MODEL_DIR = {workdir}")
-        # use model name get dir.
-        model_name_dir = os.path.join(workdir, str(self.model_name))
-        if not os.path.exists(model_name_dir):
-            print(
-                "Model has not been download."
-                "shark_downloader will automatically download by tank_url if provided."
-                " You can also manually to download the model from shark_tank by yourself."
-            )
-        os.makedirs(model_name_dir, exist_ok=True)
-        print(f"TMP_MODELNAME_DIR = {model_name_dir}")
-
         if self.model_type in ["tflite-tosa"]:
-            self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tflite.mlir"
-            self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tfite.mlir"])
+            self.mlir_url = (
+                self.tank_url
+                + "/"
+                + str(self.model_name)
+                + "/"
+                + str(self.model_name)
+                + "_tflite.mlir"
+            )
+            self.mlir_file = "/".join(
+                [self.model_name_dir, str(self.model_name) + "_tflite.mlir"]
+            )
         elif self.model_type in ["tensorflow"]:
-            self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tf.mlir"
-            self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tf.mlir"])
+            self.mlir_url = (
+                self.tank_url
+                + "/"
+                + str(self.model_name)
+                + "/"
+                + str(self.model_name)
+                + "_tf.mlir"
+            )
+            self.mlir_file = "/".join(
+                [self.model_name_dir, str(self.model_name) + "_tf.mlir"]
+            )
         elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
             self.mlir_url = (
                 self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_" + str(self.model_type) + ".mlir"
             )
             self.mlir_file = "/".join(
                 [
-                    model_name_dir,
+                    self.model_name_dir,
                     str(self.model_name) + "_" + str(self.model_type) + ".mlir",
                 ]
             )
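As a worked example of the string concatenation above (not part of the diff): with the values used by the albert test further down, model_name = "albert_lite_base", model_type = "tflite-tosa", and tank_url = "https://storage.googleapis.com/shark_tank", the tflite branch resolves to:

    tank_url = "https://storage.googleapis.com/shark_tank"
    model_name = "albert_lite_base"
    mlir_url = tank_url + "/" + model_name + "/" + model_name + "_tflite.mlir"
    # -> https://storage.googleapis.com/shark_tank/albert_lite_base/albert_lite_base_tflite.mlir
    # The matching input.json for the same model is fetched from:
    # -> https://storage.googleapis.com/shark_tank/albert_lite_base/input.json
    # and both files are cached under <local_tank_dir>/albert_lite_base/.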
@@ -21,7 +21,7 @@ class AlbertTfliteModuleTester:
     def create_and_check_module(self):
         shark_args.save_mlir = self.save_mlir
         shark_args.save_vmfb = self.save_vmfb
-        self.shark_downloader = SharkDownloader(
+        shark_downloader = SharkDownloader(
             model_name="albert_lite_base",
             tank_url="https://storage.googleapis.com/shark_tank",
             local_tank_dir="./../gen_shark_tank",
@@ -29,18 +29,17 @@ class AlbertTfliteModuleTester:
             input_json="input.json",
             input_type="int32",
         )
-        tflite_tosa_model = self.shark_downloader.get_mlir_file()
-        inputs = self.shark_downloader.get_inputs()
-        self.shark_module = SharkInference(
-            tflite_tosa_model,
-            inputs,
+        tflite_tosa_model = shark_downloader.get_mlir_file()
+        inputs = shark_downloader.get_inputs()
+
+        shark_module = SharkInference(
+            mlir_module=tflite_tosa_model,
             function_name="main",
             device=self.device,
             dynamic=self.dynamic,
             jit_trace=True,
             mlir_dialect="tflite",
         )
-        self.shark_module.set_frontend("tflite-tosa")
-        self.shark_module.compile()
-        self.shark_module.forward(inputs)
+        shark_module.compile()
+        shark_module.forward(inputs)
         # print(shark_results)
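After these changes the test drives the API as a plain download, compile, run pipeline. A condensed sketch of that flow outside the unittest scaffolding (the import paths and the model_type keyword are assumptions, and device/dynamic come from the tester's constructor in the real test):

    from shark.shark_downloader import SharkDownloader   # import paths assumed
    from shark.shark_inference import SharkInference

    downloader = SharkDownloader(
        model_name="albert_lite_base",
        tank_url="https://storage.googleapis.com/shark_tank",
        local_tank_dir="./../gen_shark_tank",
        model_type="tflite-tosa",   # assumed keyword; the class branches on self.model_type
        input_json="input.json",
        input_type="int32",
    )
    tflite_tosa_model = downloader.get_mlir_file()
    inputs = downloader.get_inputs()

    shark_module = SharkInference(
        mlir_module=tflite_tosa_model,
        function_name="main",
        device="cpu",               # the test passes self.device
        dynamic=False,              # the test passes self.dynamic
        jit_trace=True,
        mlir_dialect="tflite",
    )
    shark_module.compile()
    shark_module.forward(inputs)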
@@ -61,9 +60,9 @@ class AlbertTfliteModuleTest(unittest.TestCase):


 if __name__ == "__main__":
-    unittest.main()
-    # module_tester = AlbertTfliteModuleTester()
-    # module_tester.create_and_check_module()
+    # unittest.main()
+    module_tester = AlbertTfliteModuleTester()
+    module_tester.create_and_check_module()

 # TEST RESULT:
 # (shark.venv) nod% python albert_lite_base_tflite_mlir_test.py
@@ -1 +0,0 @@
[[[252, 30, 233, 138, 36, 18, 220, 186, 216, 30, 144, 111, 102, 228, 255, 213, 22, 171, 83, 26, 11, 50, 9, 248, 55, 95, 95, 51, 85, 132, 217, 164, 183, 230, 117, 218, 214, 41, 144, 52, 100, 103, 202, 171, 163, 0, 148, 174, 235, 181, 235, 192, 116, 95, 150, 160, 29, 171, 56, 126, 163, 6, 61, 161, 114, 231, 5, 87, 86, 165, 205, 170, 249, 31, 106, 8, 242, 80, 140, 163, 120, 122, 235, 4, 190, 122, 39, 247, 147, 33, 57, 18, 164, 234, 85, 82, 28, 8, 231, 241, 239, 97, 42, 167, 51, 199, 212, 38, 9, 25, 100, 163, 124, 23, 111, 159, 205, 190, 184, 108, 182, 149, 58, 37, 7, 250, 109, 50, 75, 41, 239, 106, 33, 246, 24, 40, 12, 50, 205, 107, 114, 63, 2, 26, 151, 114, 22, 178, 159, 98, 216, 168, 127, 216, 152, 242, 87, 192, 121, 34, 53, 107, 151, 14, 182, 31, 207, 99, 127, 247, 82, 198, 217, 43, 86, 228, 218, 98, 187, 11, 177, 131, 255, 137, 19, 87, 152, 249, 68, 85, 48, 57, 63, 68, 123, 136, 9, 110, 218, 21, 194, 90, 163, 101, 15, 152, 30, 230, 54, 142, 254, 78, 52, 180, 249, 198, 140, 98, 81, 72, 154, 116, 72, 100, 97, 2, 171, 45, 18, 155, 134, 108, 194, 168, 213, 126, 39, 178, 102, 154, 180, 110, 155, 149, 120, 118, 111, 96, 104, 4, 131, 72, 57, 23, 196, 195, 240, 33, 60, 85, 207, 125, 179, 21, 126, 159, 105, 186, 176, 19, 100, 161, 81, 243, 1, 234, 86, 98, 48, 36, 111, 196, 230, 19, 254, 114, 27, 100, 43, 28, 95, 108, 114, 109, 186, 189, 175, 77, 191, 214, 130, 139, 147, 118, 55, 250, 31, 100, 19, 45, 205, 221, 70, 184, 248, 10, 236, 114, 195, 5, 70, 222, 20, 169, 173, 4, 125, 255, 120, 130, 215, 164, 106, 33, 34, 215, 85, 103, 16, 252, 158, 10, 229, 40, 92, 222, 246, 211, 147, 105, 243, 135, 35, 125, 148, 207, 120, 213, 221, 229, 217, 47, 152, 92, 189, 147, 183, 133, 196, 90, 229, 233, 233, 235, 107, 3, 211, 191, 10, 247, 164, 30, 231, 59]], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]