Fix load json input bug in SharkDownloader albert test

This commit is contained in:
Chi Liu
2022-06-29 21:35:46 -07:00
parent 06a45d9025
commit 96dd08cca4
3 changed files with 72 additions and 36 deletions

View File

@@ -52,6 +52,21 @@ class SharkDownloader:
print("Error. No tank_url, No model name,Please input either one.")
return
self.workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
os.makedirs(self.workdir, exist_ok=True)
print(f"TMP_MODEL_DIR = {self.workdir}")
# use model name get dir.
self.model_name_dir = os.path.join(self.workdir, str(self.model_name))
if not os.path.exists(self.model_name_dir):
print(
"Model has not been download."
"shark_downloader will automatically download by "
"tank_url if provided. You can also manually to "
"download the model from shark_tank by yourself."
)
os.makedirs(self.model_name_dir, exist_ok=True)
print(f"TMP_MODELNAME_DIR = {self.model_name_dir}")
# read inputs from json file
self.load_json_input()
# get milr model file
@@ -66,42 +81,65 @@ class SharkDownloader:
def load_json_input(self):
print("load json inputs")
if self.model_type in ["tflite-tosa"]:
input_url = (
self.tank_url
+ "/"
+ str(self.model_name)
+ "/"
+ "input.json"
)
input_file = "/".join(
[self.model_name_dir, str(self.input_json)]
)
if os.path.exists(input_file):
print("Input has been downloaded before.", input_file)
else:
print("Download input", input_url)
urllib.request.urlretrieve(input_url, input_file)
args = []
with open(self.input_json, "r") as f:
with open(input_file, "r") as f:
args = json.load(f)
self.inputs = [np.asarray(arg, dtype=self.input_type) for arg in args]
else:
print("No json input required for current model type. You could call setup_inputs(YOU_INPUTS).")
print(
"No json input required for current model type. "
"You could call setup_inputs(YOU_INPUTS)."
)
return self.inputs
def load_mlir_model(self):
workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
os.makedirs(workdir, exist_ok=True)
print(f"TMP_MODEL_DIR = {workdir}")
# use model name get dir.
model_name_dir = os.path.join(workdir, str(self.model_name))
if not os.path.exists(model_name_dir):
print(
"Model has not been download."
"shark_downloader will automatically download by tank_url if provided."
" You can also manually to download the model from shark_tank by yourself."
)
os.makedirs(model_name_dir, exist_ok=True)
print(f"TMP_MODELNAME_DIR = {model_name_dir}")
if self.model_type in ["tflite-tosa"]:
self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tflite.mlir"
self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tfite.mlir"])
self.mlir_url = (
self.tank_url
+ "/"
+ str(self.model_name)
+ "/"
+ str(self.model_name)
+ "_tflite.mlir"
)
self.mlir_file = "/".join(
[self.model_name_dir, str(self.model_name) + "_tfite.mlir"]
)
elif self.model_type in ["tensorflow"]:
self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tf.mlir"
self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tf.mlir"])
self.mlir_url = (
self.tank_url
+ "/"
+ str(self.model_name)
+ "/"
+ str(self.model_name)
+ "_tf.mlir"
)
self.mlir_file = "/".join(
[self.model_name_dir, str(self.model_name) + "_tf.mlir"]
)
elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
self.mlir_url = (
self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_" + str(self.model_type) + ".mlir"
)
self.mlir_file = "/".join(
[
model_name_dir,
self.model_name_dir,
str(self.model_name) + "_" + str(self.model_type) + ".mlir",
]
)

View File

