mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
114 lines
3.4 KiB
Python
114 lines
3.4 KiB
Python
from shark.shark_inference import SharkInference
|
|
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
|
from tank.model_utils import compare_tensors
|
|
from shark.shark_downloader import download_torch_model
|
|
|
|
import torch
|
|
import unittest
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
class MobileBertModuleTester:
    """Download, compile and numerically check the MobileBERT torch model.

    Attributes:
        benchmark: when truthy, ``create_and_check_module`` additionally runs
            ``benchmark_all_csv`` after the correctness check.
    """

    # Model id used both to fetch the artifacts and to label the benchmark
    # report, kept in one place so the two cannot drift apart.
    MODEL_NAME = "google/mobilebert-uncased"

    def __init__(
        self,
        benchmark=False,
    ):
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        """Compile MobileBERT for ``device`` and compare against reference output.

        Args:
            dynamic: forwarded to ``download_torch_model`` and to the benchmark
                report; presumably selects the dynamic-shape variant of the
                model -- confirm against ``download_torch_model``.
            device: device string handed to ``SharkInference`` (the tests in
                this file use "cpu", "gpu", "intel-gpu" and "vulkan").

        Raises:
            AssertionError: if ``compare_tensors`` reports that the compiled
                module's results differ from the downloaded reference output.
        """
        # ``golden_out`` is the reference output shipped with the artifact
        # (presumably produced by running the original torch model).
        model_mlir, func_name, inputs, golden_out = download_torch_model(
            self.MODEL_NAME, dynamic
        )

        shark_module = SharkInference(
            model_mlir,
            func_name,
            device=device,
            mlir_dialect="linalg",
            is_benchmark=self.benchmark,
        )
        shark_module.compile()
        results = shark_module.forward(inputs)
        assert compare_tensors(golden_out, results), (
            f"{self.MODEL_NAME} output mismatch "
            f"(device={device!r}, dynamic={dynamic})"
        )

        if self.benchmark:
            shark_module.shark_runner.benchmark_all_csv(
                inputs,
                self.MODEL_NAME,
                dynamic,
                device,
                "torch",
            )
|
|
|
|
|
|
class MobileBertModuleTest(unittest.TestCase):
    """MobileBERT compile-and-compare tests for CPU, GPU, Intel GPU and Vulkan.

    Device-specific tests are skipped when ``check_device_drivers(<device>)``
    returns a truthy value (presumably meaning the driver is unavailable).
    """

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        # Under pytest, build the tester from the ``--benchmark`` CLI option.
        self.module_tester = MobileBertModuleTester(
            benchmark=pytestconfig.getoption("benchmark")
        )

    def setUp(self):
        # pytest fills autouse fixtures before unittest's setUp runs, so this
        # fallback only takes effect when the file is executed directly via
        # ``unittest.main()``, where ``configure`` is never called.
        # NOTE(review): the plain unittest runner also ignores the
        # ``pytest.mark.skipif`` marks below, so device tests are not skipped
        # there.
        if not hasattr(self, "module_tester"):
            self.module_tester = MobileBertModuleTester()

    def test_module_static_cpu(self):
        self.module_tester.create_and_check_module(dynamic=False, device="cpu")

    def test_module_dynamic_cpu(self):
        self.module_tester.create_and_check_module(dynamic=True, device="cpu")

    @pytest.mark.skipif(
        check_device_drivers("intel-gpu"), reason=device_driver_info("intel-gpu")
    )
    def test_module_static_intelgpu(self):
        self.module_tester.create_and_check_module(
            dynamic=False, device="intel-gpu"
        )

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_static_gpu(self):
        self.module_tester.create_and_check_module(dynamic=False, device="gpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_dynamic_gpu(self):
        self.module_tester.create_and_check_module(dynamic=True, device="gpu")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        self.module_tester.create_and_check_module(dynamic=False, device="vulkan")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_dynamic_vulkan(self):
        self.module_tester.create_and_check_module(dynamic=True, device="vulkan")
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: hand this module's TestCase classes to the stdlib
    # unittest runner (``module`` is spelled out; it is the default value).
    unittest.main(module="__main__")
|