mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Divide iree_utils and do module imports on function calls.
This commit is contained in:
@@ -1,18 +1,26 @@
|
||||
def pytest_addoption(parser):
    """Attach the SHARK command-line switches to the pytest machinery.

    Every switch is a ``store_true`` flag that is registered exactly once.
    The default is the boolean ``False`` rather than the string ``"False"``:
    a non-empty string is truthy, so a string default would make
    ``pytestconfig.getoption(...)`` look enabled even when the flag was not
    given on the command line.

    Args:
        parser: the pytest option parser handed to this hook.
    """
    # (flag, help text) pairs; the help strings are shown by `pytest --help`.
    shark_flags = (
        ("--save_mlir", "Pass option to save input MLIR"),
        ("--save_vmfb", "Pass option to save IREE output .vmfb"),
        ("--benchmark", "Pass option to benchmark and write results.csv"),
        (
            "--save_temps",
            "Saves IREE reproduction artifacts for filing upstream issues.",
        ),
    )
    for flag, help_text in shark_flags:
        parser.addoption(
            flag,
            action="store_true",
            default=False,
            help=help_text,
        )
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
from transformers import AutoModelForSequenceClassification, BertTokenizer, TFBertModel
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -12,17 +15,13 @@ torch.manual_seed(0)
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
|
||||
def __init__(self, hf_model_name):
    """Load a pretrained sequence-classification model into ``self.model``.

    Args:
        hf_model_name: model identifier forwarded to
            ``AutoModelForSequenceClassification.from_pretrained``.
    """
    super().__init__()
    self.model = AutoModelForSequenceClassification.from_pretrained(
        hf_model_name,  # The pretrained model.
        num_labels=2,  # The number of output labels--2 for binary classification.
        output_attentions=False,  # Whether the model returns attentions weights.
        output_hidden_states=False,  # Whether the model returns all hidden-states.
        # NOTE(review): presumably needed so the model can be traced
        # (callers in this repo pass jit_trace=True) -- confirm.
        torchscript=True,
    )
|
||||
|
||||
@@ -44,7 +43,6 @@ def get_hf_model(name):
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
@@ -61,6 +59,7 @@ def get_vision_model(torch_model):
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
# Utility function for comparing two tensors (torch).
|
||||
@@ -70,4 +69,3 @@ def compare_tensors(torch_tensor, numpy_tensor):
|
||||
atol = 1e-03
|
||||
torch_to_numpy = torch_tensor.detach().numpy()
|
||||
return np.allclose(torch_to_numpy, numpy_tensor, rtol, atol)
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from transformers import AutoModelForSequenceClassification, BertTokenizer, TFBertModel
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
@@ -13,20 +16,20 @@ BATCH_SIZE = 1
|
||||
# Input signature for the `@tf.function`-decorated `forward` method: one
# int32 spec of shape [BATCH_SIZE, MAX_SEQUENCE_LENGTH] per positional input.
tf_bert_input = [
    # input_ids
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    # attention_mask
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
    # token_type_ids
    tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]
|
||||
|
||||
class TFHuggingFaceLanguage(tf.Module):
|
||||
|
||||
class TFHuggingFaceLanguage(tf.Module):
|
||||
def __init__(self, hf_model_name):
    """Load a pretrained TF BERT model and give it a 3-argument ``predict``.

    Args:
        hf_model_name: model identifier forwarded to
            ``TFBertModel.from_pretrained``.
    """
    super(TFHuggingFaceLanguage, self).__init__()
    # from_pt=True is forwarded to from_pretrained; presumably it loads the
    # weights from a PyTorch checkpoint -- confirm against the HF docs.
    self.m = TFBertModel.from_pretrained(hf_model_name, from_pt=True)

    # Rebind `predict` to a positional wrapper around `call` that always
    # runs with training=False. Nothing is invoked here; the wrapper is only
    # installed on the model instance.
    self.m.predict = lambda x, y, z: self.m.call(
        input_ids=x, attention_mask=y, token_type_ids=z, training=False
    )
|
||||
|
||||
@tf.function(input_signature=tf_bert_input)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
@@ -36,17 +39,24 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
def get_TFhf_model(name):
    """Build a TF Hugging Face model plus a sample input and its output.

    Args:
        name: model identifier passed to ``TFHuggingFaceLanguage``.

    Returns:
        A ``(model, test_input, actual_out)`` tuple. ``test_input`` is the
        ``(input_ids, attention_mask, token_type_ids)`` triple and
        ``actual_out`` is ``model.forward(*test_input)``.
    """
    model = TFHuggingFaceLanguage(name)
    # NOTE(review): the tokenizer checkpoint is fixed and does not follow
    # `name` -- confirm this is intended for non-MiniLM models.
    tokenizer = BertTokenizer.from_pretrained(
        "microsoft/MiniLM-L12-H384-uncased"
    )
    text = "Replace me by any text you'd like."
    # Pad/truncate to MAX_SEQUENCE_LENGTH so the sample matches the fixed
    # sequence length used elsewhere in this module.
    encoded_input = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
    )
    # Convert every field to a tensor and add a leading axis of size 1.
    for key in encoded_input:
        encoded_input[key] = tf.expand_dims(
            tf.convert_to_tensor(encoded_input[key]), 0
        )
    test_input = (
        encoded_input["input_ids"],
        encoded_input["attention_mask"],
        encoded_input["token_type_ids"],
    )
    actual_out = model.forward(*test_input)
    return model, test_input, actual_out
|
||||
|
||||
@@ -58,5 +68,3 @@ def compare_tensors_tf(tf_tensor, numpy_tensor):
|
||||
atol = 1e-03
|
||||
tf_to_numpy = tf_tensor.numpy()
|
||||
return np.allclose(tf_to_numpy, numpy_tensor, rtol, atol)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -10,8 +10,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class MiniLMModuleTester:
|
||||
|
||||
class MiniLMModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -25,24 +25,29 @@ class MiniLMModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
    """Compile MiniLM through SHARK and check it against the reference.

    Reads ``self.device``, ``self.dynamic``, ``self.save_mlir`` and
    ``self.save_vmfb``; the two save flags are written to the global
    ``shark_args`` before the module is built.

    Raises:
        AssertionError: if ``compare_tensors`` reports a mismatch between
            the reference output and the SHARK output.
    """
    # `test_input` instead of `input` to avoid shadowing the builtin.
    model, test_input, expected_out = get_hf_model(
        "microsoft/MiniLM-L12-H384-uncased"
    )
    shark_args.save_mlir = self.save_mlir
    shark_args.save_vmfb = self.save_vmfb
    shark_module = SharkInference(
        model,
        (test_input,),
        device=self.device,
        dynamic=self.dynamic,
        jit_trace=True,
    )
    shark_module.compile()
    results = shark_module.forward((test_input,))
    assert compare_tensors(expected_out, results)
|
||||
|
||||
class MiniLMModuleTest(unittest.TestCase):
|
||||
|
||||
class MiniLMModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = MiniLMModuleTester(self)
|
||||
|
||||
@@ -50,20 +55,24 @@ class MiniLMModuleTest(unittest.TestCase):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="language models failing for dynamic case")
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
@@ -71,9 +80,9 @@ class MiniLMModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
@@ -81,14 +90,14 @@ class MiniLMModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -8,10 +8,10 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
#torch.manual_seed(0)
|
||||
# torch.manual_seed(0)
|
||||
|
||||
|
||||
class AlbertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -28,21 +28,24 @@ class AlbertModuleTester:
|
||||
model, input, act_out = get_hf_model("albert-base-v2")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
|
||||
class AlbertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = AlbertModuleTester(self)
|
||||
self.module_tester.save_mlir = self.save_mlir
|
||||
@@ -51,20 +54,26 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
@@ -72,9 +81,9 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
@@ -82,13 +91,14 @@ class AlbertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,8 +11,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class AlexnetModuleTester:
|
||||
|
||||
class AlexnetModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -26,69 +26,75 @@ class AlexnetModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
    """Compile AlexNet through SHARK and check it against the reference.

    Reads ``self.device``, ``self.dynamic``, ``self.save_mlir`` and
    ``self.save_vmfb``; the two save flags are written to the global
    ``shark_args`` before the module is built.

    Raises:
        AssertionError: if ``compare_tensors`` reports a mismatch between
            the reference output and the SHARK output.
    """
    # `test_input` instead of `input` to avoid shadowing the builtin.
    model, test_input, expected_out = get_vision_model(
        models.alexnet(pretrained=True)
    )
    shark_args.save_mlir = self.save_mlir
    shark_args.save_vmfb = self.save_vmfb
    shark_module = SharkInference(
        model,
        (test_input,),
        device=self.device,
        dynamic=self.dynamic,
    )
    shark_module.compile()
    results = shark_module.forward((test_input,))
    assert compare_tensors(expected_out, results)
|
||||
|
||||
class AlexnetModuleTest(unittest.TestCase):
    """Runs the AlexNet SHARK check for each device, static and dynamic."""

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        # Capture the SHARK command-line switches from the pytest config.
        self.save_mlir = pytestconfig.getoption("save_mlir")
        self.save_vmfb = pytestconfig.getoption("save_vmfb")

    def setUp(self):
        # NOTE(review): the test case itself is passed positionally to the
        # tester, and the save flags captured by `configure` are not
        # forwarded to it in this class -- confirm both are intended.
        self.module_tester = AlexnetModuleTester(self)

    def _run_case(self, dynamic, device):
        """Point the tester at one configuration and run its check."""
        self.module_tester.dynamic = dynamic
        self.module_tester.device = device
        self.module_tester.create_and_check_module()

    def test_module_static_cpu(self):
        self._run_case(False, "cpu")

    def test_module_dynamic_cpu(self):
        self._run_case(True, "cpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason="nvidia-smi not found"
    )
    def test_module_static_gpu(self):
        self._run_case(False, "gpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason="nvidia-smi not found"
    )
    def test_module_dynamic_gpu(self):
        self._run_case(True, "gpu")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"),
        reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
    )
    def test_module_static_vulkan(self):
        self._run_case(False, "vulkan")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"),
        reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
    )
    def test_module_dynamic_vulkan(self):
        self._run_case(True, "vulkan")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -8,10 +8,10 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
