Add resnest model to the shark_inference examples list.

This commit is contained in:
Prashant Kumar
2022-07-07 18:48:01 +05:30
parent f49a2c3df4
commit 2e5cb4ba76

View File

@@ -0,0 +1,35 @@
import torch
import torchvision.models as models
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
class ResnestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = torch.hub.load("zhanghang1989/ResNeSt", "resnest50", pretrained=True)
self.model.eval()
def forward(self, input):
return self.model.forward(input)
input = torch.randn(1, 3, 224, 224)
mlir_importer = SharkImporter(
ResnestModule(),
(input),
frontend="torch",
)
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(tracing_required=True)
print(golden_out)
shark_module = SharkInference(vision_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input))
print("Obtained result", result)