Files
AMD-SHARK-Studio/tank/examples/tapas-base_tf/tapas-base_tf_test.py

65 lines
1.9 KiB
Python

from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
import iree.compiler as ireec
import unittest
import pytest
import numpy as np
class TapasBaseModuleTester:
    """Helper that compiles and runs the TAPAS base model and checks its output.

    The model artifacts are fetched with ``download_model`` using the
    ``"google/tapas-base"`` identifier and the ``"tf"`` frontend.
    """

    def __init__(self, benchmark=False):
        """Store the benchmark flag.

        Args:
            benchmark: Value kept on ``self.benchmark``; it is not read
                anywhere else in this class.
        """
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        """Compile the model for ``device`` and compare against the golden output.

        Args:
            dynamic: Accepted for interface compatibility; not used by this
                method.
            device: Device name forwarded to ``SharkInference``.

        Raises:
            AssertionError: If the forward result differs from the downloaded
                golden output beyond ``rtol=1e-02`` / ``atol=1e-03``.
        """
        # download_model yields the model, its entry function name, sample
        # inputs and the reference output to compare against.
        downloaded = download_model("google/tapas-base", frontend="tf")
        mlir_model, entry_point, sample_inputs, expected = downloaded

        runner = SharkInference(
            mlir_model,
            entry_point,
            device=device,
            mlir_dialect="mhlo",
        )
        runner.compile()
        actual = runner.forward(sample_inputs)

        # Argument order is kept as in the original call: the golden output is
        # passed first and the computed result second.
        np.testing.assert_allclose(expected, actual, rtol=1e-02, atol=1e-03)
# The skip is declared with a mark on the class. ``pytest.skip(...)`` is an
# imperative call that raises as soon as it is evaluated, so using it as a
# decorator aborted the class body at import time instead of marking tests.
@pytest.mark.skip(reason="Input must be a pandas dataframe.")
class TapasBaseModuleTest(unittest.TestCase):
    """Static-shape TAPAS base tests for the cpu, cuda and vulkan devices.

    Every test delegates to ``TapasBaseModuleTester.create_and_check_module``.
    The whole class is currently skipped (see the skip reason on the class).
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        """Create the module tester, configured from the ``benchmark`` option.

        Args:
            pytestconfig: pytest's built-in config fixture; the ``benchmark``
                option is read from it.
        """
        # Pass the option as the ``benchmark`` argument directly rather than
        # handing the TestCase instance to the constructor and overwriting it.
        self.module_tester = TapasBaseModuleTester(
            benchmark=pytestconfig.getoption("benchmark")
        )

    def test_module_static_cpu(self):
        """Run the static-shape check on the ``cpu`` device."""
        dynamic = False
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device)

    # Skipped when check_device_drivers("cuda") is truthy; the skip reason
    # text comes from device_driver_info("cuda").
    @pytest.mark.skipif(
        check_device_drivers("cuda"), reason=device_driver_info("cuda")
    )
    def test_module_static_cuda(self):
        """Run the static-shape check on the ``cuda`` device."""
        dynamic = False
        device = "cuda"
        self.module_tester.create_and_check_module(dynamic, device)

    # Skipped when check_device_drivers("vulkan") is truthy; the skip reason
    # text comes from device_driver_info("vulkan").
    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        """Run the static-shape check on the ``vulkan`` device."""
        dynamic = False
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device)
# Script entry point: hand control to unittest's command-line runner when this
# file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()