mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 22:07:55 -05:00
Added device map and runs test according to the driver.
Automatically runs test according to the device driver present.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
[pytest]
|
||||
addopts = --workers auto -m cpu_static --verbose -p no:warnings
|
||||
addopts = --workers auto --verbose -p no:warnings
|
||||
norecursedirs = inference
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from shark.shark_runner import SharkInference
|
||||
from shark.iree_utils import check_device_drivers
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -78,15 +79,36 @@ def compare_tensors(torch_tensor, numpy_tensor):
|
||||
|
||||
############################# Model Tests ####################################
|
||||
|
||||
# A specific case can be run with -m parameter. For eg., to run gpu - static
|
||||
# case for all models it can be run by `pytest -m gpu_static` command.
|
||||
# A specific case can be run by commenting different cases. Runs all the test
|
||||
# across cpu, gpu and vulkan according to available drivers.
|
||||
pytest_param = pytest.mark.parametrize(('dynamic', 'device'), [
|
||||
pytest.param(False, 'cpu', marks=pytest.mark.cpu_static),
|
||||
pytest.param(True, 'cpu', marks=pytest.mark.cpu_dynamic),
|
||||
pytest.param(False, 'gpu', marks=pytest.mark.gpu_static),
|
||||
pytest.param(True, 'gpu', marks=pytest.mark.gpu_dynamic),
|
||||
pytest.param(False, 'vulkan', marks=pytest.mark.vulkan_static),
|
||||
pytest.param(True, 'vulkan', marks=pytest.mark.vulkan_dynamic),
|
||||
pytest.param(False, 'cpu'),
|
||||
# TODO: Language models are failing for dynamic case..
|
||||
pytest.param(True, 'cpu', marks=pytest.mark.skip),
|
||||
pytest.param(False,
|
||||
'gpu',
|
||||
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")),
|
||||
pytest.param(True,
|
||||
'gpu',
|
||||
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||
reason="nvidia-smi not found")),
|
||||
pytest.param(
|
||||
False,
|
||||
'vulkan',
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)),
|
||||
pytest.param(
|
||||
True,
|
||||
'vulkan',
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason=
|
||||
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||
)),
|
||||
])
|
||||
|
||||
|
||||
|
||||
@@ -19,11 +19,38 @@ import numpy as np
|
||||
import os
|
||||
from shark.torch_mlir_utils import get_module_name_for_asm_dump
|
||||
|
||||
IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
|
||||
IREE_DEVICE_MAP = {
|
||||
"cpu": "dylib",
|
||||
"gpu": "cuda",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "vulkan"
|
||||
}
|
||||
|
||||
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
if (device in ["gpu", "cuda"]):
|
||||
try:
|
||||
subprocess.check_output('nvidia-smi')
|
||||
except Exception:
|
||||
return True
|
||||
elif (device in ["metal", "vulkan"]):
|
||||
try:
|
||||
subprocess.check_output('vulkaninfo')
|
||||
except Exception:
|
||||
return True
|
||||
elif (device == "cpu"):
|
||||
return False
|
||||
# Unknown device.
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_iree_compiled_module(module, device: str):
|
||||
"""TODO: Documentation"""
|
||||
"""Given an mlir module returns the compiled .vmfb"""
|
||||
args = ["--iree-llvm-target-cpu-features=host"]
|
||||
if (device == "cpu"):
|
||||
find_triple_cmd = "uname -s -m"
|
||||
@@ -65,7 +92,7 @@ def export_iree_module_to_vmfb(module, device: str, directory: str):
|
||||
|
||||
|
||||
def get_results(compiled_vm, input, config):
|
||||
"""TODO: Documentation"""
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
result = compiled_vm(*device_inputs)
|
||||
result_tensors = []
|
||||
|
||||
Reference in New Issue
Block a user