Enable Shark Trainer APIs

Prashant Kumar
2022-04-21 14:49:57 +00:00
parent d9f47b59f2
commit 37468f7bb8
6 changed files with 132 additions and 89 deletions

.style.yapf Normal file

@@ -0,0 +1,3 @@
[style]
based_on_style = google
column_limit = 80
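
The new .style.yapf drives the mechanical reformatting seen throughout this commit. As a minimal sketch of exercising it programmatically (assumes yapf is installed; FormatCode is yapf's public API and returns the formatted source plus a changed flag):

from yapf.yapflib.yapf_api import FormatCode

source = "def f( a,b ):\n  return a+b\n"
# style_config points at the .style.yapf shown above.
formatted, changed = FormatCode(source, style_config=".style.yapf")
print(formatted)

In practice the same style is applied in place from the command line with yapf -i.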

Resnet50 training example (filename not preserved in this capture)

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torchvision.models as models
from shark_runner import SharkTrainer
from shark.shark_runner import SharkTrainer
class NeuralNet(nn.Module):
@@ -36,17 +36,17 @@ results = shark_module.train((input,))
# print(results)
input = torch.randn(1, 3, 224, 224)
labels = torch.randn(1, 1000)
# input = torch.randn(1, 3, 224, 224)
# labels = torch.randn(1, 1000)
class Resnet50Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.resnet = models.resnet50(pretrained=True)
self.train(True)
# class Resnet50Module(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.resnet = models.resnet50(pretrained=True)
# self.train(True)
def forward(self, img):
return self.resnet.forward(img)
# def forward(self, img):
# return self.resnet.forward(img)
shark_module = SharkTrainer(Resnet50Module(), (input,), (labels,), from_aot=True)
results = shark_module.train((input,))
# shark_module = SharkTrainer(Resnet50Module(), (input,), (labels,), from_aot=True)
# results = shark_module.train((input,))
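
Taken together, the hunks above switch the import to the shark package path and comment the Resnet50 path back out, leaving the small NeuralNet example active. A minimal sketch of the resulting usage (the module body and shapes are illustrative; the SharkTrainer call is copied from the example):

import torch
import torch.nn as nn
from shark.shark_runner import SharkTrainer

class NeuralNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        return self.fc(x)

input = torch.randn(4, 64)
labels = torch.randn(4, 10)
shark_module = SharkTrainer(NeuralNet(), (input,), (labels,), from_aot=True)
results = shark_module.train((input,))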

shark/functorch_utils.py

@@ -23,6 +23,7 @@ import copy
class AOTModule:
def __init__(self, model, inputs, labels=None, custom_inference_fn=None):
self.model = model
self.inputs = inputs
@@ -56,41 +57,71 @@ class AOTModule:
loss.backward()
optimizer.step()
# Replaces None types with zeros.
# def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule) -> fx.GraphModule:
# for node in fx_g.graph.nodes:
# if node.op == 'output':
# # output nodes always have one argument
# node_arg = node.args[0]
# out_nodes = []
# if isinstance(node_arg, list):
# for out_node in node_arg:
# if isinstance(out_node, type(None)):
# print("None node found replacing with zeros")
# with fx_g.graph.inserting_before(out_node):
# new_node = fx_g.graph.call_function(torch.ops.aten.zeros, (-1,))
# out_nodes.append(new_node)
# else:
# out_nodes.append(out_node)
# node.args = (tuple(out_nodes),)
# fx_g.graph.lint()
# fx_g.recompile()
# return fx_g
# Drops None returns instead of replacing them.
def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
for node in fx_g.graph.nodes:
if node.op == "output":
# output nodes always have one argument
node_arg = node.args[0]
out_nodes = []
if isinstance(node_arg, list):
# Don't return NoneType elements.
for out_node in node_arg:
if not isinstance(out_node, type(None)):
out_nodes.append(out_node)
# If there is a single tensor/element to be returned, don't
# create a tuple for it.
if len(node_arg) == 1:
node.args = node_arg
if len(out_nodes) == 1:
node.args = out_nodes
else:
node.args = (tuple(node_arg),)
node.args = (tuple(out_nodes),)
fx_g.graph.lint()
fx_g.recompile()
return fx_g
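
For intuition, the rewritten function filters None entries out of the fx output node rather than materializing zeros for them. A self-contained toy illustration of the same idea (WithNone is a made-up module, not part of this commit):

import torch
import torch.fx as fx

class WithNone(torch.nn.Module):

    def forward(self, x):
        # Mimics a backward graph whose output list contains None gradients.
        return [x + 1, None]

g = fx.symbolic_trace(WithNone())
for node in g.graph.nodes:
    if node.op == "output":
        kept = [a for a in node.args[0] if a is not None]
        # A single survivor is returned directly; otherwise return a tuple.
        node.args = (kept[0],) if len(kept) == 1 else (tuple(kept),)
g.graph.lint()
g.recompile()
print(g(torch.ones(2)))  # tensor([2., 2.]), the None is gone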
def get_forward_graph(self, fx_g: fx.GraphModule, inps):
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
return_fx = copy.deepcopy(fx_g)
f = self.change_fx_graph_return_to_tuple(fx_g)
f = torch.jit.script(fx_g)
f = torch.jit.freeze(f.eval())
torch.jit.save(f, "forw.pt")
f = torch.jit.load("forw.pt")
self.forward_graph = f
self.forward_inputs = copy.deepcopy(inps)
return f
return return_fx
def get_backward_graph(self, fx_g: fx.GraphModule, inps):
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
return_fx = copy.deepcopy(fx_g)
f = self.change_fx_graph_return_to_tuple(fx_g)
f = torch.jit.script(fx_g)
f = torch.jit.freeze(f.eval())
torch.jit.save(f, "back.pt")
f = torch.jit.load("back.pt")
self.backward_graph = f
self.backward_inputs = copy.deepcopy(inps)
return f
return return_fx
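
Both getters now hand back a deep copy of the incoming fx graph while the scripted, frozen TorchScript artifact is stashed on the instance. The script/freeze/save/load round trip can be tried in isolation; a minimal sketch with a stand-in Linear module:

import torch

scripted = torch.jit.script(torch.nn.Linear(4, 2))
# freeze() requires an eval()-mode ScriptModule and inlines the weights.
frozen = torch.jit.freeze(scripted.eval())
torch.jit.save(frozen, "forw.pt")
reloaded = torch.jit.load("forw.pt")
print(reloaded(torch.randn(1, 4)).shape)  # torch.Size([1, 2])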
def generate_inference_graph(self):
aot_model = memory_efficient_fusion(

shark/iree_utils.py

@@ -24,8 +24,7 @@ IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
def get_iree_compiled_module(module, device: str):
"""TODO: Documentation"""
flatbuffer_blob = ireec.compile_str(
str(module), target_backends=[IREE_DEVICE_MAP[device]]
)
str(module), target_backends=[IREE_DEVICE_MAP[device]])
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
config = ireert.Config(IREE_DEVICE_MAP[device])
ctx = ireert.SystemContext(config=config)
@@ -34,24 +33,24 @@ def get_iree_compiled_module(module, device: str):
ModuleCompiled = ctx.modules.module["forward"]
return ModuleCompiled, config
def export_iree_module_to_vmfb(module, device: str, directory: str):
module_name = get_module_name_for_asm_dump(module)
flatbuffer_blob = ireec.compile_str(
str(module), target_backends=[IREE_DEVICE_MAP[device]]
)
str(module), target_backends=[IREE_DEVICE_MAP[device]])
filename = os.path.join(directory, module_name + ".vmfb")
with open(filename, 'wb') as f:
f.write(flatbuffer_blob)
def get_results(compiled_vm, input, config):
"""TODO: Documentation"""
# TODO: Support returning multiple outputs.
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm(*device_inputs)
result_numpy = np.asarray(result, dtype=result.dtype)
# TODO: Segfault if the copy of numpy array is not returned.
result_copy = np.copy(result_numpy)
# {k:v.to_host() for k, v in device_outputs.items(
return result_copy
result_tensors = []
if isinstance(result, tuple):
for val in result:
result_tensors.append(np.copy(np.asarray(val, val.dtype)))
return result_tensors
else:
return np.copy(np.asarray(result, dtype=result.dtype))
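
End to end, the compile-and-run path in this file looks roughly as follows on a hand-written module. A sketch only: the MLIR snippet and its textual syntax depend on the IREE/MLIR version in use, and "dylib" mirrors IREE_DEVICE_MAP["cpu"] above:

import numpy as np
import iree.compiler as ireec
import iree.runtime as ireert

MLIR = """
func.func @forward(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = arith.addf %arg0, %arg0 : tensor<4xf32>
  return %0 : tensor<4xf32>
}
"""
blob = ireec.compile_str(MLIR, target_backends=["dylib"])
config = ireert.Config("dylib")
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(ireert.VmModule.from_flatbuffer(blob))
forward = ctx.modules.module["forward"]
arg = ireert.asdevicearray(config.device, np.ones(4, dtype=np.float32))
result = forward(arg)
# Copy to host before the device buffers are released, as get_results does.
print(np.copy(np.asarray(result, dtype=result.dtype)))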

shark/shark_runner.py

@@ -16,14 +16,17 @@ from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
import argparse
import os
# from functorch_utils import AOTModule
from shark.functorch_utils import AOTModule
def dir_path(path):
if os.path.isdir(path):
return path
else:
raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
raise argparse.ArgumentTypeError(
f"readable_dir:{path} is not a valid path")
class SharkRunner:
"""TODO: Write the description"""
@@ -37,28 +40,41 @@ class SharkRunner:
from_aot: bool,
):
self.parser = argparse.ArgumentParser(description='SHARK runner.')
self.parser.add_argument("--repro_dir", help="Directory to which module files will be saved for reproduction or debugging.", type=dir_path, default="/tmp/")
self.parser.add_argument("--save_mlir", default=False, action="store_true", help="Saves input MLIR module to /tmp/ directory.")
self.parser.add_argument("--save_vmfb", default=False, action="store_true", help="Saves iree .vmfb module to /tmp/ directory.")
self.parser.add_argument(
"--repro_dir",
help=
"Directory to which module files will be saved for reproduction or debugging.",
type=dir_path,
default="/tmp/")
self.parser.add_argument(
"--save_mlir",
default=False,
action="store_true",
help="Saves input MLIR module to /tmp/ directory.")
self.parser.add_argument(
"--save_vmfb",
default=False,
action="store_true",
help="Saves iree .vmfb module to /tmp/ directory.")
self.parser.parse_args(namespace=self)
self.torch_module = model
self.input = input
self.torch_mlir_module = get_torch_mlir_module(
model, input, dynamic, tracing_required, from_aot
)
self.torch_mlir_module = get_torch_mlir_module(model, input, dynamic,
tracing_required,
from_aot)
if self.save_mlir:
export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
if self.save_vmfb:
export_iree_module_to_vmfb(self.torch_mlir_module, device, self.repro_dir)
export_iree_module_to_vmfb(self.torch_mlir_module, device,
self.repro_dir)
(
self.iree_compilation_module,
self.iree_config,
) = get_iree_compiled_module(self.torch_mlir_module, device)
def forward(self, input):
return get_results(
self.iree_compilation_module, input, self.iree_config
)
return get_results(self.iree_compilation_module, input,
self.iree_config)
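
The parse_args(namespace=self) call above is what makes self.repro_dir, self.save_mlir and self.save_vmfb readable later in __init__. A tiny standalone demonstration of that argparse pattern (the Demo class is illustrative):

import argparse

class Demo:

    def __init__(self, argv=None):
        parser = argparse.ArgumentParser(description="SHARK runner demo.")
        parser.add_argument("--save_mlir", default=False, action="store_true")
        parser.add_argument("--repro_dir", default="/tmp/")
        # namespace=self stores parsed flags as attributes of this instance
        # instead of a separate Namespace object.
        parser.parse_args(args=argv, namespace=self)

d = Demo(["--save_mlir"])
print(d.save_mlir, d.repro_dir)  # True /tmp/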
class SharkInference:
@@ -78,17 +94,16 @@ class SharkInference:
self.input = input
self.from_aot = from_aot
# if from_aot:
# aot_module = AOTModule(
# model, input, custom_inference_fn=custom_inference_fn
# )
# aot_module.generate_inference_graph()
# self.model = aot_module.forward_graph
# self.input = aot_module.forward_inputs
if from_aot:
aot_module = AOTModule(model,
input,
custom_inference_fn=custom_inference_fn)
aot_module.generate_inference_graph()
self.model = aot_module.forward_graph
self.input = aot_module.forward_inputs
self.shark_runner = SharkRunner(
self.model, self.input, dynamic, device, jit_trace, from_aot
)
self.shark_runner = SharkRunner(self.model, self.input, dynamic, device,
jit_trace, from_aot)
def forward(self, inputs):
# TODO Capture weights and inputs in case of AOT, Also rework the
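
With the previously commented block enabled, from_aot=True swaps the user's model for AOTModule's traced forward graph before SharkRunner compiles it. Hypothetical usage (the keyword name is read off the constructor body above; the remaining parameters are assumed to have defaults):

import torch
from shark.shark_runner import SharkInference

class Small(torch.nn.Module):

    def forward(self, x):
        return torch.tanh(x)

x = torch.randn(1, 8)
module = SharkInference(Small(), (x,), from_aot=True)
print(module.forward((x,)))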
@@ -120,17 +135,16 @@ class SharkTrainer:
self.forward_graph = aot_module.forward_graph
self.forward_inputs = aot_module.forward_inputs
self.backward_graph = aot_module.backward_graph
print(self.backward_graph.graph)
self.backward_inputs = aot_module.backward_inputs
# self.shark_forward = SharkRunner(
# self.forward_graph,
# self.forward_inputs,
# dynamic,
# device,
# jit_trace,
# from_aot,
# )
self.shark_forward = SharkRunner(
self.forward_graph,
self.forward_inputs,
dynamic,
device,
jit_trace,
from_aot,
)
self.shark_backward = SharkRunner(
self.backward_graph,
self.backward_inputs,
@@ -153,5 +167,5 @@ class SharkTrainer:
for _ in range(iters):
self.shark_forward.forward(forward_inputs)
# self.shark_backward.forward(backward_inputs)
self.shark_backward.forward(backward_inputs)
return

shark/torch_mlir_utils.py

@@ -24,23 +24,24 @@ from torch_mlir.dialects.torch.importer.jit_ir import (
ModuleBuilder,
)
from torch_mlir_e2e_test.torchscript.serialization import (
extract_serializable_annotations,
apply_serializable_annotations,
SerializableTest
)
extract_serializable_annotations, apply_serializable_annotations,
SerializableTest)
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir.ir import StringAttr
def get_module_name_for_asm_dump(module):
"""Gets a name suitable for an assembly dump.
The name is not guaranteed to be unique.
"""
if not "torch.debug_module_name" in module.operation.attributes:
return "UnnammedModule"
return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
return StringAttr(
module.operation.attributes["torch.debug_module_name"]).value
def export_module_to_mlir_file(module, directory: str):
"""Writes MLIR module to /tmp/module.mlir for debugging or performance use."""
module_name = get_module_name_for_asm_dump(module)
@@ -49,6 +50,7 @@ def export_module_to_mlir_file(module, directory: str):
with open(filename, 'w') as f:
f.write(asm)
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
"""TODO: Include necessary documentation"""
@@ -65,9 +67,8 @@ def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
return annotations_list
def shark_jit_trace(
module, input: tuple, dynamic: bool, tracing_required: bool
):
def shark_jit_trace(module, input: tuple, dynamic: bool,
tracing_required: bool):
"""TODO: Include necessary documentation."""
if not tracing_required:
@@ -77,26 +78,21 @@ def shark_jit_trace(
actual_script = traced_module._actual_script_module
export(actual_script.forward)
annotate_args_decorator = annotate_args(
get_input_annotations(input, dynamic)
)
get_input_annotations(input, dynamic))
annotate_args_decorator(actual_script.forward)
module = torch.jit.script(actual_script)
# TODO: remove saved annotations.pickle
torchscript_module_bytes = module.save_to_buffer(
{
"annotations.pkl": pickle.dumps(
extract_serializable_annotations(module)
)
}
)
serializable_test = SerializableTest(
unique_name="", program=torchscript_module_bytes, trace=None
)
torchscript_module_bytes = module.save_to_buffer({
"annotations.pkl":
pickle.dumps(extract_serializable_annotations(module))
})
serializable_test = SerializableTest(unique_name="",
program=torchscript_module_bytes,
trace=None)
_extra_files = {"annotations.pkl": ""}
module = torch.jit.load(
io.BytesIO(serializable_test.program), _extra_files=_extra_files
)
module = torch.jit.load(io.BytesIO(serializable_test.program),
_extra_files=_extra_files)
# Load the pickled annotations.
annotations = pickle.loads(_extra_files["annotations.pkl"])
apply_serializable_annotations(module, annotations)
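
The export and annotate_args decorators applied to actual_script.forward come from torch-mlir's e2e test framework; their usual decorator form, for reference (shapes and dtype here are illustrative):

import torch
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

class TanhModule(torch.nn.Module):

    @export
    @annotate_args([
        None,  # the "self" slot
        ([-1, -1], torch.float32, True),  # dynamic 2-D float tensor
    ])
    def forward(self, x):
        return torch.tanh(x)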