mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Update the resnet50 example to use the shark_downloader.
The resnet50 example is updated to use the shark_downloader instead of shark_importer and inference.
This commit is contained in:
@@ -5,6 +5,7 @@ import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_torch_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
@@ -63,18 +64,18 @@ labels = load_labels()
|
||||
|
||||
##############################################################################
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
print(input.shape)
|
||||
|
||||
## The img is passed to determine the input shape.
|
||||
shark_module = SharkInference(Resnet50Module(), (img,))
|
||||
shark_module.compile()
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
results = shark_module.forward((img,))
|
||||
mlir_model, func_name, inputs, golden_out = download_torch_model("resnet50")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((img.detach().numpy(),))
|
||||
|
||||
print("The top 3 results obtained via shark_runner is:")
|
||||
print(top3_possibilities(torch.from_numpy(results)))
|
||||
print(top3_possibilities(torch.from_numpy(result)))
|
||||
|
||||
print()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user