Files
AMD-SHARK-Studio/tank/deberta-base_tf/deberta-base_tf_test.py
2022-07-29 09:34:50 -07:00

65 lines
1.9 KiB
Python

from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_tf_model
from shark.parser import shark_args
import iree.compiler as ireec
import unittest
import pytest
import numpy as np
import tempfile
import os
class DebertaBaseModuleTester:
    """Helper that compiles ``microsoft/deberta-base`` and checks its output.

    The helper fetches the model through ``download_tf_model``, runs it
    through ``SharkInference`` and compares the result with the golden
    output returned alongside the model.
    """

    def __init__(
        self,
        benchmark=False,
    ):
        """Store the benchmark setting.

        Args:
            benchmark: Value kept on the instance as ``self.benchmark``.
                Nothing in this class reads it.
        """
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        """Compile the model for ``device`` and compare against golden output.

        Args:
            dynamic: Accepted so callers can keep passing it; this method
                does not use the value.
            device: Device identifier forwarded unchanged to
                ``SharkInference``.

        Raises:
            AssertionError: If the computed result differs from the golden
                output beyond ``rtol=1e-02`` / ``atol=1e-03``.
        """
        model, func_name, inputs, golden_out = download_tf_model(
            "microsoft/deberta-base"
        )
        shark_module = SharkInference(
            model, func_name, device=device, mlir_dialect="mhlo"
        )
        shark_module.compile()
        result = shark_module.forward(inputs)
        # assert_allclose takes (actual, desired) and scales the relative
        # tolerance by ``desired``, so the golden reference goes second.
        np.testing.assert_allclose(result, golden_out, rtol=1e-02, atol=1e-03)
# ``pytest.skip`` is an imperative call that raises immediately, so it cannot
# be used as a decorator; the declarative ``pytest.mark.skip`` marker is the
# supported way to skip every test in the class.
@pytest.mark.skip(reason="Model can't be imported.")
class DebertaBaseModuleTest(unittest.TestCase):
    """Static-shape checks of ``microsoft/deberta-base`` on cpu, gpu and vulkan.

    All tests are currently skipped with the reason "Model can't be imported."
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        """Create the module tester using the ``benchmark`` pytest option."""
        # Pass the option value directly; the tester's first positional
        # parameter is ``benchmark``, not the test case.
        self.module_tester = DebertaBaseModuleTester(
            benchmark=pytestconfig.getoption("benchmark")
        )

    def test_module_static_cpu(self):
        """Run the static-shape check on the "cpu" device."""
        self.module_tester.create_and_check_module(dynamic=False, device="cpu")

    # NOTE(review): presumably check_device_drivers returns a truthy value
    # when the driver is unavailable - confirm against shark.iree_utils.
    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_static_gpu(self):
        """Run the static-shape check on the "gpu" device."""
        self.module_tester.create_and_check_module(dynamic=False, device="gpu")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        """Run the static-shape check on the "vulkan" device."""
        self.module_tester.create_and_check_module(
            dynamic=False, device="vulkan"
        )
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()