mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for AOTModule in shark_runner.
Added support for AOTModule inference on a simple example.
This commit is contained in:
33
shark_runner/fullyconnected.py
Normal file
33
shark_runner/fullyconnected.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark_runner import SharkInference
|
||||
|
||||
class NeuralNet(nn.Module):
    """Tiny fully connected network used as a SHARK inference example.

    Architecture: Linear(10 -> 16) -> ReLU -> Linear(16 -> 2).
    The module is put into eval mode at construction time because it is
    only used for inference in this example.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(10, 16)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(16, 2)
        # Inference-only example: disable training mode up front.
        self.train(False)

    def forward(self, x):
        # Pipe the input straight through the three stages.
        return self.l2(self.relu(self.l1(x)))
|
||||
|
||||
# --- Example driver: run NeuralNet through SHARK for inference. ---
model = NeuralNet()

# Training objects are set up for a future training example; the
# inference path below does not use them.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Named `inputs` to avoid shadowing the builtin `input`.
inputs = torch.randn(1, 10)
labels = torch.randn(1, 2)

# Reuse the single `model` instance instead of constructing a second,
# differently-initialized NeuralNet() just for SharkInference.
shark_module = SharkInference(model, inputs, from_aot=True)

results = shark_module.forward(inputs)
print(results)
|
||||
|
||||
@@ -11,3 +11,83 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from functorch.compile import (
|
||||
aot_module,
|
||||
min_cut_rematerialization_partition,
|
||||
)
|
||||
from torch_mlir_utils import get_torch_mlir_module
|
||||
from torch import optim, fx
|
||||
from typing import List
|
||||
|
||||
|
||||
class AOTModule:
    """Captures a model's forward/backward computation via functorch AOT.

    `aot_module` invokes the `fw_compiler`/`bw_compiler` callbacks with
    the traced FX graphs; those callbacks script, freeze and stash the
    graphs so SHARK can compile them later.

    Attributes populated after generate_inference_graph() /
    generate_training_graph():
        forward_graph / backward_graph:   frozen TorchScript modules
        forward_inputs / backward_inputs: inputs handed to each graph
    """

    def __init__(self, model, inputs):
        # Model to trace and the example inputs that drive tracing.
        self.model = model
        self.inputs = inputs
        # Filled in by the compiler callbacks during tracing.
        self.forward_graph = None
        self.backward_graph = None
        self.forward_inputs = None
        self.backward_inputs = None

    def inference(self, model, inputs):
        """Run one no-grad forward pass so AOT tracing fires."""
        with torch.no_grad():
            model(inputs)

    def train(self, model, inputs):
        """Run one forward+backward pass so AOT tracing fires.

        NOTE(review): unlike `inference`, this unpacks `inputs` as
        keyword arguments and assumes the model output exposes `.loss`
        — confirm callers pass a suitable dict of tensors.
        """
        model(**inputs).loss.sum().backward()

    def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
        """Rewrite the FX graph's output node so it no longer returns a
        Python list: the list elements are spliced directly into the
        output node's args before scripting."""
        for node in fx_g.graph.nodes:
            if node.op == 'output':
                # output nodes always have one argument
                node_arg = node.args[0]
                if isinstance(node_arg, list):
                    node.args = node_arg
        fx_g.graph.lint()
        fx_g.recompile()
        return fx_g

    def _script_freeze_roundtrip(self, fx_g: fx.GraphModule, path):
        """Shared tail of the fw/bw compiler callbacks: normalize the
        graph's return, script + freeze it, and round-trip it through
        `path` on disk (TODO: hardcoded CWD-relative filenames)."""
        fx_g = self.change_fx_graph_return_to_tuple(fx_g)
        f = torch.jit.script(fx_g)
        f = torch.jit.freeze(f.eval())
        torch.jit.save(f, path)
        return torch.jit.load(path)

    def get_forward_graph(self, fx_g: fx.GraphModule, inps):
        """fw_compiler callback: capture the forward graph and inputs."""
        f = self._script_freeze_roundtrip(fx_g, "forw.pt")
        self.forward_graph = f
        self.forward_inputs = inps
        return f

    def get_backward_graph(self, fx_g: fx.GraphModule, inps):
        """bw_compiler callback: capture the backward graph and inputs."""
        f = self._script_freeze_roundtrip(fx_g, "back.pt")
        self.backward_graph = f
        self.backward_inputs = inps
        return f

    def generate_inference_graph(self):
        """Trace the model once (forward only) to populate forward_graph."""
        aot_model = aot_module(
            self.model,
            fw_compiler=self.get_forward_graph,
            bw_compiler=self.get_backward_graph,
            partition_fn=min_cut_rematerialization_partition,
        )
        self.inference(aot_model, self.inputs)

    def generate_training_graph(self):
        """Trace the model forward+backward to populate both graphs."""
        aot_model = aot_module(
            self.model,
            fw_compiler=self.get_forward_graph,
            bw_compiler=self.get_backward_graph,
            partition_fn=min_cut_rematerialization_partition,
        )
        self.train(aot_model, self.inputs)
|
||||
|
||||
@@ -39,7 +39,7 @@ def get_results(compiled_vm, input):
|
||||
|
||||
# TODO: Currently only one output and input is supported.
|
||||
# Extend it to support multiple inputs and outputs.
|
||||
result = compiled_vm(input)
|
||||
result = compiled_vm(*input)
|
||||
result_numpy = np.asarray(result, dtype=result.dtype)
|
||||
|
||||
# TODO: Segfault if the copy of numpy array is not returned.
|
||||
|
||||
@@ -68,9 +68,9 @@ labels = load_labels()
|
||||
input = torch.randn(1,3,224,224)
|
||||
print(input.shape)
|
||||
|
||||
shark_module = SharkInference(Resnet50Module(), (input,))
|
||||
shark_module = SharkInference(Resnet50Module(), (img,))
|
||||
|
||||
results = shark_module.forward(img)
|
||||
results = shark_module.forward((img,))
|
||||
|
||||
print("The top 3 results obtained via torch-mlir via iree-backend is:")
|
||||
print(top3_possibilities(torch.from_numpy(results)))
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
from torch_mlir_utils import get_torch_mlir_module
|
||||
from iree_utils import get_results, get_iree_compiled_module
|
||||
from functorch_utils import AOTModule
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SharkRunner:
|
||||
@@ -53,12 +55,28 @@ class SharkInference:
|
||||
jit_trace: bool = False,
|
||||
from_aot: bool = False,
|
||||
):
|
||||
self.model = model
|
||||
self.input = input
|
||||
|
||||
if(from_aot):
|
||||
aot_module = AOTModule(model, input)
|
||||
aot_module.generate_inference_graph()
|
||||
self.model = aot_module.forward_graph
|
||||
self.input = aot_module.forward_inputs
|
||||
|
||||
self.shark_runner = SharkRunner(
|
||||
model, input, dynamic, device, jit_trace, from_aot
|
||||
self.model, self.input, dynamic, device, jit_trace, from_aot
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.shark_runner.forward(input)
|
||||
input_list = []
|
||||
# TODO Capture weights and inputs in case of AOT, Also rework the
|
||||
# forward pass.
|
||||
if(True):
|
||||
for input in self.input:
|
||||
input_list.append(input.detach().numpy())
|
||||
|
||||
return self.shark_runner.forward(input_list)
|
||||
|
||||
|
||||
class SharkTrainer:
|
||||
|
||||
Reference in New Issue
Block a user