Merge pull request #19 from NodLabs/ean-export-modules

Add shark-runner command line options for saving modules.
This commit is contained in:
powderluv
2022-04-15 16:17:04 -07:00
committed by GitHub
3 changed files with 49 additions and 4 deletions

View File

@@ -15,6 +15,8 @@
import iree.runtime as ireert
import iree.compiler as ireec
import numpy as np
import os
from shark.torch_mlir_utils import get_module_name_for_asm_dump
IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
@@ -32,6 +34,14 @@ def get_iree_compiled_module(module, device: str):
ModuleCompiled = ctx.modules.module["forward"]
return ModuleCompiled, config
def export_iree_module_to_vmfb(module, device: str, directory: str):
module_name = get_module_name_for_asm_dump(module)
flatbuffer_blob = ireec.compile_str(
str(module), target_backends=[IREE_DEVICE_MAP[device]]
)
filename = os.path.join(directory, module_name + ".vmfb")
with open(filename, 'wb') as f:
f.write(flatbuffer_blob)
def get_results(compiled_vm, input, config):
"""TODO: Documentation"""

View File

@@ -12,11 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.torch_mlir_utils import get_torch_mlir_module
from shark.iree_utils import get_results, get_iree_compiled_module
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
import argparse
import os
# from functorch_utils import AOTModule
def dir_path(path):
if os.path.isdir(path):
return path
else:
raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
class SharkRunner:
"""TODO: Write the description"""
@@ -29,16 +36,25 @@ class SharkRunner:
tracing_required: bool,
from_aot: bool,
):
self.parser = argparse.ArgumentParser(description='SHARK runner.')
self.parser.add_argument("--repro_dir", help="Directory to which module files will be saved for reproduction or debugging.", type=dir_path, default="/tmp/")
self.parser.add_argument("--save_mlir", default=False, action="store_true", help="Saves input MLIR module to /tmp/ directory.")
self.parser.add_argument("--save_vmfb", default=False, action="store_true", help="Saves iree .vmfb module to /tmp/ directory.")
self.parser.parse_args(namespace=self)
self.torch_module = model
self.input = input
self.torch_mlir_module = get_torch_mlir_module(
model, input, dynamic, tracing_required, from_aot
)
if self.save_mlir:
export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
if self.save_vmfb:
export_iree_module_to_vmfb(self.torch_mlir_module, device, self.repro_dir)
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(self.torch_mlir_module, device)
def forward(self, input):
return get_results(
self.iree_compilation_module, input, self.iree_config

View File

@@ -15,7 +15,10 @@
import torch
import io
import pickle
import sys
import os
from io import StringIO
from torch_mlir.dialects.torch.importer.jit_ir import (
ClassAnnotator,
ModuleBuilder,
@@ -28,7 +31,23 @@ from torch_mlir_e2e_test.torchscript.serialization import (
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir.ir import StringAttr
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
The name is not guaranteed to be unique.
"""
if not "torch.debug_module_name" in module.operation.attributes:
return "UnnammedModule"
return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
def export_module_to_mlir_file(module, directory: str):
"""Writes MLIR module to /tmp/module.mlir for debugging or performance use."""
module_name = get_module_name_for_asm_dump(module)
asm = module.operation.get_asm()
filename = os.path.join(directory, module_name + ".mlir")
with open(filename, 'w') as f:
f.write(asm)
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
"""TODO: Include necessary documentation"""