This commit is contained in:
unknown
2025-12-29 17:33:13 -08:00
parent 70b1994b10
commit fd2521e0ef

View File

@@ -455,10 +455,32 @@ class DeviceInstaller():
# ============================================================
elif has_working_cuda() and (has_nvidia_gpu_pci() or is_wsl2()):
version_out = ''
msg = ''
# 1) CUDA RUNTIME detection
try:
import ctypes
libcudart = ctypes.CDLL("libcudart.so")
libcudart = None
if os.name == "nt":
# Native Windows: CUDA runtime is a DLL
for dll in (
"cudart64_130.dll",
"cudart64_121.dll",
"cudart64_120.dll",
"cudart64_118.dll",
"cudart64_117.dll",
"cudart64_116.dll",
"cudart64_115.dll",
):
try:
libcudart = ctypes.CDLL(dll)
break
except OSError:
pass
else:
# Linux + WSL2
libcudart = ctypes.CDLL("libcudart.so")
if not libcudart:
raise OSError
version = ctypes.c_int()
if libcudart.cudaRuntimeGetVersion(ctypes.byref(version)) == 0:
device_count = ctypes.c_int()
@@ -468,7 +490,9 @@ class DeviceInstaller():
major = v // 1000
minor = (v % 1000) // 10
version_out = f"{major}.{minor}"
except Exception:
else:
msg = f'Runtime present ({version.value}) but no devices'
except (OSError, AttributeError):
pass
# CUDA TOOLKIT detection (fallback only)
if not version_out:
@@ -493,11 +517,11 @@ class DeviceInstaller():
version_out = f.read()
break
if not version_out:
msg = 'CUDA runtime detected but NVIDIA CUDA Toolkit not installed.'
if msg == '':
msg = 'CUDA runtime detected but NVIDIA CUDA Toolkit or Runtime not installed.'
else:
version_str = toolkit_version_parse(version_out)
cmp = toolkit_version_compare(version_str, cuda_version_range)
if cmp == -1:
msg = f'CUDA {version_str} < min {cuda_version_range["min"]}. Please upgrade.'
elif cmp == 1: