diff --git a/shark/examples/shark_inference/mhlo_example.py b/shark/examples/shark_inference/mhlo_example.py
index b4c0cde0..369d63e1 100644
--- a/shark/examples/shark_inference/mhlo_example.py
+++ b/shark/examples/shark_inference/mhlo_example.py
@@ -12,7 +12,23 @@ mhlo_ir = r"""builtin.module {
 arg0 = np.ones((1, 4)).astype(np.float32)
 arg1 = np.ones((4, 1)).astype(np.float32)
 
-shark_module = SharkInference(mhlo_ir, (arg0, arg1))
-shark_module.set_frontend("mhlo")
+print("Running shark on cpu backend")
+shark_module = SharkInference(
+    mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
+)
+shark_module.compile()
+print(shark_module.forward((arg0, arg1)))
+
+print("Running shark on cuda backend")
+shark_module = SharkInference(
+    mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
+)
+shark_module.compile()
+print(shark_module.forward((arg0, arg1)))
+
+print("Running shark on vulkan backend")
+shark_module = SharkInference(
+    mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
+)
 shark_module.compile()
 print(shark_module.forward((arg0, arg1)))
diff --git a/shark/examples/shark_inference/minilm_jit.py b/shark/examples/shark_inference/minilm_jit.py
index ab8c0b0f..e1344f4c 100644
--- a/shark/examples/shark_inference/minilm_jit.py
+++ b/shark/examples/shark_inference/minilm_jit.py
@@ -1,6 +1,7 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from shark.shark_inference import SharkInference
+from shark.shark_importer import SharkImporter
 
 torch.manual_seed(0)
 tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
@@ -23,12 +24,22 @@ class MiniLMSequenceClassification(torch.nn.Module):
 
 test_input = torch.randint(2, (1, 128)).to(torch.int32)
 
-shark_module = SharkInference(
+mlir_importer = SharkImporter(
     MiniLMSequenceClassification(),
     (test_input, test_input, test_input),
-    jit_trace=True,
+    frontend="torch",
 )
+# Torch Hugging Face models need tracing.
+(minilm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
+    tracing_required=True
+)
+
+print(golden_out)
+
+shark_module = SharkInference(
+    minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
+)
 shark_module.compile()
 result = shark_module.forward((test_input, test_input, test_input))
 print("Obtained result", result)
diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py
index 42ffa0fc..de7eef8c 100644
--- a/shark/iree_utils/_common.py
+++ b/shark/iree_utils/_common.py
@@ -75,3 +75,17 @@ def check_device_drivers(device):
             return True
     return False
+
+
+# Installation info for the missing device drivers.
+def device_driver_info(device):
+    if device in ["gpu", "cuda"]:
+        print(
+            "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
+        )
+    elif device in ["metal", "vulkan"]:
+        print(
+            "vulkaninfo not found, please install the Vulkan SDK from https://vulkan.lunarg.com/sdk/home or from your distribution"
+        )
+    else:
+        print(f"{device} is not supported.")
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index 48c6b28b..29eaac20 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -14,7 +14,6 @@
 import iree.runtime as ireert
 import iree.compiler as ireec
 from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
-from shark.model_annotation import *
 import numpy as np
 import os
@@ -157,23 +156,7 @@ def export_module_to_mlir_file(module, frontend, directory: str):
 
 def get_results(compiled_vm, input, config, frontend="torch"):
     """Runs a .vmfb file given inputs and config and returns output."""
-    device_inputs = input
-    if frontend in ["torch", "pytorch"]:
-        device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
-    if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
-        device_inputs = []
-        for a in input:
-            if isinstance(a, list):
-                device_inputs.append(
-                    [
-                        ireert.asdevicearray(
-                            config.device, val, dtype=val.dtype
-                        )
-                        for val in a
-                    ]
-                )
-            else:
-                device_inputs.append(ireert.asdevicearray(config.device, a))
+    device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
     result = compiled_vm(*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
diff --git a/shark/shark_importer.py b/shark/shark_importer.py
index cbbc6a24..43e7e23c 100644
--- a/shark/shark_importer.py
+++ b/shark/shark_importer.py
@@ -1,214 +1,155 @@
 # Lint as: python3
 """SHARK Importer"""
-import numpy as np
-import os
-import csv
-import urllib.request
-from shark.iree_utils._common import IREE_TARGET_MAP
-import json
-from tank.model_utils_tflite import TFLiteModelUtil
+import sys
+
+# List of the supported frontends.
+supported_frontends = {
+    "tensorflow",
+    "tf",
+    "pytorch",
+    "torch",
+    "tf-lite",
+    "tflite",
+}
 
 
 class SharkImporter:
+    """
+    SharkImporter converts a frontend module into an
+    mlir_module. The supported frameworks are tensorflow,
+    pytorch, and tf-lite.
+
+    ...
+
+    Attributes
+    ----------
+    module :
+        torch, tensorflow or tf-lite module.
+    inputs :
+        inputs to the module; may be required for the shape
+        information.
+    frontend: str
+        frontend to which the module belongs.
+
+    Methods
+    -------
+    import_mlir(is_dynamic, tracing_required, func_name):
+        is_dynamic: whether the input shapes are treated as fully dynamic (pytorch specific).
+        tracing_required: whether tracing is required (pytorch specific).
+        func_name: the function to be traced out or imported to mlir.
+
+    import_debug(is_dynamic, tracing_required, func_name):
+        returns the converted (mlir_module, func_name) along with the inputs
+        and golden outputs.
+        The inputs and outputs are converted into numpy arrays.
+ """ + def __init__( self, - model_name, - model_type: str = "torch", - input_details=None, - output_details=None, - model_path=None, + module, + inputs: tuple = (), + frontend: str = "torch", ): - self.model_name = model_name - self.model_type = model_type - self.input_details = ( - input_details # used for tflite, optional for tf/pytorch - ) - self.output_details = ( - output_details # used for tflite, optional for tf/pytorch - ) - self.inputs = [] - self.model_path = model_path # url to download the model - self.raw_model_file = ( - None # local address for raw tf/tflite/pytorch model - ) - self.mlir_file = ( - None # local address for .mlir file of tf/tflite/pytorch model - ) - self.mlir_model = None # read of .mlir file - self.output_tensor = ( - None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor - ) - self.interpreter = None # could be tflite/tf/torch_interpreter in utils - - # create tmp model file directory - if self.model_path is None and self.model_name is None: + self.module = module + self.inputs = None if len(inputs) == 0 else inputs + self.frontend = frontend + if not self.frontend in supported_frontends: print( - "Error. No model_path, No model name,Please input either one." + f"The frontend is not in the supported_frontends: {supported_frontends}" ) - return + sys.exit(1) - print("Setting up for TMP_WORK_DIR") - self.workdir = os.path.join( - os.path.dirname(__file__), "./../gen_shark_tank" + # NOTE: The default function for torch is "forward" and tf-lite is "main". + + def _torch_mlir(self, is_dynamic, tracing_required): + from shark.torch_mlir_utils import get_torch_mlir_module + + return get_torch_mlir_module( + self.module, self.inputs, is_dynamic, tracing_required ) - os.makedirs(self.workdir, exist_ok=True) - print(f"TMP_WORK_DIR = {self.workdir}") - # compile and run tfhub tflite - if self.model_type == "tflite": - load_model_success = self.load_tflite_model() - if not load_model_success: - print("Error, load tflite model fail") - return + def _tf_mlir(self, func_name): + from iree.compiler import tf as tfc - if (self.input_details is None) or (self.output_details is None): + return tfc.compile_module( + self.module, exported_names=[func_name], import_only=True + ) + + def _tflite_mlir(self, func_name): + from iree.compiler import tflite as tflitec + + # TODO(Chi): Just add the conversion of tflite model here. + return tflitec.compile_module( + self.module, exported_names=[func_name], import_only=True + ) + + # Adds the conversion of the frontend with the private function. + def import_mlir( + self, + is_dynamic=False, + tracing_required=False, + func_name="forward", + ): + if self.frontend in ["torch", "pytorch"]: + if self.inputs == None: print( - "Setting up tflite interpreter to get model input details" + "Please pass in the inputs, the inputs are required to determine the shape of the mlir_module" ) - self.setup_interpreter() + sys.exit(1) + return self._torch_mlir(is_dynamic, tracing_required), func_name + if self.frontend in ["tf", "tensorflow"]: + return self._tf_mlir(func_name), func_name + if self.frontend in ["tflite", "tf-lite"]: + func_name = "main" + return self._tflite_mlir(func_name), func_name - inputs = self.generate_inputs( - self.input_details - ) # device_inputs - self.setup_inputs(inputs) + # Converts the frontend specific tensors into np array. 
+    def convert_to_numpy(self, array_tuple: tuple):
+        if self.frontend in ["torch", "pytorch"]:
+            return [x.detach().numpy() for x in array_tuple]
+        if self.frontend in ["tf", "tensorflow"]:
+            return [x.numpy() for x in array_tuple]
+        if self.frontend in ["tf-lite", "tflite"]:
+            # TODO(Chi): Validate for tf-lite tensors.
+            return [x.numpy() for x in array_tuple]
-        elif self.model_type in ["tensorflow, tf, torch, pytorch"]:
-            print(self.model_type, " Not Implemented yet")
-
-    def load_tflite_model(self):
-        # use model name get dir.
-        tflite_model_name_dir = os.path.join(self.workdir, str(self.model_name))
-
-        os.makedirs(tflite_model_name_dir, exist_ok=True)
-        print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
-
-        self.raw_model_file = "/".join(
-            [tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"]
-        )
-        self.mlir_file = "/".join(
-            [tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
-        )
-
-        if os.path.exists(self.raw_model_file):
+    def import_debug(
+        self,
+        is_dynamic=False,
+        tracing_required=False,
+        func_name="forward",
+    ):
+        if self.inputs is None:
             print(
-                "Local address for .tflite model file Exists: ",
-                self.raw_model_file,
+                f"No inputs provided: {self.inputs}. Please provide inputs or simply run import_mlir."
             )
-        else:
-            print("No local tflite file, Download tflite model")
-            if self.model_path is None:
-                # get model file from tflite_model_list.csv or download from gs://bucket
-                print("No model_path, get from tflite_model_list.csv")
-                tflite_model_list_path = os.path.join(
-                    os.path.dirname(__file__),
-                    "../tank/tflite/tflite_model_list.csv",
-                )
-                tflite_model_list = csv.reader(open(tflite_model_list_path))
-                for row in tflite_model_list:
-                    if str(row[0]) == str(self.model_name):
-                        self.model_path = row[1]
-                        print("tflite_model_name", str(row[0]))
-                        print("tflite_model_link", self.model_path)
-                if self.model_path is None:
-                    print("Error, No model path find in tflite_model_list.csv")
-                    return False
-            urllib.request.urlretrieve(self.model_path, self.raw_model_file)
-        if os.path.exists(self.mlir_file):
-            print("Exists MLIR model ", self.mlir_file)
-        else:
-            print(
-                "No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
+            sys.exit(1)
+
+        imported_mlir = self.import_mlir(
+            is_dynamic, tracing_required, func_name
+        )
+        # TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
+        # TODO: Check for multiple outputs.
+ if self.frontend in ["torch", "pytorch"]: + golden_out = self.module(*self.inputs) + return ( + imported_mlir, + self.convert_to_numpy(self.inputs), + golden_out.detach().numpy(), ) - print("Convert tflite to tosa.mlir") - import iree.compiler.tflite as ireec_tflite - - ireec_tflite.compile_file( - self.raw_model_file, - input_type="tosa", - save_temp_iree_input=self.mlir_file, - target_backends=[IREE_TARGET_MAP["cpu"]], - import_only=False, + if self.frontend in ["tf", "tensorflow"]: + golden_out = self.module.forward(*self.inputs) + return ( + imported_mlir, + self.convert_to_numpy(self.inputs), + golden_out.numpy(), ) - with open(self.mlir_file) as f: - self.mlir_model = f.read() - return True - - def setup_interpreter(self): - if self.model_type == "tflite": - self.interpreter = TFLiteModelUtil(self.raw_model_file) - ( - self.input_details, - self.output_details, - ) = self.interpreter.setup_tflite_interpreter() - - def generate_inputs(self, input_details): - self.inputs = [] - for tmp_input in input_details: - print(str(tmp_input["shape"]), tmp_input["dtype"].__name__) - self.inputs.append( - np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"]) + if self.frontend in ["tflite", "tf-lite"]: + # TODO(Chi): Validate it for tflite models. + golden_out = self.module.main(*self.inputs) + return ( + imported_mlir, + self.convert_to_numpy(self.inputs), + golden_out.numpy(), ) - # save inputs into json file - tmp_json = [] - for tmp_input in input_details: - print(str(tmp_input["shape"]), tmp_input["dtype"].__name__) - tmp_json.append( - np.ones( - shape=tmp_input["shape"], dtype=tmp_input["dtype"] - ).tolist() - ) - with open("input1.json", "w") as f: - json.dump(tmp_json, f) - return self.inputs - - # def get_model_details(self): - # if self.model_type == "tflite": - # print("Get tflite input output details") - # self.input_details = self.tflite_interpreter.get_input_details() - # self.output_details = self.tflite_interpreter.get_output_details() - # return self.input_details, self.output_details - - def setup_inputs(self, inputs): - print("Setting up inputs") - self.inputs = inputs - - def get_mlir_model(self): - return self.mlir_model - - def get_inputs(self): - return self.inputs - - def get_raw_model_output(self): - if self.model_type == "tflite": - self.output_tensor = self.interpreter.invoke_tflite(self.inputs) - return self.output_tensor - - def get_model_details(self): - return self.input_details, self.output_details - - def get_raw_model_file(self): - return self.raw_model_file - - # def invoke_tflite(self, inputs): - # print("invoke_tflite") - # for i, input in enumerate(self.inputs): - # self.tflite_interpreter.set_tensor( - # self.input_details[i]["index"], input - # ) - # self.tflite_interpreter.invoke() - # - # # post process tflite_result for compare with mlir_result, - # # for tflite the output is a list of numpy.tensor - # tflite_results = [] - # for output_detail in self.output_details: - # tflite_results.append( - # np.array( - # self.tflite_interpreter.get_tensor(output_detail["index"]) - # ) - # ) - # - # for i in range(len(self.output_details)): - # out_dtype = self.output_details[i]["dtype"] - # tflite_results[i] = tflite_results[i].astype(out_dtype) - # return tflite_results diff --git a/shark/shark_inference.py b/shark/shark_inference.py index 5955d2cb..09e7a710 100644 --- a/shark/shark_inference.py +++ b/shark/shark_inference.py @@ -9,138 +9,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from shark.parser import shark_args
 from shark.shark_runner import SharkRunner
-import sys
-import os
-
-
-# Prints to stderr.
-def print_err(*a):
-    print(*a, file=sys.stderr)
 
 
 class SharkInference:
-    """Inference API targeting pytorch, tensorflow, linalg, mhlo and tosa frontend."""
+    """
+    Runs prediction or inference on an mlir_module.
+
+    ...
+
+    Attributes
+    ----------
+    mlir_module : str
+        mlir_module represented as a string.
+    function_name : str
+        function to execute in the given mlir_module.
+    device : str
+        device to execute the mlir_module on.
+        Currently supports cpu, cuda, vulkan, and metal backends.
+    mlir_dialect: str
+        The dialect in which the given mlir_module is expressed.
+        Refer to {https://mlir.llvm.org/docs/Dialects/}
+
+    Methods
+    -------
+    compile():
+        Compiles the given mlir_module for the target device and
+        prepares it for execution. Must be called before forward().
+    forward(inputs):
+        Runs the compiled mlir_module with the given inputs and
+        returns the result. The inputs should be a tuple of numpy
+        arrays.
+
+    TODO(Stanley) Add the benchmark APIs with is_benchmark = True argument.
+    """
 
     def __init__(
         self,
-        model,
-        input: tuple,
-        device: str = None,
-        dynamic: bool = False,
-        jit_trace: bool = False,
-        benchmark_mode: bool = False,
+        mlir_module: str,
+        function_name: str = "forward",
+        device: str = "cpu",
+        mlir_dialect: str = "linalg",
    ):
-        self.model = model
-        self.input = input
-        self.dynamic = dynamic
-        self.jit_trace = jit_trace
-        self.benchmark_mode = benchmark_mode
-
-        # By default it's torch frontend.
-        self.frontend = "pytorch"
-
-        # Sets the device.
-        self.device = device if device is not None else shark_args.device
-
-        self.model_config_path = shark_args.model_config_path
+        self.mlir_module = mlir_module
+        self.function_name = function_name
+        self.device = device
+        self.mlir_dialect = mlir_dialect
         self.shark_runner = None
 
-    # Sets the frontend i.e `pytorch` or `tensorflow`.
-    def set_frontend(self, frontend: str):
-        if frontend not in [
-            "pytorch",
-            "torch",
-            "tensorflow",
-            "tf",
-            "mhlo",
-            "linalg",
-            "tosa",
-            "tflite",
-            "tflite-tosa",
-        ]:
-            print_err("frontend not supported.")
-        else:
-            self.frontend = frontend
     def compile(self):
-        # Inference do not use AOT. # TODO: Remove the from_aot arg as it's not
-        # needed.
-        from_aot = False
-        if self.benchmark_mode == True:
-            # Only import shark_benchmark runner when needed.
-            from shark.shark_benchmark_runner import SharkBenchmarkRunner
-
-            self.shark_runner = SharkBenchmarkRunner(
-                self.model,
-                self.input,
-                self.dynamic,
-                self.device,
-                self.jit_trace,
-                from_aot,
-                self.frontend,
-            )
-        else:
-            self.shark_runner = SharkRunner(
-                self.model,
-                self.input,
-                self.dynamic,
-                self.device,
-                self.jit_trace,
-                from_aot,
-                self.frontend,
-                self.model_config_path,
-            )
-
-    # inputs are considered to be np.array.
-    def forward(self, inputs):
-        input_list = inputs
-        # converts the inputs to numpy.
-        if self.frontend in ["pytorch", "torch"]:
-            input_list = [x.detach().numpy() for x in inputs]
-        elif self.frontend in ["tensorflow", "tf"]:
-            input_list = [x.numpy() for x in inputs]
-        return self.shark_runner.forward(input_list, self.frontend)
-
-    def import_mlir(self, model_name="model", dir=os.getcwd()):
-        self.shark_runner.import_mlir(model_name, dir)
-
-    # Saves the .vmfb module.
-    def save_module(self, dir=None):
-        if dir is None:
-            return self.shark_runner.save_module()
-        return self.shark_runner.save_module(dir)
-
-    ######### Benchmark Related Functions #########
-    def benchmark_mode(func):
-        def inner(self, *args, **kwargs):
-            assert (
-                self.benchmark_mode
-            ), "SharkRunner needs to be in benchmark mode to run benchmark methods."
-            return func(self, *args, **kwargs)
-
-        return inner
-
-    @benchmark_mode
-    def benchmark_all(self, inputs):
-        self.shark_runner.benchmark_all(inputs)
-
-    @benchmark_mode
-    def benchmark_frontend(self, inputs):
-        self.shark_runner.benchmark_frontend(inputs)
-
-    @benchmark_mode
-    def benchmark_python(self, inputs):
-        self.shark_runner.benchmark_python(inputs)
-
-    @benchmark_mode
-    def benchmark_c(self):
-        self.shark_runner.benchmark_c()
-
-    @benchmark_mode
-    def benchmark_all_csv(self, inputs, modelname, dynamic, device_str):
-        self.shark_runner.benchmark_all_csv(
-            inputs, modelname, dynamic, device_str
+        # TODO: (Stanley) Update the shark_benchmark APIs.
+        self.shark_runner = SharkRunner(
+            self.mlir_module,
+            self.function_name,
+            self.device,
+            self.mlir_dialect,
        )
+
+    # Inputs are considered to be a tuple of np.arrays.
+    def forward(self, inputs: tuple):
+        return self.shark_runner.run(inputs)
diff --git a/shark/shark_runner.py b/shark/shark_runner.py
index 53136a94..e8ba3b92 100644
--- a/shark/shark_runner.py
+++ b/shark/shark_runner.py
@@ -11,137 +11,94 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from iree.compiler import tf as tfc
-import iree.compiler.tflite as ireec_tflite
-from torch.utils._python_dispatch import enable_torch_dispatch_mode
-from torch_mlir.eager_mode import torch_mlir_tensor
-from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
-from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
-from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
-from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
-from shark.parser import shark_args
 from shark.iree_utils.compile_utils import (
     get_iree_compiled_module,
-    export_iree_module_to_vmfb,
-    export_module_to_mlir_file,
     get_results,
+    export_iree_module_to_vmfb,
 )
+from shark.iree_utils._common import check_device_drivers, device_driver_info
 import os
+import sys
+
+
+# Supported dialects by the shark-runtime.
+supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite"}
 
 
 class SharkRunner:
-    """Base class for Shark Inference and Shark Runner."""
+    """
+    Base class for SharkInference and SharkTrainer
+    used to execute an mlir_module.
+
+    ...
+
+    Attributes
+    ----------
+    mlir_module : str
+        mlir_module represented as a string.
+    function_name : str
+        function to execute in the given mlir_module.
+    device : str
+        device to execute the mlir_module on.
+        Currently supports cpu, cuda, vulkan, and metal backends.
+    mlir_dialect: str
+        The dialect in which the given mlir_module is expressed.
+        Refer to {https://mlir.llvm.org/docs/Dialects/}
+
+    Methods
+    -------
+    run(inputs):
+        Runs the mlir_module with the given inputs and returns the
+        result. The inputs should be provided as a tuple of numpy
+        arrays.
+    input_info():
+        Gives the information about the inputs required by the `function_name`.
+        This can be expensive as it does string matching to do so.
+ """ def __init__( self, - model, - input: tuple, - dynamic: bool = False, - device: str = None, - jit_trace: bool = False, - from_aot: bool = False, - frontend: str = "torch", - model_config_path: str = None, + mlir_module: str, + function_name: str = "forward", + device: str = "cpu", + mlir_dialect: str = "linalg", ): - self.model = model - self.frontend_model = model - self.from_aot = from_aot - self.input = input - self.frontend = frontend - self.vmfb_file = None - func_name = "forward" - self.device = device if device is not None else shark_args.device + self.mlir_module = mlir_module + self.function_name = function_name + self.device = device + self.mlir_dialect = mlir_dialect - if self.frontend in ["tflite-tosa"]: - func_name = "main" - elif self.frontend in ["pytorch", "torch"]: - # get torch-mlir dialect - # self.model = torch.Module - # Lowers in linalg dialect. - # TODO assert - # TODO tosa dialect from torch_module. - self.model = get_torch_mlir_module( - self.model, input, dynamic, jit_trace, from_aot - ) - elif self.frontend in ["tensorflow", "tf"]: - # get mhlo dialect - # self.model = tf.Module - # TODO assert - self.model = tfc.compile_module( - self.model, exported_names=[func_name], import_only=True - ) - elif self.frontend in ["tflite"]: - print("Setting up for IREE compiler tflite") - # get tosa dialect - # self.model = model.tflite - # TODO assert - self.model = ireec_tflite.compile_file( - self.model, input_type="tosa", import_only=True - ) - func_name = "main" + if check_device_drivers(self.device): + device_driver_info(self.device) + sys.exit(1) - # TODO: We can capture the .vmfb module here and later use it for saving - # rather than recompiling it again, if used for saving. + # Compile the module to get the .vmfb. ( self.iree_compilation_module, self.iree_config, ) = get_iree_compiled_module( - self.model, + self.mlir_module, self.device, - self.frontend, - func_name=func_name, - model_config_path=model_config_path, + self.mlir_dialect, + func_name=self.function_name, ) - # Debugging Options: - if shark_args.save_mlir: - export_module_to_mlir_file( - self.model, self.frontend, shark_args.repro_dir - ) - if shark_args.save_vmfb: - self.vmfb_file = self.save_module(shark_args.repro_dir) - - # All the timings and benchmarking can be done here. - def forward(self, input, frontend): + def run(self, inputs: tuple): return get_results( - self.iree_compilation_module, input, self.iree_config, frontend + self.iree_compilation_module, + inputs, + self.iree_config, + self.mlir_dialect, ) - # Saves the .mlir file, can be in tosa, linalg or mhlo dialect. - # torch-mlir can export tosa or linalg dialects. - # tensorflow models get exported to mhlo dialect. - def import_mlir(self, model_name, dir): - filename = os.path.join(dir, f"{model_name}_{self.frontend}.mlir") - with open(filename, "w") as f: - f.write(self.model) - print(f"Saved mlir in {filename}.") - return filename - # TODO: Instead of passing directory and having names decided by the module # , user may want to save the module with manual names. def save_module(self, dir=os.getcwd()): return export_iree_module_to_vmfb( - self.model, self.device, dir, self.frontend + self.model, self.device, dir, self.mlir_dialect ) - # TODO: Load a module and directly use it, we will need to set the frontend - # in this case. - def load_module(self, name): + # TODO: Get the input information from the mlir_module. + def input_info(self): pass - - -# TODO: Document shark_eager mode. 
-class SharkEagerMode: - def __init__(self, device="cpu"): - if device == "refbackend": - torch_mlir_tensor.backend = EagerModeRefBackend() - else: - torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend( - device - ) - self.guard = enable_torch_dispatch_mode(TorchMLIRTensor) - self.guard.__enter__() - - def __del__(self): - self.guard.__exit__(None, None, None) diff --git a/tank/model_utils_tflite.py b/shark/tflite_utils.py similarity index 100% rename from tank/model_utils_tflite.py rename to shark/tflite_utils.py diff --git a/shark/torch_mlir_utils.py b/shark/torch_mlir_utils.py index f65f415d..46878189 100644 --- a/shark/torch_mlir_utils.py +++ b/shark/torch_mlir_utils.py @@ -110,14 +110,14 @@ def get_torch_mlir_module( module, input: tuple, dynamic: bool, - tracing_required: bool, - from_aot: bool = False, + jit_trace: bool, + from_torchscript: bool = False, ): """TODO: Include necessary documentation.""" # Tracing is not required from the aot_module. - if not from_aot: - module = shark_jit_trace(module, input, dynamic, tracing_required) + if not from_torchscript: + module = shark_jit_trace(module, input, dynamic, jit_trace) mb = ModuleBuilder() class_annotator = ClassAnnotator() diff --git a/tank/pytorch/MiniLM-L12-H384-uncased/MiniLM-L12-H384-uncased_pytorch_test.py b/tank/pytorch/MiniLM-L12-H384-uncased/MiniLM-L12-H384-uncased_pytorch_test.py index 132f52c9..9971776e 100644 --- a/tank/pytorch/MiniLM-L12-H384-uncased/MiniLM-L12-H384-uncased_pytorch_test.py +++ b/tank/pytorch/MiniLM-L12-H384-uncased/MiniLM-L12-H384-uncased_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_hf_model, compare_tensors from shark.parser import shark_args @@ -30,12 +31,16 @@ class MiniLMModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, - jit_trace=True, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=True + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/albert-base-v2/albert-base-v2_pytorch_test.py b/tank/pytorch/albert-base-v2/albert-base-v2_pytorch_test.py index 06f2ce3c..f69d81f9 100644 --- a/tank/pytorch/albert-base-v2/albert-base-v2_pytorch_test.py +++ b/tank/pytorch/albert-base-v2/albert-base-v2_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_hf_model, compare_tensors from shark.parser import shark_args @@ -28,12 +29,16 @@ class AlbertModuleTester: model, input, act_out = get_hf_model("albert-base-v2") shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, - jit_trace=True, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=True + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results 
= shark_module.forward((input,)) @@ -55,9 +60,6 @@ class AlbertModuleTest(unittest.TestCase): self.module_tester.device = "cpu" self.module_tester.create_and_check_module() - @pytest.mark.xfail( - reason="Language models currently failing for dynamic case" - ) def test_module_dynamic_cpu(self): self.module_tester.dynamic = True self.module_tester.device = "cpu" @@ -79,7 +81,6 @@ class AlbertModuleTest(unittest.TestCase): self.module_tester.device = "gpu" self.module_tester.create_and_check_module() - @pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554") @pytest.mark.skipif( check_device_drivers("vulkan"), reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases", @@ -89,7 +90,6 @@ class AlbertModuleTest(unittest.TestCase): self.module_tester.device = "vulkan" self.module_tester.create_and_check_module() - @pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554") @pytest.mark.skipif( check_device_drivers("vulkan"), reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases", diff --git a/tank/pytorch/alexnet/alexnet_pytorch_test.py b/tank/pytorch/alexnet/alexnet_pytorch_test.py index 1fbb3a0a..16694caf 100644 --- a/tank/pytorch/alexnet/alexnet_pytorch_test.py +++ b/tank/pytorch/alexnet/alexnet_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_vision_model, compare_tensors from shark.parser import shark_args @@ -31,11 +32,16 @@ class AlexnetModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=False + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/bert-base-uncased/bert-base-uncased_pytorch_test.py b/tank/pytorch/bert-base-uncased/bert-base-uncased_pytorch_test.py index 2db9c6f1..0e81c0da 100644 --- a/tank/pytorch/bert-base-uncased/bert-base-uncased_pytorch_test.py +++ b/tank/pytorch/bert-base-uncased/bert-base-uncased_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_hf_model, compare_tensors from shark.parser import shark_args @@ -28,12 +29,16 @@ class BertModuleTester: model, input, act_out = get_hf_model("bert-base-uncased") shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, - jit_trace=True, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=True + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/distilbert-base-uncased/distilbert-base-uncased_pytorch_test.py b/tank/pytorch/distilbert-base-uncased/distilbert-base-uncased_pytorch_test.py index 9ba260f8..f6f93b85 100644 --- 
a/tank/pytorch/distilbert-base-uncased/distilbert-base-uncased_pytorch_test.py +++ b/tank/pytorch/distilbert-base-uncased/distilbert-base-uncased_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_hf_model, compare_tensors from shark.parser import shark_args @@ -28,12 +29,16 @@ class DistilBertModuleTester: model, input, act_out = get_hf_model("distilbert-base-uncased") shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, - jit_trace=True, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=True + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/resnet101/resnet101_pytorch_test.py b/tank/pytorch/resnet101/resnet101_pytorch_test.py index af982650..93e599be 100644 --- a/tank/pytorch/resnet101/resnet101_pytorch_test.py +++ b/tank/pytorch/resnet101/resnet101_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_vision_model, compare_tensors from shark.parser import shark_args @@ -31,11 +32,16 @@ class Resnet101ModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=False + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/resnet18/resnet18_pytorch_test.py b/tank/pytorch/resnet18/resnet18_pytorch_test.py index a3372bee..5b3a93f7 100644 --- a/tank/pytorch/resnet18/resnet18_pytorch_test.py +++ b/tank/pytorch/resnet18/resnet18_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_vision_model, compare_tensors from shark.parser import shark_args @@ -31,11 +32,16 @@ class Resnet18ModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=False + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/resnet50/resnet50_pytorch_test.py b/tank/pytorch/resnet50/resnet50_pytorch_test.py index 4ac2e839..cc93af6a 100644 --- a/tank/pytorch/resnet50/resnet50_pytorch_test.py +++ b/tank/pytorch/resnet50/resnet50_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import 
SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_vision_model, compare_tensors from shark.parser import shark_args @@ -31,11 +32,16 @@ class Resnet50ModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=False + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/squeezenet1_0/squeezenet1_0_pytorch_test.py b/tank/pytorch/squeezenet1_0/squeezenet1_0_pytorch_test.py index 7d95b8ee..07a1579e 100644 --- a/tank/pytorch/squeezenet1_0/squeezenet1_0_pytorch_test.py +++ b/tank/pytorch/squeezenet1_0/squeezenet1_0_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_vision_model, compare_tensors from shark.parser import shark_args @@ -31,11 +32,16 @@ class SqueezenetModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic, tracing_required=False + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,)) diff --git a/tank/pytorch/wide_resnet50_2/wide_resnet50_2_pytorch_test.py b/tank/pytorch/wide_resnet50_2/wide_resnet50_2_pytorch_test.py index c40a59f3..1cc60e12 100644 --- a/tank/pytorch/wide_resnet50_2/wide_resnet50_2_pytorch_test.py +++ b/tank/pytorch/wide_resnet50_2/wide_resnet50_2_pytorch_test.py @@ -1,4 +1,5 @@ from shark.shark_inference import SharkInference +from shark.shark_importer import SharkImporter from shark.iree_utils._common import check_device_drivers from tank.model_utils import get_vision_model, compare_tensors from shark.parser import shark_args @@ -31,11 +32,16 @@ class WideResnet50ModuleTester: ) shark_args.save_mlir = self.save_mlir shark_args.save_vmfb = self.save_vmfb - shark_module = SharkInference( + mlir_importer = SharkImporter( model, (input,), - device=self.device, - dynamic=self.dynamic, + frontend="torch", + ) + minilm_mlir, func_name = mlir_importer.import_mlir( + is_dynamic=self.dynamic + ) + shark_module = SharkInference( + minilm_mlir, func_name, device="cpu", mlir_dialect="linalg" ) shark_module.compile() results = shark_module.forward((input,))
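
The changes above converge on a single flow: SharkImporter turns a frontend module into (mlir_module, func_name) plus numpy inputs and golden outputs, and SharkInference compiles and runs that module on a chosen backend. The sketch below is a minimal, hypothetical usage example (SmallModel, its shapes, and the prints are illustrative and not part of this change); the calls mirror the updated minilm_jit.py and the refreshed tests.

import torch

from shark.shark_importer import SharkImporter
from shark.shark_inference import SharkInference


# Toy stand-in module; any torch.nn.Module that can be jit-traced should work.
class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


test_input = torch.randn(1, 4)

# Import the torch module into MLIR and capture numpy inputs plus golden outputs.
mlir_importer = SharkImporter(
    SmallModel(),
    (test_input,),
    frontend="torch",
)
(mlir_module, func_name), inputs, golden_out = mlir_importer.import_debug(
    tracing_required=True
)

# Compile for a chosen backend and run; forward takes a tuple of numpy arrays.
shark_module = SharkInference(
    mlir_module, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(tuple(inputs))
print("golden:", golden_out)
print("shark :", result)

Swapping device for "cuda" or "vulkan" follows the mhlo_example.py pattern, provided the corresponding drivers pass check_device_drivers.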