#torch.manual_seed(0)
|
||||
# torch.manual_seed(0)
|
||||
|
||||
|
||||
class BertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -28,42 +28,51 @@ class BertModuleTester:
|
||||
model, input, act_out = get_hf_model("bert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class BertModuleTest(unittest.TestCase):
|
||||
|
||||
class BertModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = BertModuleTester(self)
|
||||
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
@@ -71,9 +80,9 @@ class BertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
@@ -81,14 +90,14 @@ class BertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_hf_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -8,10 +8,10 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
#torch.manual_seed(0)
|
||||
# torch.manual_seed(0)
|
||||
|
||||
|
||||
class DistilBertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -28,43 +28,54 @@ class DistilBertModuleTester:
|
||||
model, input, act_out = get_hf_model("distilbert-base-uncased")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class DistilBertModuleTest(unittest.TestCase):
|
||||
|
||||
class DistilBertModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = DistilBertModuleTester(self)
|
||||
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.xfail(reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
@@ -72,9 +83,9 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
@@ -82,14 +93,14 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9554")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,8 +11,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class Resnet101ModuleTester:
|
||||
|
||||
class Resnet101ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -26,68 +26,75 @@ class Resnet101ModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
    """Compile ResNet-101 through SHARK and check it against the reference.

    Reads ``self.device``, ``self.dynamic``, ``self.save_mlir`` and
    ``self.save_vmfb``; the two save flags are written to the global
    ``shark_args`` before the module is built.

    Raises:
        AssertionError: if ``compare_tensors`` reports a mismatch between
            the reference output and the SHARK output.
    """
    # `test_input` instead of `input` to avoid shadowing the builtin.
    model, test_input, expected_out = get_vision_model(
        models.resnet101(pretrained=True)
    )
    shark_args.save_mlir = self.save_mlir
    shark_args.save_vmfb = self.save_vmfb
    shark_module = SharkInference(
        model,
        (test_input,),
        device=self.device,
        dynamic=self.dynamic,
    )
    shark_module.compile()
    results = shark_module.forward((test_input,))
    assert compare_tensors(expected_out, results)
|
||||
|
||||
class Resnet101ModuleTest(unittest.TestCase):
    """Runs the ResNet-101 SHARK check for each device, static and dynamic."""

    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        # Capture the SHARK command-line switches from the pytest config.
        self.save_mlir = pytestconfig.getoption("save_mlir")
        self.save_vmfb = pytestconfig.getoption("save_vmfb")

    def setUp(self):
        # NOTE(review): the test case itself is passed positionally to the
        # tester, and the save flags captured by `configure` are not
        # forwarded to it in this class -- confirm both are intended.
        self.module_tester = Resnet101ModuleTester(self)

    def _run_case(self, dynamic, device):
        """Point the tester at one configuration and run its check."""
        self.module_tester.dynamic = dynamic
        self.module_tester.device = device
        self.module_tester.create_and_check_module()

    def test_module_static_cpu(self):
        self._run_case(False, "cpu")

    def test_module_dynamic_cpu(self):
        self._run_case(True, "cpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason="nvidia-smi not found"
    )
    def test_module_static_gpu(self):
        self._run_case(False, "gpu")

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason="nvidia-smi not found"
    )
    def test_module_dynamic_gpu(self):
        self._run_case(True, "gpu")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"),
        reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
    )
    def test_module_static_vulkan(self):
        self._run_case(False, "vulkan")

    @pytest.mark.skipif(
        check_device_drivers("vulkan"),
        reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
    )
    def test_module_dynamic_vulkan(self):
        self._run_case(True, "vulkan")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,8 +11,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class Resnet18ModuleTester:
