This commit is contained in:
unknown
2025-12-29 15:30:33 -08:00
parent ddd5b6f974
commit 70b1994b10

View File

@@ -285,17 +285,23 @@ class DeviceInstaller():
return False
return False
def is_wsl2():
if os.name != "posix":
return False
try:
with open("/proc/version", "r", encoding="utf-8", errors="ignore") as f:
return "microsoft" in f.read().lower()
except Exception:
return False
def has_working_cuda():
# CUDA does not exist on macOS
if sys.platform == "darwin":
return False
# nvidia-smi is the only reliable cross-platform signal
if not has_cmd("nvidia-smi"):
return False
out = try_cmd("nvidia-smi -L").lower()
if not out:
return False
# Guard against common failure states
if "failed" in out or "error" in out or "no devices were found" in out:
return False
return "gpu" in out
@@ -447,8 +453,25 @@ class DeviceInstaller():
# ============================================================
# CUDA
# ============================================================
elif has_working_cuda() and has_nvidia_gpu_pci():
elif has_working_cuda() and (has_nvidia_gpu_pci() or is_wsl2()):
version_out = ''
# 1) CUDA RUNTIME detection
try:
import ctypes
libcudart = ctypes.CDLL("libcudart.so")
version = ctypes.c_int()
if libcudart.cudaRuntimeGetVersion(ctypes.byref(version)) == 0:
device_count = ctypes.c_int()
if libcudart.cudaGetDeviceCount(ctypes.byref(device_count)) == 0:
if device_count.value > 0:
v = version.value
major = v // 1000
minor = (v % 1000) // 10
version_out = f"{major}.{minor}"
except Exception:
pass
# CUDA TOOLKIT detection (fallback only)
if not version_out:
if os.name == 'posix':
for p in (
'/usr/local/cuda/version.json',
@@ -470,7 +493,7 @@ class DeviceInstaller():
version_out = f.read()
break
if not version_out:
msg = 'CUDA hardware detected but NVIDIA CUDA Toolkit not installed.'
msg = 'CUDA runtime detected but NVIDIA CUDA Toolkit not installed.'
else:
version_str = toolkit_version_parse(version_out)
cmp = toolkit_version_compare(version_str, cuda_version_range)