Add run_on_refbackend and gpu configs. (#52)

`run_on_refbackend` is added to run code compiled for the linalg_on_tensors
backend on torch-mlir's refbackend. Also added GPU configs and flags.
This commit is contained in:
Prashant Kumar
2022-05-24 02:48:11 +05:30
committed by GitHub
parent 2df9128749
commit 4b63e0e04a
3 changed files with 24 additions and 3 deletions

View File

@@ -52,7 +52,7 @@ def check_device_drivers(device):
def get_iree_compiled_module(module, device: str):
"""Given an mlir module returns the compiled .vmfb"""
args = ["--iree-llvm-target-cpu-features=host"]
if (device == "cpu"):
if device == "cpu":
find_triple_cmd = "uname -s -m"
os_name, proc_name = subprocess.run(
find_triple_cmd, shell=True, stdout=subprocess.PIPE,
@@ -71,6 +71,12 @@ def get_iree_compiled_module(module, device: str):
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
args.append(f"-iree-llvm-target-triple={target_triple}")
if device == "gpu":
args += ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa"]
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
flatbuffer_blob = ireec.compile_str(
str(module), target_backends=[IREE_DEVICE_MAP[device]], extra_args=args)
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file, run_on_refbackend
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
import os
from shark.functorch_utils import AOTModule
@@ -23,7 +23,7 @@ import time
class SharkRunner:
"""TODO: Write the description"""
"""Base class for Shark Inference and Shark Runner."""
def __init__(
self,
@@ -55,6 +55,10 @@ class SharkRunner:
return get_results(self.iree_compilation_module, input,
self.iree_config)
# Refbackend is used only for debugging purpose. It can be quite slow.
def run_on_refbackend(self, input):
    """Run this module's torch-mlir IR on torch-mlir's refbackend.

    Debugging-only path (can be quite slow). Delegates to the
    module-level `run_on_refbackend` helper with the stored
    torch-mlir module.
    """
    mlir_module = self.torch_mlir_module
    return run_on_refbackend(mlir_module, input)
class SharkInference:
"""TODO: Write the description"""
@@ -107,6 +111,9 @@ class SharkInference:
input_list = [x.detach().numpy() for x in inputs]
return self.shark_runner.forward(input_list)
def run_on_refbackend(self, inputs):
    """Run inference via torch-mlir's refbackend (debugging only; slow).

    Delegates to the underlying SharkRunner and returns its result.
    Previously the result was dropped (missing `return`), so this
    method always yielded None while SharkRunner's version returned
    the computed outputs; returning it makes the two consistent.
    """
    return self.shark_runner.run_on_refbackend(inputs)
class SharkTrainer:
"""TODO: Write the description"""

View File

@@ -27,6 +27,8 @@ from torch_mlir_e2e_test.torchscript.serialization import (
extract_serializable_annotations, apply_serializable_annotations,
SerializableTest)
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir.ir import StringAttr
@@ -66,6 +68,12 @@ def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
annotations_list.append(tuple(temp_list))
return annotations_list
def run_on_refbackend(torch_module, inputs):
    """Compile `torch_module` with torch-mlir's RefBackend and execute it.

    Debugging-only reference path: compiles the linalg-on-tensors MLIR
    module with RefBackendLinalgOnTensorsBackend, loads the resulting
    jit module, and invokes its `forward` entry point.

    Args:
        torch_module: torch-mlir module lowered to linalg-on-tensors.
        inputs: iterable of torch tensors (converted via `.numpy()`).

    Returns:
        Whatever `jit_module.forward` produces for the given inputs.
    """
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(torch_module)
    jit_module = backend.load(compiled)
    np_inputs = [x.numpy() for x in inputs]
    # Bug fix: only np_inputs[0] was forwarded before, silently ignoring
    # every other input. Unpack them all; single-input callers see
    # identical behavior.
    return jit_module.forward(*np_inputs)
def shark_jit_trace(module, input: tuple, dynamic: bool,
tracing_required: bool):