nv: check if jitlink is avail (#12808)

* nv: check if jitlink is avail

* why

* fix

* fix
This commit is contained in:
nimlgen
2025-10-20 18:13:16 +08:00
committed by GitHub
parent b8a9cce783
commit b5e36e3c6c

View File

@@ -69,7 +69,9 @@ class PTXCompiler(Compiler):
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
class NVPTXCompiler(PTXCompiler):
def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
def __init__(self, arch:str):
nvrtc_check(nvrtc.nvJitLinkVersion(ctypes.byref(ctypes.c_uint()), ctypes.byref(ctypes.c_uint())))
super().__init__(arch, cache_key="nv_ptx")
def compile(self, src:str) -> bytes:
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)