Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 21:38:04 -05:00)
Add unet_torch reference. (#283)
* Add unet_torch reference.
* Delete distilbert-base-uncased_torch_test.py
reference_models/unet_torch/unet_torch_test.py (new file, 91 lines)
@@ -0,0 +1,91 @@
from shark.shark_inference import SharkInference
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_downloader import download_torch_model

import unittest
import numpy as np
import pytest


class UnetModuleTester:
    def __init__(
        self,
        benchmark=False,
    ):
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        model_mlir, func_name, input, act_out = download_torch_model(
            "unet", dynamic
        )

        # from shark.shark_importer import SharkImporter
        # mlir_importer = SharkImporter(
        #     model,
        #     (input,),
        #     frontend="torch",
        # )
        # minilm_mlir, func_name = mlir_importer.import_mlir(
        #     is_dynamic=dynamic, tracing_required=True
        # )

        shark_module = SharkInference(
            model_mlir,
            func_name,
            device=device,
            mlir_dialect="linalg",
            is_benchmark=self.benchmark,
        )
        shark_module.compile()
        results = shark_module.forward(input)
        np.testing.assert_allclose(act_out, results, rtol=1e-02, atol=1e-03)

        if self.benchmark == True:
            shark_module.shark_runner.benchmark_all_csv(
                (input),
                "unet",
                dynamic,
                device,
                "torch",
            )


class UnetModuleTest(unittest.TestCase):
    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        self.module_tester = UnetModuleTester(self)
        self.module_tester.benchmark = pytestconfig.getoption("benchmark")

    def test_module_static_cpu(self):
        dynamic = False
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_static_gpu(self):
        dynamic = False
        device = "gpu"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        dynamic = False
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("intel-gpu"),
        reason=device_driver_info("intel-gpu"),
    )
    def test_module_static_intel_gpu(self):
        dynamic = False
        device = "intel-gpu"
        self.module_tester.create_and_check_module(dynamic, device)


if __name__ == "__main__":
    unittest.main()
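For readers who want to exercise the same flow outside the pytest harness, the following is a minimal sketch based only on the calls visible in this diff (download_torch_model, SharkInference, compile, forward). The device, tolerances, and positional arguments mirror the static-CPU test case above; running it as a standalone script from the repo root is an assumption, not something this commit sets up.

# Minimal sketch: run the UNet reference model once on CPU, mirroring
# create_and_check_module above. All API calls are taken from the diff;
# everything else (script location, printing) is illustrative only.
import numpy as np

from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model

# Fetch the linalg MLIR, the entry function name, a sample input,
# and the expected (golden) output for the static-shape UNet model.
model_mlir, func_name, inp, golden = download_torch_model("unet", False)

# Build and compile the module for the CPU backend.
module = SharkInference(
    model_mlir,
    func_name,
    device="cpu",
    mlir_dialect="linalg",
)
module.compile()

# Run inference and compare against the golden output with the same
# tolerances the test uses.
result = module.forward(inp)
np.testing.assert_allclose(golden, result, rtol=1e-02, atol=1e-03)
print("unet_torch reference matches within tolerance")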