mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 22:07:55 -05:00
Added device map and runs test according to the driver.
Automatically runs test according to the device driver present.
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
addopts = --workers auto -m cpu_static --verbose -p no:warnings
|
addopts = --workers auto --verbose -p no:warnings
|
||||||
norecursedirs = inference
|
norecursedirs = inference
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from shark.shark_runner import SharkInference
|
from shark.shark_runner import SharkInference
|
||||||
|
from shark.iree_utils import check_device_drivers
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -78,15 +79,36 @@ def compare_tensors(torch_tensor, numpy_tensor):
|
|||||||
|
|
||||||
############################# Model Tests ####################################
|
############################# Model Tests ####################################
|
||||||
|
|
||||||
# A specific case can be run with -m parameter. For eg., to run gpu - static
|
# A specific case can be run by commenting different cases. Runs all the test
|
||||||
# case for all models it can be run by `pytest -m gpu_static` command.
|
# across cpu, gpu and vulkan according to available drivers.
|
||||||
pytest_param = pytest.mark.parametrize(('dynamic', 'device'), [
|
pytest_param = pytest.mark.parametrize(('dynamic', 'device'), [
|
||||||
pytest.param(False, 'cpu', marks=pytest.mark.cpu_static),
|
pytest.param(False, 'cpu'),
|
||||||
pytest.param(True, 'cpu', marks=pytest.mark.cpu_dynamic),
|
# TODO: Language models are failing for dynamic case..
|
||||||
pytest.param(False, 'gpu', marks=pytest.mark.gpu_static),
|
pytest.param(True, 'cpu', marks=pytest.mark.skip),
|
||||||
pytest.param(True, 'gpu', marks=pytest.mark.gpu_dynamic),
|
pytest.param(False,
|
||||||
pytest.param(False, 'vulkan', marks=pytest.mark.vulkan_static),
|
'gpu',
|
||||||
pytest.param(True, 'vulkan', marks=pytest.mark.vulkan_dynamic),
|
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||||
|
reason="nvidia-smi not found")),
|
||||||
|
pytest.param(True,
|
||||||
|
'gpu',
|
||||||
|
marks=pytest.mark.skipif(check_device_drivers("gpu"),
|
||||||
|
reason="nvidia-smi not found")),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
'vulkan',
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
check_device_drivers("vulkan"),
|
||||||
|
reason=
|
||||||
|
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||||
|
)),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
'vulkan',
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
check_device_drivers("vulkan"),
|
||||||
|
reason=
|
||||||
|
"vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases"
|
||||||
|
)),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,38 @@ import numpy as np
|
|||||||
import os
|
import os
|
||||||
from shark.torch_mlir_utils import get_module_name_for_asm_dump
|
from shark.torch_mlir_utils import get_module_name_for_asm_dump
|
||||||
|
|
||||||
IREE_DEVICE_MAP = {"cpu": "dylib", "gpu": "cuda", "vulkan": "vulkan"}
|
IREE_DEVICE_MAP = {
|
||||||
|
"cpu": "dylib",
|
||||||
|
"gpu": "cuda",
|
||||||
|
"cuda": "cuda",
|
||||||
|
"vulkan": "vulkan",
|
||||||
|
"metal": "vulkan"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def check_device_drivers(device):
|
||||||
|
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||||
|
if (device in ["gpu", "cuda"]):
|
||||||
|
try:
|
||||||
|
subprocess.check_output('nvidia-smi')
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
elif (device in ["metal", "vulkan"]):
|
||||||
|
try:
|
||||||
|
subprocess.check_output('vulkaninfo')
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
elif (device == "cpu"):
|
||||||
|
return False
|
||||||
|
# Unknown device.
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_iree_compiled_module(module, device: str):
|
def get_iree_compiled_module(module, device: str):
|
||||||
"""TODO: Documentation"""
|
"""Given an mlir module returns the compiled .vmfb"""
|
||||||
args = ["--iree-llvm-target-cpu-features=host"]
|
args = ["--iree-llvm-target-cpu-features=host"]
|
||||||
if (device == "cpu"):
|
if (device == "cpu"):
|
||||||
find_triple_cmd = "uname -s -m"
|
find_triple_cmd = "uname -s -m"
|
||||||
@@ -65,7 +92,7 @@ def export_iree_module_to_vmfb(module, device: str, directory: str):
|
|||||||
|
|
||||||
|
|
||||||
def get_results(compiled_vm, input, config):
|
def get_results(compiled_vm, input, config):
|
||||||
"""TODO: Documentation"""
|
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||||
result = compiled_vm(*device_inputs)
|
result = compiled_vm(*device_inputs)
|
||||||
result_tensors = []
|
result_tensors = []
|
||||||
|
|||||||
Reference in New Issue
Block a user