Files
AMD-SHARK-Studio/tank/pytorch/mobilebert-uncased/mobilebert-uncased_test.py
2022-08-10 11:24:00 -07:00

114 lines
3.4 KiB
Python

from shark.shark_inference import SharkInference
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import compare_tensors
from shark.shark_downloader import download_torch_model
import torch
import unittest
import numpy as np
import pytest
class MobileBertModuleTester:
    """Compile the downloaded MobileBERT module with SHARK and check it.

    The ``benchmark`` attribute is stored as given; benchmark results are
    only emitted when it compares equal to ``True``.
    """

    def __init__(self, benchmark=False):
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        """Run one compile/forward/compare cycle on ``device``.

        Args:
            dynamic: forwarded to ``download_torch_model`` and to the
                benchmark call.
            device: device string handed to ``SharkInference``.

        Raises:
            AssertionError: if ``compare_tensors`` does not report a match
                between the downloaded output and the SHARK result.
        """
        model_name = "google/mobilebert-uncased"

        # The MLIR module, entry-point name, inputs and the output to
        # compare against all come from the downloader. (This file used to
        # carry a commented-out SharkImporter path for importing locally.)
        model_mlir, func_name, inputs, expected_out = download_torch_model(
            model_name, dynamic
        )

        module = SharkInference(
            model_mlir,
            func_name,
            device=device,
            mlir_dialect="linalg",
            is_benchmark=self.benchmark,
        )
        module.compile()
        actual_out = module.forward(inputs)
        assert True == compare_tensors(expected_out, actual_out)

        # Explicit comparison kept on purpose: the flag is assigned from a
        # pytest option elsewhere and may not be a strict bool, so plain
        # truthiness could behave differently.
        wants_benchmark = self.benchmark == True
        if wants_benchmark:
            module.shark_runner.benchmark_all_csv(
                inputs,
                model_name,
                dynamic,
                device,
                "torch",
            )
class MobileBertModuleTest(unittest.TestCase):
    """Per-device test cases for google/mobilebert-uncased.

    Each test delegates to ``MobileBertModuleTester.create_and_check_module``
    with a fixed (dynamic, device) combination. The intel-gpu, gpu and
    vulkan variants are skipped when ``check_device_drivers`` returns a
    truthy value for that device, with ``device_driver_info`` supplying the
    skip reason.
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        """Create the helper for each test from the ``benchmark`` option."""
        # The helper's only constructor argument is ``benchmark``; pass the
        # option value there instead of (mistakenly) passing this TestCase
        # instance positionally and overwriting the attribute afterwards.
        self.module_tester = MobileBertModuleTester(
            benchmark=pytestconfig.getoption("benchmark")
        )

    def test_module_static_cpu(self):
        self.module_tester.create_and_check_module(False, "cpu")

    def test_module_dynamic_cpu(self):
        self.module_tester.create_and_check_module(True, "cpu")

    @pytest.mark.skipif(
        check_device_drivers("intel-gpu"),
        reason=device_driver_info("intel-gpu"),
    )
    def test_module_static_intelgpu(self):
        self.module_tester.create_and_check_module(False, "intel-gpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_static_gpu(self):
        self.module_tester.create_and_check_module(False, "gpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_dynamic_gpu(self):
        self.module_tester.create_and_check_module(True, "gpu")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        self.module_tester.create_and_check_module(False, "vulkan")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_dynamic_vulkan(self):
        self.module_tester.create_and_check_module(True, "vulkan")
# Script entry point: run the test cases with the unittest runner.
# NOTE(review): unittest.main() does not execute pytest fixtures, so the
# autouse `configure` fixture that assigns `self.module_tester` presumably
# never runs on this path — confirm, or invoke this file through pytest.
if __name__ == "__main__":
    unittest.main()