Add vgg and alexnet models to passing tests.

This commit is contained in:
Prashant Kumar
2022-05-27 15:39:14 +00:00
parent cee02f6a61
commit 04291fdfcb
3 changed files with 31 additions and 3 deletions

View File

@@ -41,4 +41,4 @@ jobs:
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
pytest --workers auto
pytest

View File

@@ -93,7 +93,7 @@ print(shark_module.forward((arg0, arg1)))
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|--------------------|----------------------|----------|----------|-------------|
| AlexNet | :heavy_check_mark: (Script) | | | |
| AlexNet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| DenseNet121 | :heavy_check_mark: (Script) | | | |
| MNasNet1_0 | :heavy_check_mark: (Script) | | | |
| MobileNetV2 | :heavy_check_mark: (Script) | | | |
@@ -108,7 +108,7 @@ print(shark_module.forward((arg0, arg1)))
| Regnet | :heavy_check_mark: (Script) | | | |
| Resnest | :x: (Script) | | | |
| Vision Transformer | :heavy_check_mark: (Script) | | | |
| VGG 16 | :heavy_check_mark: (Script) | | | |
| VGG 16 | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| Wide Resnet | :heavy_check_mark: (Script) | :heavy_check_mark: | :heavy_check_mark: | |
| RAFT | :x: (JIT) | | | |

View File

@@ -206,3 +206,31 @@ def test_squeezenet(dynamic, device):
shark_module.compile()
results = shark_module.forward((input,))
assert True == compare_tensors(act_out, results)
@pytest_param
def test_vgg16(dynamic, device):
model, input, act_out = get_vision_model(
models.vgg16(pretrained=True))
shark_module = SharkInference(
model,
(input,),
device=device,
dynamic=dynamic,
)
shark_module.compile()
results = shark_module.forward((input,))
assert True == compare_tensors(act_out, results)
@pytest_param
def test_alexnet(dynamic, device):
model, input, act_out = get_vision_model(
models.alexnet(pretrained=True))
shark_module = SharkInference(
model,
(input,),
device=device,
dynamic=dynamic,
)
shark_module.compile()
results = shark_module.forward((input,))
assert True == compare_tensors(act_out, results)