Update pytorch tests to support vulkan and cuda.

All the model validation pass except distilbert which is failing
in torch-mlir lowering. Also, added the mobilebert-uncased model to
the torch test suite.
This commit is contained in:
Prashant Kumar
2022-07-08 13:12:04 +05:30
parent 9cc92d0e7d
commit fa7ee7e099
12 changed files with 139 additions and 103 deletions

View File

@@ -198,6 +198,7 @@ result = shark_module.forward((arg0, arg1))
| BigBird | :green_heart: (AOT) | | | |
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
| GPT2 | :broken_heart: (AOT) | | | |
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
### Torchvision Models

View File

@@ -80,10 +80,8 @@ def check_device_drivers(device):
# Installation info for the missing device drivers.
def device_driver_info(device):
if device in ["gpu", "cuda"]:
print(
"nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
)
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
elif device in ["metal", "vulkan"]:
print("vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution")
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
else:
print(f"{device} is not supported.")
return f"{device} is not supported."

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -55,39 +55,30 @@ class MiniLMModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="language models failing for dynamic case")
def test_module_dynamic_cpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -61,31 +61,25 @@ class AlbertModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
@@ -55,39 +55,30 @@ class BertModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
def test_module_dynamic_cpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -50,31 +50,33 @@ class DistilBertModuleTest(unittest.TestCase):
def setUp(self):
self.module_tester = DistilBertModuleTester(self)
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
def test_module_dynamic_cpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
@@ -84,7 +86,7 @@ class DistilBertModuleTest(unittest.TestCase):
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",

View File

@@ -0,0 +1,89 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
import numpy as np
import pytest
torch.manual_seed(0)
class MobileBertUncasedModuleTester:
def __init__(
self,
dynamic=False,
device="cpu",
save_mlir=False,
save_vmfb=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
self.save_vmfb = save_vmfb
def create_and_check_module(self):
model, input, act_out = get_hf_model("google/mobilebert-uncased")
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
mlir_importer = SharkImporter(
model,
(input,),
frontend="torch",
)
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
results = shark_module.forward((input,))
assert True == compare_tensors(act_out, results)
class MobileBertModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
self.save_vmfb = pytestconfig.getoption("save_vmfb")
def setUp(self):
self.module_tester = MobileBertUncasedModuleTester(self)
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
def test_module_dynamic_cpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
if __name__ == "__main__":
unittest.main()

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -61,31 +61,25 @@ class Resnet101ModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -61,31 +61,25 @@ class Resnet18ModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -61,31 +61,25 @@ class Resnet50ModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -61,31 +61,25 @@ class SqueezenetModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"

View File

@@ -1,6 +1,6 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from shark.iree_utils._common import check_device_drivers
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
@@ -61,31 +61,25 @@ class WideResnet50ModuleTest(unittest.TestCase):
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_static_gpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
def test_module_dynamic_gpu(self):
self.module_tester.dynamic = True
self.module_tester.device = "gpu"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_static_vulkan(self):
self.module_tester.dynamic = False
self.module_tester.device = "vulkan"
self.module_tester.create_and_check_module()
@pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
)
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
def test_module_dynamic_vulkan(self):
self.module_tester.dynamic = True
self.module_tester.device = "vulkan"