mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for training in shark_runner.
Added support for training via AOT_module in shark_runner.
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark_runner import SharkInference
|
||||
from shark_runner import SharkInference, SharkTrainer
|
||||
|
||||
|
||||
class NeuralNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(NeuralNet, self).__init__()
|
||||
self.l1 = nn.Linear(10, 16)
|
||||
self.l1 = nn.Linear(10, 16)
|
||||
self.relu = nn.ReLU()
|
||||
self.l2 = nn.Linear(16, 2)
|
||||
self.train(False)
|
||||
@@ -16,6 +17,7 @@ class NeuralNet(nn.Module):
|
||||
out = self.l2(out)
|
||||
return out
|
||||
|
||||
|
||||
model = NeuralNet()
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
@@ -29,5 +31,9 @@ shark_module = SharkInference(NeuralNet(), input, from_aot = True)
|
||||
|
||||
results = shark_module.forward(input)
|
||||
|
||||
print(results)
|
||||
#TODO: Currently errors out in torch-mlir lowering pass.
|
||||
shark_trainer_module = SharkTrainer(
|
||||
NeuralNet(), (input,), (labels,), dynamic=True, from_aot=True
|
||||
)
|
||||
|
||||
shark_trainer_module.train(input)
|
||||
|
||||
@@ -23,9 +23,10 @@ from typing import List
|
||||
|
||||
|
||||
class AOTModule:
|
||||
def __init__(self, model, inputs):
|
||||
def __init__(self, model, inputs, labels = None):
|
||||
self.model = model
|
||||
self.inputs = inputs
|
||||
self.labels = labels
|
||||
self.forward_graph = None
|
||||
self.backward_graph = None
|
||||
self.forward_inputs = None
|
||||
@@ -37,10 +38,17 @@ class AOTModule:
|
||||
for _ in range(iters):
|
||||
out = model(inputs)
|
||||
|
||||
def train(
    self,
    model,
    inputs,
    labels,
    criterion=None,
    optimizer=None,
    iters=1,
):
    """Run ``iters`` optimisation steps of ``model`` on a single batch.

    Each step zeroes the gradients, calls ``model(*inputs)``, evaluates
    ``criterion(output, *labels)``, back-propagates the loss and applies
    one optimizer update.

    Args:
        model: Callable module to train; its ``parameters()`` are only
            consulted when ``optimizer`` is not supplied.
        inputs: Iterable of positional arguments unpacked into ``model``.
        labels: Iterable of targets unpacked into ``criterion`` after the
            model output.
        criterion: Loss callable. Defaults to
            ``torch.nn.CrossEntropyLoss()``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
            Defaults to ``torch.optim.SGD(model.parameters(), lr=0.01)``.
        iters: Number of optimisation steps to perform (default 1).

    Raises:
        TypeError: If ``labels`` is ``None``.
    """
    if labels is None:
        # ``*labels`` below would fail with an opaque unpacking error;
        # report the real problem instead (same exception type).
        raise TypeError("train() requires labels, got None")
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for _ in range(iters):
        optimizer.zero_grad()
        output = model(*inputs)
        loss = criterion(output, *labels)
        loss.backward()
        optimizer.step()
|
||||
|
||||
def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
|
||||
for node in fx_g.graph.nodes:
|
||||
@@ -48,12 +56,11 @@ class AOTModule:
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, list):
|
||||
node.args = node_arg
|
||||
node.args = (tuple(node_arg),)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
|
||||
def get_forward_graph(self, fx_g: fx.GraphModule, inps):
|
||||
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
|
||||
f = torch.jit.script(fx_g)
|
||||
@@ -90,4 +97,4 @@ class AOTModule:
|
||||
bw_compiler=self.get_backward_graph,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self.train(aot_model, self.inputs)
|
||||
self.train(aot_model, self.inputs, self.labels)
|
||||
|
||||
@@ -65,7 +65,7 @@ labels = load_labels()
|
||||
|
||||
##############################################################################
|
||||
|
||||
input = torch.randn(1,3,224,224)
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
print(input.shape)
|
||||
|
||||
shark_module = SharkInference(Resnet50Module(), (img,))
|
||||
|
||||
@@ -58,7 +58,7 @@ class SharkInference:
|
||||
self.model = model
|
||||
self.input = input
|
||||
|
||||
if(from_aot):
|
||||
if from_aot:
|
||||
aot_module = AOTModule(model, input)
|
||||
aot_module.generate_inference_graph()
|
||||
self.model = aot_module.forward_graph
|
||||
@@ -70,17 +70,68 @@ class SharkInference:
|
||||
|
||||
def forward(self, input):
    """Run the compiled module on the inputs captured at construction.

    Args:
        input: Currently ignored; the tensors stored in ``self.input``
            are used instead (see the TODO below).

    Returns:
        Whatever ``self.shark_runner.forward`` returns for the list of
        NumPy arrays built from ``self.input``.
    """
    # TODO: Capture weights and inputs in case of AOT, and rework the
    # forward pass.
    # Detach each tensor before converting it to a NumPy array.
    input_list = [tensor.detach().numpy() for tensor in self.input]
    return self.shark_runner.forward(input_list)
