Refactor shark_runner and shark_inference to only support mlir_modules.

1. The old shark_inference is split into shark_importer and
   shark_inference (see the usage sketch below).
2. All the tank/pytorch tests have been updated.
Prashant Kumar
2022-06-23 20:22:31 +05:30
parent 44dce561e9
commit b07377cbfd
19 changed files with 378 additions and 479 deletions
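
With this split, importing a frontend model and executing it are now two separate steps: SharkImporter produces the mlir_module and function name, and SharkInference compiles and runs it. A minimal sketch of the new flow for a torch model, pieced together from the updated examples and tests below (ToyModule and its shapes are made up purely for illustration):

import torch
from shark.shark_importer import SharkImporter
from shark.shark_inference import SharkInference

# Toy module for illustration only; any torch.nn.Module follows the same flow.
class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

inputs = (torch.randn(1, 4),)

# Step 1: import the torch module into an mlir_module (linalg dialect).
mlir_importer = SharkImporter(ToyModule(), inputs, frontend="torch")
mlir_module, func_name = mlir_importer.import_mlir(tracing_required=False)

# Step 2: compile and run the mlir_module on the chosen device.
shark_module = SharkInference(
    mlir_module, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)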

View File

@@ -12,7 +12,23 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
print("Running shark on cpu backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print("Running shark on cuda backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
print("Running shark on vulkan backend")
shark_module = SharkInference(
mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))

View File

@@ -1,6 +1,7 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
@@ -23,12 +24,22 @@ class MiniLMSequenceClassification(torch.nn.Module):
test_input = torch.randint(2, (1, 128)).to(torch.int32)
shark_module = SharkInference(
mlir_importer = SharkImporter(
MiniLMSequenceClassification(),
(test_input, test_input, test_input),
jit_trace=True,
frontend="torch",
)
# torch hugging face models need tracing.
(minilm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
tracing_required=True
)
print(golden_out)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward((test_input, test_input, test_input))
print("Obtained result", result)

View File

@@ -75,3 +75,17 @@ def check_device_drivers(device):
return True
return False
# Installation info for the missing device drivers.
def device_driver_info(device):
if device in ["gpu", "cuda"]:
print(
"nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
)
elif device in ["metal", "vulkan"]:
print(
"vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
)
else:
print(f"{device} is not supported.")

View File

@@ -14,7 +14,6 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
from shark.model_annotation import *
import numpy as np
import os
@@ -157,23 +156,7 @@ def export_module_to_mlir_file(module, frontend, directory: str):
def get_results(compiled_vm, input, config, frontend="torch"):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = input
if frontend in ["torch", "pytorch"]:
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
device_inputs = []
for a in input:
if isinstance(a, list):
device_inputs.append(
[
ireert.asdevicearray(
config.device, val, dtype=val.dtype
)
for val in a
]
)
else:
device_inputs.append(ireert.asdevicearray(config.device, a))
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result_tensors = []
if isinstance(result, tuple):

View File

@@ -1,214 +1,155 @@
# Lint as: python3
"""SHARK Importer"""
import numpy as np
import os
import csv
import urllib.request
from shark.iree_utils._common import IREE_TARGET_MAP
import json
from tank.model_utils_tflite import TFLiteModelUtil
import sys
# List of the supported frontends.
supported_frontends = {
"tensorflow",
"tf",
"pytorch",
"torch",
"tf-lite",
"tflite",
}
class SharkImporter:
"""
SharkImporter converts frontend modules into an
mlir_module. The supported frameworks are tensorflow,
pytorch, and tf-lite.
...
Attributes
----------
module :
torch, tensorflow or tf-lite module.
inputs :
inputs to the module, may be required for the shape
information.
frontend: str
frontend to which the module belongs.
Methods
-------
import_mlir(is_dynamic, tracing_required, func_name):
is_dynamic: input shapes to be totally dynamic (pytorch specific).
tracing_required: whether tracing is required (pytorch specific).
func_name: The function to be traced out or imported to mlir.
import_debug(is_dynamic, tracing_required, func_name):
returns the converted (mlir_module, func_name) along with the inputs
and golden outputs.
The inputs and outputs are converted into np arrays.
"""
def __init__(
self,
model_name,
model_type: str = "torch",
input_details=None,
output_details=None,
model_path=None,
module,
inputs: tuple = (),
frontend: str = "torch",
):
self.model_name = model_name
self.model_type = model_type
self.input_details = (
input_details # used for tflite, optional for tf/pytorch
)
self.output_details = (
output_details # used for tflite, optional for tf/pytorch
)
self.inputs = []
self.model_path = model_path # url to download the model
self.raw_model_file = (
None # local address for raw tf/tflite/pytorch model
)
self.mlir_file = (
None # local address for .mlir file of tf/tflite/pytorch model
)
self.mlir_model = None # read of .mlir file
self.output_tensor = (
None # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
)
self.interpreter = None # could be tflite/tf/torch_interpreter in utils
# create tmp model file directory
if self.model_path is None and self.model_name is None:
self.module = module
self.inputs = None if len(inputs) == 0 else inputs
self.frontend = frontend
if not self.frontend in supported_frontends:
print(
"Error. No model_path, No model name,Please input either one."
f"The frontend is not in the supported_frontends: {supported_frontends}"
)
return
sys.exit(1)
print("Setting up for TMP_WORK_DIR")
self.workdir = os.path.join(
os.path.dirname(__file__), "./../gen_shark_tank"
# NOTE: The default function for torch is "forward" and tf-lite is "main".
def _torch_mlir(self, is_dynamic, tracing_required):
from shark.torch_mlir_utils import get_torch_mlir_module
return get_torch_mlir_module(
self.module, self.inputs, is_dynamic, tracing_required
)
os.makedirs(self.workdir, exist_ok=True)
print(f"TMP_WORK_DIR = {self.workdir}")
# compile and run tfhub tflite
if self.model_type == "tflite":
load_model_success = self.load_tflite_model()
if not load_model_success:
print("Error, load tflite model fail")
return
def _tf_mlir(self, func_name):
from iree.compiler import tf as tfc
if (self.input_details is None) or (self.output_details is None):
return tfc.compile_module(
self.module, exported_names=[func_name], import_only=True
)
def _tflite_mlir(self, func_name):
from iree.compiler import tflite as tflitec
# TODO(Chi): Just add the conversion of tflite model here.
return tflitec.compile_module(
self.module, exported_names=[func_name], import_only=True
)
# Dispatches to the frontend-specific private conversion function.
def import_mlir(
self,
is_dynamic=False,
tracing_required=False,
func_name="forward",
):
if self.frontend in ["torch", "pytorch"]:
if self.inputs == None:
print(
"Setting up tflite interpreter to get model input details"
"Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
)
self.setup_interpreter()
sys.exit(1)
return self._torch_mlir(is_dynamic, tracing_required), func_name
if self.frontend in ["tf", "tensorflow"]:
return self._tf_mlir(func_name), func_name
if self.frontend in ["tflite", "tf-lite"]:
func_name = "main"
return self._tflite_mlir(func_name), func_name
inputs = self.generate_inputs(
self.input_details
) # device_inputs
self.setup_inputs(inputs)
# Converts the frontend-specific tensors into np arrays.
def convert_to_numpy(self, array_tuple: tuple):
if self.frontend in ["torch", "pytorch"]:
return [x.detach().numpy() for x in array_tuple]
if self.frontend in ["tf", "tensorflow"]:
return [x.numpy() for x in array_tuple]
if self.frontend in ["tf-lite", "tflite"]:
# TODO(Chi): Validate for tf-lite tensors.
return [x.numpy() for x in array_tuple]
elif self.model_type in ["tensorflow, tf, torch, pytorch"]:
print(self.model_type, " Not Implemented yet")
def load_tflite_model(self):
# use model name get dir.
tflite_model_name_dir = os.path.join(self.workdir, str(self.model_name))
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
self.raw_model_file = "/".join(
[tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"]
)
self.mlir_file = "/".join(
[tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
)
if os.path.exists(self.raw_model_file):
def import_debug(
self,
is_dynamic=False,
tracing_required=False,
func_name="forward",
):
if self.inputs == None:
print(
"Local address for .tflite model file Exists: ",
self.raw_model_file,
f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
)
else:
print("No local tflite file, Download tflite model")
if self.model_path is None:
# get model file from tflite_model_list.csv or download from gs://bucket
print("No model_path, get from tflite_model_list.csv")
tflite_model_list_path = os.path.join(
os.path.dirname(__file__),
"../tank/tflite/tflite_model_list.csv",
)
tflite_model_list = csv.reader(open(tflite_model_list_path))
for row in tflite_model_list:
if str(row[0]) == str(self.model_name):
self.model_path = row[1]
print("tflite_model_name", str(row[0]))
print("tflite_model_link", self.model_path)
if self.model_path is None:
print("Error, No model path find in tflite_model_list.csv")
return False
urllib.request.urlretrieve(self.model_path, self.raw_model_file)
if os.path.exists(self.mlir_file):
print("Exists MLIR model ", self.mlir_file)
else:
print(
"No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
sys.exit(1)
imported_mlir = self.import_mlir(
is_dynamic, tracing_required, func_name
)
# TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
# TODO: Check for multiple outputs.
if self.frontend in ["torch", "pytorch"]:
golden_out = self.module(*self.inputs)
return (
imported_mlir,
self.convert_to_numpy(self.inputs),
golden_out.detach().numpy(),
)
print("Convert tflite to tosa.mlir")
import iree.compiler.tflite as ireec_tflite
ireec_tflite.compile_file(
self.raw_model_file,
input_type="tosa",
save_temp_iree_input=self.mlir_file,
target_backends=[IREE_TARGET_MAP["cpu"]],
import_only=False,
if self.frontend in ["tf", "tensorflow"]:
golden_out = self.module.forward(*self.inputs)
return (
imported_mlir,
self.convert_to_numpy(self.inputs),
golden_out.numpy(),
)
with open(self.mlir_file) as f:
self.mlir_model = f.read()
return True
def setup_interpreter(self):
if self.model_type == "tflite":
self.interpreter = TFLiteModelUtil(self.raw_model_file)
(
self.input_details,
self.output_details,
) = self.interpreter.setup_tflite_interpreter()
def generate_inputs(self, input_details):
self.inputs = []
for tmp_input in input_details:
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
self.inputs.append(
np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"])
if self.frontend in ["tflite", "tf-lite"]:
# TODO(Chi): Validate it for tflite models.
golden_out = self.module.main(*self.inputs)
return (
imported_mlir,
self.convert_to_numpy(self.inputs),
golden_out.numpy(),
)
# save inputs into json file
tmp_json = []
for tmp_input in input_details:
print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
tmp_json.append(
np.ones(
shape=tmp_input["shape"], dtype=tmp_input["dtype"]
).tolist()
)
with open("input1.json", "w") as f:
json.dump(tmp_json, f)
return self.inputs
# def get_model_details(self):
# if self.model_type == "tflite":
# print("Get tflite input output details")
# self.input_details = self.tflite_interpreter.get_input_details()
# self.output_details = self.tflite_interpreter.get_output_details()
# return self.input_details, self.output_details
def setup_inputs(self, inputs):
print("Setting up inputs")
self.inputs = inputs
def get_mlir_model(self):
return self.mlir_model
def get_inputs(self):
return self.inputs
def get_raw_model_output(self):
if self.model_type == "tflite":
self.output_tensor = self.interpreter.invoke_tflite(self.inputs)
return self.output_tensor
def get_model_details(self):
return self.input_details, self.output_details
def get_raw_model_file(self):
return self.raw_model_file
# def invoke_tflite(self, inputs):
# print("invoke_tflite")
# for i, input in enumerate(self.inputs):
# self.tflite_interpreter.set_tensor(
# self.input_details[i]["index"], input
# )
# self.tflite_interpreter.invoke()
#
# # post process tflite_result for compare with mlir_result,
# # for tflite the output is a list of numpy.tensor
# tflite_results = []
# for output_detail in self.output_details:
# tflite_results.append(
# np.array(
# self.tflite_interpreter.get_tensor(output_detail["index"])
# )
# )
#
# for i in range(len(self.output_details)):
# out_dtype = self.output_details[i]["dtype"]
# tflite_results[i] = tflite_results[i].astype(out_dtype)
# return tflite_results
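
The tests below only exercise the torch path. For completeness, a sketch of how the tensorflow path is expected to work, assuming a tf.Module that exposes its callable as forward (import_debug calls module.forward for the golden output); ToyTFModule and its shapes are made up for illustration:

import tensorflow as tf
from shark.shark_importer import SharkImporter

# Toy tf.Module for illustration only.
class ToyTFModule(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
    def forward(self, x):
        return tf.matmul(x, tf.ones([4, 2]))

inputs = (tf.ones([1, 4], dtype=tf.float32),)
mlir_importer = SharkImporter(ToyTFModule(), inputs, frontend="tf")
(tf_mlir, func_name), np_inputs, golden_out = mlir_importer.import_debug()
# tf_mlir can then be handed to SharkInference, presumably with
# mlir_dialect="mhlo", since tensorflow models are exported to mhlo.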

View File

@@ -9,138 +9,64 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
import sys
import os
# Prints to stderr.
def print_err(*a):
print(*a, file=sys.stderr)
class SharkInference:
"""Inference API targeting pytorch, tensorflow, linalg, mhlo and tosa frontend."""
"""
Runs prediction or inference on mlir_module.
...
Attributes
----------
mlir_module : str
mlir_module represented as a string.
function_name : str
function to execute in the given mlir_module.
device : str
device on which to execute the mlir_module.
Currently supports cpu, cuda, vulkan, and metal backends.
mlir_dialect: str
The dialect the given mlir_module is in.
Refer to {https://mlir.llvm.org/docs/Dialects/}
Methods
-------
run(inputs=None):
Runs the mlir_module with the given inputs; if the inputs are not
given, it autogenerates them. The inputs should be
numpy arrays.
input_info():
Gives information about the inputs required by the `function_name`.
This can be expensive as it relies on string matching.
TODO(Stanley) Add the benchmark APIs with is_benchmark = True argument.
"""
def __init__(
self,
model,
input: tuple,
device: str = None,
dynamic: bool = False,
jit_trace: bool = False,
benchmark_mode: bool = False,
mlir_module: str,
function_name: str = "forward",
device: str = "cpu",
mlir_dialect: str = "linalg",
):
self.model = model
self.input = input
self.dynamic = dynamic
self.jit_trace = jit_trace
self.benchmark_mode = benchmark_mode
# By default it's torch frontend.
self.frontend = "pytorch"
# Sets the device.
self.device = device if device is not None else shark_args.device
self.model_config_path = shark_args.model_config_path
self.mlir_module = mlir_module
self.function_name = function_name
self.device = device
self.mlir_dialect = mlir_dialect
self.shark_runner = None
# Sets the frontend i.e `pytorch` or `tensorflow`.
def set_frontend(self, frontend: str):
if frontend not in [
"pytorch",
"torch",
"tensorflow",
"tf",
"mhlo",
"linalg",
"tosa",
"tflite",
"tflite-tosa",
]:
print_err("frontend not supported.")
else:
self.frontend = frontend
def compile(self):
# Inference do not use AOT. # TODO: Remove the from_aot arg as it's not
# needed.
from_aot = False
if self.benchmark_mode == True:
# Only import shark_benchmark runner when needed.
from shark.shark_benchmark_runner import SharkBenchmarkRunner
self.shark_runner = SharkBenchmarkRunner(
self.model,
self.input,
self.dynamic,
self.device,
self.jit_trace,
from_aot,
self.frontend,
)
else:
self.shark_runner = SharkRunner(
self.model,
self.input,
self.dynamic,
self.device,
self.jit_trace,
from_aot,
self.frontend,
self.model_config_path,
)
# inputs are considered to be np.array.
def forward(self, inputs):
input_list = inputs
# converts the inputs to numpy.
if self.frontend in ["pytorch", "torch"]:
input_list = [x.detach().numpy() for x in inputs]
elif self.frontend in ["tensorflow", "tf"]:
input_list = [x.numpy() for x in inputs]
return self.shark_runner.forward(input_list, self.frontend)
def import_mlir(self, model_name="model", dir=os.getcwd()):
self.shark_runner.import_mlir(model_name, dir)
# Saves the .vmfb module.
def save_module(self, dir=None):
if dir is None:
return self.shark_runner.save_module()
return self.shark_runner.save_module(dir)
######### Benchmark Related Functions #########
def benchmark_mode(func):
def inner(self, *args, **kwargs):
assert (
self.benchmark_mode
), "SharkRunner needs to be in benchmark mode to run benchmark methods."
return func(self, *args, **kwargs)
return inner
@benchmark_mode
def benchmark_all(self, inputs):
self.shark_runner.benchmark_all(inputs)
@benchmark_mode
def benchmark_frontend(self, inputs):
self.shark_runner.benchmark_frontend(inputs)
@benchmark_mode
def benchmark_python(self, inputs):
self.shark_runner.benchmark_python(inputs)
@benchmark_mode
def benchmark_c(self):
self.shark_runner.benchmark_c()
@benchmark_mode
def benchmark_all_csv(self, inputs, modelname, dynamic, device_str):
self.shark_runner.benchmark_all_csv(
inputs, modelname, dynamic, device_str
# TODO: (Stanley) Update the shark_benchmark APIs.
self.shark_runner = SharkRunner(
self.mlir_module,
self.function_name,
self.device,
self.mlir_dialect,
)
# inputs are considered to be a tuple of np.arrays.
def forward(self, inputs: tuple):
return self.shark_runner.run(inputs)

View File

@@ -11,137 +11,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from iree.compiler import tf as tfc
import iree.compiler.tflite as ireec_tflite
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch_mlir.eager_mode import torch_mlir_tensor
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
from shark.parser import shark_args
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
export_iree_module_to_vmfb,
export_module_to_mlir_file,
get_results,
export_iree_module_to_vmfb,
)
from shark.iree_utils._common import check_device_drivers, device_driver_info
import os
import sys
# Dialects supported by the shark-runtime.
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite"}
class SharkRunner:
"""Base class for Shark Inference and Shark Runner."""
"""
Base class for SharkInference and SharkTrainer
used to execute an mlir_module.
...
Attributes
----------
mlir_module : str
mlir_module represented as a string.
function_name : str
function to execute in the given mlir_module.
device : str
device on which to execute the mlir_module.
Currently supports cpu, cuda, vulkan, and metal backends.
mlir_dialect: str
The dialect the given mlir_module is in.
Refer to {https://mlir.llvm.org/docs/Dialects/}
Methods
-------
run(inputs=None):
Runs the mlir_module with the given inputs; if the inputs are not
given, it autogenerates them. The inputs should be
numpy arrays.
input_info():
Gives information about the inputs required by the `function_name`.
This can be expensive as it relies on string matching.
"""
def __init__(
self,
model,
input: tuple,
dynamic: bool = False,
device: str = None,
jit_trace: bool = False,
from_aot: bool = False,
frontend: str = "torch",
model_config_path: str = None,
mlir_module: str,
function_name: str = "forward",
device: str = "cpu",
mlir_dialect: str = "linalg",
):
self.model = model
self.frontend_model = model
self.from_aot = from_aot
self.input = input
self.frontend = frontend
self.vmfb_file = None
func_name = "forward"
self.device = device if device is not None else shark_args.device
self.mlir_module = mlir_module
self.function_name = function_name
self.device = device
self.mlir_dialect = mlir_dialect
if self.frontend in ["tflite-tosa"]:
func_name = "main"
elif self.frontend in ["pytorch", "torch"]:
# get torch-mlir dialect
# self.model = torch.Module
# Lowers in linalg dialect.
# TODO assert
# TODO tosa dialect from torch_module.
self.model = get_torch_mlir_module(
self.model, input, dynamic, jit_trace, from_aot
)
elif self.frontend in ["tensorflow", "tf"]:
# get mhlo dialect
# self.model = tf.Module
# TODO assert
self.model = tfc.compile_module(
self.model, exported_names=[func_name], import_only=True
)
elif self.frontend in ["tflite"]:
print("Setting up for IREE compiler tflite")
# get tosa dialect
# self.model = model.tflite
# TODO assert
self.model = ireec_tflite.compile_file(
self.model, input_type="tosa", import_only=True
)
func_name = "main"
if check_device_drivers(self.device):
device_driver_info(self.device)
sys.exit(1)
# TODO: We can capture the .vmfb module here and reuse it for saving,
# rather than recompiling it again when the module is saved.
# Compile the module to get the .vmfb.
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(
self.model,
self.mlir_module,
self.device,
self.frontend,
func_name=func_name,
model_config_path=model_config_path,
self.mlir_dialect,
func_name=self.function_name,
)
# Debugging Options:
if shark_args.save_mlir:
export_module_to_mlir_file(
self.model, self.frontend, shark_args.repro_dir
)
if shark_args.save_vmfb:
self.vmfb_file = self.save_module(shark_args.repro_dir)
# All the timings and benchmarking can be done here.
def forward(self, input, frontend):
def run(self, inputs: tuple):
return get_results(
self.iree_compilation_module, input, self.iree_config, frontend
self.iree_compilation_module,
inputs,
self.iree_config,
self.mlir_dialect,
)
# Saves the .mlir file, can be in tosa, linalg or mhlo dialect.
# torch-mlir can export tosa or linalg dialects.
# tensorflow models get exported to mhlo dialect.
def import_mlir(self, model_name, dir):
filename = os.path.join(dir, f"{model_name}_{self.frontend}.mlir")
with open(filename, "w") as f:
f.write(self.model)
print(f"Saved mlir in {filename}.")
return filename
# TODO: Instead of passing a directory and having names decided by the
# module, the user may want to save the module with manual names.
def save_module(self, dir=os.getcwd()):
return export_iree_module_to_vmfb(
self.model, self.device, dir, self.frontend
self.model, self.device, dir, self.mlir_dialect
)
# TODO: Load a module and directly use it, we will need to set the frontend
# in this case.
def load_module(self, name):
# TODO: Get the input information from the mlir_module.
def input_info(self):
pass
# TODO: Document shark_eager mode.
class SharkEagerMode:
def __init__(self, device="cpu"):
if device == "refbackend":
torch_mlir_tensor.backend = EagerModeRefBackend()
else:
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(
device
)
self.guard = enable_torch_dispatch_mode(TorchMLIRTensor)
self.guard.__enter__()
def __del__(self):
self.guard.__exit__(None, None, None)
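
SharkRunner can also be driven directly with an mlir_module string; SharkInference.compile() is just the wrapper that constructs one. A sketch using a made-up minimal mhlo module (the real mhlo_ir body from the example file is not shown in this diff):

import numpy as np
from shark.shark_runner import SharkRunner

# Hypothetical tiny mhlo module, for illustration only.
mhlo_ir = r"""builtin.module {
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
    %0 = "mhlo.add"(%arg0, %arg1) : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
    return %0 : tensor<1x4xf32>
  }
}"""

arg0 = np.ones((1, 4), dtype=np.float32)
arg1 = np.ones((1, 4), dtype=np.float32)

# SharkRunner compiles the module in its constructor, so no compile() call is needed.
shark_runner = SharkRunner(mhlo_ir, "forward", device="cpu", mlir_dialect="mhlo")
print(shark_runner.run((arg0, arg1)))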

View File

@@ -110,14 +110,14 @@ def get_torch_mlir_module(
module,
input: tuple,
dynamic: bool,
tracing_required: bool,
from_aot: bool = False,
jit_trace: bool,
from_torchscript: bool = False,
):
"""TODO: Include necessary documentation."""
# Tracing is not required from the aot_module.
if not from_aot:
module = shark_jit_trace(module, input, dynamic, tracing_required)
if not from_torchscript:
module = shark_jit_trace(module, input, dynamic, jit_trace)
mb = ModuleBuilder()
class_annotator = ClassAnnotator()
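
A sketch of calling get_torch_mlir_module directly with the renamed parameters, assuming the signature shown in this hunk (AddOne is made up for illustration):

import torch
from shark.torch_mlir_utils import get_torch_mlir_module

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# jit_trace replaces the old tracing_required flag; from_torchscript replaces
# from_aot and skips shark_jit_trace when the module is already TorchScript.
linalg_module = get_torch_mlir_module(
    AddOne(), (torch.randn(2, 3),), dynamic=False, jit_trace=False
)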

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -30,12 +31,16 @@ class MiniLMModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class AlbertModuleTester:
model, input, act_out = get_hf_model("albert-base-v2")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))
@@ -55,9 +60,6 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(
reason="Language models currently failing for dynamic case"
)
def test_module_dynamic_cpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "cpu"
@@ -79,7 +81,6 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
@@ -89,7 +90,6 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class AlexnetModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class BertModuleTester:
model, input, act_out = get_hf_model("bert-base-uncased")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class DistilBertModuleTester:
model, input, act_out = get_hf_model("distilbert-base-uncased")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=True
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet101ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet18ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet50ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class SqueezenetModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic, tracing_required=False
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))

View File

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class WideResnet50ModuleTester:
)
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
shark_module = SharkInference(
mlir_importer = SharkImporter(
model,
(input,),
device=self.device,
dynamic=self.dynamic,
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=self.dynamic
)
shark_module = SharkInference(
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
results = shark_module.forward((input,))