nv ptx print log (#5691)

This commit is contained in:
nimlgen
2024-07-24 21:40:58 +03:00
committed by GitHub
parent bf24be4c8c
commit b026312a31

View File

@@ -78,16 +78,17 @@ class NVCompiler(Compiler):
raise CompileError(f"compile failed: {_get_bytes(prog, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, cuda_check).decode()}")
return _get_bytes(prog, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize, cuda_check)
def jitlink_check(status):
if status != 0: raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}")
def jitlink_check(status, ctx=None):
if status != 0:
err_log = _get_bytes(ctx, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, cuda_check).decode() if ctx else ""
raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}\n{err_log}")
class NVPTXCompiler(NVCompiler):
def compile(self, src:str) -> bytes:
ptxsrc = src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5")
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])))
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "<null>".encode()))
if nvrtc.nvJitLinkComplete(handle) != 0:
raise CompileError(f"compile failed: {_get_bytes(handle, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, jitlink_check).decode()}")
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "<null>".encode()), handle)
jitlink_check(nvrtc.nvJitLinkComplete(handle), handle)
return _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
class NVSignal(HCQSignal):