Merge pull request #122 from nod-ai/ean-tests

Enable passing of --save_mlir cmd-line option through pytest.
This commit is contained in:
Ean Garvey
2022-06-13 20:27:20 -05:00
committed by GitHub
10 changed files with 78 additions and 2 deletions

3
tank/conftest.py Normal file
View File

@@ -0,0 +1,3 @@
def pytest_addoption(parser):
    """Register SHARK-specific command-line options with pytest.

    Invoked automatically by pytest at startup because this lives in a
    ``conftest.py``.

    Args:
        parser: pytest's ``Parser`` object used to declare new options.
    """
    # Attaches SHARK command-line arguments to the pytest machinery.
    # NOTE: the default must be the boolean False, not the string "False".
    # A non-empty string is truthy, so getoption("save_mlir") would read
    # as enabled even when the flag was never passed.
    parser.addoption(
        "--save_mlir",
        action="store_true",
        default=False,
        help="Pass option to save input MLIR module to /tmp/ directory.",
    )

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -15,12 +16,15 @@ class AlbertModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_hf_model("albert-base-v2")
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(model, (input,),
device=self.device,
dynamic=self.dynamic,
@@ -30,10 +34,15 @@ class AlbertModuleTester:
assert True == compare_tensors(act_out, results)
class AlbertModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = AlbertModuleTester(self)
self.module_tester.save_mlir = self.save_mlir
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -16,12 +17,15 @@ class AlexnetModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_vision_model(models.alexnet(pretrained=True))
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(
model,
(input,),
@@ -34,6 +38,10 @@ class AlexnetModuleTester:
class AlexnetModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = AlexnetModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -15,12 +16,15 @@ class BertModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_hf_model("bert-base-uncased")
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(model, (input,),
device=self.device,
dynamic=self.dynamic,
@@ -31,6 +35,10 @@ class BertModuleTester:
class BertModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = BertModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_hf_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -15,12 +16,15 @@ class MiniLMModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_hf_model("microsoft/MiniLM-L12-H384-uncased")
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(model, (input,),
device=self.device,
dynamic=self.dynamic,
@@ -31,6 +35,10 @@ class MiniLMModuleTester:
class MiniLMModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = MiniLMModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -16,12 +17,15 @@ class Resnet101ModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_vision_model(models.resnet101(pretrained=True))
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(
model,
(input,),
@@ -34,6 +38,10 @@ class Resnet101ModuleTester:
class Resnet101ModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = Resnet101ModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -16,12 +17,15 @@ class Resnet18ModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_vision_model(models.resnet18(pretrained=True))
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(
model,
(input,),
@@ -34,6 +38,10 @@ class Resnet18ModuleTester:
class Resnet18ModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = Resnet18ModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -16,12 +17,15 @@ class Resnet50ModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_vision_model(models.resnet50(pretrained=True))
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(
model,
(input,),
@@ -34,6 +38,10 @@ class Resnet50ModuleTester:
class Resnet50ModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = Resnet50ModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -16,12 +17,15 @@ class SqueezenetModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_vision_model(models.squeezenet1_0(pretrained=True))
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(
model,
(input,),
@@ -34,6 +38,10 @@ class SqueezenetModuleTester:
class SqueezenetModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = SqueezenetModuleTester(self)

View File

@@ -1,6 +1,7 @@
from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
from tank.model_utils import get_vision_model, compare_tensors
from shark.parser import shark_args
import torch
import unittest
@@ -16,12 +17,15 @@ class WideResnet50ModuleTester:
self,
dynamic=False,
device="cpu",
save_mlir=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
def create_and_check_module(self):
model, input, act_out = get_vision_model(models.wide_resnet50_2(pretrained=True))
shark_args.save_mlir = self.save_mlir
shark_module = SharkInference(
model,
(input,),
@@ -34,6 +38,10 @@ class WideResnet50ModuleTester:
class WideResnet50ModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
def setUp(self):
self.module_tester = WideResnet50ModuleTester(self)