Refactor shark_runner and shark_inference to only support mlir_modules.
1. shark_inference is split into shark_importer and shark_inference. 2. All the tank/pytorch tests have been updated.
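For orientation, a minimal sketch of the new two-step flow, adapted from the examples touched in this diff; the torch frontend, the default "forward" entry point, and the placeholder module/input below are assumptions for illustration only:

import torch
from shark.shark_importer import SharkImporter
from shark.shark_inference import SharkInference

model = MyTorchModule()  # hypothetical torch.nn.Module, illustration only
inputs = (torch.ones(1, 128, dtype=torch.int32),)

# Step 1: import the frontend module into an mlir_module (torch lowers to linalg).
mlir_importer = SharkImporter(model, inputs, frontend="torch")
(mlir_module, func_name), np_inputs, golden_out = mlir_importer.import_debug(
    tracing_required=True
)

# Step 2: SharkInference now only consumes mlir_modules; compile and run it.
shark_module = SharkInference(
    mlir_module, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward(inputs)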
@@ -12,7 +12,23 @@ mhlo_ir = r"""builtin.module {
arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)

shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
print("Running shark on cpu backend")
shark_module = SharkInference(
    mhlo_ir, function_name="forward", device="cpu", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))

print("Running shark on cuda backend")
shark_module = SharkInference(
    mhlo_ir, function_name="forward", device="cuda", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))

print("Running shark on vulkan backend")
shark_module = SharkInference(
    mhlo_ir, function_name="forward", device="vulkan", mlir_dialect="mhlo"
)
shark_module.compile()
print(shark_module.forward((arg0, arg1)))

@@ -1,6 +1,7 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter

torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
@@ -23,12 +24,22 @@ class MiniLMSequenceClassification(torch.nn.Module):

test_input = torch.randint(2, (1, 128)).to(torch.int32)

shark_module = SharkInference(
mlir_importer = SharkImporter(
    MiniLMSequenceClassification(),
    (test_input, test_input, test_input),
    jit_trace=True,
    frontend="torch",
)

# torch hugging face models need tracing.
(minilm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
    tracing_required=True
)

print(golden_out)

shark_module = SharkInference(
    minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward((test_input, test_input, test_input))
print("Obtained result", result)

@@ -75,3 +75,17 @@ def check_device_drivers(device):
        return True

    return False


# Installation info for the missing device drivers.
def device_driver_info(device):
    if device in ["gpu", "cuda"]:
        print(
            "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
        )
    elif device in ["metal", "vulkan"]:
        print(
            "vulkaninfo not found, install from https://vulkan.lunarg.com/sdk/home or your distribution"
        )
    else:
        print(f"{device} is not supported.")

@@ -14,7 +14,6 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
from shark.model_annotation import *
import numpy as np
import os

@@ -157,23 +156,7 @@ def export_module_to_mlir_file(module, frontend, directory: str):

def get_results(compiled_vm, input, config, frontend="torch"):
    """Runs a .vmfb file given inputs and config and returns output."""
    device_inputs = input
    if frontend in ["torch", "pytorch"]:
        device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
    if frontend in ["tensorflow", "tf", "tflite", "tflite-tosa"]:
        device_inputs = []
        for a in input:
            if isinstance(a, list):
                device_inputs.append(
                    [
                        ireert.asdevicearray(
                            config.device, val, dtype=val.dtype
                        )
                        for val in a
                    ]
                )
            else:
                device_inputs.append(ireert.asdevicearray(config.device, a))
    device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
    result = compiled_vm(*device_inputs)
    result_tensors = []
    if isinstance(result, tuple):

@@ -1,214 +1,155 @@
# Lint as: python3
"""SHARK Importer"""

import numpy as np
import os
import csv
import urllib.request
from shark.iree_utils._common import IREE_TARGET_MAP
import json
from tank.model_utils_tflite import TFLiteModelUtil
import sys

# List of the supported frontends.
supported_frontends = {
    "tensorflow",
    "tf",
    "pytorch",
    "torch",
    "tf-lite",
    "tflite",
}


class SharkImporter:
    """
    SharkImporter converts frontend modules into a
    mlir_module. The supported frameworks are tensorflow,
    pytorch, and tf-lite.

    ...

    Attributes
    ----------
    module :
        torch, tensorflow or tf-lite module.
    inputs :
        inputs to the module, may be required for the shape
        information.
    frontend: str
        frontend to which the module belongs.

    Methods
    -------
    import_mlir(is_dynamic, tracing_required, func_name):
        is_dynamic: input shapes to be totally dynamic (pytorch specific).
        tracing_required: whether tracing is required (pytorch specific).
        func_name: The function to be traced out or imported to mlir.

    import_debug(is_dynamic, tracing_required, func_name):
        returns the converted (mlir_module, func_name) with inputs and golden
        outputs.
        The inputs and outputs are converted into np array.
    """

    def __init__(
        self,
        model_name,
        model_type: str = "torch",
        input_details=None,
        output_details=None,
        model_path=None,
        module,
        inputs: tuple = (),
        frontend: str = "torch",
    ):
        self.model_name = model_name
        self.model_type = model_type
        self.input_details = (
            input_details  # used for tflite, optional for tf/pytorch
        )
        self.output_details = (
            output_details  # used for tflite, optional for tf/pytorch
        )
        self.inputs = []
        self.model_path = model_path  # url to download the model
        self.raw_model_file = (
            None  # local address for raw tf/tflite/pytorch model
        )
        self.mlir_file = (
            None  # local address for .mlir file of tf/tflite/pytorch model
        )
        self.mlir_model = None  # read of .mlir file
        self.output_tensor = (
            None  # the raw tf/pytorch/tflite_output_tensor, not mlir_tensor
        )
        self.interpreter = None  # could be tflite/tf/torch_interpreter in utils

        # create tmp model file directory
        if self.model_path is None and self.model_name is None:
        self.module = module
        self.inputs = None if len(inputs) == 0 else inputs
        self.frontend = frontend
        if not self.frontend in supported_frontends:
            print(
                "Error. No model_path, No model name,Please input either one."
                f"The frontend is not in the supported_frontends: {supported_frontends}"
            )
            return
            sys.exit(1)

        print("Setting up for TMP_WORK_DIR")
        self.workdir = os.path.join(
            os.path.dirname(__file__), "./../gen_shark_tank"
    # NOTE: The default function for torch is "forward" and tf-lite is "main".

    def _torch_mlir(self, is_dynamic, tracing_required):
        from shark.torch_mlir_utils import get_torch_mlir_module

        return get_torch_mlir_module(
            self.module, self.inputs, is_dynamic, tracing_required
        )
        os.makedirs(self.workdir, exist_ok=True)
        print(f"TMP_WORK_DIR = {self.workdir}")

        # compile and run tfhub tflite
        if self.model_type == "tflite":
            load_model_success = self.load_tflite_model()
            if not load_model_success:
                print("Error, load tflite model fail")
                return
    def _tf_mlir(self, func_name):
        from iree.compiler import tf as tfc

        if (self.input_details is None) or (self.output_details is None):
        return tfc.compile_module(
            self.module, exported_names=[func_name], import_only=True
        )

    def _tflite_mlir(self, func_name):
        from iree.compiler import tflite as tflitec

        # TODO(Chi): Just add the conversion of tflite model here.
        return tflitec.compile_module(
            self.module, exported_names=[func_name], import_only=True
        )

    # Adds the conversion of the frontend with the private function.
    def import_mlir(
        self,
        is_dynamic=False,
        tracing_required=False,
        func_name="forward",
    ):
        if self.frontend in ["torch", "pytorch"]:
            if self.inputs == None:
                print(
                    "Setting up tflite interpreter to get model input details"
                    "Please pass in the inputs, the inputs are required to determine the shape of the mlir_module"
                )
            self.setup_interpreter()
                sys.exit(1)
            return self._torch_mlir(is_dynamic, tracing_required), func_name
        if self.frontend in ["tf", "tensorflow"]:
            return self._tf_mlir(func_name), func_name
        if self.frontend in ["tflite", "tf-lite"]:
            func_name = "main"
            return self._tflite_mlir(func_name), func_name

            inputs = self.generate_inputs(
                self.input_details
            )  # device_inputs
            self.setup_inputs(inputs)
    # Converts the frontend specific tensors into np array.
    def convert_to_numpy(self, array_tuple: tuple):
        if self.frontend in ["torch", "pytorch"]:
            return [x.detach().numpy() for x in array_tuple]
        if self.frontend in ["tf", "tensorflow"]:
            return [x.numpy() for x in array_tuple]
        if self.frontend in ["tf-lite", "tflite"]:
            # TODO(Chi): Validate for tf-lite tensors.
            return [x.numpy() for x in array_tuple]

        elif self.model_type in ["tensorflow, tf, torch, pytorch"]:
            print(self.model_type, " Not Implemented yet")

    def load_tflite_model(self):
        # use model name get dir.
        tflite_model_name_dir = os.path.join(self.workdir, str(self.model_name))

        os.makedirs(tflite_model_name_dir, exist_ok=True)
        print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")

        self.raw_model_file = "/".join(
            [tflite_model_name_dir, str(self.model_name) + "_tflite.tflite"]
        )
        self.mlir_file = "/".join(
            [tflite_model_name_dir, str(self.model_name) + "_tflite.mlir"]
        )

        if os.path.exists(self.raw_model_file):
    def import_debug(
        self,
        is_dynamic=False,
        tracing_required=False,
        func_name="forward",
    ):
        if self.inputs == None:
            print(
                "Local address for .tflite model file Exists: ",
                self.raw_model_file,
                f"There is no input provided: {self.inputs}, please provide inputs or simply run import_mlir."
            )
        else:
            print("No local tflite file, Download tflite model")
            if self.model_path is None:
                # get model file from tflite_model_list.csv or download from gs://bucket
                print("No model_path, get from tflite_model_list.csv")
                tflite_model_list_path = os.path.join(
                    os.path.dirname(__file__),
                    "../tank/tflite/tflite_model_list.csv",
                )
                tflite_model_list = csv.reader(open(tflite_model_list_path))
                for row in tflite_model_list:
                    if str(row[0]) == str(self.model_name):
                        self.model_path = row[1]
                        print("tflite_model_name", str(row[0]))
                        print("tflite_model_link", self.model_path)
            if self.model_path is None:
                print("Error, No model path find in tflite_model_list.csv")
                return False
            urllib.request.urlretrieve(self.model_path, self.raw_model_file)
        if os.path.exists(self.mlir_file):
            print("Exists MLIR model ", self.mlir_file)
        else:
            print(
                "No tflite tosa.mlir, please use python generate_sharktank.py to download tosa model"
            sys.exit(1)

        imported_mlir = self.import_mlir(
            is_dynamic, tracing_required, func_name
        )
        # TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
        # TODO: Check for multiple outputs.
        if self.frontend in ["torch", "pytorch"]:
            golden_out = self.module(*self.inputs)
            return (
                imported_mlir,
                self.convert_to_numpy(self.inputs),
                golden_out.detach().numpy(),
            )
            print("Convert tflite to tosa.mlir")
            import iree.compiler.tflite as ireec_tflite

            ireec_tflite.compile_file(
                self.raw_model_file,
                input_type="tosa",
                save_temp_iree_input=self.mlir_file,
                target_backends=[IREE_TARGET_MAP["cpu"]],
                import_only=False,
        if self.frontend in ["tf", "tensorflow"]:
            golden_out = self.module.forward(*self.inputs)
            return (
                imported_mlir,
                self.convert_to_numpy(self.inputs),
                golden_out.numpy(),
            )
        with open(self.mlir_file) as f:
            self.mlir_model = f.read()
        return True

    def setup_interpreter(self):
        if self.model_type == "tflite":
            self.interpreter = TFLiteModelUtil(self.raw_model_file)
            (
                self.input_details,
                self.output_details,
            ) = self.interpreter.setup_tflite_interpreter()

    def generate_inputs(self, input_details):
        self.inputs = []
        for tmp_input in input_details:
            print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
            self.inputs.append(
                np.ones(shape=tmp_input["shape"], dtype=tmp_input["dtype"])
        if self.frontend in ["tflite", "tf-lite"]:
            # TODO(Chi): Validate it for tflite models.
            golden_out = self.module.main(*self.inputs)
            return (
                imported_mlir,
                self.convert_to_numpy(self.inputs),
                golden_out.numpy(),
            )
        # save inputs into json file
        tmp_json = []
        for tmp_input in input_details:
            print(str(tmp_input["shape"]), tmp_input["dtype"].__name__)
            tmp_json.append(
                np.ones(
                    shape=tmp_input["shape"], dtype=tmp_input["dtype"]
                ).tolist()
            )
        with open("input1.json", "w") as f:
            json.dump(tmp_json, f)
        return self.inputs

    # def get_model_details(self):
    #     if self.model_type == "tflite":
    #         print("Get tflite input output details")
    #         self.input_details = self.tflite_interpreter.get_input_details()
    #         self.output_details = self.tflite_interpreter.get_output_details()
    #         return self.input_details, self.output_details

    def setup_inputs(self, inputs):
        print("Setting up inputs")
        self.inputs = inputs

    def get_mlir_model(self):
        return self.mlir_model

    def get_inputs(self):
        return self.inputs

    def get_raw_model_output(self):
        if self.model_type == "tflite":
            self.output_tensor = self.interpreter.invoke_tflite(self.inputs)
        return self.output_tensor

    def get_model_details(self):
        return self.input_details, self.output_details

    def get_raw_model_file(self):
        return self.raw_model_file

    # def invoke_tflite(self, inputs):
    #     print("invoke_tflite")
    #     for i, input in enumerate(self.inputs):
    #         self.tflite_interpreter.set_tensor(
    #             self.input_details[i]["index"], input
    #         )
    #         self.tflite_interpreter.invoke()
    #
    #     # post process tflite_result for compare with mlir_result,
    #     # for tflite the output is a list of numpy.tensor
    #     tflite_results = []
    #     for output_detail in self.output_details:
    #         tflite_results.append(
    #             np.array(
    #                 self.tflite_interpreter.get_tensor(output_detail["index"])
    #             )
    #         )
    #
    #     for i in range(len(self.output_details)):
    #         out_dtype = self.output_details[i]["dtype"]
    #         tflite_results[i] = tflite_results[i].astype(out_dtype)
    #     return tflite_results

@@ -9,138 +9,64 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
import sys
import os


# Prints to stderr.
def print_err(*a):
    print(*a, file=sys.stderr)


class SharkInference:
    """Inference API targeting pytorch, tensorflow, linalg, mhlo and tosa frontend."""
    """
    Runs prediction or inference on mlir_module.

    ...

    Attributes
    ----------
    mlir_module : str
        mlir_module represented in string.
    function_name : str
        function to execute in the given mlir_module.
    device : str
        device to execute the mlir_module on.
        currently supports cpu, cuda, vulkan, and metal backends.
    mlir_dialect: str
        The dialect in which the given mlir_module is in.
        Refer to {https://mlir.llvm.org/docs/Dialects/}

    Methods
    -------
    run(inputs=None):
        Runs the mlir_module with the given inputs, if the inputs are not
        given it autogenerates the inputs. Also, the inputs should be a
        numpy array.
    input_info():
        Gives the information about the inputs required by the `function_name`.
        This can be expensive as it does string matching to do so.

    TODO(Stanley) Add the benchmark APIs with is_benchmark = True argument.
    """

    def __init__(
        self,
        model,
        input: tuple,
        device: str = None,
        dynamic: bool = False,
        jit_trace: bool = False,
        benchmark_mode: bool = False,
        mlir_module: str,
        function_name: str = "forward",
        device: str = "cpu",
        mlir_dialect: str = "linalg",
    ):
        self.model = model
        self.input = input
        self.dynamic = dynamic
        self.jit_trace = jit_trace
        self.benchmark_mode = benchmark_mode

        # By default it's torch frontend.
        self.frontend = "pytorch"

        # Sets the device.
        self.device = device if device is not None else shark_args.device

        self.model_config_path = shark_args.model_config_path
        self.mlir_module = mlir_module
        self.function_name = function_name
        self.device = device
        self.mlir_dialect = mlir_dialect

        self.shark_runner = None

    # Sets the frontend i.e `pytorch` or `tensorflow`.
    def set_frontend(self, frontend: str):
        if frontend not in [
            "pytorch",
            "torch",
            "tensorflow",
            "tf",
            "mhlo",
            "linalg",
            "tosa",
            "tflite",
            "tflite-tosa",
        ]:
            print_err("frontend not supported.")
        else:
            self.frontend = frontend

    def compile(self):
        # Inference do not use AOT. # TODO: Remove the from_aot arg as it's not
        # needed.
        from_aot = False
        if self.benchmark_mode == True:
            # Only import shark_benchmark runner when needed.
            from shark.shark_benchmark_runner import SharkBenchmarkRunner

            self.shark_runner = SharkBenchmarkRunner(
                self.model,
                self.input,
                self.dynamic,
                self.device,
                self.jit_trace,
                from_aot,
                self.frontend,
            )
        else:
            self.shark_runner = SharkRunner(
                self.model,
                self.input,
                self.dynamic,
                self.device,
                self.jit_trace,
                from_aot,
                self.frontend,
                self.model_config_path,
            )

    # inputs are considered to be np.array.
    def forward(self, inputs):
        input_list = inputs
        # converts the inputs to numpy.
        if self.frontend in ["pytorch", "torch"]:
            input_list = [x.detach().numpy() for x in inputs]
        elif self.frontend in ["tensorflow", "tf"]:
            input_list = [x.numpy() for x in inputs]
        return self.shark_runner.forward(input_list, self.frontend)

    def import_mlir(self, model_name="model", dir=os.getcwd()):
        self.shark_runner.import_mlir(model_name, dir)

    # Saves the .vmfb module.
    def save_module(self, dir=None):
        if dir is None:
            return self.shark_runner.save_module()
        return self.shark_runner.save_module(dir)

    ######### Benchmark Related Functions #########
    def benchmark_mode(func):
        def inner(self, *args, **kwargs):
            assert (
                self.benchmark_mode
            ), "SharkRunner needs to be in benchmark mode to run benchmark methods."
            return func(self, *args, **kwargs)

        return inner

    @benchmark_mode
    def benchmark_all(self, inputs):
        self.shark_runner.benchmark_all(inputs)

    @benchmark_mode
    def benchmark_frontend(self, inputs):
        self.shark_runner.benchmark_frontend(inputs)

    @benchmark_mode
    def benchmark_python(self, inputs):
        self.shark_runner.benchmark_python(inputs)

    @benchmark_mode
    def benchmark_c(self):
        self.shark_runner.benchmark_c()

    @benchmark_mode
    def benchmark_all_csv(self, inputs, modelname, dynamic, device_str):
        self.shark_runner.benchmark_all_csv(
            inputs, modelname, dynamic, device_str
        # TODO: (Stanley) Update the shark_benchmark APIs.
        self.shark_runner = SharkRunner(
            self.mlir_module,
            self.function_name,
            self.device,
            self.mlir_dialect,
        )

    # inputs are considered to be tuple of np.array.
    def forward(self, inputs: tuple):
        return self.shark_runner.run(inputs)

@@ -11,137 +11,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from iree.compiler import tf as tfc
import iree.compiler.tflite as ireec_tflite
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch_mlir.eager_mode import torch_mlir_tensor
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend

from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
from shark.parser import shark_args
from shark.iree_utils.compile_utils import (
    get_iree_compiled_module,
    export_iree_module_to_vmfb,
    export_module_to_mlir_file,
    get_results,
    export_iree_module_to_vmfb,
)
from shark.iree_utils._common import check_device_drivers, device_driver_info
import os
import sys


# supported dialects by the shark-runtime.
supported_dialects = {"linalg", "mhlo", "tosa", "tf-lite"}


class SharkRunner:
    """Base class for Shark Inference and Shark Runner."""
    """
    Base class for SharkInference and SharkTrainer
    used to execute an mlir_module.

    ...

    Attributes
    ----------
    mlir_module : str
        mlir_module represented in string.
    function_name : str
        function to execute in the given mlir_module.
    device : str
        device to execute the mlir_module on.
        currently supports cpu, cuda, vulkan, and metal backends.
    mlir_dialect: str
        The dialect in which the given mlir_module is in.
        Refer to {https://mlir.llvm.org/docs/Dialects/}

    Methods
    -------
    run(inputs=None):
        Runs the mlir_module with the given inputs, if the inputs are not
        given it autogenerates the inputs. Also, the inputs should be a
        numpy array.
    input_info():
        Gives the information about the inputs required by the `function_name`.
        This can be expensive as it does string matching to do so.
    """

    def __init__(
        self,
        model,
        input: tuple,
        dynamic: bool = False,
        device: str = None,
        jit_trace: bool = False,
        from_aot: bool = False,
        frontend: str = "torch",
        model_config_path: str = None,
        mlir_module: str,
        function_name: str = "forward",
        device: str = "cpu",
        mlir_dialect: str = "linalg",
    ):
        self.model = model
        self.frontend_model = model
        self.from_aot = from_aot
        self.input = input
        self.frontend = frontend
        self.vmfb_file = None
        func_name = "forward"
        self.device = device if device is not None else shark_args.device
        self.mlir_module = mlir_module
        self.function_name = function_name
        self.device = device
        self.mlir_dialect = mlir_dialect

        if self.frontend in ["tflite-tosa"]:
            func_name = "main"
        elif self.frontend in ["pytorch", "torch"]:
            # get torch-mlir dialect
            # self.model = torch.Module
            # Lowers in linalg dialect.
            # TODO assert
            # TODO tosa dialect from torch_module.
            self.model = get_torch_mlir_module(
                self.model, input, dynamic, jit_trace, from_aot
            )
        elif self.frontend in ["tensorflow", "tf"]:
            # get mhlo dialect
            # self.model = tf.Module
            # TODO assert
            self.model = tfc.compile_module(
                self.model, exported_names=[func_name], import_only=True
            )
        elif self.frontend in ["tflite"]:
            print("Setting up for IREE compiler tflite")
            # get tosa dialect
            # self.model = model.tflite
            # TODO assert
            self.model = ireec_tflite.compile_file(
                self.model, input_type="tosa", import_only=True
            )
            func_name = "main"
        if check_device_drivers(self.device):
            device_driver_info(self.device)
            sys.exit(1)

        # TODO: We can capture the .vmfb module here and later use it for saving
        # rather than recompiling it again, if used for saving.
        # Compile the module to get the .vmfb.
        (
            self.iree_compilation_module,
            self.iree_config,
        ) = get_iree_compiled_module(
            self.model,
            self.mlir_module,
            self.device,
            self.frontend,
            func_name=func_name,
            model_config_path=model_config_path,
            self.mlir_dialect,
            func_name=self.function_name,
        )

        # Debugging Options:
        if shark_args.save_mlir:
            export_module_to_mlir_file(
                self.model, self.frontend, shark_args.repro_dir
            )
        if shark_args.save_vmfb:
            self.vmfb_file = self.save_module(shark_args.repro_dir)

    # All the timings and benchmarking can be done here.
    def forward(self, input, frontend):
    def run(self, inputs: tuple):
        return get_results(
            self.iree_compilation_module, input, self.iree_config, frontend
            self.iree_compilation_module,
            inputs,
            self.iree_config,
            self.mlir_dialect,
        )

    # Saves the .mlir file, can be in tosa, linalg or mhlo dialect.
    # torch-mlir can export tosa or linalg dialects.
    # tensorflow models get exported to mhlo dialect.
    def import_mlir(self, model_name, dir):
        filename = os.path.join(dir, f"{model_name}_{self.frontend}.mlir")
        with open(filename, "w") as f:
            f.write(self.model)
        print(f"Saved mlir in {filename}.")
        return filename

    # TODO: Instead of passing directory and having names decided by the module
    # , user may want to save the module with manual names.
    def save_module(self, dir=os.getcwd()):
        return export_iree_module_to_vmfb(
            self.model, self.device, dir, self.frontend
            self.model, self.device, dir, self.mlir_dialect
        )

    # TODO: Load a module and directly use it, we will need to set the frontend
    # in this case.
    def load_module(self, name):
    # TODO: Get the input information from the mlir_module.
    def input_info(self):
        pass


# TODO: Document shark_eager mode.
class SharkEagerMode:
    def __init__(self, device="cpu"):
        if device == "refbackend":
            torch_mlir_tensor.backend = EagerModeRefBackend()
        else:
            torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(
                device
            )
        self.guard = enable_torch_dispatch_mode(TorchMLIRTensor)
        self.guard.__enter__()

    def __del__(self):
        self.guard.__exit__(None, None, None)

@@ -110,14 +110,14 @@ def get_torch_mlir_module(
    module,
    input: tuple,
    dynamic: bool,
    tracing_required: bool,
    from_aot: bool = False,
    jit_trace: bool,
    from_torchscript: bool = False,
):
    """TODO: Include necessary documentation."""

    # Tracing is not required from the aot_module.
    if not from_aot:
        module = shark_jit_trace(module, input, dynamic, tracing_required)
    if not from_torchscript:
        module = shark_jit_trace(module, input, dynamic, jit_trace)

    mb = ModuleBuilder()
    class_annotator = ClassAnnotator()

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -30,12 +31,16 @@ class MiniLMModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            jit_trace=True,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=True
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class AlbertModuleTester:
        model, input, act_out = get_hf_model("albert-base-v2")
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            jit_trace=True,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=True
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))
@@ -55,9 +60,6 @@ class AlbertModuleTest(unittest.TestCase):
        self.module_tester.device = "cpu"
        self.module_tester.create_and_check_module()

    @pytest.mark.xfail(
        reason="Language models currently failing for dynamic case"
    )
    def test_module_dynamic_cpu(self):
        self.module_tester.dynamic = True
        self.module_tester.device = "cpu"
@@ -79,7 +81,6 @@ class AlbertModuleTest(unittest.TestCase):
        self.module_tester.device = "gpu"
        self.module_tester.create_and_check_module()

    @pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
    @pytest.mark.skipif(
        check_device_drivers("vulkan"),
        reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
@@ -89,7 +90,6 @@ class AlbertModuleTest(unittest.TestCase):
        self.module_tester.device = "vulkan"
        self.module_tester.create_and_check_module()

    @pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
    @pytest.mark.skipif(
        check_device_drivers("vulkan"),
        reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class AlexnetModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=False
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class BertModuleTester:
        model, input, act_out = get_hf_model("bert-base-uncased")
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            jit_trace=True,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=True
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -28,12 +29,16 @@ class DistilBertModuleTester:
        model, input, act_out = get_hf_model("distilbert-base-uncased")
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            jit_trace=True,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=True
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet101ModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=False
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet18ModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=False
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class Resnet50ModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=False
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class SqueezenetModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic, tracing_required=False
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))

@@ -1,4 +1,5 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -31,11 +32,16 @@ class WideResnet50ModuleTester:
        )
        shark_args.save_mlir = self.save_mlir
        shark_args.save_vmfb = self.save_vmfb
        shark_module = SharkInference(
        mlir_importer = SharkImporter(
            model,
            (input,),
            device=self.device,
            dynamic=self.dynamic,
            frontend="torch",
        )
        minilm_mlir, func_name = mlir_importer.import_mlir(
            is_dynamic=self.dynamic
        )
        shark_module = SharkInference(
            minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
        )
        shark_module.compile()
        results = shark_module.forward((input,))