# File: AMD-SHARK-Studio/tank/mpnet-base_tf/mpnet-base_tf_test.py
# Listing timestamp: 2022-08-18 21:21:43 -07:00
# Listing size: 72 lines, 2.3 KiB (Python)

from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
import iree.compiler as ireec
import unittest
import pytest
import numpy as np
class MpNetModuleTester:
    """Run the ``microsoft/mpnet-base`` TF model through SHARK and check it.

    Attributes:
        benchmark: Flag supplied by the caller. It is stored on the instance
            but not read by any method of this class.
    """

    def __init__(self, benchmark=False):
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        """Compile the downloaded model for ``device`` and compare outputs.

        Args:
            dynamic: Accepted for signature compatibility; this method does
                not use it.
            device: Device name forwarded to ``SharkInference`` (the tests in
                this file pass "cpu", "gpu" and "vulkan").

        Raises:
            AssertionError: If the two outputs differ beyond ``rtol=1e-2`` /
                ``atol=1e-3``.
        """
        artifacts = download_tf_model("microsoft/mpnet-base")
        # The downloader yields four items; presumably the last one is the
        # reference ("golden") output for the bundled inputs.
        mlir_model, entry_point, sample_inputs, expected = artifacts

        runner = SharkInference(
            mlir_model,
            entry_point,
            device=device,
            mlir_dialect="mhlo",
        )
        runner.compile()
        produced = runner.forward(sample_inputs)

        # Keep the original argument roles: the reference output is passed as
        # `actual` and the SHARK result as `desired`, so the relative
        # tolerance is scaled by the SHARK result.
        np.testing.assert_allclose(
            actual=expected,
            desired=produced,
            rtol=1e-2,
            atol=1e-3,
        )
class MpNetModuleTest(unittest.TestCase):
    """Static-shape MPNet tests, one per backend (cpu, gpu, vulkan).

    Every test delegates to ``MpNetModuleTester.create_and_check_module`` and
    is marked ``xfail`` with a link to the tracking issue.

    NOTE: ``self.module_tester`` is created by the autouse pytest fixture
    ``configure``, so these tests rely on being collected by pytest.
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        """Build the shared tester using pytest's ``benchmark`` option."""
        # Pass the option by keyword. `benchmark` is the tester's only
        # constructor argument, so passing the TestCase instance positionally
        # would store the TestCase itself in that slot.
        self.module_tester = MpNetModuleTester(
            benchmark=pytestconfig.getoption("benchmark")
        )

    @pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
    def test_module_static_cpu(self):
        """Check the static-shape model on the "cpu" device."""
        dynamic = False
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device)

    # Skipped when check_device_drivers("gpu") is truthy; presumably that
    # signals the driver is unavailable (verify against shark.iree_utils).
    @pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_static_gpu(self):
        """Check the static-shape model on the "gpu" device."""
        dynamic = False
        device = "gpu"
        self.module_tester.create_and_check_module(dynamic, device)

    # Same skip condition as the gpu test, keyed on "vulkan".
    @pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK/issues/203")
    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        """Check the static-shape model on the "vulkan" device."""
        dynamic = False
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device)

    # Disabled intel-gpu variant, kept for reference. The reason it was
    # disabled is not recorded in this file.
    # @pytest.mark.skipif(
    #     check_device_drivers("intel-gpu"),
    #     reason=device_driver_info("intel-gpu"),
    # )
    # def test_module_static_intel_gpu(self):
    #     dynamic = False
    #     device = "intel-gpu"
    #     self.module_tester.create_and_check_module(dynamic, device)
# Allow running this file directly as a script via the unittest runner.
if __name__ == "__main__":
    unittest.main()