Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-09 13:57:54 -05:00)
Enable Shark Trainer APIs
.style.yapf (new file)
@@ -0,0 +1,3 @@
+[style]
+based_on_style = google
+column_limit = 80
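The new .style.yapf pins the repository to yapf's google base style with an 80-column limit, which is what drives the reflowed call sites in the rest of this commit. The config can be exercised through yapf's Python API; a minimal sketch, assuming yapf is installed and the script runs from the repository root (the sample source string is invented):

# Sketch: apply the repo's yapf config to a snippet (assumes `pip install yapf`).
from yapf.yapflib.yapf_api import FormatCode

source = "def f( a,b ):\n    return a+b\n"  # invented unformatted input
# FormatCode returns (formatted_source, changed) in yapf releases of this era.
formatted, changed = FormatCode(source, style_config=".style.yapf")
print(formatted)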
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision.models as models
-from shark_runner import SharkTrainer
+from shark.shark_runner import SharkTrainer


 class NeuralNet(nn.Module):
@@ -36,17 +36,17 @@ results = shark_module.train((input,))

 # print(results)

-input = torch.randn(1, 3, 224, 224)
-labels = torch.randn(1, 1000)
+# input = torch.randn(1, 3, 224, 224)
+# labels = torch.randn(1, 1000)

-class Resnet50Module(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.resnet = models.resnet50(pretrained=True)
-        self.train(True)
+# class Resnet50Module(torch.nn.Module):
+#     def __init__(self):
+#         super().__init__()
+#         self.resnet = models.resnet50(pretrained=True)
+#         self.train(True)

-    def forward(self, img):
-        return self.resnet.forward(img)
+#     def forward(self, img):
+#         return self.resnet.forward(img)

-shark_module = SharkTrainer(Resnet50Module(), (input,), (labels,), from_aot=True)
-results = shark_module.train((input,))
+# shark_module = SharkTrainer(Resnet50Module(), (input,), (labels,), from_aot=True)
+# results = shark_module.train((input,))
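For reference, the lines this hunk comments out form a complete driver for the trainer API. Uncommented, they read as below (assembled verbatim from the example above; the first run downloads the pretrained ResNet-50 weights):

# The Resnet50 training example from this file, in runnable form.
import torch
import torchvision.models as models
from shark.shark_runner import SharkTrainer

input = torch.randn(1, 3, 224, 224)  # one NCHW image batch
labels = torch.randn(1, 1000)        # dummy targets over the 1000 classes

class Resnet50Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet50(pretrained=True)
        self.train(True)  # keep the module in training mode

    def forward(self, img):
        return self.resnet.forward(img)

shark_module = SharkTrainer(Resnet50Module(), (input,), (labels,), from_aot=True)
results = shark_module.train((input,))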
shark/functorch_utils.py
@@ -23,6 +23,7 @@ import copy


 class AOTModule:
+
     def __init__(self, model, inputs, labels=None, custom_inference_fn=None):
         self.model = model
         self.inputs = inputs
@@ -56,41 +57,71 @@ class AOTModule:
         loss.backward()
         optimizer.step()

+    # Replaces None types with zeros.
+    # def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule) -> fx.GraphModule:
+    #     for node in fx_g.graph.nodes:
+    #         if node.op == 'output':
+    #             # output nodes always have one argument
+    #             node_arg = node.args[0]
+    #             out_nodes = []
+    #             if isinstance(node_arg, list):
+    #                 for out_node in node_arg:
+    #                     if isinstance(out_node, type(None)):
+    #                         print("None node found replacing with zeros")
+    #                         with fx_g.graph.inserting_before(out_node):
+    #                             new_node = fx_g.graph.call_function(torch.ops.aten.zeros, (-1,))
+    #                         out_nodes.append(new_node)
+    #                     else:
+    #                         out_nodes.append(out_node)
+
+    #             node.args = (tuple(out_nodes),)
+    #     fx_g.graph.lint()
+    #     fx_g.recompile()
+    #     return fx_g
+
+    # Doesn't replace the None type.
     def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
         for node in fx_g.graph.nodes:
             if node.op == "output":
                 # output nodes always have one argument
                 node_arg = node.args[0]
                 out_nodes = []
                 if isinstance(node_arg, list):
+                    # Don't return NoneType elements.
+                    for out_node in node_arg:
+                        if not isinstance(out_node, type(None)):
+                            out_nodes.append(out_node)
                     # If there is a single tensor/element to be returned, don't
                     # create a tuple for it.
-                    if len(node_arg) == 1:
-                        node.args = node_arg
+                    if len(out_nodes) == 1:
+                        node.args = out_nodes
                     else:
-                        node.args = (tuple(node_arg),)
+                        node.args = (tuple(out_nodes),)
         fx_g.graph.lint()
         fx_g.recompile()
         return fx_g

     def get_forward_graph(self, fx_g: fx.GraphModule, inps):
-        f = self.change_fx_graph_return_to_tuple(fx_g)
+        fx_g = self.change_fx_graph_return_to_tuple(fx_g)
+        return_fx = copy.deepcopy(fx_g)
         f = torch.jit.script(fx_g)
         f = torch.jit.freeze(f.eval())
         torch.jit.save(f, "forw.pt")
         f = torch.jit.load("forw.pt")
         self.forward_graph = f
         self.forward_inputs = copy.deepcopy(inps)
-        return f
+        return return_fx

     def get_backward_graph(self, fx_g: fx.GraphModule, inps):
-        f = self.change_fx_graph_return_to_tuple(fx_g)
+        fx_g = self.change_fx_graph_return_to_tuple(fx_g)
+        return_fx = copy.deepcopy(fx_g)
         f = torch.jit.script(fx_g)
         f = torch.jit.freeze(f.eval())
         torch.jit.save(f, "back.pt")
         f = torch.jit.load("back.pt")
         self.backward_graph = f
         self.backward_inputs = copy.deepcopy(inps)
-        return f
+        return return_fx

     def generate_inference_graph(self):
         aot_model = memory_efficient_fusion(
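The heart of change_fx_graph_return_to_tuple is rewriting the FX graph's single output node: NoneType results (for example, gradients of frozen parameters) are dropped, and a list-valued return is repacked as a tuple so torch.jit.script can handle it. A self-contained sketch of the same rewrite on a toy module (TwoOut is invented for illustration and is not part of SHARK):

# Standalone sketch: repack an fx.GraphModule's list output as a tuple.
import torch
import torch.fx as fx

class TwoOut(torch.nn.Module):  # invented module with a list-valued output
    def forward(self, x):
        return [x + 1, x * 2]

fx_g = fx.symbolic_trace(TwoOut())
for node in fx_g.graph.nodes:
    if node.op == "output":
        node_arg = node.args[0]  # output nodes always have one argument
        if isinstance(node_arg, list):
            out_nodes = [n for n in node_arg if n is not None]  # drop NoneType
            if len(out_nodes) == 1:
                node.args = (out_nodes[0],)      # single result: return it bare
            else:
                node.args = (tuple(out_nodes),)  # several results: return a tuple
fx_g.graph.lint()
fx_g.recompile()
print(fx_g(torch.ones(2)))  # now a tuple, no longer a list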
shark/iree_utils.py
@@ -24,8 +24,7 @@ IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
 def get_iree_compiled_module(module, device: str):
     """TODO: Documentation"""
     flatbuffer_blob = ireec.compile_str(
-        str(module), target_backends=[IREE_DEVICE_MAP[device]]
-    )
+        str(module), target_backends=[IREE_DEVICE_MAP[device]])
     vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
     config = ireert.Config(IREE_DEVICE_MAP[device])
     ctx = ireert.SystemContext(config=config)
@@ -34,24 +33,24 @@ def get_iree_compiled_module(module, device: str):
     ModuleCompiled = ctx.modules.module["forward"]
     return ModuleCompiled, config


 def export_iree_module_to_vmfb(module, device: str, directory: str):
     module_name = get_module_name_for_asm_dump(module)
     flatbuffer_blob = ireec.compile_str(
-        str(module), target_backends=[IREE_DEVICE_MAP[device]]
-    )
+        str(module), target_backends=[IREE_DEVICE_MAP[device]])
     filename = os.path.join(directory, module_name + ".vmfb")
     with open(filename, 'wb') as f:
         f.write(flatbuffer_blob)


 def get_results(compiled_vm, input, config):
     """TODO: Documentation"""

     # TODO: Support returning multiple outputs.
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
     result = compiled_vm(*device_inputs)
-    result_numpy = np.asarray(result, dtype=result.dtype)
-    # TODO: Segfault if the copy of numpy array is not returned.
-    result_copy = np.copy(result_numpy)
-    # {k:v.to_host() for k, v in device_outputs.items(
-    return result_copy
+    result_tensors = []
+    if (isinstance(result, tuple)):
+        for val in result:
+            result_tensors.append(np.copy(np.asarray(val, val.dtype)))
+        return result_tensors
+    else:
+        return np.copy(np.asarray(result, dtype=result.dtype))
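The replacement body of get_results generalizes the old single-output path: every device array is materialized with np.asarray and explicitly copied, since the earlier TODO records a segfault when an un-copied view outlives the device buffer, and tuples of outputs are now copied element by element. A numpy-only sketch of that copy-out logic (plain ndarrays stand in for IREE DeviceArrays here):

# Numpy-only sketch of the copy-out logic in the new get_results.
import numpy as np

def copy_out(result):
    if isinstance(result, tuple):
        # Multiple outputs: copy each one so the caller owns the memory.
        return [np.copy(np.asarray(val, val.dtype)) for val in result]
    # Single output: same copy, returned bare.
    return np.copy(np.asarray(result, dtype=result.dtype))

print(copy_out(np.arange(4, dtype=np.float32)))
print(copy_out((np.zeros(2), np.ones(3))))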
shark/shark_runner.py
@@ -16,14 +16,17 @@ from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_
 from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb
 import argparse
 import os
-# from functorch_utils import AOTModule
+from shark.functorch_utils import AOTModule


 def dir_path(path):
     if os.path.isdir(path):
         return path
     else:
-        raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
+        raise argparse.ArgumentTypeError(
+            f"readable_dir:{path} is not a valid path")


 class SharkRunner:
     """TODO: Write the description"""
@@ -37,28 +40,41 @@ class SharkRunner:
         from_aot: bool,
     ):
         self.parser = argparse.ArgumentParser(description='SHARK runner.')
-        self.parser.add_argument("--repro_dir", help="Directory to which module files will be saved for reproduction or debugging.", type=dir_path, default="/tmp/")
-        self.parser.add_argument("--save_mlir", default=False, action="store_true", help="Saves input MLIR module to /tmp/ directory.")
-        self.parser.add_argument("--save_vmfb", default=False, action="store_true", help="Saves iree .vmfb module to /tmp/ directory.")
+        self.parser.add_argument(
+            "--repro_dir",
+            help=
+            "Directory to which module files will be saved for reproduction or debugging.",
+            type=dir_path,
+            default="/tmp/")
+        self.parser.add_argument(
+            "--save_mlir",
+            default=False,
+            action="store_true",
+            help="Saves input MLIR module to /tmp/ directory.")
+        self.parser.add_argument(
+            "--save_vmfb",
+            default=False,
+            action="store_true",
+            help="Saves iree .vmfb module to /tmp/ directory.")
         self.parser.parse_args(namespace=self)
         self.torch_module = model
         self.input = input
-        self.torch_mlir_module = get_torch_mlir_module(
-            model, input, dynamic, tracing_required, from_aot
-        )
+        self.torch_mlir_module = get_torch_mlir_module(model, input, dynamic,
+                                                       tracing_required,
+                                                       from_aot)
         if self.save_mlir:
             export_module_to_mlir_file(self.torch_mlir_module, self.repro_dir)
         if self.save_vmfb:
-            export_iree_module_to_vmfb(self.torch_mlir_module, device, self.repro_dir)
+            export_iree_module_to_vmfb(self.torch_mlir_module, device,
+                                       self.repro_dir)
         (
             self.iree_compilation_module,
             self.iree_config,
         ) = get_iree_compiled_module(self.torch_mlir_module, device)

     def forward(self, input):
-        return get_results(
-            self.iree_compilation_module, input, self.iree_config
-        )
+        return get_results(self.iree_compilation_module, input,
+                           self.iree_config)


 class SharkInference:
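One detail worth noting in the hunk above is parse_args(namespace=self): argparse writes every parsed flag directly onto the SharkRunner instance, which is why self.save_mlir, self.save_vmfb, and self.repro_dir are available further down without explicit assignment. A stripped-down sketch of the pattern (the Demo class is invented for illustration):

# Sketch: argparse storing parsed flags straight onto an object via namespace=.
import argparse

class Demo:
    def __init__(self, argv=None):
        parser = argparse.ArgumentParser(description="demo")
        parser.add_argument("--save_mlir", default=False, action="store_true")
        parser.add_argument("--repro_dir", default="/tmp/")
        parser.parse_args(argv, namespace=self)  # attributes land on self

d = Demo(["--save_mlir"])
print(d.save_mlir, d.repro_dir)  # True /tmp/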
@@ -78,17 +94,16 @@ class SharkInference:
         self.input = input
         self.from_aot = from_aot

-        # if from_aot:
-        #     aot_module = AOTModule(
-        #         model, input, custom_inference_fn=custom_inference_fn
-        #     )
-        #     aot_module.generate_inference_graph()
-        #     self.model = aot_module.forward_graph
-        #     self.input = aot_module.forward_inputs
+        if from_aot:
+            aot_module = AOTModule(model,
+                                   input,
+                                   custom_inference_fn=custom_inference_fn)
+            aot_module.generate_inference_graph()
+            self.model = aot_module.forward_graph
+            self.input = aot_module.forward_inputs

-        self.shark_runner = SharkRunner(
-            self.model, self.input, dynamic, device, jit_trace, from_aot
-        )
+        self.shark_runner = SharkRunner(self.model, self.input, dynamic, device,
+                                        jit_trace, from_aot)

     def forward(self, inputs):
         # TODO Capture weights and inputs in case of AOT, Also rework the
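With the from_aot branch now live, SharkInference routes the model through AOTModule to extract a scripted forward graph before compiling it with SharkRunner. A hypothetical driver for that path; the constructor's full signature is not visible in this diff, so the call below only mirrors the parameter names that do appear above:

# Hypothetical usage of the AOT inference path (signature partly assumed).
import torch
from shark.shark_runner import SharkInference  # import path assumed

model = torch.nn.Linear(4, 2)
inp = torch.randn(1, 4)
shark_inference = SharkInference(model, (inp,), from_aot=True)
print(shark_inference.forward((inp,)))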
@@ -120,17 +135,16 @@ class SharkTrainer:
         self.forward_graph = aot_module.forward_graph
         self.forward_inputs = aot_module.forward_inputs
         self.backward_graph = aot_module.backward_graph
+        print(self.backward_graph.graph)
         self.backward_inputs = aot_module.backward_inputs

-        # self.shark_forward = SharkRunner(
-        #     self.forward_graph,
-        #     self.forward_inputs,
-        #     dynamic,
-        #     device,
-        #     jit_trace,
-        #     from_aot,
-        # )
+        self.shark_forward = SharkRunner(
+            self.forward_graph,
+            self.forward_inputs,
+            dynamic,
+            device,
+            jit_trace,
+            from_aot,
+        )
         self.shark_backward = SharkRunner(
             self.backward_graph,
             self.backward_inputs,
@@ -153,5 +167,5 @@ class SharkTrainer:

         for _ in range(iters):
             self.shark_forward.forward(forward_inputs)
-            # self.shark_backward.forward(backward_inputs)
+            self.shark_backward.forward(backward_inputs)
         return
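The final hunk enables the backward pass, so every training iteration now runs the compiled forward graph followed by the compiled backward graph. A schematic of the resulting loop with SharkRunner stubbed out (StubRunner and the placeholder input tuples are invented; only the call sequence mirrors the code above):

# Schematic of the enabled loop: forward graph, then backward graph, per iteration.
class StubRunner:  # stand-in for SharkRunner's compiled-module interface
    def __init__(self, name):
        self.name = name
    def forward(self, inputs):
        print(f"running {self.name} graph on {len(inputs)} input(s)")

shark_forward = StubRunner("forward")
shark_backward = StubRunner("backward")
forward_inputs, backward_inputs = (0,), (0, 0)  # placeholders

for _ in range(2):  # iters
    shark_forward.forward(forward_inputs)
    shark_backward.forward(backward_inputs)  # newly enabled by this commit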
shark/torch_mlir_utils.py
@@ -24,23 +24,24 @@ from torch_mlir.dialects.torch.importer.jit_ir import (
     ModuleBuilder,
 )
 from torch_mlir_e2e_test.torchscript.serialization import (
-    extract_serializable_annotations,
-    apply_serializable_annotations,
-    SerializableTest
-)
+    extract_serializable_annotations, apply_serializable_annotations,
+    SerializableTest)

 from torch_mlir.passmanager import PassManager
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 from torch_mlir.ir import StringAttr


 def get_module_name_for_asm_dump(module):
     """Gets a name suitable for an assembly dump.
     The name is not guaranteed to be unique.
     """
     if not "torch.debug_module_name" in module.operation.attributes:
         return "UnnammedModule"
-    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
+    return StringAttr(
+        module.operation.attributes["torch.debug_module_name"]).value


 def export_module_to_mlir_file(module, directory: str):
     """Writes MLIR module to /tmp/module.mlir for debugging or performance use."""
     module_name = get_module_name_for_asm_dump(module)
@@ -49,6 +50,7 @@ def export_module_to_mlir_file(module, directory: str):
     with open(filename, 'w') as f:
         f.write(asm)

+
 def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
     """TODO: Include necessary documentation"""
@@ -65,9 +67,8 @@ def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
     return annotations_list


-def shark_jit_trace(
-    module, input: tuple, dynamic: bool, tracing_required: bool
-):
+def shark_jit_trace(module, input: tuple, dynamic: bool,
+                    tracing_required: bool):
     """TODO: Include necessary documentation."""

     if not tracing_required:
@@ -77,26 +78,21 @@ def shark_jit_trace(
     actual_script = traced_module._actual_script_module
     export(actual_script.forward)
     annotate_args_decorator = annotate_args(
-        get_input_annotations(input, dynamic)
-    )
+        get_input_annotations(input, dynamic))
     annotate_args_decorator(actual_script.forward)
     module = torch.jit.script(actual_script)

     # TODO: remove saved annotations.pickle
-    torchscript_module_bytes = module.save_to_buffer(
-        {
-            "annotations.pkl": pickle.dumps(
-                extract_serializable_annotations(module)
-            )
-        }
-    )
-    serializable_test = SerializableTest(
-        unique_name="", program=torchscript_module_bytes, trace=None
-    )
+    torchscript_module_bytes = module.save_to_buffer({
+        "annotations.pkl":
+            pickle.dumps(extract_serializable_annotations(module))
+    })
+    serializable_test = SerializableTest(unique_name="",
+                                         program=torchscript_module_bytes,
+                                         trace=None)
     _extra_files = {"annotations.pkl": ""}
-    module = torch.jit.load(
-        io.BytesIO(serializable_test.program), _extra_files=_extra_files
-    )
+    module = torch.jit.load(io.BytesIO(serializable_test.program),
+                            _extra_files=_extra_files)
     # Load the pickled annotations.
     annotations = pickle.loads(_extra_files["annotations.pkl"])
     apply_serializable_annotations(module, annotations)
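The annotate_args decorator used in shark_jit_trace expects one entry per forward argument: None for self, then a (shape, dtype, has-value-semantics) triple per tensor, with -1 marking a dynamic dimension; that is the list get_input_annotations builds. A sketch of the expected format (sketch_input_annotations is an invented stand-in, not the SHARK helper):

# Sketch of the annotation list consumed by torch-mlir's annotate_args.
import torch

def sketch_input_annotations(inputs, dynamic):
    annotations = [None]  # first entry corresponds to `self`
    for t in inputs:
        shape = [-1] * t.dim() if dynamic else list(t.shape)
        annotations.append((shape, t.dtype, True))  # True: value semantics
    return annotations

inp = (torch.randn(1, 3, 224, 224),)
print(sketch_input_annotations(inp, dynamic=False))
# [None, ([1, 3, 224, 224], torch.float32, True)]
print(sketch_input_annotations(inp, dynamic=True))
# [None, ([-1, -1, -1, -1], torch.float32, True)]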