diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 00000000..7da00642
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,3 @@
+[style]
+  based_on_style = google
+  column_limit = 80
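Every hunk below is mechanical output of this config. To reproduce the formatting locally you can go through yapf's Python API; this is an illustrative sketch only (the messy string is made up, and on yapf releases of this era FormatCode returns a (text, changed) pair):

import torch  # not needed for yapf itself; shown for parity with the repo
from yapf.yapflib.yapf_api import FormatCode

# Deliberately mis-formatted input to run through the new repo style.
messy = "def f( a,b ):\n      return a+b\n"
formatted, changed = FormatCode(messy, style_config=".style.yapf")
print(formatted)  # def f(a, b): / return a + b, at 4-space google indent
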
- if len(node_arg) == 1: - node.args = node_arg + if len(out_nodes) == 1: + node.args = out_nodes else: - node.args = (tuple(node_arg),) + node.args = (tuple(out_nodes),) fx_g.graph.lint() fx_g.recompile() return fx_g def get_forward_graph(self, fx_g: fx.GraphModule, inps): - fx_g = self.change_fx_graph_return_to_tuple(fx_g) + return_fx = copy.deepcopy(fx_g) + f = self.change_fx_graph_return_to_tuple(fx_g) f = torch.jit.script(fx_g) f = torch.jit.freeze(f.eval()) torch.jit.save(f, "forw.pt") f = torch.jit.load("forw.pt") self.forward_graph = f self.forward_inputs = copy.deepcopy(inps) - return f + return return_fx def get_backward_graph(self, fx_g: fx.GraphModule, inps): - fx_g = self.change_fx_graph_return_to_tuple(fx_g) + return_fx = copy.deepcopy(fx_g) + f = self.change_fx_graph_return_to_tuple(fx_g) f = torch.jit.script(fx_g) f = torch.jit.freeze(f.eval()) torch.jit.save(f, "back.pt") f = torch.jit.load("back.pt") self.backward_graph = f self.backward_inputs = copy.deepcopy(inps) - return f + return return_fx def generate_inference_graph(self): aot_model = memory_efficient_fusion( diff --git a/shark/iree_utils.py b/shark/iree_utils.py index 0524c389..428d6929 100644 --- a/shark/iree_utils.py +++ b/shark/iree_utils.py @@ -24,8 +24,7 @@ IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"} def get_iree_compiled_module(module, device: str): """TODO: Documentation""" flatbuffer_blob = ireec.compile_str( - str(module), target_backends=[IREE_DEVICE_MAP[device]] - ) + str(module), target_backends=[IREE_DEVICE_MAP[device]]) vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob) config = ireert.Config(IREE_DEVICE_MAP[device]) ctx = ireert.SystemContext(config=config) @@ -34,24 +33,24 @@ def get_iree_compiled_module(module, device: str): ModuleCompiled = ctx.modules.module["forward"] return ModuleCompiled, config + def export_iree_module_to_vmfb(module, device: str, directory: str): module_name = get_module_name_for_asm_dump(module) flatbuffer_blob = ireec.compile_str( - str(module), target_backends=[IREE_DEVICE_MAP[device]] - ) + str(module), target_backends=[IREE_DEVICE_MAP[device]]) filename = os.path.join(directory, module_name + ".vmfb") with open(filename, 'wb') as f: f.write(flatbuffer_blob) + def get_results(compiled_vm, input, config): """TODO: Documentation""" - - # TODO: Support returning multiple outputs. device_inputs = [ireert.asdevicearray(config.device, a) for a in input] result = compiled_vm(*device_inputs) - result_numpy = np.asarray(result, dtype=result.dtype) - # TODO: Segfault if the copy of numpy array is not returned. 
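For context on the new change_fx_graph_return_to_tuple above, the same output-node surgery can be reproduced on a toy torch.fx graph. This is a minimal standalone sketch, not part of the patch; TwoOutputs and its returned None are made up to mimic what AOT autograd can leave in a backward graph's outputs:

import torch
import torch.fx as fx


class TwoOutputs(torch.nn.Module):

    def forward(self, x):
        # The None mimics a placeholder output in an AOT backward graph.
        return [x + 1, None]


fx_g = fx.symbolic_trace(TwoOutputs())
for node in fx_g.graph.nodes:
    if node.op == "output":
        # An output node's single argument holds everything the graph returns.
        kept = [v for v in node.args[0] if v is not None]
        # Unwrap a singleton, otherwise return a tuple, as the patch does.
        node.args = (kept[0],) if len(kept) == 1 else (tuple(kept),)
fx_g.graph.lint()
fx_g.recompile()
print(fx_g(torch.ones(2)))  # tensor([2., 2.]); the None entry is gone
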
-    result_copy = np.copy(result_numpy)
-    # {k:v.to_host() for k, v in device_outputs.items(
-    return result_copy
-
+    result_tensors = []
+    if (isinstance(result, tuple)):
+        for val in result:
+            result_tensors.append(np.copy(np.asarray(val, val.dtype)))
+        return result_tensors
+    else:
+        return np.copy(np.asarray(result, dtype=result.dtype))
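The behavioral change in get_results is easiest to see in isolation: the old code failed on tuple results (its TODO noted that multiple outputs were unsupported), while the new code copies each element out separately. A sketch of the same logic with plain numpy arrays; normalize_results is an illustrative name, not an API added by this patch:

import numpy as np


def normalize_results(result):
    # Same shape of logic as the new get_results(): copy every returned
    # device array into a host-owned numpy buffer (the old TODO noted a
    # segfault when the copy was skipped).
    if isinstance(result, tuple):
        return [np.copy(np.asarray(val, dtype=val.dtype)) for val in result]
    return np.copy(np.asarray(result, dtype=result.dtype))


print(normalize_results(np.arange(3.0)))                 # single output
print(normalize_results((np.arange(2.0), np.zeros(2))))  # multiple outputs
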
diff --git a/shark/shark_runner.py b/shark/shark_runner.py
index ae8247eb..da9ace55 100644
--- a/shark/shark_runner.py
+++ b/shark/shark_runner.py
@@ -16,14 +16,17 @@
 from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file
 from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
 import argparse
 import os
-# from functorch_utils import AOTModule
+from shark.functorch_utils import AOTModule
+
 
 def dir_path(path):
     if os.path.isdir(path):
         return path
     else:
-        raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
-
+        raise argparse.ArgumentTypeError(
+            f"readable_dir:{path} is not a valid path")
+
+
 class SharkRunner:
     """TODO: Write the description"""
 
     def __init__(
         self,
         model,
         input: tuple,
         dynamic: bool,
         device: str,
         tracing_required: bool,
         from_aot: bool,
     ):
         self.parser = argparse.ArgumentParser(description='SHARK runner.')
-        self.parser.add_argument("--repro_dir", help="Directory to which module files will be saved for reproduction or debugging.", type=dir_path, default="/tmp/")
-        self.parser.add_argument("--save_mlir", default=False, action="store_true", help="Saves input MLIR module to /tmp/ directory.")
-        self.parser.add_argument("--save_vmfb", default=False, action="store_true", help="Saves iree .vmfb module to /tmp/ directory.")
+        self.parser.add_argument(
+            "--repro_dir",
+            help=
+            "Directory to which module files will be saved for reproduction or debugging.",
+            type=dir_path,
+            default="/tmp/")
+        self.parser.add_argument(
+            "--save_mlir",
+            default=False,
+            action="store_true",
+            help="Saves input MLIR module to /tmp/ directory.")
+        self.parser.add_argument(
+            "--save_vmfb",
+            default=False,
+            action="store_true",
+            help="Saves iree .vmfb module to /tmp/ directory.")
         self.parser.parse_args(namespace=self)
         self.torch_module = model
         self.input = input
-        self.torch_mlir_module = get_torch_mlir_module(
-            model, input, dynamic, tracing_required, from_aot
-        )
+        self.torch_mlir_module = get_torch_mlir_module(model, input, dynamic,
+                                                       tracing_required,
+                                                       from_aot)
         if self.save_mlir:
             export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
         if self.save_vmfb:
-            export_iree_module_to_vmfb(self.torch_mlir_module, device, self.repro_dir)
+            export_iree_module_to_vmfb(self.torch_mlir_module, device,
+                                       self.repro_dir)
         (
             self.iree_compilation_module,
             self.iree_config,
         ) = get_iree_compiled_module(self.torch_mlir_module, device)
-    
+
     def forward(self, input):
-        return get_results(
-            self.iree_compilation_module, input, self.iree_config
-        )
+        return get_results(self.iree_compilation_module, input,
+                           self.iree_config)
 
 
 class SharkInference:
@@ -78,17 +94,16 @@ class SharkInference:
         self.input = input
         self.from_aot = from_aot
 
-        # if from_aot:
-        #     aot_module = AOTModule(
-        #         model, input, custom_inference_fn=custom_inference_fn
-        #     )
-        #     aot_module.generate_inference_graph()
-        #     self.model = aot_module.forward_graph
-        #     self.input = aot_module.forward_inputs
+        if from_aot:
+            aot_module = AOTModule(model,
+                                   input,
+                                   custom_inference_fn=custom_inference_fn)
+            aot_module.generate_inference_graph()
+            self.model = aot_module.forward_graph
+            self.input = aot_module.forward_inputs
 
-        self.shark_runner = SharkRunner(
-            self.model, self.input, dynamic, device, jit_trace, from_aot
-        )
+        self.shark_runner = SharkRunner(self.model, self.input, dynamic, device,
+                                        jit_trace, from_aot)
 
     def forward(self, inputs):
         # TODO Capture weights and inputs in case of AOT, Also rework the
@@ -120,17 +135,16 @@ class SharkTrainer:
         self.forward_graph = aot_module.forward_graph
         self.forward_inputs = aot_module.forward_inputs
         self.backward_graph = aot_module.backward_graph
-        print(self.backward_graph.graph)
         self.backward_inputs = aot_module.backward_inputs
 
-        # self.shark_forward = SharkRunner(
-        #     self.forward_graph,
-        #     self.forward_inputs,
-        #     dynamic,
-        #     device,
-        #     jit_trace,
-        #     from_aot,
-        # )
+        self.shark_forward = SharkRunner(
+            self.forward_graph,
+            self.forward_inputs,
+            dynamic,
+            device,
+            jit_trace,
+            from_aot,
+        )
         self.shark_backward = SharkRunner(
             self.backward_graph,
             self.backward_inputs,
@@ -153,5 +167,5 @@
         for _ in range(iters):
             self.shark_forward.forward(forward_inputs)
-            # self.shark_backward.forward(backward_inputs)
+            self.shark_backward.forward(backward_inputs)
         return
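With the AOT path in SharkInference re-enabled, end-to-end usage is expected to look roughly like this. MulModule is a toy stand-in, and any constructor details beyond the model, input, and from_aot parameters visible in this diff are assumptions rather than a documented signature:

import torch
from shark.shark_runner import SharkInference


class MulModule(torch.nn.Module):

    def forward(self, x):
        return x * 2


inp = torch.randn(4, 4)
# from_aot=True now routes through AOTModule.generate_inference_graph().
shark_module = SharkInference(MulModule(), (inp,), from_aot=True)
print(shark_module.forward((inp,)))
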
""" if not "torch.debug_module_name" in module.operation.attributes: return "UnnammedModule" - return StringAttr(module.operation.attributes["torch.debug_module_name"]).value - + return StringAttr( + module.operation.attributes["torch.debug_module_name"]).value + + def export_module_to_mlir_file(module, directory: str): """Writes MLIR module to /tmp/module.mlir for debugging or performance use.""" module_name = get_module_name_for_asm_dump(module) @@ -49,6 +50,7 @@ def export_module_to_mlir_file(module, directory: str): with open(filename, 'w') as f: f.write(asm) + def get_input_annotations(inputs: tuple, dynamic: bool) -> list: """TODO: Include necessary documentation""" @@ -65,9 +67,8 @@ def get_input_annotations(inputs: tuple, dynamic: bool) -> list: return annotations_list -def shark_jit_trace( - module, input: tuple, dynamic: bool, tracing_required: bool -): +def shark_jit_trace(module, input: tuple, dynamic: bool, + tracing_required: bool): """TODO: Include necessary documentation.""" if not tracing_required: @@ -77,26 +78,21 @@ def shark_jit_trace( actual_script = traced_module._actual_script_module export(actual_script.forward) annotate_args_decorator = annotate_args( - get_input_annotations(input, dynamic) - ) + get_input_annotations(input, dynamic)) annotate_args_decorator(actual_script.forward) module = torch.jit.script(actual_script) # TODO: remove saved annotations.pickle - torchscript_module_bytes = module.save_to_buffer( - { - "annotations.pkl": pickle.dumps( - extract_serializable_annotations(module) - ) - } - ) - serializable_test = SerializableTest( - unique_name="", program=torchscript_module_bytes, trace=None - ) + torchscript_module_bytes = module.save_to_buffer({ + "annotations.pkl": + pickle.dumps(extract_serializable_annotations(module)) + }) + serializable_test = SerializableTest(unique_name="", + program=torchscript_module_bytes, + trace=None) _extra_files = {"annotations.pkl": ""} - module = torch.jit.load( - io.BytesIO(serializable_test.program), _extra_files=_extra_files - ) + module = torch.jit.load(io.BytesIO(serializable_test.program), + _extra_files=_extra_files) # Load the pickled annotations. annotations = pickle.loads(_extra_files["annotations.pkl"]) apply_serializable_annotations(module, annotations)