diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 0c9b0e64..ea0ed277 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -52,6 +52,21 @@ class SharkDownloader: print("Error. No tank_url, No model name,Please input either one.") return + self.workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir) + os.makedirs(self.workdir, exist_ok=True) + print(f"TMP_MODEL_DIR = {self.workdir}") + # use model name get dir. + self.model_name_dir = os.path.join(self.workdir, str(self.model_name)) + if not os.path.exists(self.model_name_dir): + print( + "Model has not been download." + "shark_downloader will automatically download by " + "tank_url if provided. You can also manually to " + "download the model from shark_tank by yourself." + ) + os.makedirs(self.model_name_dir, exist_ok=True) + print(f"TMP_MODELNAME_DIR = {self.model_name_dir}") + # read inputs from json file self.load_json_input() # get milr model file @@ -66,42 +81,65 @@ class SharkDownloader: def load_json_input(self): print("load json inputs") if self.model_type in ["tflite-tosa"]: + input_url = ( + self.tank_url + + "/" + + str(self.model_name) + + "/" + + "input.json" + ) + input_file = "/".join( + [self.model_name_dir, str(self.input_json)] + ) + if os.path.exists(input_file): + print("Input has been downloaded before.", input_file) + else: + print("Download input", input_url) + urllib.request.urlretrieve(input_url, input_file) + args = [] - with open(self.input_json, "r") as f: + with open(input_file, "r") as f: args = json.load(f) self.inputs = [np.asarray(arg, dtype=self.input_type) for arg in args] else: - print("No json input required for current model type. You could call setup_inputs(YOU_INPUTS).") + print( + "No json input required for current model type. " + "You could call setup_inputs(YOU_INPUTS)." + ) return self.inputs def load_mlir_model(self): - workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir) - os.makedirs(workdir, exist_ok=True) - print(f"TMP_MODEL_DIR = {workdir}") - # use model name get dir. - model_name_dir = os.path.join(workdir, str(self.model_name)) - if not os.path.exists(model_name_dir): - print( - "Model has not been download." - "shark_downloader will automatically download by tank_url if provided." - " You can also manually to download the model from shark_tank by yourself." - ) - os.makedirs(model_name_dir, exist_ok=True) - print(f"TMP_MODELNAME_DIR = {model_name_dir}") - if self.model_type in ["tflite-tosa"]: - self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tflite.mlir" - self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tfite.mlir"]) + self.mlir_url = ( + self.tank_url + + "/" + + str(self.model_name) + + "/" + + str(self.model_name) + + "_tflite.mlir" + ) + self.mlir_file = "/".join( + [self.model_name_dir, str(self.model_name) + "_tfite.mlir"] + ) elif self.model_type in ["tensorflow"]: - self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tf.mlir" - self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tf.mlir"]) + self.mlir_url = ( + self.tank_url + + "/" + + str(self.model_name) + + "/" + + str(self.model_name) + + "_tf.mlir" + ) + self.mlir_file = "/".join( + [self.model_name_dir, str(self.model_name) + "_tf.mlir"] + ) elif self.model_type in ["torch", "jax", "mhlo", "tosa"]: self.mlir_url = ( self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_" + str(self.model_type) + ".mlir" ) self.mlir_file = "/".join( [ - model_name_dir, + self.model_name_dir, str(self.model_name) + "_" + str(self.model_type) + ".mlir", ] ) diff --git a/tank/albert_lite_base/albert_lite_base_tflite_mlir_test.py b/tank/albert_lite_base/albert_lite_base_tflite_mlir_test.py index 0e9317eb..b5cc44cd 100644 --- a/tank/albert_lite_base/albert_lite_base_tflite_mlir_test.py +++ b/tank/albert_lite_base/albert_lite_base_tflite_mlir_test.py @@ -21,7 +21,7 @@ class AlbertTfliteModuleTester: def create_and_check_module(self): shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - self.shark_downloader = SharkDownloader( + shark_downloader = SharkDownloader( model_name="albert_lite_base", tank_url="https://storage.googleapis.com/shark_tank", local_tank_dir="./../gen_shark_tank", @@ -29,18 +29,17 @@ class AlbertTfliteModuleTester: input_json="input.json", input_type="int32", ) - tflite_tosa_model = self.shark_downloader.get_mlir_file() - inputs = self.shark_downloader.get_inputs() - self.shark_module = SharkInference( - tflite_tosa_model, - inputs, + tflite_tosa_model = shark_downloader.get_mlir_file() + inputs = shark_downloader.get_inputs() + + shark_module = SharkInference( + mlir_module=tflite_tosa_model, + function_name="main", device=self.device, - dynamic=self.dynamic, - jit_trace=True, + mlir_dialect="tflite", ) - self.shark_module.set_frontend("tflite-tosa") - self.shark_module.compile() - self.shark_module.forward(inputs) + shark_module.compile() + shark_module.forward(inputs) # print(shark_results) @@ -61,9 +60,9 @@ class AlbertTfliteModuleTest(unittest.TestCase): if __name__ == "__main__": - unittest.main() - # module_tester = AlbertTfliteModuleTester() - # module_tester.create_and_check_module() + # unittest.main() + module_tester = AlbertTfliteModuleTester() + module_tester.create_and_check_module() # TEST RESULT: # (shark.venv) nod% python albert_lite_base_tflite_mlir_test.py diff --git a/tank/albert_lite_base/input.json b/tank/albert_lite_base/input.json deleted file mode 100644 index fbea5730..00000000 --- a/tank/albert_lite_base/input.json +++ /dev/null @@ -1 +0,0 @@ -[[[252, 30, 233, 138, 36, 18, 220, 186, 216, 30, 144, 111, 102, 228, 255, 213, 22, 171, 83, 26, 11, 50, 9, 248, 55, 95, 95, 51, 85, 132, 217, 164, 183, 230, 117, 218, 214, 41, 144, 52, 100, 103, 202, 171, 163, 0, 148, 174, 235, 181, 235, 192, 116, 95, 150, 160, 29, 171, 56, 126, 163, 6, 61, 161, 114, 231, 5, 87, 86, 165, 205, 170, 249, 31, 106, 8, 242, 80, 140, 163, 120, 122, 235, 4, 190, 122, 39, 247, 147, 33, 57, 18, 164, 234, 85, 82, 28, 8, 231, 241, 239, 97, 42, 167, 51, 199, 212, 38, 9, 25, 100, 163, 124, 23, 111, 159, 205, 190, 184, 108, 182, 149, 58, 37, 7, 250, 109, 50, 75, 41, 239, 106, 33, 246, 24, 40, 12, 50, 205, 107, 114, 63, 2, 26, 151, 114, 22, 178, 159, 98, 216, 168, 127, 216, 152, 242, 87, 192, 121, 34, 53, 107, 151, 14, 182, 31, 207, 99, 127, 247, 82, 198, 217, 43, 86, 228, 218, 98, 187, 11, 177, 131, 255, 137, 19, 87, 152, 249, 68, 85, 48, 57, 63, 68, 123, 136, 9, 110, 218, 21, 194, 90, 163, 101, 15, 152, 30, 230, 54, 142, 254, 78, 52, 180, 249, 198, 140, 98, 81, 72, 154, 116, 72, 100, 97, 2, 171, 45, 18, 155, 134, 108, 194, 168, 213, 126, 39, 178, 102, 154, 180, 110, 155, 149, 120, 118, 111, 96, 104, 4, 131, 72, 57, 23, 196, 195, 240, 33, 60, 85, 207, 125, 179, 21, 126, 159, 105, 186, 176, 19, 100, 161, 81, 243, 1, 234, 86, 98, 48, 36, 111, 196, 230, 19, 254, 114, 27, 100, 43, 28, 95, 108, 114, 109, 186, 189, 175, 77, 191, 214, 130, 139, 147, 118, 55, 250, 31, 100, 19, 45, 205, 221, 70, 184, 248, 10, 236, 114, 195, 5, 70, 222, 20, 169, 173, 4, 125, 255, 120, 130, 215, 164, 106, 33, 34, 215, 85, 103, 16, 252, 158, 10, 229, 40, 92, 222, 246, 211, 147, 105, 243, 135, 35, 125, 148, 207, 120, 213, 221, 229, 217, 47, 152, 92, 189, 147, 183, 133, 196, 90, 229, 233, 233, 235, 107, 3, 211, 191, 10, 247, 164, 30, 231, 59]], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]] \ No newline at end of file