This commit is contained in:
unknown
2025-12-29 15:30:33 -08:00
parent ddd5b6f974
commit 70b1994b10

View File

@@ -285,17 +285,23 @@ class DeviceInstaller():
return False
return False
def is_wsl2():
if os.name != "posix":
return False
try:
with open("/proc/version", "r", encoding="utf-8", errors="ignore") as f:
return "microsoft" in f.read().lower()
except Exception:
return False
def has_working_cuda():
# CUDA does not exist on macOS
if sys.platform == "darwin":
return False
# nvidia-smi is the only reliable cross-platform signal
if not has_cmd("nvidia-smi"):
return False
out = try_cmd("nvidia-smi -L").lower()
if not out:
return False
# Guard against common failure states
if "failed" in out or "error" in out or "no devices were found" in out:
return False
return "gpu" in out
@@ -447,30 +453,47 @@ class DeviceInstaller():
# ============================================================
# CUDA
# ============================================================
elif has_working_cuda() and has_nvidia_gpu_pci():
elif has_working_cuda() and (has_nvidia_gpu_pci() or is_wsl2()):
version_out = ''
if os.name == 'posix':
for p in (
'/usr/local/cuda/version.json',
'/usr/local/cuda/version.txt',
):
if os.path.exists(p):
with open(p, 'r', encoding='utf-8', errors='ignore') as f:
version_out = f.read()
break
elif os.name == 'nt':
cuda_path = os.environ.get('CUDA_PATH')
if cuda_path:
# 1) CUDA RUNTIME detection
try:
import ctypes
libcudart = ctypes.CDLL("libcudart.so")
version = ctypes.c_int()
if libcudart.cudaRuntimeGetVersion(ctypes.byref(version)) == 0:
device_count = ctypes.c_int()
if libcudart.cudaGetDeviceCount(ctypes.byref(device_count)) == 0:
if device_count.value > 0:
v = version.value
major = v // 1000
minor = (v % 1000) // 10
version_out = f"{major}.{minor}"
except Exception:
pass
# CUDA TOOLKIT detection (fallback only)
if not version_out:
if os.name == 'posix':
for p in (
os.path.join(cuda_path, 'version.json'),
os.path.join(cuda_path, 'version.txt'),
'/usr/local/cuda/version.json',
'/usr/local/cuda/version.txt',
):
if os.path.exists(p):
with open(p, 'r', encoding='utf-8', errors='ignore') as f:
version_out = f.read()
break
elif os.name == 'nt':
cuda_path = os.environ.get('CUDA_PATH')
if cuda_path:
for p in (
os.path.join(cuda_path, 'version.json'),
os.path.join(cuda_path, 'version.txt'),
):
if os.path.exists(p):
with open(p, 'r', encoding='utf-8', errors='ignore') as f:
version_out = f.read()
break
if not version_out:
msg = 'CUDA hardware detected but NVIDIA CUDA Toolkit not installed.'
msg = 'CUDA runtime detected but NVIDIA CUDA Toolkit not installed.'
else:
version_str = toolkit_version_parse(version_out)
cmp = toolkit_version_compare(version_str, cuda_version_range)