mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Merge pull request #19 from NodLabs/ean-export-modules
Add shark-runner command line options for saving modules.
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
import numpy as np
|
||||
import os
|
||||
from shark.torch_mlir_utils import get_module_name_for_asm_dump
|
||||
|
||||
IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
|
||||
|
||||
@@ -32,6 +34,14 @@ def get_iree_compiled_module(module, device: str):
|
||||
ModuleCompiled = ctx.modules.module["forward"]
|
||||
return ModuleCompiled, config
|
||||
|
||||
def export_iree_module_to_vmfb(module, device: str, directory: str):
|
||||
module_name = get_module_name_for_asm_dump(module)
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
str(module), target_backends=[IREE_DEVICE_MAP[device]]
|
||||
)
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(flatbuffer_blob)
|
||||
|
||||
def get_results(compiled_vm, input, config):
|
||||
"""TODO: Documentation"""
|
||||
|
||||
@@ -12,11 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module
|
||||
from shark.iree_utils import get_results, get_iree_compiled_module
|
||||
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
|
||||
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
|
||||
import argparse
|
||||
import os
|
||||
# from functorch_utils import AOTModule
|
||||
|
||||
|
||||
def dir_path(path):
|
||||
if os.path.isdir(path):
|
||||
return path
|
||||
else:
|
||||
raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
|
||||
|
||||
class SharkRunner:
|
||||
"""TODO: Write the description"""
|
||||
|
||||
@@ -29,16 +36,25 @@ class SharkRunner:
|
||||
tracing_required: bool,
|
||||
from_aot: bool,
|
||||
):
|
||||
self.parser = argparse.ArgumentParser(description='SHARK runner.')
|
||||
self.parser.add_argument("--repro_dir", help="Directory to which module files will be saved for reproduction or debugging.", type=dir_path, default="/tmp/")
|
||||
self.parser.add_argument("--save_mlir", default=False, action="store_true", help="Saves input MLIR module to /tmp/ directory.")
|
||||
self.parser.add_argument("--save_vmfb", default=False, action="store_true", help="Saves iree .vmfb module to /tmp/ directory.")
|
||||
self.parser.parse_args(namespace=self)
|
||||
self.torch_module = model
|
||||
self.input = input
|
||||
self.torch_mlir_module = get_torch_mlir_module(
|
||||
model, input, dynamic, tracing_required, from_aot
|
||||
)
|
||||
if self.save_mlir:
|
||||
export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
|
||||
if self.save_vmfb:
|
||||
export_iree_module_to_vmfb(self.torch_mlir_module, device, self.repro_dir)
|
||||
(
|
||||
self.iree_compilation_module,
|
||||
self.iree_config,
|
||||
) = get_iree_compiled_module(self.torch_mlir_module, device)
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
return get_results(
|
||||
self.iree_compilation_module, input, self.iree_config
|
||||
|
||||
@@ -15,7 +15,10 @@
|
||||
import torch
|
||||
import io
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
|
||||
from io import StringIO
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import (
|
||||
ClassAnnotator,
|
||||
ModuleBuilder,
|
||||
@@ -28,7 +31,23 @@ from torch_mlir_e2e_test.torchscript.serialization import (
|
||||
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
from torch_mlir.ir import StringAttr
|
||||
|
||||
def get_module_name_for_asm_dump(module):
|
||||
"""Gets a name suitable for an assembly dump.
|
||||
The name is not guaranteed to be unique.
|
||||
"""
|
||||
if not "torch.debug_module_name" in module.operation.attributes:
|
||||
return "UnnammedModule"
|
||||
return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
|
||||
|
||||
def export_module_to_mlir_file(module, directory: str):
|
||||
"""Writes MLIR module to /tmp/module.mlir for debugging or performance use."""
|
||||
module_name = get_module_name_for_asm_dump(module)
|
||||
asm = module.operation.get_asm()
|
||||
filename = os.path.join(directory, module_name + ".mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm)
|
||||
|
||||
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
|
||||
"""TODO: Include necessary documentation"""
|
||||
|
||||
Reference in New Issue
Block a user