mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix exception in cuda bindings code on windows (#13823)
* fix cuda on windows * fix linter errors * test github action install cuda-toolkit * Revert "test github action install cuda-toolkit" This reverts commitc18ad6f937. * Revert "fix linter errors" This reverts commit00aa943e91. * Revert "fix cuda on windows" This reverts commit7aea5256b1. * fix windows sysconfig.get_config_var("MULTIARCH") is None
This commit is contained in:
@@ -12,7 +12,7 @@ llvm_lib = (r"'C:\\Program Files\\LLVM\\bin\\LLVM-C.dll' if WIN else '/opt/homeb
|
||||
repr(['LLVM'] + [f'LLVM-{i}' for i in reversed(range(14, 21+1))]))
|
||||
|
||||
webgpu_lib = "os.path.join(sysconfig.get_paths()['purelib'], 'pydawn', 'lib', 'libwebgpu_dawn.dll') if WIN else 'webgpu_dawn'"
|
||||
nv_lib_path = "f'/usr/local/cuda/targets/{sysconfig.get_config_var(\"MULTIARCH\").rsplit(\"-\", 1)[0]}/lib'"
|
||||
nv_lib_path = "f'/usr/local/cuda/targets/{sysconfig.get_config_vars().get(\"MULTIARCH\", \"\").rsplit(\"-\", 1)[0]}/lib'"
|
||||
|
||||
def load(name, dll, files, **kwargs):
|
||||
if not (f:=(root/(path:=kwargs.pop("path", __name__)).replace('.','/')/f"{name}.py")).exists() or getenv('REGEN'):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import ctypes
|
||||
from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR
|
||||
import sysconfig
|
||||
dll = DLL('nvjitlink', 'nvJitLink', f'/usr/local/cuda/targets/{sysconfig.get_config_var("MULTIARCH").rsplit("-", 1)[0]}/lib')
|
||||
dll = DLL('nvjitlink', 'nvJitLink', f'/usr/local/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib')
|
||||
nvJitLinkResult = CEnum(ctypes.c_uint32)
|
||||
NVJITLINK_SUCCESS = nvJitLinkResult.define('NVJITLINK_SUCCESS', 0)
|
||||
NVJITLINK_ERROR_UNRECOGNIZED_OPTION = nvJitLinkResult.define('NVJITLINK_ERROR_UNRECOGNIZED_OPTION', 1)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import ctypes
|
||||
from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR
|
||||
import sysconfig
|
||||
dll = DLL('nvrtc', 'nvrtc', f'/usr/local/cuda/targets/{sysconfig.get_config_var("MULTIARCH").rsplit("-", 1)[0]}/lib')
|
||||
dll = DLL('nvrtc', 'nvrtc', f'/usr/local/cuda/targets/{sysconfig.get_config_vars().get("MULTIARCH", "").rsplit("-", 1)[0]}/lib')
|
||||
nvrtcResult = CEnum(ctypes.c_uint32)
|
||||
NVRTC_SUCCESS = nvrtcResult.define('NVRTC_SUCCESS', 0)
|
||||
NVRTC_ERROR_OUT_OF_MEMORY = nvrtcResult.define('NVRTC_ERROR_OUT_OF_MEMORY', 1)
|
||||
|
||||
Reference in New Issue
Block a user