mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Refactor TF tests for importer split, update pytorch tests.
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
#SHARK Runner
|
||||
# SHARK Runner
|
||||
tqdm
|
||||
|
||||
#Testing
|
||||
# Testing
|
||||
pytest
|
||||
pytest-xdist
|
||||
pytest - xdist
|
||||
|
||||
@@ -43,6 +43,9 @@ class MaskedLM(tf.Module):
|
||||
|
||||
|
||||
def get_causal_lm_model(hf_name, text="Hello, this is the default text."):
|
||||
gpus = tf.config.experimental.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
model = MaskedLM(hf_name)
|
||||
encoded_input = preprocess_input(hf_name, text)
|
||||
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"])
|
||||
@@ -37,6 +37,9 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
|
||||
|
||||
def get_TFhf_model(name):
|
||||
gpus = tf.config.experimental.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
|
||||
@@ -15,17 +15,13 @@ torch.manual_seed(0)
|
||||
class MiniLMModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_hf_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
@@ -37,10 +33,10 @@ class MiniLMModuleTester:
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
is_dynamic=dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
minilm_mlir, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -50,53 +46,51 @@ class MiniLMModuleTester:
|
||||
class MiniLMModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = MiniLMModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -15,17 +15,13 @@ import pytest
|
||||
class AlbertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_hf_model("albert-base-v2")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
@@ -34,11 +30,12 @@ class AlbertModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -48,54 +45,51 @@ class AlbertModuleTester:
|
||||
class AlbertModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = AlbertModuleTester(self)
|
||||
self.module_tester.save_mlir = self.save_mlir
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,106 +0,0 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class AlexnetModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.alexnet(pretrained=True)
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
|
||||
class AlexnetModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = AlexnetModuleTester(self)
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
98
tank/pytorch/alexnet/alexnet_test.py
Normal file
98
tank/pytorch/alexnet/alexnet_test.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class AlexnetModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.alexnet(pretrained=True)
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
|
||||
class AlexnetModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = AlexnetModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -12,20 +12,16 @@ import pytest
|
||||
# torch.manual_seed(0)
|
||||
|
||||
|
||||
class BertModuleTester:
|
||||
class BertBaseUncasedModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_hf_model("bert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
@@ -34,67 +30,66 @@ class BertModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
|
||||
class BertModuleTest(unittest.TestCase):
|
||||
class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = BertModuleTester(self)
|
||||
self.module_tester = BertBaseUncasedModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -15,17 +15,13 @@ import pytest
|
||||
class DistilBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_hf_model("distilbert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
@@ -34,11 +30,11 @@ class DistilBertModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -48,56 +44,57 @@ class DistilBertModuleTester:
|
||||
class DistilBertModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = DistilBertModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
@pytest.mark.skip(reason="torch-mlir lowering issues")
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skip(reason="torch-mlir lowering issues")
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="torch-mlir lowering issues")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skip(reason="torch-mlir lowering issues")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="torch-mlir lowering issues")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="torch_mlir lowering issues.")
|
||||
@pytest.mark.skip(reason="torch-mlir lowering issues")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -9,23 +9,19 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
# torch.manual_seed(0)
|
||||
|
||||
|
||||
class MobileBertUncasedModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_hf_model("google/mobilebert-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
@@ -34,68 +30,66 @@ class MobileBertUncasedModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=True
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
|
||||
class MobileBertModuleTest(unittest.TestCase):
|
||||
class MobileBertUncasedModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = MobileBertUncasedModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="golden and original results mismatch")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -16,17 +16,13 @@ torch.manual_seed(0)
|
||||
class Resnet101ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.resnet101(pretrained=True)
|
||||
)
|
||||
@@ -37,11 +33,11 @@ class Resnet101ModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -51,53 +47,51 @@ class Resnet101ModuleTester:
|
||||
class Resnet101ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet101ModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -16,17 +16,13 @@ torch.manual_seed(0)
|
||||
class Resnet18ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.resnet18(pretrained=True)
|
||||
)
|
||||
@@ -37,11 +33,12 @@ class Resnet18ModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -51,53 +48,51 @@ class Resnet18ModuleTester:
|
||||
class Resnet18ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet18ModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -16,17 +16,13 @@ torch.manual_seed(0)
|
||||
class Resnet50ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.resnet50(pretrained=True)
|
||||
)
|
||||
@@ -37,11 +33,11 @@ class Resnet50ModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -51,53 +47,51 @@ class Resnet50ModuleTester:
|
||||
class Resnet50ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet50ModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -16,17 +16,13 @@ torch.manual_seed(0)
|
||||
class SqueezenetModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.squeezenet1_0(pretrained=True)
|
||||
)
|
||||
@@ -37,11 +33,11 @@ class SqueezenetModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic, tracing_required=False
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -51,53 +47,51 @@ class SqueezenetModuleTester:
|
||||
class SqueezenetModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = SqueezenetModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -16,17 +16,13 @@ torch.manual_seed(0)
|
||||
class WideResnet50ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_vision_model(
|
||||
models.wide_resnet50_2(pretrained=True)
|
||||
)
|
||||
@@ -37,11 +33,11 @@ class WideResnet50ModuleTester:
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=self.dynamic
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, device=self.device, mlir_dialect="linalg"
|
||||
mlir_module, func_name, device=device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
@@ -51,53 +47,51 @@ class WideResnet50ModuleTester:
|
||||
class WideResnet50ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = WideResnet50ModuleTester(self)
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
142
tank/tf/hf_masked_lm/MiniLM-L12-H384-uncased_tf_test.py
Normal file
142
tank/tf/hf_masked_lm/MiniLM-L12-H384-uncased_tf_test.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from tank.model_utils_tf import compare_tensors_tf, get_TFhf_model
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class MiniLMModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_TFhf_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
if (
|
||||
shark_args.save_mlir == True
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/minilm_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
|
||||
if self.save_temps == True:
|
||||
temp_dir = tempfile.mkdtemp(
|
||||
prefix="iree_tfs", dir=shark_args.repro_dir
|
||||
)
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
np.save(f"{temp_dir}/input1.npy", input[0])
|
||||
np.save(f"{temp_dir}/input2.npy", input[1])
|
||||
exp_out = act_out.numpy()
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
else:
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
outputs = results[0]
|
||||
assert True == compare_tensors_tf(act_out[0], outputs[1].to_host())
|
||||
|
||||
|
||||
class MiniLMModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = MiniLMModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,15 +16,16 @@ import os
|
||||
class AlbertBaseModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.benchmark = benchmark
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("albert-base-v2")
|
||||
@@ -35,7 +37,7 @@ class AlbertBaseModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"shark_tmp/albert_base_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/albert-base-v2_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -51,37 +53,33 @@ class AlbertBaseModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "albert-base-v2", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class AlbertBaseModuleTest(unittest.TestCase):
|
||||
@@ -91,61 +89,51 @@ class AlbertBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
# @pytest.mark.xfail(
|
||||
# reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
# )
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,15 +16,16 @@ import os
|
||||
class BertBaseUncasedModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.benchmark = benchmark
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("bert-base-uncased")
|
||||
@@ -35,7 +37,7 @@ class BertBaseUncasedModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/bert_base_uncased_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/bert-base-uncased_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -51,37 +53,33 @@ class BertBaseUncasedModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "bert_base_uncased", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
@@ -91,61 +89,51 @@ class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,15 +16,16 @@ import os
|
||||
class CamemBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.benchmark = benchmark
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("camembert-base")
|
||||
@@ -35,7 +37,7 @@ class CamemBertModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/camembert_base_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/camembert-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -51,37 +53,33 @@ class CamemBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "camembert-base", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class CamemBertModuleTest(unittest.TestCase):
|
||||
@@ -89,63 +87,53 @@ class CamemBertModuleTest(unittest.TestCase):
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = CamemBertModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,15 +16,16 @@ import os
|
||||
class ConvBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.benchmark = benchmark
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
@@ -53,37 +55,32 @@ class ConvBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "convbert-base-turkish-cased", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class ConvBertModuleTest(unittest.TestCase):
|
||||
@@ -93,61 +90,51 @@ class ConvBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536."
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,167 +0,0 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class DebertaModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("microsoft/deberta-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
if (
|
||||
shark_args.save_mlir == True
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/deberta_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
|
||||
if self.save_temps == True:
|
||||
temp_dir = tempfile.mkdtemp(
|
||||
prefix="iree_tfs", dir=shark_args.repro_dir
|
||||
)
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
np.save(f"{temp_dir}/input1.npy", input[0])
|
||||
np.save(f"{temp_dir}/input2.npy", input[1])
|
||||
exp_out = act_out.numpy()
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "deberta-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class DebertaModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = DebertaModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -12,20 +13,22 @@ import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class XLMRobertaModuleTester:
|
||||
class DebertaBaseModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("xlm-roberta-base")
|
||||
model, input, act_out = get_causal_lm_model("microsoft/deberta-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
@@ -34,7 +37,7 @@ class XLMRobertaModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/xlm_roberta_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/deberta-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -50,98 +53,86 @@ class XLMRobertaModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "xlm-roberta-base", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class XLMRobertaModuleTest(unittest.TestCase):
|
||||
class DebertaBaseModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = XLMRobertaModuleTester(self)
|
||||
self.module_tester = DebertaBaseModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(reason="Test currently hangs.")
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class DistilBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("distilbert-base-uncased")
|
||||
@@ -34,7 +37,9 @@ class DistilBertModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/distilbert_tf_{dynamic}_{device}"
|
||||
repro_path = (
|
||||
f"./shark_tmp/distilbert_base_uncased_tf_{dynamic}_{device}"
|
||||
)
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -50,37 +55,33 @@ class DistilBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "distilbert-base-uncased", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class DistilBertModuleTest(unittest.TestCase):
|
||||
@@ -90,61 +91,51 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class ElectraModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
@@ -36,7 +39,7 @@ class ElectraModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/electra_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/electra_small_discriminator_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -52,37 +55,33 @@ class ElectraModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "electra-small-discriminator", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class ElectraModuleTest(unittest.TestCase):
|
||||
@@ -92,60 +91,51 @@ class ElectraModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class FunnelModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("funnel-transformer/small")
|
||||
@@ -34,7 +37,7 @@ class FunnelModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/funnel_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/funnel_small_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -50,37 +53,33 @@ class FunnelModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "funnel-transformer-small", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class FunnelModuleTest(unittest.TestCase):
|
||||
@@ -90,70 +89,51 @@ class FunnelModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="funnel currently failing in the lowering passes."
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="funnel currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="funnel currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="funnel currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="funnel currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -12,17 +13,19 @@ import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class LayoutLmModuleTester:
|
||||
class LayoutLMModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
@@ -36,7 +39,9 @@ class LayoutLmModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/layoutlm_tf_{dynamic}_{device}"
|
||||
repro_path = (
|
||||
f"./shark_tmp/layoutlm_base_uncased_tf_{dynamic}_{device}"
|
||||
)
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -52,100 +57,87 @@ class LayoutLmModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "layoutlm-base-uncased", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class LayoutLmModuleTest(unittest.TestCase):
|
||||
class LayoutLMModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = LayoutLmModuleTester(self)
|
||||
self.module_tester = LayoutLMModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
146
tank/tf/hf_masked_lm/longformer-base-4096_tf_test.py
Normal file
146
tank/tf/hf_masked_lm/longformer-base-4096_tf_test.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class LongformerModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"allenai/longformer-base-4096"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
if (
|
||||
shark_args.save_mlir == True
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/longformer-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
|
||||
if self.save_temps == True:
|
||||
temp_dir = tempfile.mkdtemp(
|
||||
prefix="iree_tfs", dir=shark_args.repro_dir
|
||||
)
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
np.save(f"{temp_dir}/input1.npy", input[0])
|
||||
np.save(f"{temp_dir}/input2.npy", input[1])
|
||||
exp_out = act_out.numpy()
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
else:
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class LongformerModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = LongformerModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class MobileBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
@@ -36,7 +39,9 @@ class MobileBertModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/mobilebert_tf_{dynamic}_{device}"
|
||||
repro_path = (
|
||||
f"./shark_tmp/mobilebert-uncased_tf_{dynamic}_{device}"
|
||||
)
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -52,37 +57,33 @@ class MobileBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "mobilebert-uncased", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class MobileBertModuleTest(unittest.TestCase):
|
||||
@@ -92,60 +93,51 @@ class MobileBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class MpNetModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("microsoft/mpnet-base")
|
||||
@@ -34,7 +37,7 @@ class MpNetModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/mpnet_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/mpnet-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -50,37 +53,33 @@ class MpNetModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "mpnet-base", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class MpNetModuleTest(unittest.TestCase):
|
||||
@@ -90,60 +89,51 @@ class MpNetModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class RemBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("google/rembert")
|
||||
@@ -50,35 +53,32 @@ class RemBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input), "rembert", dynamic, device)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class RemBertModuleTest(unittest.TestCase):
|
||||
@@ -88,70 +88,51 @@ class RemBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -12,17 +13,19 @@ import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class RobertaModuleTester:
|
||||
class RobertaBaseModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("roberta-base")
|
||||
@@ -34,7 +37,7 @@ class RobertaModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/roberta_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/roberta-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -50,100 +53,87 @@ class RobertaModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "roberta-base", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class RobertaModuleTest(unittest.TestCase):
|
||||
class RobertaBaseModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = RobertaModuleTester(self)
|
||||
self.module_tester = RobertaBaseModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class TapasBaseModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("google/tapas-base")
|
||||
@@ -34,7 +37,7 @@ class TapasBaseModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/tapas_base_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/tapas-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -50,37 +53,33 @@ class TapasBaseModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "tapas-base", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class TapasBaseModuleTest(unittest.TestCase):
|
||||
@@ -90,60 +89,51 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -15,14 +16,16 @@ import os
|
||||
class FlauBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
@@ -36,7 +39,9 @@ class FlauBertModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/tiny_flaubert_tf_{dynamic}_{device}"
|
||||
repro_path = (
|
||||
f"./shark_tmp/tiny-random-flaubert_tf_{dynamic}_{device}"
|
||||
)
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -52,37 +57,33 @@ class FlauBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "tiny-random-flaubert", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class FlauBertModuleTest(unittest.TestCase):
|
||||
@@ -92,57 +93,49 @@ class FlauBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/154")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/154")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -1,7 +1,8 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.masked_lm_tf import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from shark.parser import shark_args
|
||||
|
||||
import iree.compiler as ireec
|
||||
@@ -12,22 +13,22 @@ import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class LongFormerModuleTester:
|
||||
class XLMRobertaModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False,
|
||||
save_temps=False,
|
||||
# benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.save_temps = save_temps
|
||||
|
||||
# self.benchmark = benchmark
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"allenai/longformer-base-4096"
|
||||
)
|
||||
model, input, act_out = get_causal_lm_model("xlm-roberta-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
|
||||
@@ -36,7 +37,7 @@ class LongFormerModuleTester:
|
||||
or shark_args.save_vmfb == True
|
||||
or self.save_temps == True
|
||||
):
|
||||
repro_path = f"./shark_tmp/longformer_tf_{dynamic}_{device}"
|
||||
repro_path = f"./shark_tmp/xlm-roberta-base_tf_{dynamic}_{device}"
|
||||
if not os.path.isdir(repro_path):
|
||||
os.mkdir(repro_path)
|
||||
shark_args.repro_dir = repro_path
|
||||
@@ -52,110 +53,93 @@ class LongFormerModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
mlir_importer = SharkImporter(
|
||||
model, (input,), frontend="tensorflow"
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_importer = SharkImporter(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
frontend="tensorflow",
|
||||
)
|
||||
mlir_module, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=dynamic, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module, func_name, device=device, mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "longformer-base-4096", dynamic, device
|
||||
)
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
class LongFormerModuleTest(unittest.TestCase):
|
||||
class XLMRobertaModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = LongFormerModuleTester(self)
|
||||
self.module_tester = XLMRobertaModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
# self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/141")
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/141")
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9536")
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/141")
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/141")
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("gpu"), reason=device_driver_info("gpu")
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/141")
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skip(reason="https://github.com/nod-ai/SHARK/issues/141")
|
||||
@pytest.mark.xfail(reason="https://github.com/iree-org/iree/issues/9524")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
Reference in New Issue
Block a user