Update the resnet50 example to use the shark_downloader.

The resnet50 example is updated to use the shark_downloader instead
of shark_importer and inference.
This commit is contained in:
Prashant Kumar
2022-07-18 10:57:53 +05:30
parent 1191f53c9d
commit 54a642e76a

View File

@@ -5,6 +5,7 @@ import torchvision.models as models
from torchvision import transforms
import sys
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
################################## Preprocessing inputs and model ############
@@ -63,18 +64,18 @@ labels = load_labels()
##############################################################################
input = torch.randn(1, 3, 224, 224)
print(input.shape)
## The img is passed to determine the input shape.
shark_module = SharkInference(Resnet50Module(), (img,))
shark_module.compile()
## Can pass any img or input to the forward module.
results = shark_module.forward((img,))
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
shark_module = SharkInference(
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
)
shark_module.compile()
result = shark_module.forward((img.detach().numpy(),))
print("The top 3 results obtained via shark_runner is:")
print(top3_possibilities(torch.from_numpy(results)))
print(top3_possibilities(torch.from_numpy(result)))
print()