cuda compilers disassemble properly (#14166)

* cuda compilers disassemble properly

* this can use system
This commit is contained in:
Christopher Milan
2026-01-15 16:02:40 -08:00
committed by GitHub
parent 14e9a71a41
commit f9ca072b61

View File

@@ -34,11 +34,11 @@ def pretty_ptx(s):
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s
def cuda_disassemble(lib:bytes, arch:str):
def cuda_disassemble(lib:bytes, arch:str, ptx=False):
try:
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
with open(fn, "wb") as f: f.write(lib.rstrip(b'\x00'))
subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn], check=False, stderr=subprocess.DEVNULL) # optional ptx -> sass step for CUDA=1
with open(fn, "wb") as f: f.write(lib.rstrip(b'\x00') if ptx else lib)
if ptx: system(f"ptxas -arch={arch} -o {fn} {fn}")
print(system(f'nvdisasm {fn}'))
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
@@ -56,11 +56,12 @@ class CUDACompiler(Compiler):
nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
return data
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
class NVCompiler(CUDACompiler):
def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
class NVCCCompiler(Compiler):
def __init__(self, arch:str, extra_options:list[str]=[]):
@@ -72,7 +73,7 @@ class NVCCCompiler(Compiler):
srcf.flush()
subprocess.run(["nvcc", f"-arch={self.arch}", "-ptx", "-o", libf.name, srcf.name] + self.extra_options, check=True)
return libf.read()
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
class PTXCompiler(Compiler):
def __init__(self, arch:str, cache_key="ptx"):
@@ -80,7 +81,7 @@ class PTXCompiler(Compiler):
super().__init__(f"compile_{cache_key}_{self.arch}")
def compile(self, src:str) -> bytes:
return src.replace("TARGET", self.arch).replace("VERSION", "8.7" if (ver:=int(self.arch[3:]))>=120 else ("7.8" if ver>=89 else "7.5")).encode()
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
class NVPTXCompiler(PTXCompiler):
def __init__(self, arch:str):
@@ -93,3 +94,4 @@ class NVPTXCompiler(PTXCompiler):
data = _get_bytes(handle, jitlink.nvJitLinkGetLinkedCubin, jitlink.nvJitLinkGetLinkedCubinSize, jitlink_check)
jitlink_check(jitlink.nvJitLinkDestroy(handle))
return data
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)