mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add examples of shark_inference via torch.script, torch.jit_trace, and
aot.
This commit is contained in:
40
shark_runner/examples/fullyconnected.py
Normal file
40
shark_runner/examples/fullyconnected.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark_runner import SharkInference, SharkTrainer
|
||||
|
||||
|
||||
class NeuralNet(nn.Module):
    """Small two-layer MLP: 10 -> Linear(16) -> ReLU -> Linear(2).

    Instantiated in eval mode since the examples only run inference.
    """

    def __init__(self):
        super(NeuralNet, self).__init__()
        # Same layer graph as the original example.
        self.l1 = nn.Linear(10, 16)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(16, 2)
        # Inference-only example: switch training mode off up front.
        self.train(False)

    def forward(self, x):
        # Feed the batch through each stage in sequence.
        hidden = self.l1(x)
        activated = self.relu(hidden)
        return self.l2(activated)
|
||||
|
||||
|
||||
# Build a reference eager-mode model instance.
model = NeuralNet()

# Loss and optimizer are set up for the (currently disabled) trainer below.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Sample batch: 10 rows of 10 features; labels sized for the 2-class head.
# NOTE(review): `input` shadows the builtin of the same name.
input = torch.randn(10, 10)
labels = torch.randn(1, 2)

# SharkInference takes the module plus a tuple of sample inputs —
# presumably to fix the compiled input shapes; confirm in shark_runner.
shark_module = SharkInference(NeuralNet(), (input,))
results = shark_module.forward((input,))

# TODO: Currently errors out in torch-mlir lowering pass.
# shark_trainer_module = SharkTrainer(
#     NeuralNet(), (input,), (labels,), dynamic=True, from_aot=True
# )

# results = shark_trainer_module.train(input)

# print(results)
|
||||
@@ -24,16 +24,10 @@ criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
|
||||
input = torch.randn(1, 10)
|
||||
input = torch.randn(10, 10)
|
||||
labels = torch.randn(1, 2)
|
||||
|
||||
shark_module = SharkInference(NeuralNet(), input, from_aot = True)
|
||||
shark_module = SharkInference(NeuralNet(), (input,), from_aot=True)
|
||||
results = shark_module.forward((input,))
|
||||
|
||||
results = shark_module.forward(input)
|
||||
|
||||
#TODO: Currently errors out in torch-mlir lowering pass.
|
||||
shark_trainer_module = SharkTrainer(
|
||||
NeuralNet(), (input,), (labels,), dynamic=True, from_aot=True
|
||||
)
|
||||
|
||||
shark_trainer_module.train(input)
|
||||
print(results)
|
||||
40
shark_runner/examples/fullyconnected_script.py
Normal file
40
shark_runner/examples/fullyconnected_script.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark_runner import SharkInference, SharkTrainer
|
||||
|
||||
|
||||
class NeuralNet(nn.Module):
    """Tiny fully connected network (10 -> 16 -> ReLU -> 2), eval-mode only."""

    def __init__(self):
        super(NeuralNet, self).__init__()
        # Layers identical to the original example graph.
        self.l1 = nn.Linear(10, 16)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(16, 2)
        # The script only runs inference, so disable training mode here.
        self.train(False)

    def forward(self, x):
        # Two linear layers with a ReLU in between.
        pre_activation = self.l1(x)
        post_activation = self.relu(pre_activation)
        return self.l2(post_activation)
|
||||
|
||||
|
||||
# Eager-mode reference instance of the model.
model = NeuralNet()

# Criterion/optimizer are only relevant once the trainer path below works.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# NOTE(review): `input` shadows the builtin of the same name.
input = torch.randn(10, 10)
labels = torch.randn(1, 2)

# Compile a fresh model for inference with the sample input tuple.
shark_module = SharkInference(NeuralNet(), (input,))
results = shark_module.forward((input,))

# TODO: Currently errors out in torch-mlir lowering pass.
# shark_trainer_module = SharkTrainer(
#     NeuralNet(), (input,), (labels,), dynamic=True, from_aot=True
# )

# results = shark_trainer_module.train(input)

# print(results)
|
||||
41
shark_runner/examples/hugging_face_lm_jit.py
Normal file
41
shark_runner/examples/hugging_face_lm_jit.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark_runner import SharkInference
|
||||
|
||||
# Fixed seed so reruns of the example are reproducible.
torch.manual_seed(0)

