[SharkInference] Make SharkInference compile the entire module
-- Previously, SharkInference compiled and exposed run APIs only for a hardcoded function named "forward".
-- This commit makes the compilation generic, so any function defined within the module can now be run.
-- It also adds an API to fetch all the function names defined within the compiled module.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
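To illustrate the new interface, a minimal usage sketch follows. The constructor arguments match the signature in this diff; the `compile()` entry point and the `"forward"` function name are assumptions for illustration, and `mlir_module` stands in for any serialized torch-mlir module:

import numpy as np
from shark.shark_inference import SharkInference

# mlir_module: serialized torch-mlir module that may define several
# functions rather than a single hardcoded "forward".
shark_module = SharkInference(mlir_module, device="cpu", mlir_dialect="linalg")
shark_module.compile()  # assumed entry point that builds the underlying SharkRunner

# New API: list every function name defined within the compiled module.
print(shark_module.get_functions_in_module())

# Run any function by name via __call__; inputs are a tuple of numpy arrays.
inputs = (np.ones((1, 3, 224, 224), dtype=np.float32),)
result = shark_module("forward", inputs)

Function selection thus moves from construction time to call time, which is what lets one compiled module serve multiple entry points.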
@@ -21,7 +21,6 @@ import torch
 from iree.runtime import DeviceArray
 from torch_mlir._mlir_libs._mlir.ir import Module
 from torch_mlir.compiler_utils import (
-    get_module_name_for_debug_dump,
     run_pipeline_with_repro_report,
 )
 from torch_mlir.eager_mode.torch_mlir_eager_backend import (
@@ -64,14 +63,13 @@ class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
         )

     def compile(self, imported_module: Module):
-        fn_name = get_module_name_for_debug_dump(imported_module)
         run_pipeline_with_repro_report(
             imported_module,
             "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
             "EagerMode",
         )
         callable, _ = get_iree_compiled_module(
-            imported_module, self.raw_device_str, func_name=fn_name
+            imported_module, self.raw_device_str
         )
         return callable
@@ -234,7 +234,6 @@ def compile_module_to_flatbuffer(
     module,
     device,
     frontend,
-    func_name,
     model_config_path,
     extra_args,
     model_name="None",
@@ -277,7 +276,7 @@ def compile_module_to_flatbuffer(
     return flatbuffer_blob


-def get_iree_module(flatbuffer_blob, device, func_name):
+def get_iree_module(flatbuffer_blob, device):
     # Returns the compiled module and the configs.
     config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_flatbuffer(
@@ -285,7 +284,7 @@ def get_iree_module(flatbuffer_blob, device, func_name):
     )
     ctx = ireert.SystemContext(config=config)
     ctx.add_vm_module(vm_module)
-    ModuleCompiled = ctx.modules.module[func_name]
+    ModuleCompiled = ctx.modules.module
     return ModuleCompiled, config

@@ -293,25 +292,22 @@ def get_iree_compiled_module(
     module,
     device: str,
     frontend: str = "torch",
-    func_name: str = "forward",
     model_config_path: str = None,
     extra_args: list = [],
 ):
     """Given a module returns the compiled .vmfb and configs"""
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, frontend, func_name, model_config_path, extra_args
+        module, device, frontend, model_config_path, extra_args
     )
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)


-def load_flatbuffer(
-    flatbuffer_path: str, device: str, func_name: str = "forward"
-):
+def load_flatbuffer(flatbuffer_path: str, device: str):

     with open(os.path.join(flatbuffer_path), "rb") as f:
         flatbuffer_blob = f.read()

-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)


 def export_iree_module_to_vmfb(
@@ -319,20 +315,19 @@ def export_iree_module_to_vmfb(
     device: str,
     directory: str,
     mlir_dialect: str = "linalg",
-    func_name: str = "forward",
     model_config_path: str = None,
     module_name: str = None,
     extra_args: list = [],
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, mlir_dialect, func_name, model_config_path, extra_args
+        module, device, mlir_dialect, model_config_path, extra_args
     )
     if module_name is None:
         device_name = (
             device if "://" not in device else "-".join(device.split("://"))
         )
-        module_name = f"{mlir_dialect}_{func_name}_{device_name}"
+        module_name = f"{mlir_dialect}_{device_name}"
     filename = os.path.join(directory, module_name + ".vmfb")
     print(f"Saved vmfb in {filename}.")
     with open(filename, "wb") as f:
@@ -355,11 +350,16 @@ def export_module_to_mlir_file(module, frontend, directory: str):


 def get_results(
-    compiled_vm, input, config, frontend="torch", send_to_host=True
+    compiled_vm,
+    function_name,
+    input,
+    config,
+    frontend="torch",
+    send_to_host=True,
 ):
     """Runs a .vmfb file given inputs and config and returns output."""
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
-    result = compiled_vm(*device_inputs)
+    result = compiled_vm[function_name](*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
         if send_to_host:
@@ -376,7 +376,7 @@ def get_results(
             return np.copy(res)
         return data
     else:
-        if send_to_host:
+        if send_to_host and result is not None:
             return result.to_host()
         return result

@@ -60,7 +60,6 @@ class SharkBenchmarkRunner(SharkRunner):
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         extra_args: list = [],
@@ -73,7 +72,6 @@ class SharkBenchmarkRunner(SharkRunner):
         SharkRunner.__init__(
             self,
             mlir_module,
-            function_name,
             device,
             self.mlir_dialect,
             self.extra_args,
@@ -85,7 +83,6 @@ class SharkBenchmarkRunner(SharkRunner):
             device,
             shark_args.repro_dir,
             self.mlir_dialect,
-            function_name,
             extra_args=self.extra_args,
         )

@@ -185,11 +182,11 @@ class SharkBenchmarkRunner(SharkRunner):
     def benchmark_python(self, inputs):
         input_list = [x for x in inputs]
         for i in range(shark_args.num_warmup_iterations):
-            self.run(input_list)
+            self.run("forward", input_list)

         begin = time.time()
         for i in range(shark_args.num_iterations):
-            out = self.run(input_list)
+            out = self.run("forward", input_list)
             if i == shark_args.num_iterations - 1:
                 end = time.time()
                 print(
@@ -40,8 +40,6 @@ class SharkInference:
     ----------
     mlir_module : str
         mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
-    function_name : str
-        function to execute in the given mlir_module.
     device : str
         device to execute the mlir_module on.
         currently supports cpu, cuda, vulkan, and metal backends.
@@ -53,10 +51,10 @@ class SharkInference:

     Methods
     -------
-    run(inputs=None):
-        Runs the mlir_module with the given inputs, if the inputs are not
-        given it autogenerates the inputs. Also, the inputs should be a
-        numpy array.
+    __call__(function_name, inputs=None):
+        Runs the function with `function_name` within the mlir_module along
+        with the given inputs, if the inputs are not given it autogenerates the
+        inputs. Also, the inputs should be a numpy array.
     input_info():
         Gives the information about the inputs required by the `function_name`.
         This can be expensive as it does string matching to do so.
@@ -66,7 +64,6 @@ class SharkInference:
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         is_benchmark: bool = False,
@@ -74,7 +71,6 @@ class SharkInference:
         dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
     ):
         self.mlir_module = mlir_module
-        self.function_name = function_name
         self.device = shark_args.device if device == "none" else device
         self.mlir_dialect = mlir_dialect
         self.is_benchmark = is_benchmark
@@ -113,7 +109,6 @@ class SharkInference:

             self.shark_runner = SharkBenchmarkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -122,7 +117,6 @@ class SharkInference:
         else:
             self.shark_runner = SharkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -138,21 +132,25 @@ class SharkInference:
         os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")

     # inputs are considered to be tuple of np.array.
-    def forward(self, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run(inputs, send_to_host)
+    def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
+        return self.shark_runner.run(function_name, inputs, send_to_host)
+
+    # Get all function names defined within the compiled module.
+    def get_functions_in_module(self):
+        return self.shark_runner.get_functions_in_module()

     # Captures the static input information from the mlir_module.
     # TODO(pashu123): Generate the input information for dynamic shapes.
-    def _input_info(self):
+    def _input_info(self, function_name):
         # func_key to get the line which contains the function.
-        func_key = "func.func @" + self.function_name
+        func_key = "func.func @" + function_name
         func_header = None
         for line in str(self.mlir_module).splitlines():
             if func_key in line:
                 func_header = line
                 break
         if func_header is None:
-            print(f"Function: {self.function_name} not found")
+            print(f"Function: {function_name} not found")

         import re
@@ -190,7 +188,6 @@ class SharkInference:
             self.device,
             dir,
             self.mlir_dialect,
-            self.function_name,
             module_name=module_name,
             extra_args=extra_args,
         )
@@ -198,7 +195,6 @@ class SharkInference:
     # load and return the module.
     def load_module(self, path, extra_args=[]):
         self.shark_runner = SharkRunner(
-            function_name=self.function_name,
             device=self.device,
             compile_vmfb=False,
             extra_args=extra_args,
@@ -209,6 +205,5 @@ class SharkInference:
         ) = load_flatbuffer(
             path,
             self.device,
-            self.function_name,
         )
         return
@@ -39,8 +39,6 @@ class SharkRunner:
     ----------
     mlir_module : str
         mlir_module represented in string.
-    function_name : str
-        function to execute in the given mlir_module.
     device : str
         device to execute the mlir_module on.
         currently supports cpu, cuda, vulkan, and metal backends.
@@ -50,10 +48,10 @@ class SharkRunner:

     Methods
     -------
-    run(inputs=None):
-        Runs the mlir_module with the given inputs, if the inputs are not
-        given it autogenerates the inputs. Also, the inputs should be a
-        numpy array.
+    run(function_name, inputs=None):
+        Runs the function with `function_name` within the mlir_module along
+        with the given inputs, if the inputs are not given it autogenerates the
+        inputs. Also, the inputs should be a numpy array.
     input_info():
         Gives the information about the inputs required by the `function_name`.
         This can be expensive as it does string matching to do so.
@@ -62,14 +60,12 @@ class SharkRunner:
    def __init__(
        self,
        mlir_module: bytes = None,
-        function_name: str = "forward",
        device: str = "none",
        mlir_dialect: str = "linalg",
        extra_args: list = [],
        compile_vmfb: bool = True,
    ):
        self.mlir_module = mlir_module
-        self.function_name = function_name
        self.device = shark_args.device if device == "none" else device
        self.mlir_dialect = mlir_dialect
        self.extra_args = extra_args
@@ -87,15 +83,19 @@ class SharkRunner:
                 self.mlir_module,
                 self.device,
                 self.mlir_dialect,
-                func_name=self.function_name,
                 extra_args=self.extra_args,
             )

-    def run(self, inputs: tuple, send_to_host=False):
+    def run(self, function_name, inputs: tuple, send_to_host=False):
         return get_results(
             self.iree_compilation_module,
+            function_name,
             inputs,
             self.iree_config,
             self.mlir_dialect,
             send_to_host,
         )
+
+    # Get all function names defined within the compiled module.
+    def get_functions_in_module(self):
+        return self.iree_compilation_module._vm_module.function_names
@@ -148,7 +148,6 @@ class SharkModuleTester:

         shark_module = SharkInference(
             model,
-            func_name,
             device=device,
             mlir_dialect=self.config["dialect"],
             is_benchmark=self.benchmark,
@@ -163,7 +162,7 @@ class SharkModuleTester:
                 self.upload_repro()
             raise

-        result = shark_module.forward(inputs)
+        result = shark_module(func_name, inputs)
         golden_out, result = self.postprocess_outputs(golden_out, result)
         try:
             np.testing.assert_allclose(