Add support for training in shark_runner.

Added support for training via AOT_module in shark_runner.
This commit is contained in:
Prashant Kumar
2022-03-13 16:01:04 +00:00
parent fba169f456
commit e6115da192
5 changed files with 87 additions and 18 deletions

View File

@@ -1,11 +1,12 @@
import torch
import torch.nn as nn
from shark_runner import SharkInference
from shark_runner import SharkInference, SharkTrainer
class NeuralNet(nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.l1 = nn.Linear(10, 16)
self.l1 = nn.Linear(10, 16)
self.relu = nn.ReLU()
self.l2 = nn.Linear(16, 2)
self.train(False)
@@ -16,6 +17,7 @@ class NeuralNet(nn.Module):
out = self.l2(out)
return out
model = NeuralNet()
criterion = nn.CrossEntropyLoss()
@@ -29,5 +31,9 @@ shark_module = SharkInference(NeuralNet(), input, from_aot = True)
results = shark_module.forward(input)
print(results)
#TODO: Currently errors out in torch-mlir lowering pass.
shark_trainer_module = SharkTrainer(
NeuralNet(), (input,), (labels,), dynamic=True, from_aot=True
)
shark_trainer_module.train(input)

View File

@@ -23,9 +23,10 @@ from typing import List
class AOTModule:
def __init__(self, model, inputs):
def __init__(self, model, inputs, labels = None):
self.model = model
self.inputs = inputs
self.labels = labels
self.forward_graph = None
self.backward_graph = None
self.forward_inputs = None
@@ -37,10 +38,17 @@ class AOTModule:
for _ in range(iters):
out = model(inputs)
def train(self, model, inputs, labels):
    """Run a short training loop on `model` so the AOT tracer can
    capture both the forward and the backward graphs.

    Args:
        model: the (aot-compiled) torch module to train.
        inputs: tuple of input tensors, splatted into the model call.
        labels: tuple of target tensors passed to the loss.
    """
    # TODO: Pass the criterion and optimizer in from the caller
    # instead of hard-coding CrossEntropyLoss / SGD here.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    iters = 1
    for _ in range(iters):
        # Standard step: clear grads, forward, loss, backward, update.
        optimizer.zero_grad()
        output = model(*inputs)
        loss = criterion(output, *labels)
        loss.backward()
        optimizer.step()
def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
for node in fx_g.graph.nodes:
@@ -48,12 +56,11 @@ class AOTModule:
# output nodes always have one argument
node_arg = node.args[0]
if isinstance(node_arg, list):
node.args = node_arg
node.args = (tuple(node_arg),)
fx_g.graph.lint()
fx_g.recompile()
return fx_g
def get_forward_graph(self, fx_g: fx.GraphModule, inps):
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
f = torch.jit.script(fx_g)
@@ -90,4 +97,4 @@ class AOTModule:
bw_compiler=self.get_backward_graph,
partition_fn=min_cut_rematerialization_partition,
)
self.train(aot_model, self.inputs)
self.train(aot_model, self.inputs, self.labels)

View File

@@ -65,7 +65,7 @@ labels = load_labels()
##############################################################################
input = torch.randn(1,3,224,224)
input = torch.randn(1, 3, 224, 224)
print(input.shape)
shark_module = SharkInference(Resnet50Module(), (img,))

View File

@@ -58,7 +58,7 @@ class SharkInference:
self.model = model
self.input = input
if(from_aot):
if from_aot:
aot_module = AOTModule(model, input)
aot_module.generate_inference_graph()
self.model = aot_module.forward_graph
@@ -70,17 +70,68 @@ class SharkInference:
def forward(self, input):
input_list = []
# TODO Capture weights and inputs in case of AOT, Also rework the
# TODO Capture weights and inputs in case of AOT, Also rework the
# forward pass.
if(True):
if True:
for input in self.input:
input_list.append(input.detach().numpy())
return self.shark_runner.forward(input_list)
class SharkTrainer:
    """Compile the forward and backward graphs of a torch module via
    AOT tracing and execute them through SharkRunner.

    Args:
        model: torch module to train.
        input: tuple of example input tensors used for tracing.
        label: tuple of example label tensors used for tracing.
        dynamic: whether to compile with dynamic shapes.
        device: backend device string ("cpu" by default).
        jit_trace: whether to lower via torch.jit.trace.
        from_aot: whether the graphs come from the AOT autograd path.
    """

    def __init__(
        self,
        model,
        input: tuple,
        label: tuple,
        dynamic: bool = False,
        device: str = "cpu",
        jit_trace: bool = False,
        from_aot: bool = True,
    ):
        self.model = model
        self.input = input
        self.label = label
        # Trace once to split the model into forward/backward fx graphs
        # plus the concrete tensors each graph expects.
        aot_module = AOTModule(model, input, label)
        aot_module.generate_training_graph()
        self.forward_graph = aot_module.forward_graph
        self.forward_inputs = aot_module.forward_inputs
        self.backward_graph = aot_module.backward_graph
        self.backward_inputs = aot_module.backward_inputs
        # One runner per graph: forward and backward are lowered and
        # executed independently.
        self.shark_forward = SharkRunner(
            self.forward_graph,
            self.forward_inputs,
            dynamic,
            device,
            jit_trace,
            from_aot,
        )
        self.shark_backward = SharkRunner(
            self.backward_graph,
            self.backward_inputs,
            dynamic,
            device,
            jit_trace,
            from_aot,
        )

    def train(self, input):
        """Run `iters` training steps on the captured graphs.

        NOTE(review): `input` is currently unused — the tensors captured
        at trace time are fed instead; confirm whether fresh inputs
        should be threaded through here.
        """
        # Convert captured tensors to numpy for the runner interface.
        forward_inputs = [t.detach().numpy() for t in self.forward_inputs]
        backward_inputs = [t.detach().numpy() for t in self.backward_inputs]
        # TODO: Pass the iter variable, and optimizer.
        iters = 1
        for _ in range(iters):
            # Bug fix: the original referenced the nonexistent
            # `self.shark_runner`; dispatch to the per-graph runners.
            self.shark_forward.forward(forward_inputs)
            self.shark_backward.forward(backward_inputs)
        return

View File

@@ -69,7 +69,9 @@ def shark_jit_trace(
traced_module = torch.jit.trace_module(module, {"forward": input[0]})
actual_script = traced_module._actual_script_module
export(script_module.forward)
annotate_args_decorator = annotate_args(get_input_annotations(input, dynamic))
annotate_args_decorator = annotate_args(
get_input_annotations(input, dynamic)
)
annotate_args_decorator(script_module.forward)
module = torch.jit.script(script_module)
@@ -112,7 +114,9 @@ def get_torch_mlir_module(
class_annotator.exportNone(module._c._type())
class_annotator.exportPath(module._c._type(), ["forward"])
class_annotator.annotateArgs(
module._c._type(), ["forward"], get_input_annotations(input, dynamic),
module._c._type(),
["forward"],
get_input_annotations(input, dynamic),
)
mb.import_module(module._c, class_annotator)
@@ -121,5 +125,6 @@ def get_torch_mlir_module(
"torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
)
pm.run(mb.module)
return mb.module