Add benchmark for TF (#87)

- Refactor SharkBenchmarker to run TF
- Add example and test to benchmark TF
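In short, after this change a tf.Module can be compiled and benchmarked through SharkInference using the TensorFlow frontend. The example file added below reduces to the following sketch (BertModule and test_input as defined there):

    shark_module = SharkInference(BertModule(), test_input, benchmark_mode=True)
    shark_module.set_frontend("tensorflow")
    shark_module.compile()
    shark_module.benchmark_all(test_input)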
Author: Stanley Winata
Date: 2022-06-03 22:10:27 -07:00
Committed by: GitHub
Parent: 7fac03a023
Commit: 276dcf1441
5 changed files with 167 additions and 16 deletions


@@ -43,7 +43,9 @@ Yellow=`tput setaf 3`
torch_mlir_bin=false
if [[ $(uname -s) = 'Darwin' ]]; then
echo "${Yellow}Apple macOS detected"
install_tensorflow_mac=true
if [[ $(uname -m) == 'arm64' ]]; then
install_tensorflow_metal_extension=true
echo "${Yellow}Apple M1 Detected"
hash rustc 2>/dev/null
if [ $? -eq 0 ];then
@@ -72,6 +74,17 @@ fi
# Upgrade pip and install requirements.
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt" --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases
if [ "$install_tensorflow_mac" = true ]; then
$PYTHON -m pip install tensorflow-macos
if [ $? -eq 0 ];then
echo "Successfully Installed Tensorflow tools"
else
echo "Could not install Tensorflow tools" >&2
fi
if [ "$install_tensorflow_metal_extension" = true ]; then
$PYTHON -m pip install tensorflow-metal
fi
fi
if [ "$torch_mlir_bin" = true ]; then
$PYTHON -m pip install --find-links https://github.com/llvm/torch-mlir/releases torch-mlir
if [ $? -eq 0 ];then
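As a quick sanity check of the macOS TensorFlow install above (hypothetical, not part of this change), the tensorflow-metal plugin should register a GPU device on Apple Silicon:

    # Hypothetical check, not part of the diff: confirm tensorflow-macos imports
    # and, on Apple Silicon, that tensorflow-metal registered a GPU device.
    import tensorflow as tf
    print(tf.__version__)
    print(tf.config.list_physical_devices("GPU"))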


@@ -0,0 +1,54 @@
import tensorflow as tf
from transformers import BertModel, BertTokenizer, TFBertModel
from shark.shark_inference import SharkInference
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
# Create a set of 2-dimensional inputs
bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
]
class BertModule(tf.Module):
def __init__(self):
super(BertModule, self).__init__()
# Load the pretrained BERT model (weights converted from the PyTorch checkpoint).
self.m = TFBertModel.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased", from_pt=True)
# Wrap call() so the model always runs in inference mode (training=False).
self.m.predict = lambda x, y, z: self.m.call(
input_ids=x, attention_mask=y, token_type_ids=z, training=False)
@tf.function(input_signature=bert_input)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)
if __name__ == "__main__":
# Prepare the tokenized input.
tokenizer = BertTokenizer.from_pretrained(
"microsoft/MiniLM-L12-H384-uncased")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text,
padding='max_length',
truncation=True,
max_length=MAX_SEQUENCE_LENGTH)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
encoded_input["token_type_ids"])
shark_module = SharkInference(
BertModule(),
test_input,
benchmark_mode=True)
shark_module.set_frontend("tensorflow")
shark_module.compile()
shark_module.benchmark_all(test_input)
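Because the tokenizer pads to max_length, each of the three inputs ends up as a [1, 512] int32 tensor, matching the static bert_input signature above. A quick way to confirm (illustrative):

    for t in test_input:
        print(t.shape, t.dtype)   # expected: (1, 512) <dtype: 'int32'>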


@@ -15,7 +15,6 @@
import iree.runtime as ireert
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
import iree.compiler as ireec
from iree.compiler import tf as tfc
from shark.torch_mlir_utils import get_module_name_for_asm_dump
from shark.cuda_utils import get_cuda_sm_cc
from shark.model_annotation import *
@@ -143,13 +142,10 @@ def compile_module_to_flatbuffer(module, device, frontend, func_name, model_conf
args += get_iree_device_args(device)
if frontend in ["tensorflow", "tf"]:
module = tfc.compile_module(module,
exported_names=[func_name],
import_only=True)
input_type = "mhlo"
elif frontend in ["mhlo", "tosa"]:
input_type = frontend
# Annotate the input module with the configs
if model_config_path != None:
# Currently tuned model only works on tf frontend
@@ -162,7 +158,7 @@ def compile_module_to_flatbuffer(module, device, frontend, func_name, model_conf
input_contents=input_module,
config_path=model_config_path)
module = str(module)
# Compile according to the input type, else just try compiling.
if input_type != "mhlo":
module = str(module)
@@ -207,7 +203,7 @@ def export_iree_module_to_vmfb(module,
func_name: str = "forward",
model_config_path: str = None):
flatbuffer_blob = compile_module_to_flatbuffer(module, device, frontend, func_name, model_config_path)
module_name = get_module_name_for_asm_dump(module)
module_name = f"{frontend}_{func_name}_{device}"
filename = os.path.join(directory, module_name + ".vmfb")
with open(filename, 'wb') as f:
f.write(flatbuffer_blob)
@@ -243,16 +239,21 @@ def get_results(compiled_vm, input, config, frontend="torch"):
######### Benchmark Related Tools ###########
def tensor_to_type_str(input_tensors: tuple):
def tensor_to_type_str(input_tensors: tuple, frontend: str):
"""
Input: A tuple of input tensors i.e tuple(torch.tensor)
Output: list of string that represent mlir types (i.e 1x24xf64)
# TODO: Support more than floats, and ints
"""
print("front:",frontend)
list_of_type = []
for input_tensor in input_tensors:
type_string = "x".join([str(dim) for dim in input_tensor.shape])
dtype_string = str(input_tensor.dtype).replace("torch.", "")
if frontend in ["torch", "pytorch"]:
dtype_string = str(input_tensor.dtype).replace("torch.", "")
elif frontend in ["tensorflow","tf"]:
dtype = input_tensor.dtype
dtype_string = re.findall('\'[^"]*\'',str(dtype))[0].replace("\'","")
regex_split = re.compile("([a-zA-Z]+)([0-9]+)")
match = regex_split.match(dtype_string)
mlir_type_string = str(match.group(1)[0]) + str(match.group(2))
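For reference, a sketch of what tensor_to_type_str now produces for the BERT inputs above (values illustrative, not taken from the diff):

    # A [1, 512] int32 tensor becomes "1x512xi32" on either frontend:
    # shape -> "1x512"; "int32" -> "i32" via the regex split; "float32" -> "f32".
    tensor_to_type_str((tf.zeros([1, 512], dtype=tf.int32),), "tf")
    # -> ["1x512xi32"]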
@@ -264,6 +265,7 @@ def tensor_to_type_str(input_tensors: tuple):
def build_benchmark_args(input_file: str,
device: str,
input_tensors: tuple,
frontend: str,
training=False):
"""
Inputs: path to the compiled vmfb file, the target device, the input tensors, the frontend, and whether the run is training or inference.
@@ -278,7 +280,7 @@ def build_benchmark_args(input_file: str,
fn_name = "train"
benchmark_cl.append(f"--entry_function={fn_name}")
benchmark_cl.append(f"--driver={IREE_DEVICE_MAP[device]}")
mlir_input_types = tensor_to_type_str(input_tensors)
mlir_input_types = tensor_to_type_str(input_tensors, frontend)
for mlir_input in mlir_input_types:
benchmark_cl.append(f"--function_input={mlir_input}")
time_extractor = "| awk \'END{{print $2 $3}}\'"
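Putting it together, the flags appended for the example above would look roughly as follows (illustrative; the driver string is whatever IREE_DEVICE_MAP resolves the device to):

    # Illustrative only: flags built by build_benchmark_args for three
    # 1x512 int32 inputs.
    # --entry_function=forward
    # --driver=<IREE_DEVICE_MAP[device]>
    # --function_input=1x512xi32   (repeated once per input tensor)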


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from iree.compiler import tf as tfc
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch_mlir.eager_mode import torch_mlir_tensor
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
@@ -45,10 +46,15 @@ class SharkRunner:
self.input = input
self.frontend = frontend
self.vmfb_file = None
func_name = "forward"
device = device if device is not None else shark_args.device
if self.frontend in ["pytorch", "torch"]:
self.model = get_torch_mlir_module(self.model, input, dynamic,
jit_trace, from_aot)
elif frontend in ["tensorflow", "tf"]:
self.model = tfc.compile_module(self.model,
exported_names=[func_name],
import_only=True)
(
self.iree_compilation_module,
self.iree_config,
@@ -103,7 +109,7 @@ class SharkBenchmarkRunner(SharkRunner):
shark_args.repro_dir,
frontend)
self.benchmark_cl = build_benchmark_args(self.vmfb_file, device, input,
from_aot)
frontend, from_aot)
def benchmark_frontend(self, inputs):
if self.frontend in ["pytorch", "torch"]:
@@ -128,7 +134,18 @@ class SharkBenchmarkRunner(SharkRunner):
)
def benchmark_tf(self, inputs):
print(f"TF benchmark not implemented yet!")
for i in range(shark_args.num_warmup_iterations):
self.frontend_model.forward(*inputs)
begin = time.time()
for i in range(shark_args.num_iterations):
out = self.frontend_model.forward(*inputs)
if i == shark_args.num_iterations - 1:
end = time.time()
break
print(
f"TF benchmark: {shark_args.num_iterations/(end-begin)} iter/second, Total Iterations: {shark_args.num_iterations}"
)
return
def benchmark_c(self):
@@ -137,7 +154,7 @@ class SharkBenchmarkRunner(SharkRunner):
def benchmark_python(self, inputs):
inputs = self.input if self.from_aot else inputs
input_list = [x.detach().numpy() for x in inputs]
input_list = [x for x in inputs]
for i in range(shark_args.num_warmup_iterations):
self.forward(input_list, self.frontend)


@@ -2,13 +2,60 @@ from shark.shark_inference import SharkInference
from shark.iree_utils import check_device_drivers
import torch
import tensorflow as tf
import numpy as np
import torchvision.models as models
from transformers import AutoModelForSequenceClassification
from transformers import AutoModelForSequenceClassification, BertTokenizer, TFBertModel
import importlib
import pytest
torch.manual_seed(0)
##################### Tensorflow Hugging Face LM Models ###################################
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
# Create a set of 2-dimensional inputs
tf_bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32)
]
class TFHuggingFaceLanguage(tf.Module):
def __init__(self, hf_model_name):
super(TFHuggingFaceLanguage, self).__init__()
# Load the pretrained Hugging Face TF model (weights converted from PyTorch).
self.m = TFBertModel.from_pretrained(
hf_model_name, from_pt=True)
# Wrap call() so the model always runs in inference mode (training=False).
self.m.predict = lambda x, y, z: self.m.call(
input_ids=x, attention_mask=y, token_type_ids=z, training=False)
@tf.function(input_signature=tf_bert_input)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)
def get_TFhf_model(name):
model = TFHuggingFaceLanguage(name)
tokenizer = BertTokenizer.from_pretrained(name)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text,
padding='max_length',
truncation=True,
max_length=MAX_SEQUENCE_LENGTH)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0)
test_input = (encoded_input["input_ids"], encoded_input["attention_mask"],
encoded_input["token_type_ids"])
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Hugging Face LM Models ###################################
@@ -74,9 +121,8 @@ pytest_benchmark_param = pytest.mark.parametrize(
pytest.param(True, 'cpu', marks=pytest.mark.skip),
])
@pytest_benchmark_param
def test_minilm_torch(dynamic, device):
def test_bench_minilm_torch(dynamic, device):
model, test_input, act_out = get_hf_model("microsoft/MiniLM-L12-H384-uncased")
shark_module = SharkInference(model, (test_input,),
device=device,
@@ -91,3 +137,22 @@ def test_minilm_torch(dynamic, device):
except Exception as e:
# If anything goes wrong during benchmarking, assert False/failure.
assert False
@pytest.mark.skipif(importlib.util.find_spec("iree.tools") is None, reason="Cannot find tools to import TF")
@pytest_benchmark_param
def test_bench_minilm_tf(dynamic, device):
model, test_input, act_out = get_TFhf_model("microsoft/MiniLM-L12-H384-uncased")
shark_module = SharkInference(model, test_input,
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True)
try:
# If benchmarking is successful, assert success/True.
shark_module.set_frontend("tensorflow")
shark_module.compile()
shark_module.benchmark_all(test_input)
assert True
except Exception as e:
# If anything goes wrong during benchmarking, assert False/failure.
assert False
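To run just the new TF benchmark locally, something like "pytest -k bench_minilm_tf" should select it (hypothetical invocation; the exact command depends on how the test suite is driven and on the device/benchmark flags configured for this repo).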