Add support for AOT_Module in shark_runner.

Added support for AOT_Module to run inference on a simple example.
This commit is contained in:
Prashant Kumar
2022-03-12 16:35:55 +00:00
parent 7a0296f359
commit fba169f456
5 changed files with 136 additions and 5 deletions

View File

@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from shark_runner import SharkInference
class NeuralNet(nn.Module):
    """Toy two-layer MLP (10 -> 16 -> 2) used as a SharkInference example.

    Constructed directly in eval mode since the example only runs inference.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(10, 16)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(16, 2)
        # Inference-only example: disable training mode up front.
        self.train(False)

    def forward(self, x):
        # linear -> ReLU -> linear, expressed as a single pipeline.
        return self.l2(self.relu(self.l1(x)))
# Example driver: run the toy model through SHARK's AOT inference path.
#
# Fixes vs. original: the constructed `model` is actually passed to
# SharkInference (the original built `model` and then handed a *second*,
# differently-initialized `NeuralNet()` to the runner); the unused
# CrossEntropyLoss/SGD pair and random `labels` are dropped since this
# example is inference-only.
model = NeuralNet()
sample_input = torch.randn(1, 10)

shark_module = SharkInference(model, sample_input, from_aot=True)
results = shark_module.forward(sample_input)
print(results)

View File

@@ -11,3 +11,83 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from functorch.compile import (
aot_module,
min_cut_rematerialization_partition,
)
from torch_mlir_utils import get_torch_mlir_module
from torch import optim, fx
from typing import List
class AOTModule:
    """Captures a model's forward/backward graphs via functorch AOT autograd.

    The fw/bw compiler callbacks script + freeze each traced
    ``fx.GraphModule`` into a TorchScript module and stash it (together with
    the example inputs that drove tracing) on the instance, so a downstream
    compiler (e.g. torch-mlir) can consume them.
    """

    def __init__(self, model, inputs):
        # Model to trace and the example inputs used to drive tracing.
        self.model = model
        self.inputs = inputs
        # Populated by the fw/bw compiler callbacks below.
        self.forward_graph = None
        self.backward_graph = None
        self.forward_inputs = None
        self.backward_inputs = None

    def inference(self, model, inputs):
        """Run one no-grad forward pass so AOT autograd traces the model."""
        iters = 1
        with torch.no_grad():
            for _ in range(iters):
                model(inputs)

    def train(self, model, inputs):
        """Run one forward+backward pass so both graphs get traced.

        NOTE(review): assumes ``inputs`` is a dict of keyword arguments and
        that the model output exposes a ``.loss`` tensor (HF-style model)
        — confirm against actual callers.
        """
        iters = 1
        for _ in range(iters):
            model(**inputs).loss.sum().backward()

    def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
        """Rewrite a list-valued output node so TorchScript accepts it.

        TorchScript cannot script the list-return form fx emits; spreading
        the list into the output node's args avoids that.
        """
        for node in fx_g.graph.nodes:
            if node.op == 'output':
                # output nodes always have one argument
                node_arg = node.args[0]
                if isinstance(node_arg, list):
                    node.args = node_arg
        fx_g.graph.lint()
        fx_g.recompile()
        return fx_g

    def _freeze_via_torchscript(self, fx_g: fx.GraphModule, path):
        """Script, freeze, and round-trip ``fx_g`` through ``path``.

        Shared by the fw/bw callbacks (previously duplicated).  The
        save/load round trip yields a self-contained ScriptModule.
        NOTE(review): writes into the current working directory — consider
        a tempfile or configurable output path.
        """
        fx_g = self.change_fx_graph_return_to_tuple(fx_g)
        scripted = torch.jit.script(fx_g)
        scripted = torch.jit.freeze(scripted.eval())
        torch.jit.save(scripted, path)
        return torch.jit.load(path)

    def get_forward_graph(self, fx_g: fx.GraphModule, inps):
        """AOT ``fw_compiler`` callback: capture the forward graph/inputs."""
        f = self._freeze_via_torchscript(fx_g, "forw.pt")
        self.forward_graph = f
        self.forward_inputs = inps
        return f

    def get_backward_graph(self, fx_g: fx.GraphModule, inps):
        """AOT ``bw_compiler`` callback: capture the backward graph/inputs."""
        f = self._freeze_via_torchscript(fx_g, "back.pt")
        self.backward_graph = f
        self.backward_inputs = inps
        return f

    def _aot_wrap(self):
        """Wrap ``self.model`` with AOT autograd using our capture callbacks."""
        return aot_module(
            self.model,
            fw_compiler=self.get_forward_graph,
            bw_compiler=self.get_backward_graph,
            partition_fn=min_cut_rematerialization_partition,
        )

    def generate_inference_graph(self):
        """Trace a forward-only pass; populates ``forward_graph``/``forward_inputs``."""
        self.inference(self._aot_wrap(), self.inputs)

    def generate_training_graph(self):
        """Trace forward+backward; populates both captured graphs."""
        self.train(self._aot_wrap(), self.inputs)

View File

@@ -39,7 +39,7 @@ def get_results(compiled_vm, input):
# TODO: Currently only one output and input is supported.
# Extend it to support multiple inputs and outputs.
result = compiled_vm(input)
result = compiled_vm(*input)
result_numpy = np.asarray(result, dtype=result.dtype)
# TODO: Segfault if the copy of numpy array is not returned.

View File

@@ -68,9 +68,9 @@ labels = load_labels()
input = torch.randn(1,3,224,224)
print(input.shape)
shark_module = SharkInference(Resnet50Module(), (input,))
shark_module = SharkInference(Resnet50Module(), (img,))
results = shark_module.forward(img)
results = shark_module.forward((img,))
print("The top 3 results obtained via torch-mlir via iree-backend is:")
print(top3_possibilities(torch.from_numpy(results)))

View File

@@ -14,6 +14,8 @@
from torch_mlir_utils import get_torch_mlir_module
from iree_utils import get_results, get_iree_compiled_module
from functorch_utils import AOTModule
import numpy as np
class SharkRunner:
@@ -53,12 +55,28 @@ class SharkInference:
jit_trace: bool = False,
from_aot: bool = False,
):
self.model = model
self.input = input
if(from_aot):
aot_module = AOTModule(model, input)
aot_module.generate_inference_graph()
self.model = aot_module.forward_graph
self.input = aot_module.forward_inputs
self.shark_runner = SharkRunner(
model, input, dynamic, device, jit_trace, from_aot
self.model, self.input, dynamic, device, jit_trace, from_aot
)
def forward(self, input):
return self.shark_runner.forward(input)
input_list = []
# TODO Capture weights and inputs in case of AOT, Also rework the
# forward pass.
if(True):
for input in self.input:
input_list.append(input.detach().numpy())
return self.shark_runner.forward(input_list)
class SharkTrainer: