mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
65 lines · 1.9 KiB · Python
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
|
from shark.shark_inference import SharkInference
|
|
from shark.shark_downloader import download_model
|
|
|
|
import iree.compiler as ireec
|
|
import unittest
|
|
import pytest
|
|
import numpy as np
|
|
|
|
|
|
class TapasBaseModuleTester:
    """Download/compile/run harness for the ``google/tapas-base`` model.

    Fetches the TF-frontend artifacts, compiles them through SharkInference
    on a given device, and checks the output against the recorded golden
    values.
    """

    def __init__(self, benchmark=False):
        # Whether benchmarking is requested; the pytest fixture that owns
        # this tester overwrites the flag from the CLI after construction.
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        """Compile the model for ``device`` and verify against golden output.

        ``dynamic`` is accepted for signature parity with the other module
        testers in this suite but is not consumed here — presumably static
        shapes only; confirm against sibling test files.
        """
        # Pull the pre-exported model, its entry-point name, sample inputs,
        # and the reference output from the model downloader cache.
        model, func_name, inputs, golden_out = download_model(
            "google/tapas-base", frontend="tf"
        )

        # Lower the mhlo module for the requested backend and execute it.
        module = SharkInference(
            model, func_name, device=device, mlir_dialect="mhlo"
        )
        module.compile()
        observed = module.forward(inputs)

        # Loose tolerances: backend lowering introduces small float drift.
        np.testing.assert_allclose(
            golden_out, observed, rtol=1e-02, atol=1e-03
        )
|
# Skip the whole class with a proper mark. The original used
# ``@pytest.skip(..., allow_module_level=True)`` as a decorator, but
# ``pytest.skip`` is a function meant to be *called* inside tests (or at
# module level), not a mark: as a decorator it executes at class-definition
# time and raises ``Skipped`` during import, which also breaks running this
# file directly with ``python``.
@pytest.mark.skip(reason="Input must be a pandas dataframe.")
class TapasBaseModuleTest(unittest.TestCase):
    """Static-shape TAPAS-base tests across cpu/cuda/vulkan backends.

    Currently skipped: the model expects a pandas DataFrame input that the
    generic download/compile pipeline does not provide.
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        # Fresh tester per test; the benchmark flag comes from the pytest
        # CLI. (The original passed ``self`` — the TestCase — as the
        # tester's ``benchmark`` argument and then overwrote it; construct
        # with the default instead.)
        self.module_tester = TapasBaseModuleTester()
        self.module_tester.benchmark = pytestconfig.getoption("benchmark")

    def test_module_static_cpu(self):
        dynamic = False
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("cuda"), reason=device_driver_info("cuda")
    )
    def test_module_static_cuda(self):
        dynamic = False
        device = "cuda"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        dynamic = False
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device)
|
if __name__ == "__main__":
    # Allow running this file directly with plain unittest (outside pytest).
    unittest.main()