Add shark-runner command line options for saving modules as .mlir or .vmfb

This commit is contained in:
monorimet
2022-04-15 01:10:30 -05:00
parent c586564356
commit 692cd180f6
3 changed files with 44 additions and 5 deletions

View File

@@ -15,6 +15,9 @@
import iree.runtime as ireert
import iree.compiler as ireec
import numpy as np
import tempfile
import os
from shark.torch_mlir_utils import get_module_name_for_asm_dump
IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
@@ -32,6 +35,14 @@ def get_iree_compiled_module(module, device: str):
ModuleCompiled = ctx.modules.module["forward"]
return ModuleCompiled, config
def export_iree_module_to_vmfb(module, device: str):
module_name = get_module_name_for_asm_dump(module)
flatbuffer_blob = ireec.compile_str(
str(module), target_backends=[IREE_DEVICE_MAP[device]]
)
filename = os.path.join(tempfile.gettempdir(), module_name + ".vmfb")
with open(filename, 'wb') as f:
f.write(flatbuffer_blob)
def get_results(compiled_vm, input, config):
"""TODO: Documentation"""

View File

@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.torch_mlir_utils import get_torch_mlir_module
from shark.iree_utils import get_results, get_iree_compiled_module
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
import argparse
# from functorch_utils import AOTModule
class SharkRunner:
"""TODO: Write the description"""
@@ -29,16 +29,24 @@ class SharkRunner:
tracing_required: bool,
from_aot: bool,
):
self.parser = argparse.ArgumentParser(description='SHARK runner.')
self.parser.add_argument("--save_mlir", default=False, action="store_true", help="Saves input MLIR module to /tmp/ directory.")
self.parser.add_argument("--save_vmfb", default=False, action="store_true", help="Saves iree .vmfb module to /tmp/ directory.")
self.parser.parse_args(namespace=self)
self.torch_module = model
self.input = input
self.torch_mlir_module = get_torch_mlir_module(
model, input, dynamic, tracing_required, from_aot
)
if self.save_mlir:
export_module_to_mlir_file(self.torch_mlir_module)
if self.save_vmfb:
export_iree_module_to_vmfb(self.torch_mlir_module, device)
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(self.torch_mlir_module, device)
def forward(self, input):
return get_results(
self.iree_compilation_module, input, self.iree_config

View File

@@ -15,7 +15,11 @@
import torch
import io
import pickle
import sys
import os
import tempfile
from io import StringIO
from torch_mlir.dialects.torch.importer.jit_ir import (
ClassAnnotator,
ModuleBuilder,
@@ -28,7 +32,23 @@ from torch_mlir_e2e_test.torchscript.serialization import (
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir.ir import StringAttr
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
The name is not guaranteed to be unique.
"""
if not "torch.debug_module_name" in module.operation.attributes:
return "UnnammedModule"
return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
def export_module_to_mlir_file(module):
"""Writes MLIR module to /tmp/module.mlir for debugging or performance use."""
module_name = get_module_name_for_asm_dump(module)
asm = module.operation.get_asm()
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
with open(filename, 'w') as f:
f.write(asm)
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
"""TODO: Include necessary documentation"""