mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 21:38:04 -05:00
44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import torch
|
|
import torchvision.models as models
|
|
from amdshark.amdshark_inference import AMDSharkInference
|
|
|
|
|
|
class VisionModule(torch.nn.Module):
    """Inference-only wrapper around a single model.

    The wrapped model is stored as the ``model`` attribute and the wrapper
    is switched out of training mode as soon as it is constructed.
    """

    def __init__(self, model):
        """Store *model* and disable training mode.

        Args:
            model: The model to delegate to; expected to be a
                ``torch.nn.Module``.
        """
        super().__init__()
        self.model = model
        # Module.train(False) also applies to registered submodules, so the
        # wrapped model leaves training mode together with the wrapper.
        self.train(False)

    def forward(self, input):
        """Run the wrapped model on *input* and return its result."""
        # Call the module itself instead of ``.forward`` so that hooks
        # registered on the wrapped model are not silently skipped.
        return self.model(input)
|
|
|
|
|
|
# Random tensor of shape (1, 3, 224, 224); it is handed to AMDSharkInference
# at construction time and again to forward() for every model below.
input = torch.randn(1, 3, 224, 224)

# Torchvision models exercised by this script. The full catalogue is at
# https://pytorch.org/vision/stable/models.html
vision_models_list = [
    models.resnet18(pretrained=True),
    models.alexnet(pretrained=True),
    models.vgg16(pretrained=True),
    models.squeezenet1_0(pretrained=True),
    models.densenet161(pretrained=True),
    models.inception_v3(pretrained=True),
    models.shufflenet_v2_x1_0(pretrained=True),
    models.mobilenet_v2(pretrained=True),
    models.mobilenet_v3_small(pretrained=True),
    models.resnext50_32x4d(pretrained=True),
    models.wide_resnet50_2(pretrained=True),
    models.mnasnet1_0(pretrained=True),
    models.efficientnet_b0(pretrained=True),
    models.regnet_y_400mf(pretrained=True),
    models.regnet_x_400mf(pretrained=True),
]

# Wrap each model, compile it, and run one forward pass on the sample input.
# The loop index was never used, so iterate over the list directly.
for vision_model in vision_models_list:
    amdshark_module = AMDSharkInference(
        VisionModule(vision_model),
        (input,),
    )
    amdshark_module.compile()
    amdshark_module.forward((input,))
|