Fix CUDA on Windows

This commit is contained in:
Irwin1138
2025-11-23 20:59:14 +02:00
parent 63a931ff76
commit 7aea5256b1
5 changed files with 30 additions and 9 deletions

View File

@@ -28,9 +28,9 @@ def __getattr__(nm):
[i for i in system("dpkg -L libc6-dev").split() if 'sys/mman.h' in i or 'sys/syscall.h' in i] +
["/usr/include/string.h", "/usr/include/elf.h", "/usr/include/unistd.h", "/usr/include/asm-generic/mman-common.h"]), use_errno=True)
case "opencl": return load("opencl", ["find_library('OpenCL')"], ["/usr/include/CL/cl.h"])
case "cuda": return load("cuda", ["find_library('cuda')"], ["/usr/include/cuda.h"], args=["-D__CUDA_API_VERSION_INTERNAL"], parse_macros=False)
case "nvrtc": return load("nvrtc", ["find_library('nvrtc')"], ["/usr/include/nvrtc.h"])
case "nvjitlink": load("nvjitlink", ["find_library('nvJitLink')"], [root/"extra/nvJitLink.h"])
case "cuda": return load("cuda", ["CUDA_PATH"], ["/usr/include/cuda.h"], args=["-D__CUDA_API_VERSION_INTERNAL"], parse_macros=False, prolog=["from tinygrad.runtime.support.cuda import CUDA_PATH"])
case "nvrtc": return load("nvrtc", ["NVRTC_PATH"], ["/usr/include/nvrtc.h"], prolog=["from tinygrad.runtime.support.cuda import NVRTC_PATH"])
case "nvjitlink": load("nvjitlink", ["NVJITLINK_PATH"], [root/"extra/nvJitLink.h"], prolog=["from tinygrad.runtime.support.cuda import NVJITLINK_PATH"])
case "kfd": return load("kfd", [], ["/usr/include/linux/kfd_ioctl.h"])
case "nv_570" | "nv_580":
return load(nm, [], [

View File

@@ -2,9 +2,9 @@
import ctypes
from tinygrad.helpers import unwrap
from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
from ctypes.util import find_library
from tinygrad.runtime.support.cuda import CUDA_PATH
# Best-effort loader for the `cuda` shared library: yields a ctypes.CDLL handle, or None when loading fails.
# NOTE(review): this is a diff hunk, so the removed and the added `try:` line both appear below.
def dll():
# Lookup before this commit: the library is located by name through find_library.
try: return ctypes.CDLL(unwrap(find_library('cuda')))
# Lookup after this commit: the location comes from CUDA_PATH in tinygrad.runtime.support.cuda.
# presumably unwrap raises when given None, which lands in the except below -- TODO confirm
try: return ctypes.CDLL(unwrap(CUDA_PATH))
# NOTE(review): the bare except swallows every error, so a missing library results in None rather than an exception.
except: pass
return None
# The name `dll` is rebound from the loader function to its result (a CDLL handle or None).
dll = dll()

View File

@@ -2,9 +2,9 @@
import ctypes
from tinygrad.helpers import unwrap
from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
from ctypes.util import find_library
from tinygrad.runtime.support.cuda import NVJITLINK_PATH
# Best-effort loader for the `nvJitLink` shared library: yields a ctypes.CDLL handle, or None when loading fails.
# NOTE(review): this is a diff hunk, so the removed and the added `try:` line both appear below.
def dll():
# Lookup before this commit: the library is located by name through find_library.
try: return ctypes.CDLL(unwrap(find_library('nvJitLink')))
# Lookup after this commit: the location comes from NVJITLINK_PATH in tinygrad.runtime.support.cuda.
# presumably unwrap raises when given None, which lands in the except below -- TODO confirm
try: return ctypes.CDLL(unwrap(NVJITLINK_PATH))
# NOTE(review): the bare except swallows every error, so a missing library results in None rather than an exception.
except: pass
return None
# The name `dll` is rebound from the loader function to its result (a CDLL handle or None).
dll = dll()

View File

@@ -2,9 +2,9 @@
import ctypes
from tinygrad.helpers import unwrap
from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
from ctypes.util import find_library
from tinygrad.runtime.support.cuda import NVRTC_PATH
# Best-effort loader for the `nvrtc` shared library: yields a ctypes.CDLL handle, or None when loading fails.
# NOTE(review): this is a diff hunk, so the removed and the added `try:` line both appear below.
def dll():
# Lookup before this commit: the library is located by name through find_library.
try: return ctypes.CDLL(unwrap(find_library('nvrtc')))
# Lookup after this commit: the location comes from NVRTC_PATH in tinygrad.runtime.support.cuda.
# presumably unwrap raises when given None, which lands in the except below -- TODO confirm
try: return ctypes.CDLL(unwrap(NVRTC_PATH))
# NOTE(review): the bare except swallows every error, so a missing library results in None rather than an exception.
except: pass
return None
# The name `dll` is rebound from the loader function to its result (a CDLL handle or None).
dll = dll()

View File

@@ -0,0 +1,21 @@
import ctypes.util, os, platform
from ctypes.util import find_library
CUDA_PATH: str | None
NVRTC_PATH: str | None
NVJITLINK_PATH: str | None
if platform.system() == "Windows":
import glob, os.path
def find_nv_dll(glob_pattern):
cuda_bin = os.path.join(os.environ.get("CUDA_PATH", ""), "bin")
matches = glob.glob(os.path.join(cuda_bin, glob_pattern))
return matches[0] if matches else None
CUDA_PATH = find_library('nvcuda')
NVRTC_PATH = find_nv_dll("nvrtc64_*.dll")
NVJITLINK_PATH = find_nv_dll("nvJitLink_*.dll")
else:
CUDA_PATH = find_library('cuda')
NVRTC_PATH = find_library('nvrtc')
NVJITLINK_PATH = find_library('nvJitLink')