@@ -21,7 +21,7 @@ class AlbertTfliteModuleTester:
def create_and_check_module(self):
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
self.shark_downloader = SharkDownloader(
shark_downloader = SharkDownloader(
model_name="albert_lite_base",
tank_url="https://storage.googleapis.com/shark_tank",
local_tank_dir="./../gen_shark_tank",
@@ -29,18 +29,17 @@ class AlbertTfliteModuleTester:
input_json="input.json",
input_type="int32",
)
tflite_tosa_model = self.shark_downloader.get_mlir_file()
inputs = self.shark_downloader.get_inputs()
self.shark_module = SharkInference(
tflite_tosa_model,
inputs,
tflite_tosa_model = shark_downloader.get_mlir_file()
inputs = shark_downloader.get_inputs()
shark_module = SharkInference(
mlir_module=tflite_tosa_model,
function_name="main",
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
mlir_dialect="tflite",
)
self.shark_module.set_frontend("tflite-tosa")
self.shark_module.compile()
self.shark_module.forward(inputs)
shark_module.compile()
shark_module.forward(inputs)
# print(shark_results)
@@ -61,9 +60,9 @@ class AlbertTfliteModuleTest(unittest.TestCase):
if __name__ == "__main__":
unittest.main()
# module_tester = AlbertTfliteModuleTester()
# module_tester.create_and_check_module()
# unittest.main()
module_tester = AlbertTfliteModuleTester()
module_tester.create_and_check_module()
# TEST RESULT:
# (shark.venv) nod% python albert_lite_base_tflite_mlir_test.py

View File

@@ -1 +0,0 @@
[[[252, 30, 233, 138, 36, 18, 220, 186, 216, 30, 144, 111, 102, 228, 255, 213, 22, 171, 83, 26, 11, 50, 9, 248, 55, 95, 95, 51, 85, 132, 217, 164, 183, 230, 117, 218, 214, 41, 144, 52, 100, 103, 202, 171, 163, 0, 148, 174, 235, 181, 235, 192, 116, 95, 150, 160, 29, 171, 56, 126, 163, 6, 61, 161, 114, 231, 5, 87, 86, 165, 205, 170, 249, 31, 106, 8, 242, 80, 140, 163, 120, 122, 235, 4, 190, 122, 39, 247, 147, 33, 57, 18, 164, 234, 85, 82, 28, 8, 231, 241, 239, 97, 42, 167, 51, 199, 212, 38, 9, 25, 100, 163, 124, 23, 111, 159, 205, 190, 184, 108, 182, 149, 58, 37, 7, 250, 109, 50, 75, 41, 239, 106, 33, 246, 24, 40, 12, 50, 205, 107, 114, 63, 2, 26, 151, 114, 22, 178, 159, 98, 216, 168, 127, 216, 152, 242, 87, 192, 121, 34, 53, 107, 151, 14, 182, 31, 207, 99, 127, 247, 82, 198, 217, 43, 86, 228, 218, 98, 187, 11, 177, 131, 255, 137, 19, 87, 152, 249, 68, 85, 48, 57, 63, 68, 123, 136, 9, 110, 218, 21, 194, 90, 163, 101, 15, 152, 30, 230, 54, 142, 254, 78, 52, 180, 249, 198, 140, 98, 81, 72, 154, 116, 72, 100, 97, 2, 171, 45, 18, 155, 134, 108, 194, 168, 213, 126, 39, 178, 102, 154, 180, 110, 155, 149, 120, 118, 111, 96, 104, 4, 131, 72, 57, 23, 196, 195, 240, 33, 60, 85, 207, 125, 179, 21, 126, 159, 105, 186, 176, 19, 100, 161, 81, 243, 1, 234, 86, 98, 48, 36, 111, 196, 230, 19, 254, 114, 27, 100, 43, 28, 95, 108, 114, 109, 186, 189, 175, 77, 191, 214, 130, 139, 147, 118, 55, 250, 31, 100, 19, 45, 205, 221, 70, 184, 248, 10, 236, 114, 195, 5, 70, 222, 20, 169, 173, 4, 125, 255, 120, 130, 215, 164, 106, 33, 34, 215, 85, 103, 16, 252, 158, 10, 229, 40, 92, 222, 246, 211, 147, 105, 243, 135, 35, 125, 148, 207, 120, 213, 221, 229, 217, 47, 152, 92, 189, 147, 183, 133, 196, 90, 229, 233, 233, 235, 107, 3, 211, 191, 10, 247, 164, 30, 231, 59]], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]