cleanup nv/cuda compilers (#5767)

* cleanup nv/cuda compilers

* destroy prog

* small test

* fix test

* nv ptx rewrite key

* jitlink free

* ptx is part of cuda
This commit is contained in:
nimlgen
2024-07-29 13:50:03 +03:00
committed by GitHub
parent 76840fd65a
commit ab3839a80a
2 changed files with 32 additions and 31 deletions

View File

@@ -2,6 +2,7 @@
import unittest
from unittest.mock import patch
import os
from tinygrad import Tensor
from tinygrad.device import Device, Compiler
from tinygrad.helpers import diskcache_get, diskcache_put, getenv
@@ -37,5 +38,11 @@ class TestCompiler(unittest.TestCase):
assert MockCompiler("disabled_key").compile_cached("123") == str.encode("123")
assert diskcache_get("disabled_key", "123") is None
def test_device_compile(self):
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}):
a = Tensor([0.,1.], device=Device.DEFAULT).realize()
(a + 1).realize()
if __name__ == "__main__":
unittest.main()

View File

@@ -1,5 +1,5 @@
import subprocess, hashlib, tempfile, ctypes, ctypes.util, re
from pathlib import Path
import subprocess, hashlib, tempfile, ctypes, ctypes.util, re, pathlib
from typing import Callable
from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
import tinygrad.runtime.autogen.nvrtc as nvrtc
from tinygrad.device import Compiler, CompileError
@@ -32,7 +32,7 @@ def pretty_ptx(s):
def cuda_disassemble(lib, arch):
try:
fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
with open(fn + ".ptx", "wb") as f: f.write(lib)
subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
@@ -40,45 +40,39 @@ def cuda_disassemble(lib, arch):
def nv_disassemble(lib):
try:
fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
with open(fn + ".cubin", "wb") as f: f.write(lib)
print(subprocess.check_output(["nvdisasm", fn+".cubin"]).decode('utf-8'))
except Exception as e: print("Failed to disasm cubin:", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
class PTXCompiler(Compiler):
def __init__(self, arch:str):
self.arch = arch
self.version = "7.8" if arch >= "sm_89" else "7.5"
super().__init__(f"compile_ptx_{self.arch}")
def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()
class CUDACompiler(Compiler):
def __init__(self, arch:str):
self.arch = arch
nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
super().__init__(f"compile_cuda_{self.arch}")
def compile(self, src:str) -> bytes:
nvrtc_check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
nvrtc_check(nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options])), prog)
return _get_bytes(prog, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize, nvrtc_check)
class NVCompiler(Compiler):
def __init__(self, arch:str):
def __init__(self, arch:str, cache_key:str="cuda"):
self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
super().__init__(f"compile_nv_{self.arch}")
def compile(self, src:str) -> bytes:
super().__init__(f"compile_{cache_key}_{self.arch}")
def _compile_program(self, src:str, nvrtc_get_content:Callable, nvrtc_get_size:Callable) -> bytes:
nvrtc_check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
nvrtc_check(nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options])), prog)
return _get_bytes(prog, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize, nvrtc_check)
data = _get_bytes(prog, nvrtc_get_content, nvrtc_get_size, nvrtc_check)
nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
return data
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
class NVPTXCompiler(NVCompiler):
class NVCompiler(CUDACompiler):
def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
class PTXCompiler(CUDACompiler):
def __init__(self, arch:str, cache_key="ptx"): super().__init__(arch, cache_key=cache_key)
def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode()
class NVPTXCompiler(PTXCompiler):
def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
def compile(self, src:str) -> bytes:
ptxsrc = src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5")
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc.encode(), len(ptxsrc), "<null>".encode()), handle)
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)
jitlink_check(nvrtc.nvJitLinkComplete(handle), handle)
return _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
data = _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
jitlink_check(nvrtc.nvJitLinkDestroy(handle))
return data