mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 06:17:55 -05:00
Update benchmark command to ToM and Add to test (#125)
-Update benchmark_cl to latest benchmark_module API -tensorflow OOM issue.
This commit is contained in:
2
.github/workflows/test-models.yml
vendored
2
.github/workflows/test-models.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
cd $GITHUB_WORKSPACE
|
||||
IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest -k 'not benchmark' --ignore=benchmarks/tests/test_hf_benchmark.py --ignore=benchmarks/tests/test_benchmark.py --ignore=tank/tf/ --ignore=shark/tests/test_shark_importer.py
|
||||
pytest -k 'not benchmark' --ignore=tank/tf/ --ignore=shark/tests/test_shark_importer.py
|
||||
|
||||
perf-macOS:
|
||||
runs-on: MacStudio
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -161,7 +161,7 @@ cython_debug/
|
||||
#.idea/
|
||||
|
||||
# Shark related artefacts
|
||||
shark.venv/
|
||||
*venv/
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
|
||||
@@ -11,6 +11,9 @@ import pytest
|
||||
import unittest
|
||||
|
||||
torch.manual_seed(0)
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
@@ -42,8 +45,7 @@ class TFHuggingFaceLanguage(tf.Module):
|
||||
|
||||
def get_TFhf_model(name):
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased")
|
||||
tokenizer = BertTokenizer.from_pretrained(name)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(text,
|
||||
padding='max_length',
|
||||
@@ -124,20 +126,19 @@ pytest_benchmark_param = pytest.mark.parametrize(
|
||||
'gpu',
|
||||
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")),
|
||||
pytest.param(True,
|
||||
'gpu',
|
||||
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")),
|
||||
pytest.param(True,
|
||||
'gpu',
|
||||
marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
'vulkan',
|
||||
False,
|
||||
'vulkan',
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)),
|
||||
pytest.param(
|
||||
True,
|
||||
'vulkan',
|
||||
True,
|
||||
'vulkan',
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
@@ -188,8 +189,7 @@ def test_bench_distilbert(dynamic, device):
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.skipif(importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF")
|
||||
@pytest.mark.skip(reason="XLM Roberta too large to test.")
|
||||
@pytest_benchmark_param
|
||||
def test_bench_xlm_roberta(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("xlm-roberta-base")
|
||||
|
||||
@@ -4,7 +4,7 @@ requires = [
|
||||
"wheel",
|
||||
"packaging",
|
||||
|
||||
"numpy",
|
||||
"numpy==1.22.4",
|
||||
"torch-mlir>=20220428.420",
|
||||
"iree-compiler>=20220427.13",
|
||||
"iree-runtime>=20220427.13",
|
||||
|
||||
@@ -72,7 +72,7 @@ fi
|
||||
|
||||
# Upgrade pip and install requirements.
|
||||
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt" --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://github.com/llvm/torch-mlir/releases
|
||||
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
|
||||
if [ "$torch_mlir_bin" = true ]; then
|
||||
$PYTHON -m pip install --find-links https://github.com/llvm/torch-mlir/releases torch-mlir --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
if [ $? -eq 0 ];then
|
||||
|
||||
@@ -42,7 +42,6 @@ IREE_TARGET_MAP = {
|
||||
"rocm": "rocm"
|
||||
}
|
||||
|
||||
|
||||
UNIT_TO_SECOND_MAP = {"ms": 0.001, "s": 1}
|
||||
|
||||
|
||||
@@ -112,7 +111,7 @@ def get_vulkan_triple_flag():
|
||||
return "-iree-vulkan-target-triple=ampere-rtx3080-linux"
|
||||
else:
|
||||
print(
|
||||
"Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord or pull up an issue."
|
||||
"Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] or pull up an issue."
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -178,8 +177,7 @@ def compile_module_to_flatbuffer(module, device, frontend, func_name,
|
||||
module = str(module)
|
||||
|
||||
# Compile according to the input type, else just try compiling.
|
||||
print(type(module))
|
||||
if input_type not in ["mhlo","tosa"]:
|
||||
if input_type not in ["mhlo", "tosa"]:
|
||||
module = str(module)
|
||||
if input_type != "":
|
||||
# Currently for MHLO/TOSA.
|
||||
@@ -318,7 +316,7 @@ def build_benchmark_args(input_file: str,
|
||||
# TODO: Replace name of train with actual train fn name.
|
||||
fn_name = "train"
|
||||
benchmark_cl.append(f"--entry_function={fn_name}")
|
||||
benchmark_cl.append(f"--driver={IREE_DEVICE_MAP[device]}")
|
||||
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
|
||||
mlir_input_types = tensor_to_type_str(input_tensors, frontend)
|
||||
for mlir_input in mlir_input_types:
|
||||
benchmark_cl.append(f"--function_input={mlir_input}")
|
||||
|
||||
@@ -12,6 +12,7 @@ from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
class SharkImporter:
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
model_type: str = "tflite",
|
||||
@@ -41,22 +42,28 @@ class SharkImporter:
|
||||
if self.model_type == "tflite":
|
||||
print("Setting up for TMP_DIR")
|
||||
exe_basename = os.path.basename(sys.argv[0])
|
||||
self.workdir = os.path.join(os.path.dirname(__file__), "tmp", exe_basename)
|
||||
self.workdir = os.path.join(os.path.dirname(__file__), "tmp",
|
||||
exe_basename)
|
||||
print(f"TMP_DIR = {self.workdir}")
|
||||
os.makedirs(self.workdir, exist_ok=True)
|
||||
self.tflite_file = '/'.join([self.workdir, 'model.tflite'])
|
||||
print("Setting up local address for tflite model file: ", self.tflite_file)
|
||||
print("Setting up local address for tflite model file: ",
|
||||
self.tflite_file)
|
||||
if os.path.exists(self.model_path):
|
||||
self.tflite_file = self.model_path
|
||||
else:
|
||||
print("Download tflite model")
|
||||
urllib.request.urlretrieve(self.model_path, self.tflite_file)
|
||||
urllib.request.urlretrieve(self.model_path,
|
||||
self.tflite_file)
|
||||
print("Setting up tflite interpreter")
|
||||
self.tflite_interpreter = tf.lite.Interpreter(model_path=self.tflite_file)
|
||||
self.tflite_interpreter = tf.lite.Interpreter(
|
||||
model_path=self.tflite_file)
|
||||
self.tflite_interpreter.allocate_tensors()
|
||||
# default input initialization
|
||||
self.input_details, self.output_details = self.get_model_details()
|
||||
inputs = self.generate_inputs(self.input_details) # device_inputs
|
||||
self.input_details, self.output_details = self.get_model_details(
|
||||
)
|
||||
inputs = self.generate_inputs(
|
||||
self.input_details) # device_inputs
|
||||
self.setup_inputs(inputs)
|
||||
|
||||
def generate_inputs(self, input_details):
|
||||
@@ -85,7 +92,8 @@ class SharkImporter:
|
||||
if self.model_source_hub == "tfhub":
|
||||
# compile and run tfhub tflite
|
||||
print("Inference tfhub model")
|
||||
self.shark_module = SharkInference(self.tflite_file, self.inputs,
|
||||
self.shark_module = SharkInference(self.tflite_file,
|
||||
self.inputs,
|
||||
device=self.device,
|
||||
dynamic=self.dynamic,
|
||||
jit_trace=self.jit_trace)
|
||||
|
||||
@@ -51,7 +51,8 @@ class SharkInference:
|
||||
# Sets the frontend i.e `pytorch` or `tensorflow`.
|
||||
def set_frontend(self, frontend: str):
|
||||
if frontend not in [
|
||||
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg", "tosa", "tflite"
|
||||
"pytorch", "torch", "tensorflow", "tf", "mhlo", "linalg",
|
||||
"tosa", "tflite"
|
||||
]:
|
||||
print_err("frontend not supported.")
|
||||
else:
|
||||
@@ -94,7 +95,7 @@ class SharkInference:
|
||||
@benchmark_mode
|
||||
def benchmark_all(self, inputs):
|
||||
self.shark_runner.benchmark_all(inputs)
|
||||
|
||||
|
||||
@benchmark_mode
|
||||
def benchmark_frontend(self, inputs):
|
||||
self.shark_runner.benchmark_frontend(inputs)
|
||||
|
||||
Reference in New Issue
Block a user