|
||||
|
||||
|
||||
class SharkTrainer:
    """Build forward and backward graphs for a model and run them.

    Construction hands the model, inputs and labels to ``AOTModule``,
    asks it to generate the training graph, and wraps the resulting
    forward and backward graphs in one ``SharkRunner`` each.
    """

    def __init__(
        self,
        model,
        input: tuple,
        label: tuple,
        dynamic: bool = False,
        device: str = "cpu",
        jit_trace: bool = False,
        from_aot: bool = True,
    ):
        """Generate the training graphs and create their runners.

        Args:
            model: Model passed to ``AOTModule``.
            input: Tuple of model inputs passed to ``AOTModule``.
            label: Tuple of labels passed to ``AOTModule``.
            dynamic: Forwarded unchanged to both ``SharkRunner`` instances.
            device: Forwarded unchanged to both ``SharkRunner`` instances.
            jit_trace: Forwarded unchanged to both ``SharkRunner``
                instances.
            from_aot: Forwarded unchanged to both ``SharkRunner``
                instances.
        """
        self.model = model
        self.input = input
        self.label = label

        aot_module = AOTModule(model, input, label)
        aot_module.generate_training_graph()
        self.forward_graph = aot_module.forward_graph
        self.forward_inputs = aot_module.forward_inputs
        self.backward_graph = aot_module.backward_graph
        self.backward_inputs = aot_module.backward_inputs

        self.shark_forward = SharkRunner(
            self.forward_graph,
            self.forward_inputs,
            dynamic,
            device,
            jit_trace,
            from_aot,
        )
        self.shark_backward = SharkRunner(
            self.backward_graph,
            self.backward_inputs,
            dynamic,
            device,
            jit_trace,
            from_aot,
        )

    def train(self, input, iters: int = 1):
        """Execute the forward and backward runners ``iters`` times.

        The tensors captured in ``self.forward_inputs`` and
        ``self.backward_inputs`` are detached and converted to NumPy
        arrays once, then fed to ``self.shark_forward`` and
        ``self.shark_backward`` respectively. Runner outputs are
        discarded and no optimizer update is applied here.

        Args:
            input: Currently ignored; the inputs captured at construction
                are used instead.
            iters: Number of forward/backward rounds to run (default 1).

        Returns:
            None.
        """
        forward_inputs = [
            tensor.detach().numpy() for tensor in self.forward_inputs
        ]
        backward_inputs = [
            tensor.detach().numpy() for tensor in self.backward_inputs
        ]

        # TODO: Pass the optimizer.
        for _ in range(iters):
            # Each graph has its own runner; the class never defines a
            # single ``shark_runner`` attribute.
            self.shark_forward.forward(forward_inputs)
            self.shark_backward.forward(backward_inputs)
        return
|
||||
|
||||
@@ -69,7 +69,9 @@ def shark_jit_trace(
|
||||
traced_module = torch.jit.trace_module(module, {"forward": input[0]})
|
||||
actual_script = traced_module._actual_script_module
|
||||
export(script_module.forward)
|
||||
annotate_args_decorator = annotate_args(get_input_annotations(input, dynamic))
|
||||
annotate_args_decorator = annotate_args(
|
||||
get_input_annotations(input, dynamic)
|
||||
)
|
||||
annotate_args_decorator(script_module.forward)
|
||||
module = torch.jit.script(script_module)
|
||||
|
||||
@@ -112,7 +114,9 @@ def get_torch_mlir_module(
|
||||
class_annotator.exportNone(module._c._type())
|
||||
class_annotator.exportPath(module._c._type(), ["forward"])
|
||||
class_annotator.annotateArgs(
|
||||
module._c._type(), ["forward"], get_input_annotations(input, dynamic),
|
||||
module._c._type(),
|
||||
["forward"],
|
||||
get_input_annotations(input, dynamic),
|
||||
)
|
||||
mb.import_module(module._c, class_annotator)
|
||||
|
||||
@@ -121,5 +125,6 @@ def get_torch_mlir_module(
|
||||
"torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
|
||||
)
|
||||
pm.run(mb.module)
|
||||
|
||||
|
||||
return mb.module
|
||||
|
||||
Reference in New Issue
Block a user