mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
fix cuda on windows
This commit is contained in:
@@ -28,9 +28,9 @@ def __getattr__(nm):
|
||||
[i for i in system("dpkg -L libc6-dev").split() if 'sys/mman.h' in i or 'sys/syscall.h' in i] +
|
||||
["/usr/include/string.h", "/usr/include/elf.h", "/usr/include/unistd.h", "/usr/include/asm-generic/mman-common.h"]), use_errno=True)
|
||||
case "opencl": return load("opencl", ["find_library('OpenCL')"], ["/usr/include/CL/cl.h"])
|
||||
case "cuda": return load("cuda", ["find_library('cuda')"], ["/usr/include/cuda.h"], args=["-D__CUDA_API_VERSION_INTERNAL"], parse_macros=False)
|
||||
case "nvrtc": return load("nvrtc", ["find_library('nvrtc')"], ["/usr/include/nvrtc.h"])
|
||||
case "nvjitlink": load("nvjitlink", ["find_library('nvJitLink')"], [root/"extra/nvJitLink.h"])
|
||||
case "cuda": return load("cuda", ["CUDA_PATH"], ["/usr/include/cuda.h"], args=["-D__CUDA_API_VERSION_INTERNAL"], parse_macros=False, prolog=["from tinygrad.runtime.support.cuda import CUDA_PATH"])
|
||||
case "nvrtc": return load("nvrtc", ["NVRTC_PATH"], ["/usr/include/nvrtc.h"], prolog=["from tinygrad.runtime.support.cuda import NVRTC_PATH"])
|
||||
case "nvjitlink": load("nvjitlink", ["NVJITLINK_PATH"], [root/"extra/nvJitLink.h"], prolog=["from tinygrad.runtime.support.cuda import NVJITLINK_PATH"])
|
||||
case "kfd": return load("kfd", [], ["/usr/include/linux/kfd_ioctl.h"])
|
||||
case "nv_570" | "nv_580":
|
||||
return load(nm, [], [
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
import ctypes
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
|
||||
from ctypes.util import find_library
|
||||
from tinygrad.runtime.support.cuda import CUDA_PATH
|
||||
def dll():
|
||||
try: return ctypes.CDLL(unwrap(find_library('cuda')))
|
||||
try: return ctypes.CDLL(unwrap(CUDA_PATH))
|
||||
except: pass
|
||||
return None
|
||||
dll = dll()
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
import ctypes
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
|
||||
from ctypes.util import find_library
|
||||
from tinygrad.runtime.support.cuda import NVJITLINK_PATH
|
||||
def dll():
|
||||
try: return ctypes.CDLL(unwrap(find_library('nvJitLink')))
|
||||
try: return ctypes.CDLL(unwrap(NVJITLINK_PATH))
|
||||
except: pass
|
||||
return None
|
||||
dll = dll()
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
import ctypes
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.runtime.support.c import Struct, CEnum, _IO, _IOW, _IOR, _IOWR
|
||||
from ctypes.util import find_library
|
||||
from tinygrad.runtime.support.cuda import NVRTC_PATH
|
||||
def dll():
|
||||
try: return ctypes.CDLL(unwrap(find_library('nvrtc')))
|
||||
try: return ctypes.CDLL(unwrap(NVRTC_PATH))
|
||||
except: pass
|
||||
return None
|
||||
dll = dll()
|
||||
|
||||
21
tinygrad/runtime/support/cuda.py
Normal file
21
tinygrad/runtime/support/cuda.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import ctypes.util, os, platform
|
||||
from ctypes.util import find_library
|
||||
|
||||
CUDA_PATH: str | None
|
||||
NVRTC_PATH: str | None
|
||||
NVJITLINK_PATH: str | None
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import glob, os.path
|
||||
def find_nv_dll(glob_pattern):
|
||||
cuda_bin = os.path.join(os.environ.get("CUDA_PATH", ""), "bin")
|
||||
matches = glob.glob(os.path.join(cuda_bin, glob_pattern))
|
||||
return matches[0] if matches else None
|
||||
|
||||
CUDA_PATH = find_library('nvcuda')
|
||||
NVRTC_PATH = find_nv_dll("nvrtc64_*.dll")
|
||||
NVJITLINK_PATH = find_nv_dll("nvJitLink_*.dll")
|
||||
else:
|
||||
CUDA_PATH = find_library('cuda')
|
||||
NVRTC_PATH = find_library('nvrtc')
|
||||
NVJITLINK_PATH = find_library('nvJitLink')
|
||||
Reference in New Issue
Block a user