|
||||
|
||||
class Resnet18ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -26,69 +26,75 @@ class Resnet18ModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.resnet18(pretrained=True))
|
||||
model, input, act_out = get_vision_model(
|
||||
models.resnet18(pretrained=True)
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class Resnet18ModuleTest(unittest.TestCase):
|
||||
|
||||
class Resnet18ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet18ModuleTester(self)
|
||||
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,8 +11,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class Resnet50ModuleTester:
|
||||
|
||||
class Resnet50ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -26,68 +26,75 @@ class Resnet50ModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.resnet50(pretrained=True))
|
||||
model, input, act_out = get_vision_model(
|
||||
models.resnet50(pretrained=True)
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class Resnet50ModuleTest(unittest.TestCase):
|
||||
|
||||
class Resnet50ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = Resnet50ModuleTester(self)
|
||||
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,8 +11,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class SqueezenetModuleTester:
|
||||
|
||||
class SqueezenetModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -26,69 +26,75 @@ class SqueezenetModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.squeezenet1_0(pretrained=True))
|
||||
model, input, act_out = get_vision_model(
|
||||
models.squeezenet1_0(pretrained=True)
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class SqueezenetModuleTest(unittest.TestCase):
|
||||
|
||||
class SqueezenetModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = SqueezenetModuleTester(self)
|
||||
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -47,18 +47,23 @@ if os.path.exists(checkpoint):
|
||||
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
|
||||
|
||||
model = model.to(device).eval().requires_grad_(False)
|
||||
clip_model_name = model.clip_model if hasattr(model, "clip_model") else "ViT-B/16"
|
||||
clip_model_name = (
|
||||
model.clip_model if hasattr(model, "clip_model") else "ViT-B/16"
|
||||
)
|
||||
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
|
||||
clip_model.eval().requires_grad_(False)
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
|
||||
mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
|
||||
zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
|
||||
target_embeds, weights = [zero_embed], []
|
||||
|
||||
txt, weight = parse_prompt(args.prompts[0])
|
||||
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
|
||||
target_embeds.append(
|
||||
clip_model.encode_text(clip.tokenize(txt).to(device)).float()
|
||||
)
|
||||
weights.append(weight)
|
||||
|
||||
weights = torch.tensor([1 - sum(weights), *weights], device=device)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils import get_vision_model, compare_tensors
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,8 +11,8 @@ import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class WideResnet50ModuleTester:
|
||||
|
||||
class WideResnet50ModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -26,69 +26,75 @@ class WideResnet50ModuleTester:
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self):
|
||||
model, input, act_out = get_vision_model(models.wide_resnet50_2(pretrained=True))
|
||||
model, input, act_out = get_vision_model(
|
||||
models.wide_resnet50_2(pretrained=True)
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
model,
|
||||
(input,),
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
)
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input,))
|
||||
assert True == compare_tensors(act_out, results)
|
||||
|
||||
class WideResnet50ModuleTest(unittest.TestCase):
|
||||
|
||||
class WideResnet50ModuleTest(unittest.TestCase):
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = WideResnet50ModuleTester(self)
|
||||
|
||||
|
||||
def test_module_static_cpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
def test_module_dynamic_cpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"), reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "gpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
self.module_tester.dynamic = False
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
self.module_tester.dynamic = True
|
||||
self.module_tester.device = "vulkan"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -9,11 +9,11 @@ inputs_signature = [
|
||||
|
||||
|
||||
class AutoModelMaskedLM(tf.Module):
|
||||
|
||||
def __init__(self, model_name):
|
||||
super(AutoModelMaskedLM, self).__init__()
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained(model_name,
|
||||
output_attentions=False)
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained(
|
||||
model_name, output_attentions=False
|
||||
)
|
||||
self.m.predict = lambda x: self.m(input_ids=x)
|
||||
|
||||
@tf.function(input_signature=inputs_signature)
|
||||
@@ -24,20 +24,26 @@ class AutoModelMaskedLM(tf.Module):
|
||||
fail_models = ["microsoft/deberta-base", "google/rembert", "google/tapas-base"]
|
||||
|
||||
supported_models = [
|
||||
"albert-base-v2", "bert-base-uncased", "camembert-base",
|
||||
"dbmdz/convbert-base-turkish-cased", "distilbert-base-uncased",
|
||||
"albert-base-v2",
|
||||
"bert-base-uncased",
|
||||
"camembert-base",
|
||||
"dbmdz/convbert-base-turkish-cased",
|
||||
"distilbert-base-uncased",
|
||||
"google/electra-small-discriminator",
|
||||
"hf-internal-testing/tiny-random-flaubert", "funnel-transformer/small",
|
||||
"microsoft/layoutlm-base-uncased", "allenai/longformer-base-4096",
|
||||
"google/mobilebert-uncased", "microsoft/mpnet-base", "roberta-base",
|
||||
"xlm-roberta-base"
|
||||
"hf-internal-testing/tiny-random-flaubert",
|
||||
"funnel-transformer/small",
|
||||
"microsoft/layoutlm-base-uncased",
|
||||
"allenai/longformer-base-4096",
|
||||
"google/mobilebert-uncased",
|
||||
"microsoft/mpnet-base",
|
||||
"roberta-base",
|
||||
"xlm-roberta-base",
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
inputs = tf.random.uniform(shape=[1, 512],
|
||||
maxval=3,
|
||||
dtype=tf.int32,
|
||||
seed=10)
|
||||
inputs = tf.random.uniform(
|
||||
shape=[1, 512], maxval=3, dtype=tf.int32, seed=10
|
||||
)
|
||||
|
||||
for model_name in supported_models:
|
||||
print(f"Running model: {model_name}")
|
||||
|
||||
@@ -19,24 +19,26 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=24,
|
||||
hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
dict_outputs=dict_outputs)
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=24,
|
||||
hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
dict_outputs=dict_outputs,
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
@@ -46,15 +48,20 @@ class BertModule(tf.Module):
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(input_signature=[
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH],
|
||||
dtype=tf.int32), #input0: input_word_ids
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH],
|
||||
dtype=tf.int32), #input1: input_mask
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH],
|
||||
dtype=tf.int32), #input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32) # input3: labels
|
||||
])
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input0: input_word_ids
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input1: input_mask
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
|
||||
]
|
||||
)
|
||||
def learn(self, input_word_ids, input_mask, segment_ids, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -77,12 +84,12 @@ class BertModule(tf.Module):
|
||||
if __name__ == "__main__":
|
||||
# BertModule()
|
||||
# Compile the model using IREE
|
||||
compiler_module = tfc.compile_module(BertModule(),
|
||||
exported_names=["learn"],
|
||||
import_only=True)
|
||||
compiler_module = tfc.compile_module(
|
||||
BertModule(), exported_names=["learn"], import_only=True
|
||||
)
|
||||
# Save module as MLIR file in a directory
|
||||
ARITFACTS_DIR = os.getcwd()
|
||||
mlir_path = os.path.join(ARITFACTS_DIR, "model.mlir")
|
||||
with open(mlir_path, "wt") as output_file:
|
||||
output_file.write(compiler_module.decode('utf-8'))
|
||||
output_file.write(compiler_module.decode("utf-8"))
|
||||
print(f"Wrote MLIR to path '{mlir_path}'")
|
||||
|
||||
@@ -21,24 +21,26 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=24,
|
||||
hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
dict_outputs=dict_outputs)
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=24,
|
||||
hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
dict_outputs=dict_outputs,
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
@@ -49,10 +51,12 @@ class BertModule(tf.Module):
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32) # labels
|
||||
])
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -69,26 +73,29 @@ class BertModule(tf.Module):
|
||||
if __name__ == "__main__":
|
||||
# BertModule()
|
||||
# Compile the model using IREE
|
||||
compiler_module = tfc.compile_module(BertModule(),
|
||||
exported_names=["learn"],
|
||||
import_only=True)
|
||||
compiler_module = tfc.compile_module(
|
||||
BertModule(), exported_names=["learn"], import_only=True
|
||||
)
|
||||
|
||||
# Compile the model using IREE
|
||||
backend = "dylib-llvm-aot"
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-stream-resource-index-bits=64", "--iree-vm-target-index-bits=64"
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
]
|
||||
backend_config = "dylib"
|
||||
#backend = "cuda"
|
||||
#backend_config = "cuda"
|
||||
#args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo")
|
||||
#flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
# Save module as MLIR file in a directory
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
@@ -100,21 +107,23 @@ if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
learn_sample_input = [
|
||||
predict_sample_input,
|
||||
np.random.randint(5, size=(BATCH_SIZE))
|
||||
np.random.randint(5, size=(BATCH_SIZE)),
|
||||
]
|
||||
warmup = 5
|
||||
total_iter = 10
|
||||
num_iter = total_iter - warmup
|
||||
for i in range(10):
|
||||
if (i == warmup - 1):
|
||||
if i == warmup - 1:
|
||||
start = time.time()
|
||||
print(
|
||||
BertCompiled.learn(predict_sample_input,
|
||||
np.random.randint(5, size=(BATCH_SIZE))))
|
||||
BertCompiled.learn(
|
||||
predict_sample_input, np.random.randint(5, size=(BATCH_SIZE))
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("time: " + str(total_time))
|
||||
|
||||
@@ -14,24 +14,26 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=24,
|
||||
hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
dict_outputs=dict_outputs)
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=24,
|
||||
hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
dict_outputs=dict_outputs,
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
@@ -42,10 +44,12 @@ class BertModule(tf.Module):
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32) # labels
|
||||
])
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -64,7 +68,7 @@ if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
bert_model = BertModule()
|
||||
warmup = 1
|
||||
@@ -72,9 +76,11 @@ if __name__ == "__main__":
|
||||
num_iter = total_iter - warmup
|
||||
for i in range(total_iter):
|
||||
print(
|
||||
bert_model.learn(predict_sample_input,
|
||||
np.random.randint(5, size=(BATCH_SIZE))))
|
||||
if (i == warmup - 1):
|
||||
bert_model.learn(
|
||||
predict_sample_input, np.random.randint(5, size=(BATCH_SIZE))
|
||||
)
|
||||
)
|
||||
if i == warmup - 1:
|
||||
start = time.time()
|
||||
|
||||
end = time.time()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from iree import runtime as ireert
|
||||
#from iree.tf.support import module_utils
|
||||
|
||||
# from iree.tf.support import module_utils
|
||||
from iree.compiler import tf as tfc
|
||||
from absl import app
|
||||
|
||||
@@ -19,22 +20,22 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=2,
|
||||
dict_outputs=dict_outputs)
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
@@ -44,15 +45,20 @@ class BertModule(tf.Module):
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(input_signature=[
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH],
|
||||
dtype=tf.int32), #input0: input_word_ids
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH],
|
||||
dtype=tf.int32), #input1: input_mask
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH],
|
||||
dtype=tf.int32), #input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32) # input3: labels
|
||||
])
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input0: input_word_ids
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input1: input_mask
|
||||
tf.TensorSpec(
|
||||
shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32
|
||||
), # input2: segment_ids
|
||||
tf.TensorSpec([BATCH_SIZE], tf.int32), # input3: labels
|
||||
]
|
||||
)
|
||||
def learn(self, input_word_ids, input_mask, segment_ids, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -75,13 +81,13 @@ class BertModule(tf.Module):
|
||||
if __name__ == "__main__":
|
||||
# BertModule()
|
||||
# Compile the model using IREE
|
||||
compiler_module = tfc.compile_module(BertModule(),
|
||||
exported_names=["learn"],
|
||||
import_only=True)
|
||||
compiler_module = tfc.compile_module(
|
||||
BertModule(), exported_names=["learn"], import_only=True
|
||||
)
|
||||
print(type(compiler_module))
|
||||
# Save module as MLIR file in a directory
|
||||
ARITFACTS_DIR = os.getcwd()
|
||||
mlir_path = os.path.join(ARITFACTS_DIR, "model.mlir")
|
||||
with open(mlir_path, "wt") as output_file:
|
||||
output_file.write(compiler_module.decode('utf-8'))
|
||||
output_file.write(compiler_module.decode("utf-8"))
|
||||
print(f"Wrote MLIR to path '{mlir_path}'")
|
||||
|
||||
@@ -21,22 +21,22 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=2,
|
||||
dict_outputs=dict_outputs)
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
@@ -47,10 +47,12 @@ class BertModule(tf.Module):
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32) # labels
|
||||
])
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -67,25 +69,28 @@ class BertModule(tf.Module):
|
||||
if __name__ == "__main__":
|
||||
# BertModule()
|
||||
# Compile the model using IREE
|
||||
compiler_module = tfc.compile_module(BertModule(),
|
||||
exported_names=["learn"],
|
||||
import_only=True)
|
||||
compiler_module = tfc.compile_module(
|
||||
BertModule(), exported_names=["learn"], import_only=True
|
||||
)
|
||||
|
||||
# Compile the model using IREE
|
||||
backend = "dylib-llvm-aot"
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false", "--iree-flow-demote-i64-to-i32"
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
backend_config = "dylib"
|
||||
#backend = "cuda"
|
||||
#backend_config = "cuda"
|
||||
#args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo")
|
||||
#flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
# Save module as MLIR file in a directory
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
@@ -97,21 +102,23 @@ if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
learn_sample_input = [
|
||||
predict_sample_input,
|
||||
np.random.randint(5, size=(BATCH_SIZE))
|
||||
np.random.randint(5, size=(BATCH_SIZE)),
|
||||
]
|
||||
warmup = 5
|
||||
total_iter = 10
|
||||
num_iter = total_iter - warmup
|
||||
for i in range(10):
|
||||
if (i == warmup - 1):
|
||||
if i == warmup - 1:
|
||||
start = time.time()
|
||||
print(
|
||||
BertCompiled.learn(predict_sample_input,
|
||||
np.random.randint(5, size=(BATCH_SIZE))))
|
||||
BertCompiled.learn(
|
||||
predict_sample_input, np.random.randint(5, size=(BATCH_SIZE))
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("time: " + str(total_time))
|
||||
|
||||
@@ -14,22 +14,22 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(vocab_size=vocab_size,
|
||||
num_layers=2,
|
||||
dict_outputs=dict_outputs)
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES)
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
@@ -40,10 +40,12 @@ class BertModule(tf.Module):
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32) # labels
|
||||
])
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
]
|
||||
)
|
||||
def learn(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
@@ -62,7 +64,7 @@ if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH))
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
bert_model = BertModule()
|
||||
warmup = 1
|
||||
@@ -70,9 +72,11 @@ if __name__ == "__main__":
|
||||
num_iter = total_iter - warmup
|
||||
for i in range(total_iter):
|
||||
print(
|
||||
bert_model.learn(predict_sample_input,
|
||||
np.random.randint(5, size=(BATCH_SIZE))))
|
||||
if (i == warmup - 1):
|
||||
bert_model.learn(
|
||||
predict_sample_input, np.random.randint(5, size=(BATCH_SIZE))
|
||||
)
|
||||
)
|
||||
if i == warmup - 1:
|
||||
start = time.time()
|
||||
|
||||
end = time.time()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,13 +12,12 @@ import tempfile
|
||||
|
||||
|
||||
class AlbertBaseModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
@@ -42,71 +41,80 @@ class AlbertBaseModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"albert-base-v2",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "albert-base-v2", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class AlbertBaseModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = AlbertBaseModuleTester(self)
|
||||
self.module_tester.save_temps=pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir=pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb=pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark=pytestconfig.getoption("benchmark")
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.xfail(reason="https://github.com/google/iree/issues/9553")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -115,8 +123,7 @@ class AlbertBaseModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -124,11 +131,11 @@ class AlbertBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -136,5 +143,5 @@ class AlbertBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,13 +12,12 @@ import tempfile
|
||||
|
||||
|
||||
class BertBaseUncasedModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
@@ -42,36 +41,40 @@ class BertBaseUncasedModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"bert_base_uncased",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "bert_base_uncased", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = BertBaseUncasedModuleTester(self)
|
||||
@@ -80,8 +83,9 @@ class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
@@ -89,24 +93,28 @@ class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -115,8 +123,7 @@ class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -124,11 +131,11 @@ class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -136,5 +143,5 @@ class BertBaseUncasedModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,13 +12,12 @@ import tempfile
|
||||
|
||||
|
||||
class CamemBertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
@@ -42,44 +41,51 @@ class CamemBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"camembert-base",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "camembert-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class CamemBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester=CamemBertModuleTester(self)
|
||||
self.module_tester = CamemBertModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
@@ -87,24 +93,28 @@ class CamemBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -113,8 +123,7 @@ class CamemBertModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -122,11 +131,11 @@ class CamemBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -134,5 +143,5 @@ class CamemBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -11,24 +11,23 @@ import numpy as np
|
||||
import tempfile
|
||||
|
||||
|
||||
class ConvBertModuleTester:
|
||||
|
||||
class ConvBertModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
self.benchmark = benchmark
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"dbmdz/convbert-base-turkish-cased")
|
||||
"dbmdz/convbert-base-turkish-cased"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
if self.save_temps == True:
|
||||
@@ -44,36 +43,40 @@ class ConvBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"convbert-base-turkish-cased",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "convbert-base-turkish-cased", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class ConvBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = ConvBertModuleTester(self)
|
||||
@@ -82,7 +85,9 @@ class ConvBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536."
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
@@ -90,24 +95,28 @@ class ConvBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -116,8 +125,7 @@ class ConvBertModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -125,11 +133,11 @@ class ConvBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -137,5 +145,5 @@ class ConvBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,18 +12,17 @@ import tempfile
|
||||
|
||||
|
||||
class DebertaModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("microsoft/deberta-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
@@ -41,36 +40,40 @@ class DebertaModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"deberta-base",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "deberta-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class DebertaModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = DebertaModuleTester(self)
|
||||
@@ -78,10 +81,11 @@ class DebertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
@@ -89,52 +93,59 @@ class DebertaModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="deberta currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -142,5 +153,5 @@ class DebertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,13 +12,12 @@ import tempfile
|
||||
|
||||
|
||||
class DistilBertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
@@ -41,35 +40,40 @@ class DistilBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"distilbert-base-uncased",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "distilbert-base-uncased", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class DistilBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = DistilBertModuleTester(self)
|
||||
@@ -78,7 +82,9 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
@@ -86,24 +92,28 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -112,8 +122,7 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -121,11 +130,11 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -133,5 +142,5 @@ class DistilBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,20 +12,21 @@ import tempfile
|
||||
|
||||
|
||||
class ElectraModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("google/electra-small-discriminator")
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"google/electra-small-discriminator"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
if self.save_temps == True:
|
||||
@@ -39,71 +40,81 @@ class ElectraModuleTester:
|
||||
np.save(f"{temp_dir}/input2.npy", input[1])
|
||||
exp_out = act_out.numpy()
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"electra-small-discriminator",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "electra-small-discriminator", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class ElectraModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = ElectraModuleTester(self)
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_temps = pytestconfig.getoption("save_temps")
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -112,8 +123,7 @@ class ElectraModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -121,11 +131,11 @@ class ElectraModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -133,5 +143,5 @@ class ElectraModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,13 +12,12 @@ import tempfile
|
||||
|
||||
|
||||
class FunnelModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
@@ -41,35 +40,40 @@ class FunnelModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"funnel-transformer-small",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "funnel-transformer-small", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class FunnelModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = FunnelModuleTester(self)
|
||||
@@ -85,15 +89,17 @@ class FunnelModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="funnel currently failing in the lowering passes.")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
@@ -101,9 +107,11 @@ class FunnelModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="funnel currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -112,8 +120,7 @@ class FunnelModuleTest(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="funnel currently failing in the lowering passes.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -122,11 +129,11 @@ class FunnelModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="funnel currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -134,5 +141,5 @@ class FunnelModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,21 +12,21 @@ import tempfile
|
||||
|
||||
|
||||
class LayoutLmModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"microsoft/layoutlm-base-uncased")
|
||||
"microsoft/layoutlm-base-uncased"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
if self.save_temps == True:
|
||||
@@ -42,35 +42,40 @@ class LayoutLmModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"layoutlm-base-uncased",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "layoutlm-base-uncased", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class LayoutLmModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = LayoutLmModuleTester(self)
|
||||
@@ -78,32 +83,38 @@ class LayoutLmModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -112,8 +123,7 @@ class LayoutLmModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -121,11 +131,11 @@ class LayoutLmModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -133,5 +143,5 @@ class LayoutLmModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,20 +12,21 @@ import tempfile
|
||||
|
||||
|
||||
class LongFormerModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("allenai/longformer-base-4096")
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"allenai/longformer-base-4096"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
if self.save_temps == True:
|
||||
@@ -41,35 +42,40 @@ class LongFormerModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"longformer-base-4096",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "longformer-base-4096", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class LongFormerModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = LongFormerModuleTester(self)
|
||||
@@ -79,45 +85,52 @@ class LongFormerModuleTest(unittest.TestCase):
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes.")
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes.")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes.")
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes.")
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -125,13 +138,14 @@ class LongFormerModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="longformer currently failing in the lowering passes.")
|
||||
reason="longformer currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -139,5 +153,5 @@ class LongFormerModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -17,21 +17,22 @@ inputs_signature = [
|
||||
|
||||
def preprocess_input(model_name, text="This is just used to compile the model"):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
inputs = tokenizer(text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
class MaskedLM(tf.Module):
|
||||
|
||||
def __init__(self, model_name):
|
||||
super(MaskedLM, self).__init__()
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained(model_name,
|
||||
output_attentions=False,
|
||||
num_labels=2)
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained(
|
||||
model_name, output_attentions=False, num_labels=2
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=inputs_signature)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,13 +12,12 @@ import tempfile
|
||||
|
||||
|
||||
class MobileBertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
@@ -41,35 +40,40 @@ class MobileBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"mobilebert-uncased",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "mobilebert-uncased", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class MobileBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = MobileBertModuleTester(self)
|
||||
@@ -78,31 +82,37 @@ class MobileBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -111,8 +121,7 @@ class MobileBertModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -120,11 +129,11 @@ class MobileBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -132,5 +141,5 @@ class MobileBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,18 +12,17 @@ import tempfile
|
||||
|
||||
|
||||
class MpNetModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("microsoft/mpnet-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
@@ -41,35 +40,40 @@ class MpNetModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"mpnet-base",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "mpnet-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class MpNetModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = MpNetModuleTester(self)
|
||||
@@ -78,31 +82,37 @@ class MpNetModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -111,8 +121,7 @@ class MpNetModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -120,11 +129,11 @@ class MpNetModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -132,5 +141,5 @@ class MpNetModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,18 +12,17 @@ import tempfile
|
||||
|
||||
|
||||
class RemBertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("google/rembert")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
@@ -41,35 +40,38 @@ class RemBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"rembert",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv((input), "rembert", dynamic, device)
|
||||
|
||||
|
||||
class RemBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = RemBertModuleTester(self)
|
||||
@@ -78,60 +80,68 @@ class RemBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.skip(reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
device = "vulkan"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
reason="rembert currently failing in the lowering passes."
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -139,5 +149,5 @@ class RemBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,18 +12,17 @@ import tempfile
|
||||
|
||||
|
||||
class RobertaModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("roberta-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
@@ -41,35 +40,40 @@ class RobertaModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"roberta-base",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "roberta-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class RobertaModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = RobertaModuleTester(self)
|
||||
@@ -78,31 +82,37 @@ class RobertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.save_vmfb = pytestconfig.getoption("save_vmfb")
|
||||
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
|
||||
|
||||
@pytest.mark.xfail(reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536")
|
||||
@pytest.mark.xfail(
|
||||
reason="Upstream IREE issue, see https://github.com/google/iree/issues/9536"
|
||||
)
|
||||
def test_module_static_cpu(self):
|
||||
dynamic = False
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -111,8 +121,7 @@ class RobertaModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -120,11 +129,11 @@ class RobertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -132,5 +141,5 @@ class RobertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,18 +12,17 @@ import tempfile
|
||||
|
||||
|
||||
class TapasBaseModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("google/tapas-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
@@ -41,35 +40,40 @@ class TapasBaseModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"tapas-base",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "tapas-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class TapasBaseModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = TapasBaseModuleTester(self)
|
||||
@@ -85,15 +89,17 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
@@ -101,9 +107,11 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -112,8 +120,7 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -122,11 +129,11 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="tapas currently failing in the lowering passes.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -134,5 +141,5 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,20 +12,21 @@ import tempfile
|
||||
|
||||
|
||||
class FlauBertModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("hf-internal-testing/tiny-random-flaubert")
|
||||
model, input, act_out = get_causal_lm_model(
|
||||
"hf-internal-testing/tiny-random-flaubert"
|
||||
)
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
if self.save_temps == True:
|
||||
@@ -41,35 +42,40 @@ class FlauBertModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"tiny-random-flaubert",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "tiny-random-flaubert", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class FlauBertModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = FlauBertModuleTester(self)
|
||||
@@ -84,24 +90,28 @@ class FlauBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -110,8 +120,7 @@ class FlauBertModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -119,11 +128,11 @@ class FlauBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -131,5 +140,5 @@ class FlauBertModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from masked_lm import get_causal_lm_model
|
||||
from tank.model_utils_tf import compare_tensors_tf
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -12,18 +12,17 @@ import tempfile
|
||||
|
||||
|
||||
class XLMRobertaModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_temps=False,
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
benchmark=False
|
||||
benchmark=False,
|
||||
):
|
||||
self.save_temps = save_temps
|
||||
self.save_mlir = save_mlir
|
||||
self.save_vmfb = save_vmfb
|
||||
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_causal_lm_model("xlm-roberta-base")
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
@@ -41,35 +40,40 @@ class XLMRobertaModuleTester:
|
||||
with open(f"{temp_dir}/expected_out.txt", "w") as out_file:
|
||||
out_file.write(np.array2string(exp_out))
|
||||
with ireec.tools.TempFileSaver(temp_dir):
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark)
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=self.benchmark,
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
assert True == compare_tensors_tf(act_out, results)
|
||||
|
||||
if self.benchmark == True:
|
||||
shark_module.benchmark_all_csv((input),
|
||||
"xlm-roberta-base",
|
||||
dynamic,
|
||||
device)
|
||||
shark_module.benchmark_all_csv(
|
||||
(input), "xlm-roberta-base", dynamic, device
|
||||
)
|
||||
|
||||
|
||||
class XLMRobertaModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.module_tester = XLMRobertaModuleTester(self)
|
||||
@@ -85,24 +89,28 @@ class XLMRobertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -111,8 +119,7 @@ class XLMRobertaModuleTest(unittest.TestCase):
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -120,11 +127,11 @@ class XLMRobertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -132,5 +139,5 @@ class XLMRobertaModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -15,21 +15,22 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True)
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False)
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
def predict(self, input_word_ids, input_mask, segment_ids):
|
||||
@@ -39,12 +40,12 @@ class BertModule(tf.Module):
|
||||
if __name__ == "__main__":
|
||||
# BertModule()
|
||||
# Compile the model using IREE
|
||||
compiler_module = tfc.compile_module(BertModule(),
|
||||
exported_names=["predict"],
|
||||
import_only=True)
|
||||
compiler_module = tfc.compile_module(
|
||||
BertModule(), exported_names=["predict"], import_only=True
|
||||
)
|
||||
# Save module as MLIR file in a directory
|
||||
ARITFACTS_DIR = os.getcwd()
|
||||
mlir_path = os.path.join(ARITFACTS_DIR, "model.mlir")
|
||||
with open(mlir_path, "wt") as output_file:
|
||||
output_file.write(compiler_module.decode('utf-8'))
|
||||
output_file.write(compiler_module.decode("utf-8"))
|
||||
print(f"Wrote MLIR to path '{mlir_path}'")
|
||||
|
||||
@@ -16,21 +16,22 @@ BATCH_SIZE = 1
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True)
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False)
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input)
|
||||
def predict(self, input_ids, attention_mask, token_type_ids):
|
||||
@@ -40,36 +41,43 @@ class BertModule(tf.Module):
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased")
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH)
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0)
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
|
||||
# Compile the model using IREE
|
||||
compiler_module = tfc.compile_module(BertModule(),
|
||||
exported_names=["predict"],
|
||||
import_only=True)
|
||||
compiler_module = tfc.compile_module(
|
||||
BertModule(), exported_names=["predict"], import_only=True
|
||||
)
|
||||
|
||||
# Compile the model using IREE
|
||||
backend = "dylib-llvm-aot"
|
||||
args = [
|
||||
"--iree-llvm-target-cpu-features=host",
|
||||
"--iree-mhlo-demote-i64-to-i32=false", "--iree-flow-demote-i64-to-i32"
|
||||
"--iree-mhlo-demote-i64-to-i32=false",
|
||||
"--iree-flow-demote-i64-to-i32",
|
||||
]
|
||||
backend_config = "dylib"
|
||||
#backend = "cuda"
|
||||
#backend_config = "cuda"
|
||||
#args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo")
|
||||
#flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
# backend = "cuda"
|
||||
# backend_config = "cuda"
|
||||
# args = ["--iree-cuda-llvm-target-arch=sm_80", "--iree-hal-cuda-disable-loop-nounroll-wa", "--iree-enable-fusion-with-reduction-ops"]
|
||||
flatbuffer_blob = compile_str(
|
||||
compiler_module,
|
||||
target_backends=[backend],
|
||||
extra_args=args,
|
||||
input_type="mhlo",
|
||||
)
|
||||
# flatbuffer_blob = compile_str(compiler_module, target_backends=["dylib-llvm-aot"])
|
||||
|
||||
# Save module as MLIR file in a directory
|
||||
vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)
|
||||
@@ -78,7 +86,9 @@ if __name__ == "__main__":
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
BertCompiled = ctx.modules.module
|
||||
result = BertCompiled.predict(encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"])
|
||||
result = BertCompiled.predict(
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
print(result)
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
|
||||
tf_model = TFBertModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased",
|
||||
from_pt=True)
|
||||
tf_model = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
tokenizer = BertTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=512)
|
||||
encoded_input = tokenizer(
|
||||
text, padding="max_length", truncation=True, max_length=512
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0)
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
output = tf_model(encoded_input)
|
||||
|
||||
print(output)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
from tank.model_utils_tf import get_TFhf_model, compare_tensors_tf
|
||||
|
||||
import tensorflow as tf
|
||||
@@ -9,14 +9,13 @@ import pytest
|
||||
|
||||
|
||||
class MiniLMTFModuleTester:
|
||||
|
||||
def create_and_check_module(self, dynamic, device):
|
||||
model, input, act_out = get_TFhf_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased")
|
||||
shark_module = SharkInference(model, (input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True)
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
model, (input,), device=device, dynamic=dynamic, jit_trace=True
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
results = shark_module.forward((input))
|
||||
@@ -24,7 +23,6 @@ class MiniLMTFModuleTester:
|
||||
|
||||
|
||||
class MiniLMTFModuleTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.module_tester = MiniLMTFModuleTester()
|
||||
|
||||
@@ -36,15 +34,17 @@ class MiniLMTFModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="TF testing temporarily unavailable.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
def test_module_dynamic_cpu(self):
|
||||
dynamic = True
|
||||
device = "cpu"
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skip(reason="TF testing temporarily unavailable.")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_static_gpu(self):
|
||||
dynamic = False
|
||||
device = "gpu"
|
||||
@@ -52,9 +52,11 @@ class MiniLMTFModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="TF testing temporarily unavailable.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
@pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
)
|
||||
def test_module_dynamic_gpu(self):
|
||||
dynamic = True
|
||||
device = "gpu"
|
||||
@@ -63,8 +65,7 @@ class MiniLMTFModuleTest(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="TF testing temporarily unavailable.")
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_static_vulkan(self):
|
||||
dynamic = False
|
||||
@@ -73,11 +74,11 @@ class MiniLMTFModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="TF testing temporarily unavailable.")
|
||||
@pytest.mark.xfail(
|
||||
reason="Language models currently failing for dynamic case")
|
||||
reason="Language models currently failing for dynamic case"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
)
|
||||
def test_module_dynamic_vulkan(self):
|
||||
dynamic = True
|
||||
@@ -85,5 +86,5 @@ class MiniLMTFModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -6,12 +6,15 @@ from shark.parser import shark_args
|
||||
import argparse
|
||||
|
||||
|
||||
seq_parser = argparse.ArgumentParser(description='Shark Sequence Classification.')
|
||||
seq_parser = argparse.ArgumentParser(
|
||||
description="Shark Sequence Classification."
|
||||
)
|
||||
seq_parser.add_argument(
|
||||
"--hf_model_name",
|
||||
type=str,
|
||||
default="bert-base-uncased",
|
||||
help="Hugging face model to run sequence classification.")
|
||||
help="Hugging face model to run sequence classification.",
|
||||
)
|
||||
|
||||
seq_args, unknown = seq_parser.parse_known_args()
|
||||
|
||||
@@ -25,45 +28,56 @@ inputs_signature = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
# For supported models please see here:
|
||||
# For supported models please see here:
|
||||
# https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForSequenceClassification
|
||||
|
||||
def preprocess_input(text = "This is just used to compile the model"):
|
||||
|
||||
def preprocess_input(text="This is just used to compile the model"):
|
||||
tokenizer = AutoTokenizer.from_pretrained(seq_args.hf_model_name)
|
||||
inputs = tokenizer(text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
return_tensors="tf",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
class SeqClassification(tf.Module):
|
||||
|
||||
def __init__(self, model_name):
|
||||
super(SeqClassification, self).__init__()
|
||||
self.m = TFAutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, output_attentions=False, num_labels=2)
|
||||
model_name, output_attentions=False, num_labels=2
|
||||
)
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)[0]
|
||||
|
||||
@tf.function(input_signature=inputs_signature)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return tf.math.softmax(self.m.predict(input_ids, attention_mask),
|
||||
axis=-1)
|
||||
return tf.math.softmax(
|
||||
self.m.predict(input_ids, attention_mask), axis=-1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
inputs = preprocess_input()
|
||||
shark_module = SharkInference(
|
||||
SeqClassification(seq_args.hf_model_name),
|
||||
(inputs["input_ids"], inputs["attention_mask"]))
|
||||
(inputs["input_ids"], inputs["attention_mask"]),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(f"Model has been successfully compiled on {shark_args.device}")
|
||||
|
||||
while True:
|
||||
input_text = input("Enter the text to classify (press q or nothing to exit): ")
|
||||
input_text = input(
|
||||
"Enter the text to classify (press q or nothing to exit): "
|
||||
)
|
||||
if not input_text or input_text == "q":
|
||||
break
|
||||
inputs = preprocess_input(input_text)
|
||||
print(shark_module.forward((inputs["input_ids"], inputs["attention_mask"])))
|
||||
print(
|
||||
shark_module.forward(
|
||||
(inputs["input_ids"], inputs["attention_mask"])
|
||||
)
|
||||
)
|
||||
|
||||
@@ -13,26 +13,35 @@ def generate_inputs(input_details):
|
||||
|
||||
args = []
|
||||
args.append(
|
||||
np.random.randint(low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"]))
|
||||
np.random.randint(
|
||||
low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"],
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
np.ones(shape=input_details[1]["shape"],
|
||||
dtype=input_details[1]["dtype"]))
|
||||
np.ones(
|
||||
shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
np.zeros(shape=input_details[2]["shape"],
|
||||
dtype=input_details[2]["dtype"]))
|
||||
np.zeros(
|
||||
shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]
|
||||
)
|
||||
)
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
my_shark_importer = SharkImporter(model_path=model_path,
|
||||
model_type="tflite",
|
||||
model_source_hub="tfhub",
|
||||
device="cpu",
|
||||
dynamic=False,
|
||||
jit_trace=True)
|
||||
if __name__ == "__main__":
|
||||
my_shark_importer = SharkImporter(
|
||||
model_path=model_path,
|
||||
model_type="tflite",
|
||||
model_source_hub="tfhub",
|
||||
device="cpu",
|
||||
dynamic=False,
|
||||
jit_trace=True,
|
||||
)
|
||||
# Case1: Use default inputs
|
||||
my_shark_importer.compile()
|
||||
shark_results = my_shark_importer.forward()
|
||||
|
||||
@@ -6,13 +6,12 @@ from shark.parser import shark_args
|
||||
|
||||
|
||||
class AlbertTfliteModuleTester:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
self,
|
||||
dynamic=False,
|
||||
device="cpu",
|
||||
save_mlir=False,
|
||||
save_vmfb=False,
|
||||
):
|
||||
self.dynamic = dynamic
|
||||
self.device = device
|
||||
@@ -22,19 +21,23 @@ class AlbertTfliteModuleTester:
|
||||
def create_and_check_module(self):
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
self.shark_downloader = SharkDownloader(model_name="albert_lite_base",
|
||||
tank_url="https://storage.googleapis.com/shark_tank",
|
||||
local_tank_dir="./../gen_shark_tank/tflite",
|
||||
model_type="tflite-tosa",
|
||||
input_json="input.json",
|
||||
input_type="int32")
|
||||
self.shark_downloader = SharkDownloader(
|
||||
model_name="albert_lite_base",
|
||||
tank_url="https://storage.googleapis.com/shark_tank",
|
||||
local_tank_dir="./../gen_shark_tank/tflite",
|
||||
model_type="tflite-tosa",
|
||||
input_json="input.json",
|
||||
input_type="int32",
|
||||
)
|
||||
tflite_tosa_model = self.shark_downloader.get_mlir_file()
|
||||
inputs = self.shark_downloader.get_inputs()
|
||||
self.shark_module = SharkInference(tflite_tosa_model,
|
||||
inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True)
|
||||
self.shark_module = SharkInference(
|
||||
tflite_tosa_model,
|
||||
inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
)
|
||||
self.shark_module.set_frontend("tflite-tosa")
|
||||
self.shark_module.compile()
|
||||
self.shark_module.forward(inputs)
|
||||
@@ -42,7 +45,6 @@ class AlbertTfliteModuleTester:
|
||||
|
||||
|
||||
class AlbertTfliteModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
@@ -58,7 +60,7 @@ class AlbertTfliteModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
# module_tester = AlbertTfliteModuleTester()
|
||||
# module_tester.create_and_check_module()
|
||||
|
||||
@@ -14,20 +14,27 @@ def generate_inputs(input_details):
|
||||
|
||||
args = []
|
||||
args.append(
|
||||
np.random.randint(low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"]))
|
||||
np.random.randint(
|
||||
low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"],
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
np.ones(shape=input_details[1]["shape"],
|
||||
dtype=input_details[1]["dtype"]))
|
||||
np.ones(
|
||||
shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
np.zeros(shape=input_details[2]["shape"],
|
||||
dtype=input_details[2]["dtype"]))
|
||||
np.zeros(
|
||||
shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]
|
||||
)
|
||||
)
|
||||
return args
|
||||
|
||||
class AlbertTfliteModuleTester:
|
||||
|
||||
class AlbertTfliteModuleTester:
|
||||
def __init__(
|
||||
self,
|
||||
dynamic=False,
|
||||
@@ -43,14 +50,16 @@ class AlbertTfliteModuleTester:
|
||||
def create_and_check_module(self):
|
||||
shark_args.save_mlir = self.save_mlir
|
||||
shark_args.save_vmfb = self.save_vmfb
|
||||
my_shark_importer = SharkImporter(model_name="albert_lite_base",
|
||||
# model_path=model_path,
|
||||
model_type="tflite",
|
||||
model_source_hub="tfhub",
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
tank_url=None)
|
||||
my_shark_importer = SharkImporter(
|
||||
model_name="albert_lite_base",
|
||||
# model_path=model_path,
|
||||
model_type="tflite",
|
||||
model_source_hub="tfhub",
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=True,
|
||||
tank_url=None,
|
||||
)
|
||||
# Case1: Use default inputs
|
||||
my_shark_importer.compile()
|
||||
shark_results = my_shark_importer.forward()
|
||||
@@ -63,7 +72,6 @@ class AlbertTfliteModuleTester:
|
||||
|
||||
|
||||
class AlbertTfliteModuleTest(unittest.TestCase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def configure(self, pytestconfig):
|
||||
self.save_mlir = pytestconfig.getoption("save_mlir")
|
||||
@@ -78,7 +86,8 @@ class AlbertTfliteModuleTest(unittest.TestCase):
|
||||
self.module_tester.device = "cpu"
|
||||
self.module_tester.create_and_check_module()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
# module_tester = AlbertTfliteModuleTester()
|
||||
# module_tester.create_and_check_module()
|
||||
unittest.main()
|
||||
|
||||
@@ -10,7 +10,6 @@ model_path = "https://tfhub.dev/neso613/lite-model/ASR_TFLite/pre_trained_models
|
||||
# Failure is due to dynamic shapes:
|
||||
# - Some improvements to tfl.strided_slice lowering are next steps
|
||||
class AsrConformerTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AsrConformerTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -18,5 +17,5 @@ class AsrConformerTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -11,18 +11,21 @@ model_path = "https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1
|
||||
|
||||
|
||||
class BirdClassifierTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BirdClassifierTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(BirdClassifierTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(BirdClassifierTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
img_path = "https://github.com/google-coral/test_data/raw/master/bird.bmp"
|
||||
img_path = (
|
||||
"https://github.com/google-coral/test_data/raw/master/bird.bmp"
|
||||
)
|
||||
local_path = "/".join([self.workdir, "bird.bmp"])
|
||||
urllib.request.urlretrieve(img_path, local_path)
|
||||
|
||||
@@ -35,5 +38,5 @@ class BirdClassifierTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
import absl.testing
|
||||
import test_util
|
||||
|
||||
model_path = "https://tfhub.dev/sayakpaul/lite-model/cartoongan/dr/1?lite-format=tflite"
|
||||
model_path = (
|
||||
"https://tfhub.dev/sayakpaul/lite-model/cartoongan/dr/1?lite-format=tflite"
|
||||
)
|
||||
|
||||
|
||||
class CartoonGanTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CartoonGanTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -16,5 +17,5 @@ class CartoonGanTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,17 +10,17 @@ model_path = "https://tfhub.dev/tulasiram58827/lite-model/craft-text-detector/dr
|
||||
# Failure: Resize lowering does not handle inferred dynamic shapes. Furthermore, the entire model
|
||||
# requires dynamic shape support.
|
||||
class CraftTextTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CraftTextTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(CraftTextTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(CraftTextTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,19 +8,20 @@ model_path = "https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2?lit
|
||||
|
||||
|
||||
class DeepLabV3Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DeepLabV3Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(DeepLabV3Test, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(DeepLabV3Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,19 +8,20 @@ model_path = "https://tfhub.dev/tensorflow/lite-model/densenet/1/metadata/1?lite
|
||||
|
||||
|
||||
class DenseNetTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DenseNetTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(DenseNetTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(DenseNetTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-5).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-5).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,28 +8,30 @@ model_path = "https://tfhub.dev/sayakpaul/lite-model/east-text-detector/dr/1?lit
|
||||
|
||||
|
||||
class EastTextDetectorTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EastTextDetectorTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(EastTextDetectorTest,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(EastTextDetectorTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()
|
||||
)
|
||||
|
||||
# The second return is extremely noisy as it is not a binary classification. To handle we
|
||||
# check normalized correlation with an expectation of "close enough".
|
||||
iree_norm = numpy.sqrt(iree_results[1] * iree_results[1])
|
||||
tflite_norm = numpy.sqrt(tflite_results[1] * tflite_results[1])
|
||||
|
||||
correlation = numpy.average(iree_results[1] * tflite_results[1] /
|
||||
iree_norm / tflite_norm)
|
||||
correlation = numpy.average(
|
||||
iree_results[1] * tflite_results[1] / iree_norm / tflite_norm
|
||||
)
|
||||
self.assertTrue(numpy.isclose(correlation, 1.0, atol=1e-2).all())
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,23 +10,25 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/efficientnet_l
|
||||
|
||||
|
||||
class EfficientnetLite0Int8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EfficientnetLite0Int8Test, self).__init__(model_path, *args,
|
||||
**kwargs)
|
||||
super(EfficientnetLite0Int8Test, self).__init__(
|
||||
model_path, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(EfficientnetLite0Int8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(EfficientnetLite0Int8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# Dequantize outputs.
|
||||
zero_point = details[0]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[0]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[0]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[0]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results - zero_point) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=5e-3).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results, dequantized_tflite_results, atol=5e-3
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [imagenet_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -35,5 +37,5 @@ class EfficientnetLite0Int8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,25 +10,26 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/efficientnet_l
|
||||
|
||||
|
||||
class EfficientnetLite0Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EfficientnetLite0Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(EfficientnetLite0Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(EfficientnetLite0Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -11,25 +11,26 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/efficientnet_2
|
||||
|
||||
|
||||
class EfficientnetTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EfficientnetTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(EfficientnetTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(EfficientnetTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -5,12 +5,13 @@ import absl.testing
|
||||
import numpy
|
||||
import test_util
|
||||
|
||||
model_path = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-64.tflite"
|
||||
model_path = (
|
||||
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-64.tflite"
|
||||
)
|
||||
|
||||
|
||||
# This test is a massive download and excluded due to causing timeouts.
|
||||
class GPT2Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GPT2Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -18,23 +19,29 @@ class GPT2Test(test_util.TFLiteModelTest):
|
||||
def generate_inputs(self, input_details):
|
||||
args = []
|
||||
args.append(
|
||||
numpy.random.randint(low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"]))
|
||||
numpy.random.randint(
|
||||
low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"],
|
||||
)
|
||||
)
|
||||
return args
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(GPT2Test, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(GPT2Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[i], tflite_results[i],
|
||||
atol=5e-3).all())
|
||||
numpy.isclose(
|
||||
iree_results[i], tflite_results[i], atol=5e-3
|
||||
).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,23 +9,25 @@ model_path = "https://tfhub.dev/sayakpaul/lite-model/arbitrary-image-stylization
|
||||
|
||||
# Failure is due to avg_pool2d.
|
||||
class ImageStylizationTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ImageStylizationTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(ImageStylizationTest,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(ImageStylizationTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
iree = iree_results[0].flatten().astype(numpy.single)
|
||||
tflite = tflite_results[0].flatten().astype(numpy.single)
|
||||
# Error is not tiny but appears close.
|
||||
self.assertTrue(
|
||||
numpy.isclose(numpy.max(numpy.abs(iree - tflite)), 0.0,
|
||||
atol=5e-2).all())
|
||||
numpy.isclose(
|
||||
numpy.max(numpy.abs(iree - tflite)), 0.0, atol=5e-2
|
||||
).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,25 +10,26 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/inception_v4_2
|
||||
|
||||
|
||||
class InceptionV4Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(InceptionV4Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(InceptionV4Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(InceptionV4Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,22 +10,23 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/inception_v4_2
|
||||
|
||||
|
||||
class InceptionV4Uint8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(InceptionV4Uint8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(InceptionV4Uint8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(InceptionV4Uint8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# Dequantize outputs.
|
||||
zero_point = details[0]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[0]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[0]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[0]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results - zero_point) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=5e-3).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results, dequantized_tflite_results, atol=5e-3
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [imagenet_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -34,5 +35,5 @@ class InceptionV4Uint8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,7 +8,6 @@ model_path = "https://tfhub.dev/tulasiram58827/lite-model/keras-ocr/dr/2?lite-fo
|
||||
|
||||
|
||||
class KerasOCRTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(KerasOCRTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -16,5 +15,5 @@ class KerasOCRTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,7 +10,6 @@ model_path = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/t
|
||||
# Currently failing further in the linalg stack:
|
||||
# Bug related to linalg fusion. Collapsing dimension despite linalg index.
|
||||
class LightningFp16Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LightningFp16Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -18,5 +17,5 @@ class LightningFp16Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -13,7 +13,6 @@ model_path = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/t
|
||||
# Currently failing further in the linalg stack:
|
||||
# Invalid cast from ui8 to f32 TODO: make tfl.cast insert a rescale for ui8
|
||||
class LightningI8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LightningI8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -22,6 +21,7 @@ class LightningI8Test(test_util.TFLiteModelTest):
|
||||
# debug the source of numerical differences.
|
||||
def plot_results(self, iree_results, tflite_results, details):
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
local_path = "/".join([self.workdir, "person.jpg"])
|
||||
im = numpy.array(Image.open(local_path))
|
||||
|
||||
@@ -38,24 +38,31 @@ class LightningI8Test(test_util.TFLiteModelTest):
|
||||
iree_y = iree_result[0, 0, :, 1] * height
|
||||
|
||||
plt.imshow(im)
|
||||
plt.scatter(tflite_y, tflite_x, label='tflite')
|
||||
plt.scatter(iree_y, iree_x, label='iree')
|
||||
plt.scatter(tflite_y, tflite_x, label="tflite")
|
||||
plt.scatter(iree_y, iree_x, label="iree")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(LightningI8Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(LightningI8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# This value is a discretized location of the persons joints. If we are
|
||||
# *close* to the expected position we can consider this good enough.
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0][:, :, :, 0],
|
||||
tflite_results[0][:, :, :, 0],
|
||||
atol=25e-3).all())
|
||||
numpy.isclose(
|
||||
iree_results[0][:, :, :, 0],
|
||||
tflite_results[0][:, :, :, 0],
|
||||
atol=25e-3,
|
||||
).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0][:, :, :, 1],
|
||||
tflite_results[0][:, :, :, 1],
|
||||
atol=25e-3).all())
|
||||
numpy.isclose(
|
||||
iree_results[0][:, :, :, 1],
|
||||
tflite_results[0][:, :, :, 1],
|
||||
atol=25e-3,
|
||||
).all()
|
||||
)
|
||||
# self.plot_results(iree_results, tflite_results, details)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
@@ -72,5 +79,5 @@ class LightningI8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,7 +10,6 @@ model_path = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/3
|
||||
# Currently failing further in the linalg stack:
|
||||
# Fusion appears to produce an invalid IR.
|
||||
class LightningTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LightningTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -18,5 +17,5 @@ class LightningTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,37 +10,39 @@ import lit.llvm
|
||||
lit.llvm.initialize(lit_config, config)
|
||||
|
||||
# name: The name of this test suite.
|
||||
config.name = 'TFLITEHUB'
|
||||
config.name = "TFLITEHUB"
|
||||
|
||||
config.test_format = lit.formats.ShTest()
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.py']
|
||||
config.suffixes = [".py"]
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
#config.use_default_substitutions()
|
||||
# config.use_default_substitutions()
|
||||
config.excludes = [
|
||||
'coco_test_data.py',
|
||||
'imagenet_test_data.py',
|
||||
'lit.cfg.py',
|
||||
'lit.site.cfg.py',
|
||||
'manual_test.py',
|
||||
'squad_test_data.py',
|
||||
'test_util.py',
|
||||
"coco_test_data.py",
|
||||
"imagenet_test_data.py",
|
||||
"lit.cfg.py",
|
||||
"lit.site.cfg.py",
|
||||
"manual_test.py",
|
||||
"squad_test_data.py",
|
||||
"test_util.py",
|
||||
]
|
||||
|
||||
config.substitutions.extend([
|
||||
('%PYTHON', sys.executable),
|
||||
])
|
||||
config.substitutions.extend(
|
||||
[
|
||||
("%PYTHON", sys.executable),
|
||||
]
|
||||
)
|
||||
|
||||
config.environment['PYTHONPATH'] = ":".join(sys.path)
|
||||
config.environment["PYTHONPATH"] = ":".join(sys.path)
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
# Enable features based on -D FEATURES=hugetest,vulkan
|
||||
# syntax.
|
||||
features_param = lit_config.params.get('FEATURES')
|
||||
features_param = lit_config.params.get("FEATURES")
|
||||
if features_param:
|
||||
config.available_features.update(features_param.split(','))
|
||||
config.available_features.update(features_param.split(","))
|
||||
|
||||
@@ -8,19 +8,20 @@ model_path = "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-styliz
|
||||
|
||||
|
||||
class MagentaTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MagentaTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MagentaTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(MagentaTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=2e-1).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=2e-1).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -2,23 +2,24 @@ import absl.flags
|
||||
import absl.testing
|
||||
import test_util
|
||||
|
||||
absl.flags.DEFINE_string('model', None, 'model path to execute')
|
||||
absl.flags.DEFINE_string("model", None, "model path to execute")
|
||||
|
||||
|
||||
class ManualTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ManualTest, self).__init__(absl.flags.FLAGS.model, *args,
|
||||
**kwargs)
|
||||
super(ManualTest, self).__init__(
|
||||
absl.flags.FLAGS.model, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(ManualTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(ManualTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
if self.model_path is not None:
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,19 +8,20 @@ model_path = "https://tfhub.dev/intel/lite-model/midas/v2_1_small/1/lite/1?lite-
|
||||
|
||||
|
||||
class MidasTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MidasTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MidasTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(MidasTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,19 +10,20 @@ model_path = "https://tfhub.dev/sayakpaul/lite-model/mirnet-fixed/dr/1?lite-form
|
||||
|
||||
# Note this one takes forever right now. Great for performance work!
|
||||
class MirnetTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MirnetTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MirnetTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(MirnetTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=5e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=5e-3).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,7 +8,6 @@ model_path = "https://tfhub.dev/tensorflow/lite-model/mnasnet_1.0_224/1/metadata
|
||||
|
||||
|
||||
class MnasnetTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MnasnetTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -16,5 +15,5 @@ class MnasnetTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,41 +8,51 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilebert-edg
|
||||
|
||||
|
||||
class MobileBertTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobileBertTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
# Inputs modified to be useful mobilebert inputs.
|
||||
def generate_inputs(self, input_details):
|
||||
for input in input_details:
|
||||
absl.logging.info("\t%s, %s", str(input["shape"]),
|
||||
input["dtype"].__name__)
|
||||
absl.logging.info(
|
||||
"\t%s, %s", str(input["shape"]), input["dtype"].__name__
|
||||
)
|
||||
|
||||
args = []
|
||||
args.append(
|
||||
numpy.random.randint(low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"]))
|
||||
numpy.random.randint(
|
||||
low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"],
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
numpy.ones(shape=input_details[1]["shape"],
|
||||
dtype=input_details[1]["dtype"]))
|
||||
numpy.ones(
|
||||
shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
numpy.zeros(shape=input_details[2]["shape"],
|
||||
dtype=input_details[2]["dtype"]))
|
||||
numpy.zeros(
|
||||
shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]
|
||||
)
|
||||
)
|
||||
return args
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobileBertTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobileBertTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-4).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[1], tflite_results[1], atol=1e-4).all())
|
||||
numpy.isclose(iree_results[1], tflite_results[1], atol=1e-4).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,41 +9,51 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilebert-edg
|
||||
|
||||
|
||||
class MobileBertTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobileBertTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
# Inputs modified to be useful mobilebert inputs.
|
||||
def generate_inputs(self, input_details):
|
||||
for input in input_details:
|
||||
absl.logging.info("\t%s, %s", str(input["shape"]),
|
||||
input["dtype"].__name__)
|
||||
absl.logging.info(
|
||||
"\t%s, %s", str(input["shape"]), input["dtype"].__name__
|
||||
)
|
||||
|
||||
args = []
|
||||
args.append(
|
||||
numpy.random.randint(low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"]))
|
||||
numpy.random.randint(
|
||||
low=0,
|
||||
high=256,
|
||||
size=input_details[0]["shape"],
|
||||
dtype=input_details[0]["dtype"],
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
numpy.ones(shape=input_details[1]["shape"],
|
||||
dtype=input_details[1]["dtype"]))
|
||||
numpy.ones(
|
||||
shape=input_details[1]["shape"], dtype=input_details[1]["dtype"]
|
||||
)
|
||||
)
|
||||
args.append(
|
||||
numpy.zeros(shape=input_details[2]["shape"],
|
||||
dtype=input_details[2]["dtype"]))
|
||||
numpy.zeros(
|
||||
shape=input_details[2]["shape"], dtype=input_details[2]["dtype"]
|
||||
)
|
||||
)
|
||||
return args
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobileBertTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobileBertTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1.0).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1.0).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[1], tflite_results[1], atol=1.0).all())
|
||||
numpy.isclose(iree_results[1], tflite_results[1], atol=1.0).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,39 +9,45 @@ model_path = "https://tfhub.dev/tensorflow/lite-model/mobilebert/1/metadata/1?li
|
||||
|
||||
|
||||
class MobileBertTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobileBertTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
# Inputs modified to be useful mobilebert inputs.
|
||||
def generate_inputs(self, input_details):
|
||||
for input in input_details:
|
||||
absl.logging.info("\t%s, %s", str(input["shape"]),
|
||||
input["dtype"].__name__)
|
||||
absl.logging.info(
|
||||
"\t%s, %s", str(input["shape"]), input["dtype"].__name__
|
||||
)
|
||||
|
||||
input_0 = np.asarray(squad_test_data._INPUT_WORD_ID,
|
||||
dtype=input_details[0]["dtype"])
|
||||
input_1 = np.asarray(squad_test_data._INPUT_TYPE_ID,
|
||||
dtype=input_details[1]["dtype"])
|
||||
input_2 = np.asarray(squad_test_data._INPUT_MASK,
|
||||
dtype=input_details[2]["dtype"])
|
||||
input_0 = np.asarray(
|
||||
squad_test_data._INPUT_WORD_ID, dtype=input_details[0]["dtype"]
|
||||
)
|
||||
input_1 = np.asarray(
|
||||
squad_test_data._INPUT_TYPE_ID, dtype=input_details[1]["dtype"]
|
||||
)
|
||||
input_2 = np.asarray(
|
||||
squad_test_data._INPUT_MASK, dtype=input_details[2]["dtype"]
|
||||
)
|
||||
return [
|
||||
input_0.reshape(input_details[0]["shape"]),
|
||||
input_1.reshape(input_details[1]["shape"]),
|
||||
input_2.reshape(input_details[2]["shape"])
|
||||
input_2.reshape(input_details[2]["shape"]),
|
||||
]
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobileBertTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobileBertTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
np.isclose(iree_results[0], tflite_results[0], atol=1e-4).all())
|
||||
np.isclose(iree_results[0], tflite_results[0], atol=1e-4).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
np.isclose(iree_results[1], tflite_results[1], atol=1e-4).all())
|
||||
np.isclose(iree_results[1], tflite_results[1], atol=1e-4).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,39 +9,45 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilebert-bas
|
||||
|
||||
|
||||
class MobileBertTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobileBertTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
# Inputs modified to be useful mobilebert inputs.
|
||||
def generate_inputs(self, input_details):
|
||||
for input in input_details:
|
||||
absl.logging.info("\t%s, %s", str(input["shape"]),
|
||||
input["dtype"].__name__)
|
||||
absl.logging.info(
|
||||
"\t%s, %s", str(input["shape"]), input["dtype"].__name__
|
||||
)
|
||||
|
||||
input_0 = np.asarray(squad_test_data._INPUT_WORD_ID,
|
||||
dtype=input_details[0]["dtype"])
|
||||
input_1 = np.asarray(squad_test_data._INPUT_TYPE_ID,
|
||||
dtype=input_details[1]["dtype"])
|
||||
input_2 = np.asarray(squad_test_data._INPUT_MASK,
|
||||
dtype=input_details[2]["dtype"])
|
||||
input_0 = np.asarray(
|
||||
squad_test_data._INPUT_WORD_ID, dtype=input_details[0]["dtype"]
|
||||
)
|
||||
input_1 = np.asarray(
|
||||
squad_test_data._INPUT_TYPE_ID, dtype=input_details[1]["dtype"]
|
||||
)
|
||||
input_2 = np.asarray(
|
||||
squad_test_data._INPUT_MASK, dtype=input_details[2]["dtype"]
|
||||
)
|
||||
return [
|
||||
input_0.reshape(input_details[0]["shape"]),
|
||||
input_1.reshape(input_details[1]["shape"]),
|
||||
input_2.reshape(input_details[2]["shape"])
|
||||
input_2.reshape(input_details[2]["shape"]),
|
||||
]
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobileBertTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobileBertTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
np.isclose(iree_results[0], tflite_results[0], atol=1e-4).all())
|
||||
np.isclose(iree_results[0], tflite_results[0], atol=1e-4).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
np.isclose(iree_results[1], tflite_results[1], atol=1e-4).all())
|
||||
np.isclose(iree_results[1], tflite_results[1], atol=1e-4).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,40 +9,46 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilebert-bas
|
||||
|
||||
|
||||
class MobileBertTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobileBertTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
# Inputs modified to be useful mobilebert inputs.
|
||||
def generate_inputs(self, input_details):
|
||||
for input in input_details:
|
||||
absl.logging.info("\t%s, %s", str(input["shape"]),
|
||||
input["dtype"].__name__)
|
||||
absl.logging.info(
|
||||
"\t%s, %s", str(input["shape"]), input["dtype"].__name__
|
||||
)
|
||||
|
||||
input_0 = np.asarray(squad_test_data._INPUT_WORD_ID,
|
||||
dtype=input_details[0]["dtype"])
|
||||
input_1 = np.asarray(squad_test_data._INPUT_TYPE_ID,
|
||||
dtype=input_details[1]["dtype"])
|
||||
input_2 = np.asarray(squad_test_data._INPUT_MASK,
|
||||
dtype=input_details[2]["dtype"])
|
||||
input_0 = np.asarray(
|
||||
squad_test_data._INPUT_WORD_ID, dtype=input_details[0]["dtype"]
|
||||
)
|
||||
input_1 = np.asarray(
|
||||
squad_test_data._INPUT_TYPE_ID, dtype=input_details[1]["dtype"]
|
||||
)
|
||||
input_2 = np.asarray(
|
||||
squad_test_data._INPUT_MASK, dtype=input_details[2]["dtype"]
|
||||
)
|
||||
return [
|
||||
input_0.reshape(input_details[0]["shape"]),
|
||||
input_1.reshape(input_details[1]["shape"]),
|
||||
input_2.reshape(input_details[2]["shape"])
|
||||
input_2.reshape(input_details[2]["shape"]),
|
||||
]
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobileBertTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobileBertTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# We have confirmed in large scale accuracy tests that differences this large is acceptable.
|
||||
self.assertTrue(
|
||||
np.isclose(iree_results[0], tflite_results[0], atol=5.0).all())
|
||||
np.isclose(iree_results[0], tflite_results[0], atol=5.0).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
np.isclose(iree_results[1], tflite_results[1], atol=5.0).all())
|
||||
np.isclose(iree_results[1], tflite_results[1], atol=5.0).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -13,15 +13,16 @@ model_path = "https://storage.googleapis.com/iree-shared-files/models/ssd_mobile
|
||||
|
||||
|
||||
class MobilenetSsdQuantTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetSsdQuantTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetSsdQuantTest,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetSsdQuantTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1.0).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1.0).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
img_path = "https://github.com/google-coral/test_data/raw/master/grace_hopper.bmp"
|
||||
@@ -37,5 +38,5 @@ class MobilenetSsdQuantTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,25 +9,26 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_2
|
||||
|
||||
|
||||
class MobilenetV1Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV1Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV1Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobilenetV1Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,22 +9,23 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v1_2
|
||||
|
||||
|
||||
class MobilenetV1Uint8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV1Uint8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV1Uint8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetV1Uint8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# Dequantize outputs.
|
||||
zero_point = details[0]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[0]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[0]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[0]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results - zero_point) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=5e-3).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results, dequantized_tflite_results, atol=5e-3
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [imagenet_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -33,5 +34,5 @@ class MobilenetV1Uint8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,17 +9,18 @@ model_path = "https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v2
|
||||
|
||||
|
||||
class MobilenetV2Int8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV2Int8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV2Int8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetV2Int8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# Although this a quantized model, inputs and outputs are in float.
|
||||
# The difference here is quite high for a dequantized output.
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=0.5).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=0.5).all()
|
||||
)
|
||||
|
||||
# Make sure the predicted class is the same.
|
||||
iree_predicted_class = numpy.argmax(iree_results[0][0])
|
||||
@@ -29,12 +30,12 @@ class MobilenetV2Int8Test(test_util.TFLiteModelTest):
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,25 +9,26 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_1
|
||||
|
||||
|
||||
class MobilenetV2Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV2Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV2Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(MobilenetV2Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,22 +9,23 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v2_2
|
||||
|
||||
|
||||
class MobilenetV2Uint8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV2Uint8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV2Uint8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetV2Uint8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# Dequantize outputs.
|
||||
zero_point = details[0]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[0]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[0]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[0]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results - zero_point) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=5e-3).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results, dequantized_tflite_results, atol=5e-3
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [imagenet_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -33,5 +34,5 @@ class MobilenetV2Uint8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,25 +9,26 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v3-l
|
||||
|
||||
|
||||
class MobilenetV3LargeTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV3LargeTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV3LargeTest,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetV3LargeTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=1e-4).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,23 +9,25 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/mobilenet_v3-l
|
||||
|
||||
|
||||
class MobilenetV3LargeUint8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV3LargeUint8Test, self).__init__(model_path, *args,
|
||||
**kwargs)
|
||||
super(MobilenetV3LargeUint8Test, self).__init__(
|
||||
model_path, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV3LargeUint8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetV3LargeUint8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# Dequantize outputs.
|
||||
zero_point = details[0]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[0]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[0]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[0]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results - zero_point) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=5e-3).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results, dequantized_tflite_results, atol=5e-3
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [imagenet_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -34,5 +36,5 @@ class MobilenetV3LargeUint8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,16 +9,17 @@ model_path = "https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v3
|
||||
|
||||
|
||||
class MobilenetV35Int8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MobilenetV35Int8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(MobilenetV35Int8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(MobilenetV35Int8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
# The difference here is quite high for a dequantized output.
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results, tflite_results, atol=0.5).all())
|
||||
numpy.isclose(iree_results, tflite_results, atol=0.5).all()
|
||||
)
|
||||
|
||||
# Make sure the predicted class is the same.
|
||||
iree_predicted_class = numpy.argmax(iree_results[0][0])
|
||||
@@ -28,12 +29,12 @@ class MobilenetV35Int8Test(test_util.TFLiteModelTest):
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = imagenet_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,7 +9,6 @@ model_path = "https://tfhub.dev/tensorflow/lite-model/nasnet/large/1/default/1?l
|
||||
|
||||
|
||||
class MnasnetTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MnasnetTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -17,5 +16,5 @@ class MnasnetTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -11,27 +11,32 @@ model_path = "https://github.com/tensorflow/tflite-micro/raw/aeac6f39e5c7475cea2
|
||||
|
||||
|
||||
class PersonDetectTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(PersonDetectTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(PersonDetectTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(PersonDetectTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()
|
||||
)
|
||||
|
||||
# TFLite is broken with this model so we hardcode the input/output details.
|
||||
def setup_tflite(self):
|
||||
self.input_details = [{
|
||||
"shape": [1, 96, 96, 1],
|
||||
"dtype": numpy.int8,
|
||||
"index": 0,
|
||||
}]
|
||||
self.output_details = [{
|
||||
"shape": [1, 2],
|
||||
"dtype": numpy.int8,
|
||||
}]
|
||||
self.input_details = [
|
||||
{
|
||||
"shape": [1, 96, 96, 1],
|
||||
"dtype": numpy.int8,
|
||||
"index": 0,
|
||||
}
|
||||
]
|
||||
self.output_details = [
|
||||
{
|
||||
"shape": [1, 2],
|
||||
"dtype": numpy.int8,
|
||||
}
|
||||
]
|
||||
|
||||
# The input has known expected values. We hardcode this value.
|
||||
def invoke_tflite(self, args):
|
||||
@@ -43,8 +48,9 @@ class PersonDetectTest(test_util.TFLiteModelTest):
|
||||
urllib.request.urlretrieve(img_path, local_path)
|
||||
|
||||
shape = input_details[0]["shape"]
|
||||
im = numpy.array(Image.open(local_path).resize(
|
||||
(shape[1], shape[2]))).astype(input_details[0]["dtype"])
|
||||
im = numpy.array(
|
||||
Image.open(local_path).resize((shape[1], shape[2]))
|
||||
).astype(input_details[0]["dtype"])
|
||||
args = [im.reshape(shape)]
|
||||
return args
|
||||
|
||||
@@ -52,5 +58,5 @@ class PersonDetectTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,25 +8,29 @@ model_path = "https://storage.googleapis.com/download.tensorflow.org/models/tfli
|
||||
|
||||
|
||||
class PoseTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(PoseTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(PoseTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(PoseTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1e-3).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[1], tflite_results[1], atol=1e-2).all())
|
||||
numpy.isclose(iree_results[1], tflite_results[1], atol=1e-2).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[2], tflite_results[2], atol=1e-2).all())
|
||||
numpy.isclose(iree_results[2], tflite_results[2], atol=1e-2).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[3], tflite_results[3], atol=1e-3).all())
|
||||
numpy.isclose(iree_results[3], tflite_results[3], atol=1e-3).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -8,19 +8,20 @@ model_path = "https://storage.googleapis.com/tf_model_garden/vision/resnet50_ima
|
||||
|
||||
|
||||
class ResNet50Int8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ResNet50Int8Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(ResNet50Int8Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(ResNet50Int8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1.0).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=1.0).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,19 +10,20 @@ model_path = "https://tfhub.dev/tulasiram58827/lite-model/rosetta/dr/1?lite-form
|
||||
|
||||
# tfl.padv2 cannot be lowered to tosa.pad. May be possible to switch tosa.concat
|
||||
class RosettaTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RosettaTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(RosettaTest, self).compare_results(iree_results, tflite_results,
|
||||
details)
|
||||
super(RosettaTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=5e-3).all())
|
||||
numpy.isclose(iree_results[0], tflite_results[0], atol=5e-3).all()
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -12,7 +12,6 @@ model_path = "https://tfhub.dev/google/lite-model/spice/1?lite-format=tflite"
|
||||
# 1. Multiple unsupported dynamic operations (tfl.stride, range, gather).
|
||||
# 2. Static version blocked by tfl.range not having a lowering for static fixed shapes.
|
||||
class SpiceTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SpiceTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
@@ -20,5 +19,5 @@ class SpiceTest(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,17 +7,17 @@ model_path = "https://tfhub.dev/tensorflow/lite-model/squeezenet/1/default/1?lit
|
||||
|
||||
|
||||
class SqueezeNetTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SqueezeNetTest, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SqueezeNetTest, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(SqueezeNetTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,27 +9,29 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_
|
||||
|
||||
|
||||
class SsdMobilenetV1Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdMobilenetV1Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdMobilenetV1Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(SsdMobilenetV1Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[i], tflite_results[i],
|
||||
atol=1e-4).all())
|
||||
numpy.isclose(
|
||||
iree_results[i], tflite_results[i], atol=1e-4
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = coco_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -9,25 +9,30 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_
|
||||
|
||||
|
||||
class SsdMobilenetV1Uint8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdMobilenetV1Uint8Test, self).__init__(model_path, *args,
|
||||
**kwargs)
|
||||
super(SsdMobilenetV1Uint8Test, self).__init__(
|
||||
model_path, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdMobilenetV1Uint8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(SsdMobilenetV1Uint8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
# Dequantize outputs.
|
||||
zero_point = details[i]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[i]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[i]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[i]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results[i] - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results[i] -
|
||||
zero_point) * scale
|
||||
dequantized_tflite_results = (
|
||||
tflite_results[i] - zero_point
|
||||
) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=0.1).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=0.1,
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [coco_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -36,5 +41,5 @@ class SsdMobilenetV1Uint8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,28 +10,31 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_
|
||||
|
||||
|
||||
class SsdMobilenetV2FpnliteTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdMobilenetV2FpnliteTest, self).__init__(model_path, *args,
|
||||
**kwargs)
|
||||
super(SsdMobilenetV2FpnliteTest, self).__init__(
|
||||
model_path, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdMobilenetV2FpnliteTest,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(SsdMobilenetV2FpnliteTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[i], tflite_results[i],
|
||||
atol=1e-4).all())
|
||||
numpy.isclose(
|
||||
iree_results[i], tflite_results[i], atol=1e-4
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = coco_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,25 +10,30 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_
|
||||
|
||||
|
||||
class SsdMobilenetV2FpnliteUint8Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdMobilenetV2FpnliteUint8Test,
|
||||
self).__init__(model_path, *args, **kwargs)
|
||||
super(SsdMobilenetV2FpnliteUint8Test, self).__init__(
|
||||
model_path, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdMobilenetV2FpnliteUint8Test,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(SsdMobilenetV2FpnliteUint8Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
# Dequantize outputs.
|
||||
zero_point = details[i]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[i]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[i]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[i]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results[i] - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results[i] -
|
||||
zero_point) * scale
|
||||
dequantized_tflite_results = (
|
||||
tflite_results[i] - zero_point
|
||||
) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=0.1).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=0.1,
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
return [coco_test_data.generate_input(self.workdir, input_details)]
|
||||
@@ -37,5 +42,5 @@ class SsdMobilenetV2FpnliteUint8Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,24 +10,28 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_
|
||||
|
||||
|
||||
class SsdMobilenetV2Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdMobilenetV2Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdMobilenetV2Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(SsdMobilenetV2Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
# Dequantize outputs.
|
||||
zero_point = details[i]['quantization_parameters']['zero_points'][0]
|
||||
scale = details[i]['quantization_parameters']['scales'][0]
|
||||
zero_point = details[i]["quantization_parameters"]["zero_points"][0]
|
||||
scale = details[i]["quantization_parameters"]["scales"][0]
|
||||
dequantized_iree_results = (iree_results[i] - zero_point) * scale
|
||||
dequantized_tflite_results = (tflite_results[i] -
|
||||
zero_point) * scale
|
||||
dequantized_tflite_results = (
|
||||
tflite_results[i] - zero_point
|
||||
) * scale
|
||||
self.assertTrue(
|
||||
numpy.isclose(dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=0.1).all())
|
||||
numpy.isclose(
|
||||
dequantized_iree_results,
|
||||
dequantized_tflite_results,
|
||||
atol=0.1,
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = coco_test_data.generate_input(self.workdir, input_details)
|
||||
@@ -39,5 +43,5 @@ class SsdMobilenetV2Test(test_util.TFLiteModelTest):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,27 +10,29 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_mobilenet_
|
||||
|
||||
|
||||
class SsdMobilenetV2Test(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdMobilenetV2Test, self).__init__(model_path, *args, **kwargs)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdMobilenetV2Test, self).compare_results(iree_results,
|
||||
tflite_results, details)
|
||||
super(SsdMobilenetV2Test, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[i], tflite_results[i],
|
||||
atol=1e-4).all())
|
||||
numpy.isclose(
|
||||
iree_results[i], tflite_results[i], atol=1e-4
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = coco_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
@@ -10,30 +10,33 @@ model_path = "https://storage.googleapis.com/iree-model-artifacts/ssd_spaghettin
|
||||
|
||||
|
||||
class SsdSpaghettinetLargeTest(test_util.TFLiteModelTest):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SsdSpaghettinetLargeTest, self).__init__(model_path, *args,
|
||||
**kwargs)
|
||||
super(SsdSpaghettinetLargeTest, self).__init__(
|
||||
model_path, *args, **kwargs
|
||||
)
|
||||
|
||||
def compare_results(self, iree_results, tflite_results, details):
|
||||
super(SsdSpaghettinetLargeTest,
|
||||
self).compare_results(iree_results, tflite_results, details)
|
||||
super(SsdSpaghettinetLargeTest, self).compare_results(
|
||||
iree_results, tflite_results, details
|
||||
)
|
||||
for i in range(len(iree_results)):
|
||||
print('iree_results: ' + str(iree_results[i]))
|
||||
print('tflite_results: ' + str(tflite_results[i]))
|
||||
print("iree_results: " + str(iree_results[i]))
|
||||
print("tflite_results: " + str(tflite_results[i]))
|
||||
self.assertTrue(
|
||||
numpy.isclose(iree_results[i], tflite_results[i],
|
||||
atol=1e-4).all())
|
||||
numpy.isclose(
|
||||
iree_results[i], tflite_results[i], atol=1e-4
|
||||
).all()
|
||||
)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
inputs = coco_test_data.generate_input(self.workdir, input_details)
|
||||
# Normalize inputs to [-1, 1].
|
||||
inputs = (inputs.astype('float32') / 127.5) - 1
|
||||
inputs = (inputs.astype("float32") / 127.5) - 1
|
||||
return [inputs]
|
||||
|
||||
def test_compile_tflite(self):
|
||||
self.compile_and_execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
absl.testing.absltest.main()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user