Files
AMD-SHARK-Studio/benchmarks/tests/test_hf_benchmark.py

46 lines
1.3 KiB
Python

import importlib
import importlib.util

import pytest
import torch

from benchmarks.hf_transformer import SharkHFBenchmarkRunner
torch.manual_seed(0)
############################# HF Benchmark Tests ####################################
# Test running benchmark module without failing.
# Shared (dynamic, device) parametrization for the benchmark tests in this file.
# Only the static-shape CPU case runs; the dynamic-shape case is skipped with an
# explicit reason so test reports say why instead of "unconditional skip".
pytest_benchmark_param = pytest.mark.parametrize(
    ("dynamic", "device"),
    [
        pytest.param(False, "cpu"),
        # TODO: Language models are failing for the dynamic case; re-enable
        # this parameter set once that is fixed.
        pytest.param(
            True,
            "cpu",
            marks=pytest.mark.skip(
                reason="Language models are failing for the dynamic case."
            ),
        ),
    ],
)
@pytest.mark.skipif(
    importlib.util.find_spec("onnxruntime") is None,
    reason="Cannot find ONNXRUNTIME.",
)
@pytest_benchmark_param
def test_HFbench_minilm_torch(dynamic, device):
    """Smoke-test that each SharkHFBenchmarkRunner benchmark call completes.

    Builds a runner for a Hugging Face model and invokes its C, Python,
    Torch and ONNX benchmark methods in turn. The test passes when none of
    them raises; no benchmark result is inspected.

    Args:
        dynamic: Forwarded to the runner's ``dynamic`` argument.
        device: Forwarded to the runner's ``device`` argument.
    """
    # NOTE(review): the test name says "minilm" but the model loaded here is
    # "bert-base-uncased" -- confirm which one is intended.
    model_name = "bert-base-uncased"
    # Random integer tensor of shape (1, 128) with values drawn from {0, 1}.
    test_input = torch.randint(2, (1, 128))

    # No try/except here on purpose: any exception propagates so pytest
    # reports the original error and traceback, rather than a bare
    # ``assert False`` that hides the cause (and is stripped under ``-O``).
    shark_module = SharkHFBenchmarkRunner(
        model_name,
        (test_input,),
        jit_trace=True,
        dynamic=dynamic,
        device=device,
    )
    shark_module.benchmark_c()
    shark_module.benchmark_python((test_input,))
    shark_module.benchmark_torch(test_input)
    shark_module.benchmark_onnx(test_input)