# "distilbert-base-uncased"
# NOTE(review): the commented name above is presumably a model not yet
# enabled in this example — confirm before adding it to the list.
models = ["albert-base-v2", "bert-base-uncased"]
|
||||
|
||||
|
||||
def prepare_sentence_tokens(tokenizer, sentence):
    """Encode *sentence* with *tokenizer* and return the ids as a 1 x N tensor."""
    token_ids = tokenizer.encode(sentence)
    return torch.tensor([token_ids])
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
    """Wraps a HuggingFace sequence-classification model so forward()
    returns only the first element of the model's output tuple."""

    def __init__(self, hf_model_name):
        super().__init__()
        hf_kwargs = dict(
            num_labels=2,  # The number of output labels--2 for binary classification.
            output_attentions=False,  # Whether the model returns attentions weights.
            output_hidden_states=False,  # Whether the model returns all hidden-states.
            torchscript=True,
        )
        # hf_model_name selects the pretrained checkpoint to download.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            hf_model_name, **hf_kwargs
        )

    def forward(self, tokens):
        # The wrapped model returns a tuple; keep only its first element.
        outputs = self.model.forward(tokens)
        return outputs[0]
|
||||
|
||||
|
||||
# Compile and run each candidate model end to end.
for hf_model in models:
    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    test_input = prepare_sentence_tokens(
        tokenizer, "this project is very interesting"
    )
    # jit_trace=True — presumably these models must be traced rather than
    # scripted; confirm against shark_runner's SharkInference.
    shark_module = SharkInference(
        HuggingFaceLanguage(hf_model),
        (test_input,),
        jit_trace=True,
    )
    shark_module.forward((test_input,))
|
||||
35
shark_runner/examples/minilm_jit.py
Normal file
35
shark_runner/examples/minilm_jit.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import transforms
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark_runner import SharkInference, SharkTrainer
|
||||
|
||||
# Fixed seed for reproducible example runs.
torch.manual_seed(0)
# Module-level tokenizer, shared by _prepare_sentence_tokens below.
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
def _prepare_sentence_tokens(sentence: str):
    """Encode *sentence* with the module-level tokenizer as a 1 x N id tensor."""
    ids = tokenizer.encode(sentence)
    return torch.tensor([ids])
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
    """MiniLM classifier wrapper whose forward() yields just the first output."""

    def __init__(self):
        super().__init__()
        hf_kwargs = dict(
            num_labels=2,  # The number of output labels--2 for binary classification.
            output_attentions=False,  # Whether the model returns attentions weights.
            output_hidden_states=False,  # Whether the model returns all hidden-states.
            torchscript=True,
        )
        # Download the fixed pretrained MiniLM checkpoint.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased", **hf_kwargs
        )

    def forward(self, tokens):
        # The wrapped model returns a tuple; keep only its first element.
        outputs = self.model.forward(tokens)
        return outputs[0]
|
||||
|
||||
|
||||
# Encode a sample sentence and run it through the compiled module.
test_input = _prepare_sentence_tokens("this project is very interesting")

# device="cpu" with jit_trace=True — presumably traced (not scripted)
# and compiled for the CPU backend; confirm in shark_runner.
shark_module = SharkInference(
    MiniLMSequenceClassification(), (test_input,), device="cpu", jit_trace=True
)
results = shark_module.forward((test_input,))
|
||||
@@ -68,8 +68,10 @@ labels = load_labels()
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
print(input.shape)
|
||||
|
||||
## The img is passed to determine the input shape.
|
||||
shark_module = SharkInference(Resnet50Module(), (img,))
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
results = shark_module.forward((img,))
|
||||
|
||||
print("The top 3 results obtained via torch-mlir via iree-backend is:")
|
||||
45
shark_runner/examples/torch_vision_models_script.py
Normal file
45
shark_runner/examples/torch_vision_models_script.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
from shark_runner import SharkInference
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
    """Inference-only adapter around an arbitrary torchvision model."""

    def __init__(self, model):
        super().__init__()
        # Keep a handle on the wrapped model and freeze training mode,
        # since these examples never train.
        self.model = model
        self.train(False)

    def forward(self, input):
        # Delegate straight to the wrapped model's forward.
        return self.model.forward(input)
|
||||
|
||||
|
||||
# Shared sample input: one 224x224 RGB image batch.
# NOTE(review): `input` shadows the builtin of the same name.
input = torch.randn(1, 3, 224, 224)

## The vision models present here: https://pytorch.org/vision/stable/models.html
vision_models_list = [
    models.resnet18(pretrained=True),
    models.alexnet(pretrained=True),
    models.vgg16(pretrained=True),
    models.squeezenet1_0(pretrained=True),
    models.densenet161(pretrained=True),
    models.inception_v3(pretrained=True),
    models.shufflenet_v2_x1_0(pretrained=True),
    models.mobilenet_v2(pretrained=True),
    models.mobilenet_v3_small(pretrained=True),
    models.resnext50_32x4d(pretrained=True),
    models.wide_resnet50_2(pretrained=True),
    models.mnasnet1_0(pretrained=True),
    models.efficientnet_b0(pretrained=True),
    models.regnet_y_400mf(pretrained=True),
    models.regnet_x_400mf(pretrained=True),
]

# Compile and run every model once. The enumerate index was unused,
# so iterate over the models directly.
for vision_model in vision_models_list:
    shark_module = SharkInference(
        VisionModule(vision_model),
        (input,),
    )
    shark_module.forward((input,))
|
||||
33
shark_runner/examples/unet_script.py
Normal file
33
shark_runner/examples/unet_script.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
from shark_runner import SharkInference
|
||||
|
||||
# Currently not supported aten.transpose_conv2d missing.
class UnetModule(torch.nn.Module):
    """Pretrained brain-segmentation U-Net from torch.hub, wrapped for inference."""

    def __init__(self):
        super().__init__()
        # Fetch the pretrained network from the hub repository.
        hub_repo = "mateuszbuda/brain-segmentation-pytorch"
        self.model = torch.hub.load(
            hub_repo,
            "unet",
            in_channels=3,
            out_channels=1,
            init_features=32,
            pretrained=True,
        )
        # Inference-only example: disable training mode.
        self.train(False)

    def forward(self, input):
        # Invoke the wrapped model on the batch.
        return self.model(input)
|
||||
|
||||
|
||||
# Random image-shaped input for the U-Net example.
# NOTE(review): `input` shadows the builtin of the same name.
input = torch.randn(1, 3, 224, 224)

print(input)
shark_module = SharkInference(
    UnetModule(),
    (input,),
)
shark_module.forward((input,))
# NOTE(review): the input is printed before and after the forward call —
# presumably to check it is not mutated; confirm the intent.
print(input)
|
||||
@@ -23,7 +23,7 @@ from typing import List
|
||||
|
||||
|
||||
class AOTModule:
|
||||
def __init__(self, model, inputs, labels = None):
|
||||
def __init__(self, model, inputs, labels=None):
|
||||
self.model = model
|
||||
self.inputs = inputs
|
||||
self.labels = labels
|
||||
@@ -36,7 +36,7 @@ class AOTModule:
|
||||
iters = 1
|
||||
with torch.no_grad():
|
||||
for _ in range(iters):
|
||||
out = model(inputs)
|
||||
out = model(*inputs)
|
||||
|
||||
def train(self, model, inputs, labels):
|
||||
# TODO: Pass the criterion and optimizer.
|
||||
@@ -52,11 +52,13 @@ class AOTModule:
|
||||
|
||||
def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, list):
|
||||
node.args = (tuple(node_arg),)
|
||||
# TODO: Check why return of tuple is not working.
|
||||
# node.args = (tuple(node_arg),)
|
||||
node.args = node_arg
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
@@ -37,11 +37,9 @@ def get_iree_compiled_module(module, device: str):
|
||||
def get_results(compiled_vm, input):
|
||||
"""TODO: Documentation"""
|
||||
|
||||
# TODO: Currently only one output and input is supported.
|
||||
# Extend it to support multiple inputs and outputs.
|
||||
# TODO: Support returning multiple outputs.
|
||||
result = compiled_vm(*input)
|
||||
result_numpy = np.asarray(result, dtype=result.dtype)
|
||||
|
||||
# TODO: Segfault if the copy of numpy array is not returned.
|
||||
result_copy = np.copy(result_numpy)
|
||||
return result_copy
|
||||
|
||||
@@ -57,6 +57,7 @@ class SharkInference:
|
||||
):
|
||||
self.model = model
|
||||
self.input = input
|
||||
self.from_aot = from_aot
|
||||
|
||||
if from_aot:
|
||||
aot_module = AOTModule(model, input)
|
||||
@@ -68,14 +69,11 @@ class SharkInference:
|
||||
self.model, self.input, dynamic, device, jit_trace, from_aot
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
input_list = []
|
||||
def forward(self, inputs):
|
||||
# TODO Capture weights and inputs in case of AOT, Also rework the
|
||||
# forward pass.
|
||||
if True:
|
||||
for input in self.input:
|
||||
input_list.append(input.detach().numpy())
|
||||
|
||||
inputs = self.input if self.from_aot else inputs
|
||||
input_list = [x.detach().numpy() for x in inputs]
|
||||
return self.shark_runner.forward(input_list)
|
||||
|
||||
|
||||
@@ -111,27 +109,27 @@ class SharkTrainer:
|
||||
jit_trace,
|
||||
from_aot,
|
||||
)
|
||||
self.shark_backward = SharkRunner(
|
||||
self.backward_graph,
|
||||
self.backward_inputs,
|
||||
dynamic,
|
||||
device,
|
||||
jit_trace,
|
||||
from_aot,
|
||||
)
|
||||
# self.shark_backward = SharkRunner(
|
||||
# self.backward_graph,
|
||||
# self.backward_inputs,
|
||||
# dynamic,
|
||||
# device,
|
||||
# jit_trace,
|
||||
# from_aot,
|
||||
# )
|
||||
|
||||
def train(self, input):
|
||||
forward_inputs = []
|
||||
backward_inputs = []
|
||||
for input in self.forward_inputs:
|
||||
forward_inputs.append(input.detach().numpy())
|
||||
for input in self.backward_inputs:
|
||||
backward_inputs.append(input.detach().numpy())
|
||||
def train(self, input):
|
||||
forward_inputs = []
|
||||
backward_inputs = []
|
||||
for input in self.forward_inputs:
|
||||
forward_inputs.append(input.detach().numpy())
|
||||
for input in self.backward_inputs:
|
||||
backward_inputs.append(input.detach().numpy())
|
||||
|
||||
# TODO: Pass the iter variable, and optimizer.
|
||||
iters = 1
|
||||
# TODO: Pass the iter variable, and optimizer.
|
||||
iters = 1
|
||||
|
||||
for _ in range(iters):
|
||||
self.shark_runner.forward(forward_inputs)
|
||||
self.shark_runner.forward(backward_inputs)
|
||||
return
|
||||
for _ in range(iters):
|
||||
self.shark_forward.forward(forward_inputs)
|
||||
# self.shark_backward.forward(backward_inputs)
|
||||
return
|
||||
|
||||
@@ -62,18 +62,14 @@ def shark_jit_trace(
|
||||
if not tracing_required:
|
||||
return torch.jit.script(module)
|
||||
|
||||
# TODO: Currently, the jit trace accepts only one input.
|
||||
if len(input) != 1:
|
||||
sys.exit("Currently, the jit_trace accepts only one input")
|
||||
|
||||
traced_module = torch.jit.trace_module(module, {"forward": input[0]})
|
||||
traced_module = torch.jit.trace_module(module, {"forward": input})
|
||||
actual_script = traced_module._actual_script_module
|
||||
export(script_module.forward)
|
||||
export(actual_script.forward)
|
||||
annotate_args_decorator = annotate_args(
|
||||
get_input_annotations(input, dynamic)
|
||||
)
|
||||
annotate_args_decorator(script_module.forward)
|
||||
module = torch.jit.script(script_module)
|
||||
annotate_args_decorator(actual_script.forward)
|
||||
module = torch.jit.script(actual_script)
|
||||
|
||||
# TODO: remove saved annotations.pickle
|
||||
torchscript_module_bytes = module.save_to_buffer(
|
||||
@@ -120,11 +116,11 @@ def get_torch_mlir_module(
|
||||
)
|
||||
mb.import_module(module._c, class_annotator)
|
||||
|
||||
mb.module.dump()
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse(
|
||||
"torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
|
||||
)
|
||||
pm.run(mb.module)
|
||||
|
||||
|
||||
return mb.module
|
||||
|
||||
Reference in New Issue
Block a user