From ce00c1c5e18e6a3457e3902aaf4d61327453da3a Mon Sep 17 00:00:00 2001
From: Abhishek-Varma
Date: Thu, 22 Dec 2022 09:23:50 +0000
Subject: [PATCH] [SharkInference] Make SharkInference compile the entire
 module

-- Previously, SharkInference compiled and provided run APIs only for a
   hardcoded function named "forward".
-- This commit makes the compilation generic so that any function defined
   within the module can be run.
-- It also adds an API to fetch all the function names defined within the
   compiled module.

Signed-off-by: Abhishek Varma
---
 shark/iree_eager_backend.py       |  4 +---
 shark/iree_utils/compile_utils.py | 32 +++++++++++++++----------------
 shark/shark_benchmark_runner.py   |  7 ++-----
 shark/shark_inference.py          | 31 +++++++++++++-----------------
 shark/shark_runner.py             | 20 +++++++++----------
 tank/test_models.py               |  3 +--
 6 files changed, 43 insertions(+), 54 deletions(-)

diff --git a/shark/iree_eager_backend.py b/shark/iree_eager_backend.py
index db7cfb9e..6be9cbf2 100644
--- a/shark/iree_eager_backend.py
+++ b/shark/iree_eager_backend.py
@@ -21,7 +21,6 @@ import torch
 from iree.runtime import DeviceArray
 from torch_mlir._mlir_libs._mlir.ir import Module
 from torch_mlir.compiler_utils import (
-    get_module_name_for_debug_dump,
     run_pipeline_with_repro_report,
 )
 from torch_mlir.eager_mode.torch_mlir_eager_backend import (
@@ -64,14 +63,13 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
         )
 
     def compile(self, imported_module: Module):
-        fn_name = get_module_name_for_debug_dump(imported_module)
         run_pipeline_with_repro_report(
             imported_module,
             "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
             "EagerMode",
         )
         callable, _ = get_iree_compiled_module(
-            imported_module, self.raw_device_str, func_name=fn_name
+            imported_module, self.raw_device_str
         )
         return callable
 
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index 02bbfeb3..c62d9208 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -234,7 +234,6 @@ def compile_module_to_flatbuffer(
     module,
     device,
     frontend,
-    func_name,
     model_config_path,
     extra_args,
     model_name="None",
@@ -277,7 +276,7 @@
     return flatbuffer_blob
 
 
-def get_iree_module(flatbuffer_blob, device, func_name):
+def get_iree_module(flatbuffer_blob, device):
     # Returns the compiled module and the configs.
     config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_flatbuffer(
@@ -285,7 +284,7 @@
     )
     ctx = ireert.SystemContext(config=config)
     ctx.add_vm_module(vm_module)
-    ModuleCompiled = ctx.modules.module[func_name]
+    ModuleCompiled = ctx.modules.module
     return ModuleCompiled, config
 
 
@@ -293,25 +292,22 @@
 def get_iree_compiled_module(
     module,
     device: str,
     frontend: str = "torch",
-    func_name: str = "forward",
     model_config_path: str = None,
     extra_args: list = [],
 ):
     """Given a module returns the compiled .vmfb and configs"""
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, frontend, func_name, model_config_path, extra_args
+        module, device, frontend, model_config_path, extra_args
     )
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)
 
 
-def load_flatbuffer(
-    flatbuffer_path: str, device: str, func_name: str = "forward"
-):
+def load_flatbuffer(flatbuffer_path: str, device: str):
     with open(os.path.join(flatbuffer_path), "rb") as f:
         flatbuffer_blob = f.read()
 
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)
 
 
 def export_iree_module_to_vmfb(
@@ -319,20 +315,19 @@
     device: str,
     directory: str,
     mlir_dialect: str = "linalg",
-    func_name: str = "forward",
     model_config_path: str = None,
     module_name: str = None,
     extra_args: list = [],
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, mlir_dialect, func_name, model_config_path, extra_args
+        module, device, mlir_dialect, model_config_path, extra_args
     )
     if module_name is None:
         device_name = (
             device if "://" not in device else "-".join(device.split("://"))
         )
