mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Merge pull request #122 from nod-ai/ean-tests
Enable passing of --save-mlir cmd-line option through pytest.
This commit is contained in:
3
tank/conftest.py
Normal file
3
tank/conftest.py
Normal file
@@ -0,0 +1,3 @@
|
||||
def pytest_addoption(parser):
|
||||
# Attaches SHARK command-line arguments to the pytest machinery.
|
||||
parser.addoption("--save_mlir", action="store_true", default="False", help="Pass option to save input MLIR module to /tmp/ directory.")
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -15,12 +16,15 @@ class AlbertModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_hf_model("albert-base-v2")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
@@ -30,10 +34,15 @@ class AlbertModuleTester:
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class AlbertModuleTest(unittest.TestCase):
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = AlbertModuleTester(self)
|
||||
|
||||
self.module_tester.save_mlir = self.save_mlir
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -16,12 +17,15 @@ class AlexnetModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.alexnet(pretrained=True))
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
@@ -34,6 +38,10 @@ class AlexnetModuleTester:
|
||||
|
||||
class AlexnetModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = AlexnetModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -15,12 +16,15 @@ class BertModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_hf_model("bert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
@@ -31,6 +35,10 @@ class BertModuleTester:
|
||||
|
||||
class BertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = BertModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -15,12 +16,15 @@ class MiniLMModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_hf_model("microsoft/MiniLM-L12-H384-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
@@ -31,6 +35,10 @@ class MiniLMModuleTester:
|
||||
|
||||
class MiniLMModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = MiniLMModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -16,12 +17,15 @@ class Resnet101ModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.resnet101(pretrained=True))
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
@@ -34,6 +38,10 @@ class Resnet101ModuleTester:
|
||||
|
||||
class Resnet101ModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet101ModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -16,12 +17,15 @@ class Resnet18ModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.resnet18(pretrained=True))
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
@@ -34,6 +38,10 @@ class Resnet18ModuleTester:
|
||||
|
||||
class Resnet18ModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet18ModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -16,12 +17,15 @@ class Resnet50ModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.resnet50(pretrained=True))
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
@@ -34,6 +38,10 @@ class Resnet50ModuleTester:
|
||||
|
||||
class Resnet50ModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet50ModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -16,12 +17,15 @@ class SqueezenetModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.squeezenet1_0(pretrained=True))
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
@@ -34,6 +38,10 @@ class SqueezenetModuleTester:
|
||||
|
||||
class SqueezenetModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = SqueezenetModuleTester(self)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
@@ -16,12 +17,15 @@ class WideResnet50ModuleTester:
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.wide_resnet50_2(pretrained=True))
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
@@ -34,6 +38,10 @@ class WideResnet50ModuleTester:
|
||||
|
||||
class WideResnet50ModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = WideResnet50ModuleTester(self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user