nv correct error messages with ptx (#5341)

* nv correct error messages with ptx

* return compile error
This commit is contained in:
nimlgen
2024-07-09 10:39:39 +03:00
committed by GitHub
parent c13da83f12
commit a2a9bfd2ec

View File

@@ -78,14 +78,17 @@ class NVCompiler(Compiler):
raise CompileError(f"compile failed: {_get_bytes(prog, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, cuda_check).decode()}")
return _get_bytes(prog, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize, cuda_check)
def jitlink_check(status):
if status != 0: raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}")
class NVPTXCompiler(NVCompiler):
def compile(self, src:str) -> bytes:
ptxsrc = src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5")
cuda_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])))
cuda_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "<null>".encode()))
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])))
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "<null>".encode()))
if nvrtc.nvJitLinkComplete(handle) != 0:
raise CompileError(f"compile failed: {_get_bytes(handle, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, cuda_check).decode()}")
return _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, cuda_check)
raise CompileError(f"compile failed: {_get_bytes(handle, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, jitlink_check).decode()}")
return _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
class HWQueue:
def __init__(self): self.q, self.binded_device, self.cmd_offsets = [], None, [0]