mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
nv: check if jitlink is avail (#12808)
* nv: check if jitlink is avail * why * fix * fix
This commit is contained in:
@@ -69,7 +69,9 @@ class PTXCompiler(Compiler):
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
|
||||
|
||||
class NVPTXCompiler(PTXCompiler):
|
||||
def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
|
||||
def __init__(self, arch:str):
|
||||
nvrtc_check(nvrtc.nvJitLinkVersion(ctypes.byref(ctypes.c_uint()), ctypes.byref(ctypes.c_uint())))
|
||||
super().__init__(arch, cache_key="nv_ptx")
|
||||
def compile(self, src:str) -> bytes:
|
||||
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
|
||||
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)
|
||||
|
||||
Reference in New Issue
Block a user