Add unet_torch reference. (#283)

* Add unet_torch reference.

* Delete distilbert-base-uncased_torch_test.py
This commit is contained in:
Prashant Kumar
2022-08-19 13:35:10 +05:30
committed by GitHub
parent 62f3573d43
commit d8c9225af8

View File

@@ -0,0 +1,91 @@
from shark.shark_inference import SharkInference
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_downloader import download_torch_model
import unittest
import numpy as np
import pytest
class UnetModuleTester:
def __init__(
self,
benchmark=False,
):
self.benchmark = benchmark
def create_and_check_module(self, dynamic, device):
model_mlir, func_name, input, act_out = download_torch_model(
"unet", dynamic
)
# from shark.shark_importer import SharkImporter
# mlir_importer = SharkImporter(
# model,
# (input,),
# frontend="torch",
# )
# minilm_mlir, func_name = mlir_importer.import_mlir(
# is_dynamic=dynamic, tracing_required=True
# )
shark_module = SharkInference(
model_mlir,
func_name,
device=device,
mlir_dialect="linalg",
is_benchmark=self.benchmark,
)
shark_module.compile()
results = shark_module.forward(input)
np.testing.assert_allclose(act_out, results, rtol=1e-02, atol=1e-03)
if self.benchmark == True:
shark_module.shark_runner.benchmark_all_csv(
(input),
"unet",
dynamic,
device,
"torch",
)
class UnetModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.module_tester = UnetModuleTester(self)
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
def test_module_static_cpu(self):
dynamic = False
device = "cpu"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.skipif(
check_device_drivers("gpu"), reason=device_driver_info("gpu")
)
def test_module_static_gpu(self):
dynamic = False
device = "gpu"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.skipif(
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
)
def test_module_static_vulkan(self):
dynamic = False
device = "vulkan"
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.skipif(
check_device_drivers("intel-gpu"),
reason=device_driver_info("intel-gpu"),
)
def test_module_static_intel_gpu(self):
dynamic = False
device = "intel-gpu"
self.module_tester.create_and_check_module(dynamic, device)
if __name__ == "__main__":
unittest.main()