mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Update pytorch tests to support vulkan and cuda.
All the model validation pass except distilbert which is failing in torch-mlir lowering. Also, added the mobilebert-uncased model to the torch test suite.
This commit is contained in:
@@ -198,6 +198,7 @@ result = shark_module.forward((arg0, arg1))
|
||||
| BigBird | :green_heart: (AOT) | | | |
|
||||
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :broken_heart: (AOT) | | | |
|
||||
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
### Torchvision Models
|
||||
|
||||
|
||||
@@ -80,10 +80,8 @@ def check_device_drivers(device):
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
|
||||
if device in ["gpu", "cuda"]:
|
||||
print(
|
||||
"nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
|
||||
)
|
||||
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
|
||||
elif device in ["metal", "vulkan"]:
|
||||
print("vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution")
|
||||
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
|
||||
else:
|
||||
print(f"{device} is not supported.")
|
||||
return f"{device} is not supported."
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -55,39 +55,30 @@ class MiniLMModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="language models failing for dynamic case")
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -61,31 +61,25 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -55,39 +55,30 @@ class BertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -50,31 +50,33 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.module_tester = DistilBertModuleTester(self)
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
@@ -84,7 +86,7 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class MobileBertUncasedModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_hf_model("google/mobilebert-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(is_dynamic=self.dynamic, tracing_required=True)
|
||||
shark_module = SharkInference(minilm_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
|
||||
class MobileBertModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = MobileBertUncasedModuleTester(self)
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -61,31 +61,25 @@ class Resnet101ModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -61,31 +61,25 @@ class Resnet18ModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -61,31 +61,25 @@ class Resnet50ModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -61,31 +61,25 @@ class SqueezenetModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -61,31 +61,25 @@ class WideResnet50ModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason=device_driver_info("gpu"))
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("vulkan"), reason=device_driver_info("vulkan"))
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
|
||||
Reference in New Issue
Block a user