Files
AMD-SHARK-Studio/shark/iree_utils/benchmark_utils.py

95 lines
3.5 KiB
Python

# Copyright 2020 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
from shark.iree_utils._common import run_cmd, IREE_DEVICE_MAP
import numpy as np
import os
import re
# Seconds per time-unit suffix found in iree-benchmark-module output.
# "ns"/"us" are included in case the tool reports finer units (presumably
# Google Benchmark's other time units) so parsing does not fail with KeyError.
UNIT_TO_SECOND_MAP = {"ns": 1e-9, "us": 1e-6, "ms": 0.001, "s": 1}
def _dtype_name(input_tensor, mlir_dialect: str) -> str:
    """Return the bare dtype name of *input_tensor*, e.g. ``"float32"``.

    Raises:
        ValueError: if *mlir_dialect* is not a supported dialect.
    """
    if mlir_dialect in ["linalg", "tosa"]:
        # Torch dtypes stringify with a "torch." prefix, e.g. "torch.float32".
        return str(input_tensor.dtype).replace("torch.", "")
    if mlir_dialect in ["mhlo", "tflite"]:
        dtype_text = str(input_tensor.dtype)
        # Some dtype objects print the name in quotes (presumably TensorFlow's
        # "<dtype: 'float32'>"); numpy dtypes print the bare name instead.
        quoted = re.findall("'[^\"]*'", dtype_text)
        name = quoted[0].replace("'", "") if quoted else dtype_text
        # Drop a module qualifier, e.g. "numpy.float32" -> "float32".
        return name.rsplit(".", 1)[-1]
    raise ValueError(
        f"Unsupported mlir_dialect {mlir_dialect!r}; expected one of "
        "'linalg', 'tosa', 'mhlo', 'tflite'."
    )


def _mlir_element_type(dtype_name: str) -> str:
    """Map a dtype name such as ``"float32"`` to its MLIR spelling ``"f32"``.

    Raises:
        ValueError: if the name is not of the form <letters><bit-width> or "bool".
    """
    if dtype_name == "bool":
        return "i1"
    match = re.match("([a-zA-Z]+)([0-9]+)", dtype_name)
    if match is None:
        raise ValueError(f"Unsupported tensor dtype {dtype_name!r}.")
    kind, bit_width = match.groups()
    # MLIR spells unsigned ints "ui<N>" and brain-float "bf16"; other kinds keep
    # the first-letter abbreviation (float32 -> f32, int64 -> i64).
    prefix = {"uint": "ui", "bfloat": "bf"}.get(kind, kind[0])
    return prefix + bit_width


def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str):
    """
    Input: A tuple of input tensors i.e tuple(torch.tensor), plus the MLIR
    dialect the model was imported from ("linalg"/"tosa" for torch-style
    dtypes, "mhlo"/"tflite" otherwise).
    Output: list of strings that represent mlir types (i.e 1x24xf64); a rank-0
    tensor yields just the element type (i.e f32).
    Raises ValueError for an unsupported dialect or dtype.
    # TODO: Support more dtypes (e.g. complex).
    """
    list_of_type = []
    for input_tensor in input_tensors:
        element_type = _mlir_element_type(_dtype_name(input_tensor, mlir_dialect))
        dims = [str(dim) for dim in input_tensor.shape]
        # Join dims and element type in one pass so scalars don't get a
        # leading "x" separator.
        list_of_type.append("x".join(dims + [element_type]))
    return list_of_type
def build_benchmark_args(
    input_file: str,
    device: str,
    input_tensors: tuple,
    mlir_dialect: str,
    training=False,
    function_name: str = None,
):
    """
    Build the iree-benchmark-module command for a compiled model.

    Inputs:
        input_file: path to the compiled vmfb module.
        device: SHARK device name, translated through IREE_DEVICE_MAP
            (raises KeyError if unknown).
        input_tensors: example inputs; only their shapes/dtypes are used.
        mlir_dialect: dialect the model was imported from (see
            tensor_to_type_str).
        training: when truthy, benchmark the "train" entry point instead of
            "forward".
        function_name: optional explicit entry function name; overrides the
            training-based default when given.

    Output: list of command tokens. The last token is a shell pipe through
    awk, so the list must be joined into a shell command string (as
    run_benchmark_module does) rather than executed as an argv list.
    """
    path = benchmark_module.__path__[0]
    benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
    if function_name is None:
        # TODO: Replace name of train with actual train fn name.
        function_name = "train" if training else "forward"
    benchmark_cl = [
        benchmarker_path,
        f"--module_file={input_file}",
        f"--entry_function={function_name}",
        f"--device={IREE_DEVICE_MAP[device]}",
    ]
    benchmark_cl.extend(
        f"--function_input={mlir_input}"
        for mlir_input in tensor_to_type_str(input_tensors, mlir_dialect)
    )
    # Keep only the last output line and print its 2nd and 3rd fields glued
    # together (presumably a time and its unit, e.g. "12.3ms"), which is the
    # format run_benchmark_module parses.
    benchmark_cl.append("| awk 'END{print $2 $3}'")
    return benchmark_cl
def run_benchmark_module(benchmark_cl):
    """
    Run a benchmark command, extract the result and return iterations/second.

    Input: benchmark command token list (as built by build_benchmark_args);
    it is joined with spaces and executed through run_cmd. The output is
    expected to begin with a number immediately followed by a unit suffix,
    e.g. "12.5ms".
    # TODO: Add an example of the benchmark command.

    Returns: 1 / (measured time in seconds).

    Raises:
        FileNotFoundError: if the benchmark binary (benchmark_cl[0]) is missing.
        ValueError: if the output cannot be parsed or uses an unknown unit.
    """
    benchmark_path = benchmark_cl[0]
    # Explicit raise instead of assert so the check survives `python -O`.
    if not os.path.exists(benchmark_path):
        raise FileNotFoundError(
            "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
        )
    bench_result = run_cmd(" ".join(benchmark_cl))
    # strip() so a leading newline/space does not defeat the anchored match.
    match = re.match("([0-9]+[.]*[0-9]*)([a-zA-Z]+)", bench_result.strip())
    if match is None:
        raise ValueError(f"Could not parse benchmark output: {bench_result!r}")
    elapsed = float(match.group(1))
    unit = match.group(2)
    if unit not in UNIT_TO_SECOND_MAP:
        raise ValueError(
            f"Unknown time unit {unit!r} in benchmark output: {bench_result!r}"
        )
    return 1.0 / (elapsed * UNIT_TO_SECOND_MAP[unit])