diff --git a/lib/classes/device_installer.py b/lib/classes/device_installer.py index be5b14b6..1ebe5527 100644 --- a/lib/classes/device_installer.py +++ b/lib/classes/device_installer.py @@ -455,10 +455,32 @@ class DeviceInstaller(): # ============================================================ elif has_working_cuda() and (has_nvidia_gpu_pci() or is_wsl2()): version_out = '' + msg = '' # 1) CUDA RUNTIME detection try: import ctypes - libcudart = ctypes.CDLL("libcudart.so") + libcudart = None + if os.name == "nt": + # Native Windows: CUDA runtime is a DLL + for dll in ( + "cudart64_130.dll", + "cudart64_121.dll", + "cudart64_120.dll", + "cudart64_118.dll", + "cudart64_117.dll", + "cudart64_116.dll", + "cudart64_115.dll", + ): + try: + libcudart = ctypes.CDLL(dll) + break + except OSError: + pass + else: + # Linux + WSL2 + libcudart = ctypes.CDLL("libcudart.so") + if not libcudart: + raise OSError version = ctypes.c_int() if libcudart.cudaRuntimeGetVersion(ctypes.byref(version)) == 0: device_count = ctypes.c_int() @@ -468,7 +490,9 @@ class DeviceInstaller(): major = v // 1000 minor = (v % 1000) // 10 version_out = f"{major}.{minor}" - except Exception: + else: + msg = f'Runtime present ({version.value}) but no devices' + except (OSError, AttributeError): pass # CUDA TOOLKIT detection (fallback only) if not version_out: @@ -493,11 +517,11 @@ class DeviceInstaller(): version_out = f.read() break if not version_out: - msg = 'CUDA runtime detected but NVIDIA CUDA Toolkit not installed.' + if msg == '': + msg = 'CUDA runtime detected but NVIDIA CUDA Toolkit or Runtime not installed.' else: version_str = toolkit_version_parse(version_out) cmp = toolkit_version_compare(version_str, cuda_version_range) - if cmp == -1: msg = f'CUDA {version_str} < min {cuda_version_range["min"]}. Please upgrade.' elif cmp == 1: