mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cuda compilers disassemble properly (#14166)
* cuda compilers disassemble properly * this can use system
This commit is contained in:
committed by
GitHub
parent
14e9a71a41
commit
f9ca072b61
@@ -34,11 +34,11 @@ def pretty_ptx(s):
|
||||
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
|
||||
return s
|
||||
|
||||
def cuda_disassemble(lib:bytes, arch:str):
|
||||
def cuda_disassemble(lib:bytes, arch:str, ptx=False):
|
||||
try:
|
||||
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
|
||||
with open(fn, "wb") as f: f.write(lib.rstrip(b'\x00'))
|
||||
subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn], check=False, stderr=subprocess.DEVNULL) # optional ptx -> sass step for CUDA=1
|
||||
with open(fn, "wb") as f: f.write(lib.rstrip(b'\x00') if ptx else lib)
|
||||
if ptx: system(f"ptxas -arch={arch} -o {fn} {fn}")
|
||||
print(system(f'nvdisasm {fn}'))
|
||||
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
|
||||
|
||||
@@ -56,11 +56,12 @@ class CUDACompiler(Compiler):
|
||||
nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
|
||||
return data
|
||||
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
|
||||
|
||||
class NVCompiler(CUDACompiler):
|
||||
def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
|
||||
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
|
||||
|
||||
class NVCCCompiler(Compiler):
|
||||
def __init__(self, arch:str, extra_options:list[str]=[]):
|
||||
@@ -72,7 +73,7 @@ class NVCCCompiler(Compiler):
|
||||
srcf.flush()
|
||||
subprocess.run(["nvcc", f"-arch={self.arch}", "-ptx", "-o", libf.name, srcf.name] + self.extra_options, check=True)
|
||||
return libf.read()
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
|
||||
|
||||
class PTXCompiler(Compiler):
|
||||
def __init__(self, arch:str, cache_key="ptx"):
|
||||
@@ -80,7 +81,7 @@ class PTXCompiler(Compiler):
|
||||
super().__init__(f"compile_{cache_key}_{self.arch}")
|
||||
def compile(self, src:str) -> bytes:
|
||||
return src.replace("TARGET", self.arch).replace("VERSION", "8.7" if (ver:=int(self.arch[3:]))>=120 else ("7.8" if ver>=89 else "7.5")).encode()
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
|
||||
|
||||
class NVPTXCompiler(PTXCompiler):
|
||||
def __init__(self, arch:str):
|
||||
@@ -93,3 +94,4 @@ class NVPTXCompiler(PTXCompiler):
|
||||
data = _get_bytes(handle, jitlink.nvJitLinkGetLinkedCubin, jitlink.nvJitLinkGetLinkedCubinSize, jitlink_check)
|
||||
jitlink_check(jitlink.nvJitLinkDestroy(handle))
|
||||
return data
|
||||
def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
|
||||
|
||||
Reference in New Issue
Block a user