-        module_name = f"{mlir_dialect}_{func_name}_{device_name}"
+        module_name = f"{mlir_dialect}_{device_name}"
     filename = os.path.join(directory, module_name + ".vmfb")
     print(f"Saved vmfb in {filename}.")
     with open(filename, "wb") as f:
@@ -355,11 +350,16 @@ def export_module_to_mlir_file(module, frontend, directory: str):
 
 
 def get_results(
-    compiled_vm, input, config, frontend="torch", send_to_host=True
+    compiled_vm,
+    function_name,
+    input,
+    config,
+    frontend="torch",
+    send_to_host=True,
 ):
     """Runs a .vmfb file given inputs and config and returns output."""
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
-    result = compiled_vm(*device_inputs)
+    result = compiled_vm[function_name](*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
         if send_to_host:
@@ -376,7 +376,7 @@
                 return np.copy(res)
             return data
     else:
-        if send_to_host:
+        if send_to_host and result is not None:
             return result.to_host()
         return result
 
diff --git a/shark/shark_benchmark_runner.py b/shark/shark_benchmark_runner.py
index 533b9beb..5651d2bf 100644
--- a/shark/shark_benchmark_runner.py
+++ b/shark/shark_benchmark_runner.py
@@ -60,7 +60,6 @@ class SharkBenchmarkRunner(SharkRunner):
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         extra_args: list = [],
@@ -73,7 +72,6 @@ class SharkBenchmarkRunner(SharkRunner):
         SharkRunner.__init__(
             self,
             mlir_module,
-            function_name,
             device,
             self.mlir_dialect,
             self.extra_args,
@@ -85,7 +83,6 @@ class SharkBenchmarkRunner(SharkRunner):
                 device,
                 shark_args.repro_dir,
                 self.mlir_dialect,
-                function_name,
                 extra_args=self.extra_args,
             )
 
@@ -185,11 +182,11 @@ class SharkBenchmarkRunner(SharkRunner):
     def benchmark_python(self, inputs):
         input_list = [x for x in inputs]
         for i in range(shark_args.num_warmup_iterations):
-            self.run(input_list)
+            self.run("forward", input_list)
 
         begin = time.time()
         for i in range(shark_args.num_iterations):
-            out = self.run(input_list)
+            out = self.run("forward", input_list)
             if i == shark_args.num_iterations - 1:
                 end = time.time()
                 print(
diff --git a/shark/shark_inference.py b/shark/shark_inference.py
index 5c8d69ca..df46a61f 100644
--- a/shark/shark_inference.py
+++ b/shark/shark_inference.py
@@ -40,8 +40,6 @@ class SharkInference:
     ----------
     mlir_module : str
         mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
-    function_name : str
-        function to execute in the given mlir_module.
     device : str
         device to execute the mlir_module on. currently supports cpu, cuda,
         vulkan, and metal backends.
@@ -53,10 +51,10 @@
     Methods
     -------
-    run(inputs=None):
-        Runs the mlir_module with the given inputs, if the inputs are not
-        given it autogenerates the inputs. Also, the inputs should be a
-        numpy array.
+    __call__(function_name, inputs=None):
+        Runs the function named `function_name` within the mlir_module
+        with the given inputs; if no inputs are given, they are
+        autogenerated. The inputs should be numpy arrays.
     input_info():
         Gives the information about the inputs required by the
         `function_name`. This can be expensive as it does string
         matching to do so.
@@ -66,7 +64,6 @@
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         is_benchmark: bool = False,
@@ -74,7 +71,6 @@
         dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
     ):
         self.mlir_module = mlir_module
-        self.function_name = function_name
         self.device = shark_args.device if device == "none" else device
         self.mlir_dialect = mlir_dialect
         self.is_benchmark = is_benchmark
@@ -113,7 +109,6 @@
 
             self.shark_runner = SharkBenchmarkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -122,7 +117,6 @@
         else:
             self.shark_runner = SharkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -138,21 +132,25 @@
             os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
 
     # inputs are considered to be tuple of np.array.
-    def forward(self, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run(inputs, send_to_host)
+    def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
+        return self.shark_runner.run(function_name, inputs, send_to_host)
+
+    # Get all function names defined within the compiled module.
+    def get_functions_in_module(self):
+        return self.shark_runner.get_functions_in_module()
 
     # Captures the static input information from the mlir_module.
     # TODO(pashu123): Generate the input information for dynamic shapes.
-    def _input_info(self):
+    def _input_info(self, function_name):
         # func_key to get the line which contains the function.
- func_key = "func.func @" + self.function_name + func_key = "func.func @" + function_name func_header = None for line in str(self.mlir_module).splitlines(): if func_key in line: func_header = line break if func_header is None: - print(f"Function: {self.function_name} not found") + print(f"Function: {function_name} not found") import re @@ -190,7 +188,6 @@ class SharkInference: self.device, dir, self.mlir_dialect, - self.function_name, module_name=module_name, extra_args=extra_args, ) @@ -198,7 +195,6 @@ class SharkInference: # load and return the module. def load_module(self, path, extra_args=[]): self.shark_runner = SharkRunner( - function_name=self.function_name, device=self.device, compile_vmfb=False, extra_args=extra_args, @@ -209,6 +205,5 @@ class SharkInference: ) = load_flatbuffer( path, self.device, - self.function_name, ) return diff --git a/shark/shark_runner.py b/shark/shark_runner.py index bbe66039..e1db7ea8 100644 --- a/shark/shark_runner.py +++ b/shark/shark_runner.py @@ -39,8 +39,6 @@ class SharkRunner: ---------- mlir_module : str mlir_module represented in string. - function_name : str - function to execute in the given mlir_module. device : str device to execute the mlir_module on. currently supports cpu, cuda, vulkan, and metal backends. @@ -50,10 +48,10 @@ class SharkRunner: Methods ------- - run(inputs=None): - Runs the mlir_module with the given inputs, if the inputs are not - given it autogenerates the inputs. Also, the inputs should be a - numpy array. + run(function_name, inputs=None): + Runs the function with `function_name` within the mlir_module along + with the given inputs, if the inputs are not given it autogenerates the + inputs. Also, the inputs should be a numpy array. input_info(): Gives the information about the inputs required by the `function_name`. This can be expensive as it does string matching to do so. @@ -62,14 +60,12 @@ class SharkRunner: def __init__( self, mlir_module: bytes = None, - function_name: str = "forward", device: str = "none", mlir_dialect: str = "linalg", extra_args: list = [], compile_vmfb: bool = True, ): self.mlir_module = mlir_module - self.function_name = function_name self.device = shark_args.device if device == "none" else device self.mlir_dialect = mlir_dialect self.extra_args = extra_args @@ -87,15 +83,19 @@ class SharkRunner: self.mlir_module, self.device, self.mlir_dialect, - func_name=self.function_name, extra_args=self.extra_args, ) - def run(self, inputs: tuple, send_to_host=False): + def run(self, function_name, inputs: tuple, send_to_host=False): return get_results( self.iree_compilation_module, + function_name, inputs, self.iree_config, self.mlir_dialect, send_to_host, ) + + # Get all function names defined within the compiled module. + def get_functions_in_module(self): + return self.iree_compilation_module._vm_module.function_names diff --git a/tank/test_models.py b/tank/test_models.py index fdf90e4a..9efec9b1 100644 --- a/tank/test_models.py +++ b/tank/test_models.py @@ -148,7 +148,6 @@ class SharkModuleTester: shark_module = SharkInference( model, - func_name, device=device, mlir_dialect=self.config["dialect"], is_benchmark=self.benchmark, @@ -163,7 +162,7 @@ class SharkModuleTester: self.upload_repro() raise - result = shark_module.forward(inputs) + result = shark_module(func_name, inputs) golden_out, result = self.postprocess_outputs(golden_out, result) try: np.testing.assert_allclose(