From f9ca072b618d4601a1784171aede748e32e2e789 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Thu, 15 Jan 2026 16:02:40 -0800 Subject: [PATCH] cuda compilers disassemble properly (#14166) * cuda compilers disassemble properly * this can use system --- tinygrad/runtime/support/compiler_cuda.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py index 2f846f5bba..bbf4470826 100644 --- a/tinygrad/runtime/support/compiler_cuda.py +++ b/tinygrad/runtime/support/compiler_cuda.py @@ -34,11 +34,11 @@ def pretty_ptx(s): s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives return s -def cuda_disassemble(lib:bytes, arch:str): +def cuda_disassemble(lib:bytes, arch:str, ptx=False): try: fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix() - with open(fn, "wb") as f: f.write(lib.rstrip(b'\x00')) - subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn], check=False, stderr=subprocess.DEVNULL) # optional ptx -> sass step for CUDA=1 + with open(fn, "wb") as f: f.write(lib.rstrip(b'\x00') if ptx else lib) + if ptx: system(f"ptxas -arch={arch} -o {fn} {fn}") print(system(f'nvdisasm {fn}')) except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.") @@ -56,11 +56,12 @@ class CUDACompiler(Compiler): nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog))) return data def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize) - def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch) + def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True) class NVCompiler(CUDACompiler): def __init__(self, arch:str): super().__init__(arch, cache_key="nv") def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize) + def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch) class NVCCCompiler(Compiler): def __init__(self, arch:str, extra_options:list[str]=[]): @@ -72,7 +73,7 @@ class NVCCCompiler(Compiler): srcf.flush() subprocess.run(["nvcc", f"-arch={self.arch}", "-ptx", "-o", libf.name, srcf.name] + self.extra_options, check=True) return libf.read() - def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch) + def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True) class PTXCompiler(Compiler): def __init__(self, arch:str, cache_key="ptx"): @@ -80,7 +81,7 @@ class PTXCompiler(Compiler): super().__init__(f"compile_{cache_key}_{self.arch}") def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "8.7" if (ver:=int(self.arch[3:]))>=120 else ("7.8" if ver>=89 else "7.5")).encode() - def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch) + def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True) class NVPTXCompiler(PTXCompiler): def __init__(self, arch:str): @@ -93,3 +94,4 @@ class NVPTXCompiler(PTXCompiler): data = _get_bytes(handle, jitlink.nvJitLinkGetLinkedCubin, jitlink.nvJitLinkGetLinkedCubinSize, jitlink_check) jitlink_check(jitlink.nvJitLinkDestroy(handle)) return data + def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)