From a2a9bfd2ece02628b70e2a45e9775a100ad1557b Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 9 Jul 2024 10:39:39 +0300 Subject: [PATCH] nv correct error messages with ptx (#5341) * nv correct error messages with ptx * return compile error --- tinygrad/runtime/ops_nv.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index c9bab072b6..9f95b9c44d 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -78,14 +78,17 @@ class NVCompiler(Compiler): raise CompileError(f"compile failed: {_get_bytes(prog, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, cuda_check).decode()}") return _get_bytes(prog, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize, cuda_check) +def jitlink_check(status): + if status != 0: raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}") + class NVPTXCompiler(NVCompiler): def compile(self, src:str) -> bytes: ptxsrc = src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5") - cuda_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()]))) - cuda_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "".encode())) + jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()]))) + jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "".encode())) if nvrtc.nvJitLinkComplete(handle) != 0: - raise CompileError(f"compile failed: {_get_bytes(handle, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, cuda_check).decode()}") - return _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, cuda_check) + raise CompileError(f"compile failed: {_get_bytes(handle, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, jitlink_check).decode()}") + return _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check) class HWQueue: def __init__(self): self.q, self.binded_device, self.cmd_offsets = [], None, [0]