fix calls and remove unused imports for check_device_drivers

This commit is contained in:
PhaneeshB
2023-09-04 19:51:40 +05:30
committed by Phaneesh Barwaria
parent 72c0a8abc8
commit 679a452139
5 changed files with 4 additions and 10 deletions

View File

@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"gpu",
"cuda",
marks=pytest.mark.skipif(
check_device_drivers("gpu"), reason="nvidia-smi not found"
check_device_drivers("cuda"), reason="nvidia-smi not found"
),
),
pytest.param(True, "gpu", marks=pytest.mark.skip),
pytest.param(True, "cuda", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",

View File

@@ -1,4 +1,3 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from shark.parser import shark_args

View File

@@ -3,10 +3,6 @@ import os
import torch
import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import (
check_device_drivers,
device_driver_info,
)
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM

View File

@@ -1,4 +1,3 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from tank.test_utils import get_valid_test_params, shark_test_name_func

View File

@@ -44,7 +44,7 @@ class TapasBaseModuleTest(unittest.TestCase):
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.skipif(
check_device_drivers("cuda"), reason=device_driver_info("gpu")
check_device_drivers("cuda"), reason=device_driver_info("cuda")
)
def test_module_static_cuda(self):
dynamic = False