diff --git a/tinygrad/runtime/autogen/__init__.py b/tinygrad/runtime/autogen/__init__.py
index 1d7e9be2f7..9e0bd00145 100644
--- a/tinygrad/runtime/autogen/__init__.py
+++ b/tinygrad/runtime/autogen/__init__.py
@@ -28,9 +28,9 @@ def __getattr__(nm):
       [i for i in system("dpkg -L libc6-dev").split() if 'sys/mman.h' in i or 'sys/syscall.h' in i] +
       ["/usr/include/string.h", "/usr/include/elf.h", "/usr/include/unistd.h", "/usr/include/asm-generic/mman-common.h"]), use_errno=True)
     case "opencl": return load("opencl", ["find_library('OpenCL')"], ["/usr/include/CL/cl.h"])
-    case "cuda": return load("cuda", ["find_library('cuda')"], ["/usr/include/cuda.h"], args=["-D__CUDA_API_VERSION_INTERNAL"], parse_macros=False)
-    case "nvrtc": return load("nvrtc", ["find_library('nvrtc')"], ["/usr/include/nvrtc.h"])
-    case "nvjitlink": load("nvjitlink", ["find_library('nvJitLink')"], [root/"extra/nvJitLink.h"])
+    case "cuda": return load("cuda", ["CUDA_PATH"], ["/usr/include/cuda.h"], args=["-D__CUDA_API_VERSION_INTERNAL"], parse_macros=False, prolog=["from tinygrad.runtime.support.cuda import CUDA_PATH"])
+    case "nvrtc": return load("nvrtc", ["NVRTC_PATH"], ["/usr/include/nvrtc.h"], prolog=["from tinygrad.runtime.support.cuda import NVRTC_PATH"])
+    case "nvjitlink": return load("nvjitlink", ["NVJITLINK_PATH"], [root/"extra/nvJitLink.h"], prolog=["from tinygrad.runtime.support.cuda import NVJITLINK_PATH"])
     case "kfd": return load("kfd", [], ["/usr/include/linux/kfd_ioctl.h"])
     case "nv_570" | "nv_580":
       return load(nm, [], [
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index e5a944cdbb..3ed11047d5 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -2,9 +2,9 @@
 import ctypes
 from tinygrad.helpers import unwrap
 from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
-from ctypes.util import find_library
+from tinygrad.runtime.support.cuda import CUDA_PATH
 def dll():
-  try: return ctypes.CDLL(unwrap(find_library('cuda')))
+  try: return ctypes.CDLL(unwrap(CUDA_PATH))
   except: pass
   return None
 dll = dll()
diff --git a/tinygrad/runtime/autogen/nvjitlink.py b/tinygrad/runtime/autogen/nvjitlink.py
index 7e07a40ae8..6cd4124682 100644
--- a/tinygrad/runtime/autogen/nvjitlink.py
+++ b/tinygrad/runtime/autogen/nvjitlink.py
@@ -2,9 +2,9 @@
 import ctypes
 from tinygrad.helpers import unwrap
 from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
-from ctypes.util import find_library
+from tinygrad.runtime.support.cuda import NVJITLINK_PATH
 def dll():
-  try: return ctypes.CDLL(unwrap(find_library('nvJitLink')))
+  try: return ctypes.CDLL(unwrap(NVJITLINK_PATH))
   except: pass
   return None
 dll = dll()
diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py
index ce7083f553..26f4fd92ca 100644
--- a/tinygrad/runtime/autogen/nvrtc.py
+++ b/tinygrad/runtime/autogen/nvrtc.py
@@ -2,9 +2,9 @@
 import ctypes
 from tinygrad.helpers import unwrap
 from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
-from ctypes.util import find_library
+from tinygrad.runtime.support.cuda import NVRTC_PATH
 def dll():
-  try: return ctypes.CDLL(unwrap(find_library('nvrtc')))
+  try: return ctypes.CDLL(unwrap(NVRTC_PATH))
   except: pass
   return None
 dll = dll()
diff --git a/tinygrad/runtime/support/cuda.py b/tinygrad/runtime/support/cuda.py
new file mode 100644
index 0000000000..ed2e112610
--- /dev/null
+++ b/tinygrad/runtime/support/cuda.py
@@ -0,0 +1,25 @@
+import glob, os, platform
+from ctypes.util import find_library
+
+# library locations handed to ctypes.CDLL by the autogen bindings; None when a library could not be located
+CUDA_PATH: str | None
+NVRTC_PATH: str | None
+NVJITLINK_PATH: str | None
+
+if platform.system() == "Windows":
+  def find_nv_dll(glob_pattern: str) -> str | None:
+    """Return a DLL matching glob_pattern in the bin directory under the CUDA_PATH environment variable, or None."""
+    # without the env var the search would silently fall back to a "bin" directory relative to the cwd
+    if not (cuda_root := os.environ.get("CUDA_PATH")): return None
+    # glob order is arbitrary, so sort and take the last match to keep the choice deterministic
+    matches = sorted(glob.glob(os.path.join(cuda_root, "bin", glob_pattern)))
+    return matches[-1] if matches else None
+
+  # NOTE: the CUDA_PATH env var is a directory (its bin subfolder is searched); the CUDA_PATH constant is the driver library
+  CUDA_PATH = find_library('nvcuda')
+  NVRTC_PATH = find_nv_dll("nvrtc64_*.dll")
+  NVJITLINK_PATH = find_nv_dll("nvJitLink_*.dll")
+else:
+  CUDA_PATH = find_library('cuda')
+  NVRTC_PATH = find_library('nvrtc')
+  NVJITLINK_PATH = find_library('nvJitLink')