# Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
# (synced 2026-02-19 11:56:43 -05:00)
# Standard library
import importlib

# Third party
import pytest
import torch

# Project local
from benchmarks.hf_transformer import AMDSharkHFBenchmarkRunner

# Pin the RNG seed so the randomly generated test inputs below are
# reproducible from run to run.
torch.manual_seed(0)

############################# HF Benchmark Tests ####################################
# Shared parametrization for the benchmark smoke tests: each test runs once
# per (dynamic, device) combination and should complete without failing.
pytest_benchmark_param = pytest.mark.parametrize(
    ("dynamic", "device"),
    [
        pytest.param(False, "cpu"),
        # TODO: language models currently fail when dynamic shapes are
        # enabled, so the dynamic case is skipped for now.
        pytest.param(True, "cpu", marks=pytest.mark.skip),
    ],
)
@pytest.mark.skipif(
    importlib.util.find_spec("onnxruntime") is None,
    reason="Cannot find ONNXRUNTIME.",
)
@pytest_benchmark_param
def test_HFbench_minilm_torch(dynamic, device):
    """Smoke test: exercise every AMDSharkHFBenchmarkRunner benchmark path.

    Builds a bert-base-uncased runner over a fixed random input and invokes
    the C, Python, Torch, and ONNX benchmark entry points, asserting only
    that none of them raises.

    Args:
        dynamic: whether to compile with dynamic shapes (parametrized).
        device: target device string, e.g. "cpu" (parametrized).
    """
    model_name = "bert-base-uncased"
    # Random token ids in {0, 1}, shape (batch=1, seq_len=128); deterministic
    # thanks to the module-level torch.manual_seed(0).
    test_input = torch.randint(2, (1, 128))

    # Intentionally no try/except here: the original wrapped this in
    # `try: ... assert True / except Exception: assert False`, which
    # swallowed the real exception and its traceback. Letting any exception
    # propagate fails the test just the same while preserving the root cause
    # in pytest's report.
    amdshark_module = AMDSharkHFBenchmarkRunner(
        model_name,
        (test_input,),
        jit_trace=True,
        dynamic=dynamic,
        device=device,
    )
    amdshark_module.benchmark_c()
    amdshark_module.benchmark_python((test_input,))
    amdshark_module.benchmark_torch(test_input)
    amdshark_module.benchmark_onnx(test_input)