[SharkInference] Make SharkInference compile the entire module

-- Previously SharkInference was compiling and providing run APIs
   for a harcoded function with function name "forward".
-- This commit makes the compiling functionality generic and now
   any function being defined within the module can be run.
-- It also creates an API to fetch all the function names defined
   within the compiled module.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek-Varma
2022-12-22 09:23:50 +00:00
parent 136021424c
commit ce00c1c5e1
6 changed files with 43 additions and 54 deletions

View File

@@ -21,7 +21,6 @@ import torch
from iree.runtime import DeviceArray
from torch_mlir._mlir_libs._mlir.ir import Module
from torch_mlir.compiler_utils import (
get_module_name_for_debug_dump,
run_pipeline_with_repro_report,
)
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
@@ -64,14 +63,13 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
)
def compile(self, imported_module: Module):
fn_name = get_module_name_for_debug_dump(imported_module)
run_pipeline_with_repro_report(
imported_module,
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
"EagerMode",
)
callable, _ = get_iree_compiled_module(
imported_module, self.raw_device_str, func_name=fn_name
imported_module, self.raw_device_str
)
return callable

View File

@@ -234,7 +234,6 @@ def compile_module_to_flatbuffer(
module,
device,
frontend,
func_name,
model_config_path,
extra_args,
model_name="None",
@@ -277,7 +276,7 @@ def compile_module_to_flatbuffer(
return flatbuffer_blob
def get_iree_module(flatbuffer_blob, device, func_name):
def get_iree_module(flatbuffer_blob, device):
# Returns the compiled module and the configs.
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
@@ -285,7 +284,7 @@ def get_iree_module(flatbuffer_blob, device, func_name):
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = ctx.modules.module[func_name]
ModuleCompiled = ctx.modules.module
return ModuleCompiled, config
@@ -293,25 +292,22 @@ def get_iree_compiled_module(
module,
device: str,
frontend: str = "torch",
func_name: str = "forward",
model_config_path: str = None,
extra_args: list = [],
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, func_name, model_config_path, extra_args
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, func_name)
return get_iree_module(flatbuffer_blob, device)
def load_flatbuffer(
flatbuffer_path: str, device: str, func_name: str = "forward"
):
def load_flatbuffer(flatbuffer_path: str, device: str):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
return get_iree_module(flatbuffer_blob, device, func_name)
return get_iree_module(flatbuffer_blob, device)
def export_iree_module_to_vmfb(
@@ -319,20 +315,19 @@ def export_iree_module_to_vmfb(
device: str,
directory: str,
mlir_dialect: str = "linalg",
func_name: str = "forward",
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, mlir_dialect, func_name, model_config_path, extra_args
module, device, mlir_dialect, model_config_path, extra_args
)
if module_name is None:
device_name = (
device if "://" not in device else "-".join(device.split("://"))
)
module_name = f"{mlir_dialect}_{func_name}_{device_name}"
module_name = f"{mlir_dialect}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:
@@ -355,11 +350,16 @@ def export_module_to_mlir_file(module, frontend, directory: str):
def get_results(
compiled_vm, input, config, frontend="torch", send_to_host=True
compiled_vm,
function_name,
input,
config,
frontend="torch",
send_to_host=True,
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
@@ -376,7 +376,7 @@ def get_results(
return np.copy(res)
return data
else:
if send_to_host:
if send_to_host and result is not None:
return result.to_host()
return result

View File

@@ -60,7 +60,6 @@ class SharkBenchmarkRunner(SharkRunner):
def __init__(
self,
mlir_module: bytes,
function_name: str = "forward",
device: str = "none",
mlir_dialect: str = "linalg",
extra_args: list = [],
@@ -73,7 +72,6 @@ class SharkBenchmarkRunner(SharkRunner):
SharkRunner.__init__(
self,
mlir_module,
function_name,
device,
self.mlir_dialect,
self.extra_args,
@@ -85,7 +83,6 @@ class SharkBenchmarkRunner(SharkRunner):
device,
shark_args.repro_dir,
self.mlir_dialect,
function_name,
extra_args=self.extra_args,
)
@@ -185,11 +182,11 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_python(self, inputs):
input_list = [x for x in inputs]
for i in range(shark_args.num_warmup_iterations):
self.run(input_list)
self.run("forward", input_list)
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.run(input_list)
out = self.run("forward", input_list)
if i == shark_args.num_iterations - 1:
end = time.time()
print(

View File

@@ -40,8 +40,6 @@ class SharkInference:
----------
mlir_module : str
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
function_name : str
function to execute in the given mlir_module.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -53,10 +51,10 @@ class SharkInference:
Methods
-------
run(inputs=None):
Runs the mlir_module with the given inputs, if the inputs are not
given it autogenerates the inputs. Also, the inputs should be a
numpy array.
__call__(function_name, inputs=None):
Runs the function with `function_name` within the mlir_module along
with the given inputs, if the inputs are not given it autogenerates the
inputs. Also, the inputs should be a numpy array.
input_info():
Gives the information about the inputs required by the `function_name`.
This can be expensive as it does string matching to do so.
@@ -66,7 +64,6 @@ class SharkInference:
def __init__(
self,
mlir_module: bytes,
function_name: str = "forward",
device: str = "none",
mlir_dialect: str = "linalg",
is_benchmark: bool = False,
@@ -74,7 +71,6 @@ class SharkInference:
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
):
self.mlir_module = mlir_module
self.function_name = function_name
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
@@ -113,7 +109,6 @@ class SharkInference:
self.shark_runner = SharkBenchmarkRunner(
self.mlir_module,
self.function_name,
self.device,
self.mlir_dialect,
extra_args=extra_args,
@@ -122,7 +117,6 @@ class SharkInference:
else:
self.shark_runner = SharkRunner(
self.mlir_module,
self.function_name,
self.device,
self.mlir_dialect,
extra_args=extra_args,
@@ -138,21 +132,25 @@ class SharkInference:
os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
# inputs are considered to be tuple of np.array.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run(inputs, send_to_host)
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.shark_runner.get_functions_in_module()
# Captures the static input information from the mlir_module.
# TODO(pashu123): Generate the input information for dynamic shapes.
def _input_info(self):
def _input_info(self, function_name):
# func_key to get the line which contains the function.
func_key = "func.func @" + self.function_name
func_key = "func.func @" + function_name
func_header = None
for line in str(self.mlir_module).splitlines():
if func_key in line:
func_header = line
break
if func_header is None:
print(f"Function: {self.function_name} not found")
print(f"Function: {function_name} not found")
import re
@@ -190,7 +188,6 @@ class SharkInference:
self.device,
dir,
self.mlir_dialect,
self.function_name,
module_name=module_name,
extra_args=extra_args,
)
@@ -198,7 +195,6 @@ class SharkInference:
# load and return the module.
def load_module(self, path, extra_args=[]):
self.shark_runner = SharkRunner(
function_name=self.function_name,
device=self.device,
compile_vmfb=False,
extra_args=extra_args,
@@ -209,6 +205,5 @@ class SharkInference:
) = load_flatbuffer(
path,
self.device,
self.function_name,
)
return

View File

@@ -39,8 +39,6 @@ class SharkRunner:
----------
mlir_module : str
mlir_module represented in string.
function_name : str
function to execute in the given mlir_module.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -50,10 +48,10 @@ class SharkRunner:
Methods
-------
run(inputs=None):
Runs the mlir_module with the given inputs, if the inputs are not
given it autogenerates the inputs. Also, the inputs should be a
numpy array.
run(function_name, inputs=None):
Runs the function with `function_name` within the mlir_module along
with the given inputs, if the inputs are not given it autogenerates the
inputs. Also, the inputs should be a numpy array.
input_info():
Gives the information about the inputs required by the `function_name`.
This can be expensive as it does string matching to do so.
@@ -62,14 +60,12 @@ class SharkRunner:
def __init__(
self,
mlir_module: bytes = None,
function_name: str = "forward",
device: str = "none",
mlir_dialect: str = "linalg",
extra_args: list = [],
compile_vmfb: bool = True,
):
self.mlir_module = mlir_module
self.function_name = function_name
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
@@ -87,15 +83,19 @@ class SharkRunner:
self.mlir_module,
self.device,
self.mlir_dialect,
func_name=self.function_name,
extra_args=self.extra_args,
)
def run(self, inputs: tuple, send_to_host=False):
def run(self, function_name, inputs: tuple, send_to_host=False):
return get_results(
self.iree_compilation_module,
function_name,
inputs,
self.iree_config,
self.mlir_dialect,
send_to_host,
)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.iree_compilation_module._vm_module.function_names

View File

@@ -148,7 +148,6 @@ class SharkModuleTester:
shark_module = SharkInference(
model,
func_name,
device=device,
mlir_dialect=self.config["dialect"],
is_benchmark=self.benchmark,
@@ -163,7 +162,7 @@ class SharkModuleTester:
self.upload_repro()
raise
result = shark_module.forward(inputs)
result = shark_module(func_name, inputs)
golden_out, result = self.postprocess_outputs(golden_out, result)
try:
np.testing.assert_allclose(