From efac5b9ef64fff4cb40f96b7b8c61365dbb4e289 Mon Sep 17 00:00:00 2001
From: Christopher Milan
Date: Sun, 8 Feb 2026 19:58:48 -0800
Subject: [PATCH] new style NV/CUDA renderers, try 2 (#14634)

* new style NV/CUDA renderers, try 2

* fix diskcache

---
 tinygrad/device.py                        |  6 +++--
 tinygrad/renderer/cstyle.py               | 13 +++++-----
 tinygrad/renderer/ptx.py                  |  8 ++++---
 tinygrad/runtime/ops_cuda.py              |  8 +++----
 tinygrad/runtime/ops_nv.py                | 10 ++++----
 tinygrad/runtime/support/compiler_cuda.py | 29 ++++++++++-------------
 6 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/tinygrad/device.py b/tinygrad/device.py
index fd34e25caa..7fc2dda00b 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -278,7 +278,9 @@ class Compiler:
   def disassemble(self, lib:bytes): pass
 
 @dataclass(frozen=True)
-class CompilerPair: renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial|None; ctrl_var:ContextVar|None = None # noqa: E702
+class CompilerPair:
+  renderer:type[Renderer]|functools.partial; compiler:type[Compiler]|functools.partial|None = None; ctrl_var:ContextVar|None = None # noqa: E702
+  name:str|None = None
 
 @dataclass(frozen=True)
 class CompilerSet: cset:list[CompilerPair]; ctrl_var:ContextVar|None = None # noqa: E702
@@ -293,7 +295,7 @@ class Compiled:
     self.comp_sets:dict[Any, tuple[ContextVar|None, tuple[type[Renderer]|functools.partial, type[Compiler]|functools.partial|None]]] = {}
     self.cached_pair:dict[Any, tuple[Renderer, Compiler|None]] = {}
     for cpair in (compilers.cset if compilers is not None else [CompilerPair(Renderer, Compiler)]):
-      self.comp_sets[self._compiler_name(cpair.renderer, cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
+      self.comp_sets[cpair.name or self._compiler_name(cpair.renderer, cpair.compiler)] = (cpair.ctrl_var, (cpair.renderer, cpair.compiler))
 
   @property
   def renderer(self) -> Renderer: return self._select_compiler_pair()[0]
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 8f44d32108..ae0e407e95 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -386,15 +386,17 @@ class MetalRenderer(CStyleLanguage):
 _nms = list("xyzwabcdefghijkl") + [f'v{i}' for i in range(16, 32)]
 
 class CUDARenderer(CStyleLanguage):
-  device = "CUDA"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  def __init__(self, arch:str):
-    self.arch, arch_ver = arch, int(arch[3:])
-    self.tensor_cores = tc.cuda_sm89 if arch_ver >= 89 else tc.cuda_sm80 if arch_ver >= 80 else tc.cuda_sm75 if arch_ver >= 75 else []
-  def __reduce__(self): return self.__class__, (self.arch,)
+  def __init__(self, arch:str, device:str="NV", use_nvcc=False):
+    from tinygrad.runtime.support.compiler_cuda import NVRTCCompiler, NVCCCompiler
+    from tinygrad.runtime.support.hcq import MOCKGPU
+    self.device, self.arch, self.use_nvcc = device, arch, use_nvcc
+    self.compiler = (NVCCCompiler if use_nvcc else NVRTCCompiler)(arch, ptx=bool(MOCKGPU) or device == "CUDA", cache_key=device.lower())
+    self.tensor_cores = tc.cuda_sm89 if (ver:=int(arch[3:])) >= 89 else tc.cuda_sm80 if ver >= 80 else tc.cuda_sm75 if ver >= 75 else []
+  def __reduce__(self): return self.__class__, (self.arch, self.device, self.use_nvcc)
 
   # language options
   # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
@@ -551,7 +553,6 @@ class AMDHIPRenderer(CStyleLanguage):
     for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-class NVRenderer(CUDARenderer): device = "NV"
 class HIPRenderer(AMDHIPRenderer): device = "HIP"
 class AMDHIPCCRenderer(AMDHIPRenderer):
   def __init__(self, arch:str):
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index 6f90df2962..91c519dac9 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -144,9 +144,11 @@ class PTXRenderer(Renderer):
   tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]]
   code_for_op = asm_for_op
   extra_matcher = ptx_matcher
-  def __init__(self, arch:str, device="CUDA"):
-    self.device, self.arch, arch_ver = device, arch, int(arch[3:])
-    self.tensor_cores = PTXRenderer.tc_sm80 if arch_ver >= 80 else tc.cuda_sm75 if arch_ver >= 75 else []
+  def __init__(self, arch:str, device="NV"):
+    from tinygrad.runtime.support.compiler_cuda import NVPTXCompiler, PTXCompiler
+    from tinygrad.runtime.support.hcq import MOCKGPU
+    self.compiler, self.device, self.arch = (PTXCompiler if bool(MOCKGPU) or device == "CUDA" else NVPTXCompiler)(arch), device, arch
+    self.tensor_cores = PTXRenderer.tc_sm80 if (ver:=int(arch[3:])) >= 80 else tc.cuda_sm75 if ver >= 75 else []
   def __reduce__(self): return self.__class__, (self.arch, self.device)
 
   # language options
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 8b7a077e81..573833b933 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -5,7 +5,7 @@ from tinygrad.device import Compiled, BufferSpec, LRUAllocator, CompilerPair, Co
 from tinygrad.renderer.cstyle import CUDARenderer
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.runtime.autogen import cuda
-from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler, NVCCCompiler
+from tinygrad.runtime.support.compiler_cuda import pretty_ptx
 from tinygrad.runtime.support.c import init_c_struct_t, init_c_var
 if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
 if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported
@@ -118,9 +118,9 @@ class CUDADevice(Compiled):
     CUDADevice.devices.append(self)
 
     from tinygrad.runtime.graph.cuda import CUDAGraph
-    compilers = CompilerSet([CompilerPair(functools.partial(CUDARenderer, self.arch), functools.partial(CUDACompiler, self.arch)),
-                             CompilerPair(functools.partial(PTXRenderer, self.arch), functools.partial(PTXCompiler, self.arch), CUDA_PTX),
-                             CompilerPair(functools.partial(CUDARenderer, self.arch), functools.partial(NVCCCompiler, self.arch))], ctrl_var=CUDA_CC)
+    compilers = CompilerSet([CompilerPair(functools.partial(CUDARenderer, self.arch, device="CUDA")),
+                             CompilerPair(functools.partial(PTXRenderer, self.arch, device="CUDA"), ctrl_var=CUDA_PTX),
+                             CompilerPair(functools.partial(CUDARenderer, self.arch, device="CUDA", use_nvcc=True), name="NVCC")], ctrl_var=CUDA_CC)
     super().__init__(device, CUDAAllocator(self), compilers, functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph)
 
   def synchronize(self):
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index bdfef9bf18..21804d6dce 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -10,8 +10,7 @@ from tinygrad.device import Compiled, BufferSpec, CompilerPair, CompilerSet
 from tinygrad.helpers import getenv, mv_address, round_up, data64, data64_le, prod, OSX, to_mv, hi32, lo32, NV_CC, NV_PTX, NV_NAK, PROFILE
 from tinygrad.helpers import ContextVar, VIZ, ProfileEvent
 from tinygrad.renderer.ptx import PTXRenderer
-from tinygrad.renderer.cstyle import NVRenderer
-from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, NVPTXCompiler, NVCompiler
+from tinygrad.renderer.cstyle import CUDARenderer
 from tinygrad.runtime.autogen import nv_570, nv_580, pci, mesa
 from tinygrad.runtime.support.elf import elf_loader
 from tinygrad.runtime.support.nv.nvdev import NVDev, NVMemoryManager
@@ -620,10 +619,9 @@ class NVDevice(HCQCompiled[NVSignal]):
     self.arch: str = "sm_120" if self.sm_version==0xa04 else f"sm_{(self.sm_version>>8)&0xff}{(val>>4) if (val:=self.sm_version&0xff) > 0xf else val}"
     self.sass_version = ((self.sm_version & 0xf00) >> 4) | (self.sm_version & 0xf)
 
-    cucc, ptxcc = (CUDACompiler, PTXCompiler) if MOCKGPU else (NVCompiler, NVPTXCompiler)
-    compilers = CompilerSet(ctrl_var=NV_CC, cset=[CompilerPair(functools.partial(NVRenderer, self.arch),functools.partial(cucc, self.arch)),
-                                                  CompilerPair(functools.partial(PTXRenderer, self.arch, device="NV"), functools.partial(ptxcc, self.arch), NV_PTX),
-                                                  CompilerPair(functools.partial(NAKRenderer, self.arch, self.max_warps_per_sm), None, NV_NAK)])
+    compilers = CompilerSet(ctrl_var=NV_CC, cset=[CompilerPair(functools.partial(CUDARenderer, self.arch)),
+                                                  CompilerPair(functools.partial(PTXRenderer, self.arch, device="NV"), ctrl_var=NV_PTX),
+                                                  CompilerPair(functools.partial(NAKRenderer, self.arch, self.max_warps_per_sm), ctrl_var=NV_NAK)])
     super().__init__(device, NVAllocator(self), compilers, functools.partial(NVProgram, self), NVSignal, NVComputeQueue, NVCopyQueue)
 
     self.pma_enabled = PMA.value > 0 and PROFILE >= 1
diff --git a/tinygrad/runtime/support/compiler_cuda.py b/tinygrad/runtime/support/compiler_cuda.py
index bbf4470826..8c4aa4234a 100644
--- a/tinygrad/runtime/support/compiler_cuda.py
+++ b/tinygrad/runtime/support/compiler_cuda.py
@@ -1,5 +1,4 @@
-import subprocess, hashlib, tempfile, ctypes, re, pathlib
-from typing import Callable
+import hashlib, tempfile, ctypes, re, pathlib
 from tinygrad.helpers import to_char_p_p, colored, getenv, system
 from tinygrad.runtime.support.c import init_c_var
 from tinygrad.runtime.autogen import nvrtc, nvjitlink as jitlink
@@ -42,36 +41,32 @@ def cuda_disassemble(lib:bytes, arch:str, ptx=False):
       print(system(f'nvdisasm {fn}'))
   except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
 
-class CUDACompiler(Compiler):
-  def __init__(self, arch:str, cache_key:str="cuda"):
-    self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}']
+class NVRTCCompiler(Compiler):
+  def __init__(self, arch:str, ptx=True, cache_key:str="cuda"):
+    self.ptx, self.arch, self.compile_options = ptx, arch, [f'--gpu-architecture={arch}']
     self.compile_options += [f"-I{CUDA_PATH}/include"] if CUDA_PATH else ["-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"]
     nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
     if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
     super().__init__(f"compile_{cache_key}_{self.arch}")
 
-  def _compile_program(self, src:str, nvrtc_get_content:Callable, nvrtc_get_size:Callable) -> bytes:
+  def compile(self, src:str) -> bytes:
     nvrtc_check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "".encode(), 0, None, None))
     nvrtc_check(nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options])), prog)
-    data = _get_bytes(prog, nvrtc_get_content, nvrtc_get_size, nvrtc_check)
+    data = _get_bytes(prog, nvrtc.nvrtcGetPTX if self.ptx else nvrtc.nvrtcGetCUBIN,
+                      nvrtc.nvrtcGetPTXSize if self.ptx else nvrtc.nvrtcGetCUBINSize, nvrtc_check)
     nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
     return data
-  def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
-  def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
-
-class NVCompiler(CUDACompiler):
-  def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
-  def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
-  def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
+  def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=self.ptx)
 
 class NVCCCompiler(Compiler):
-  def __init__(self, arch:str, extra_options:list[str]=[]):
+  def __init__(self, arch:str, ptx:bool=True, cache_key:str="cuda", extra_options:list[str]=[]):
+    assert ptx, "NVCCCompiler cubin support unimplemented"
     self.arch, self.extra_options = arch, extra_options
-    super().__init__(f"compile_nvcc_{self.arch}_{hashlib.sha256(' '.join(extra_options).encode()).hexdigest()[:8]}")
+    super().__init__(f"compile_nvcc_{cache_key}_{self.arch}_{hashlib.sha256(' '.join(extra_options).encode()).hexdigest()[:8]}")
   def compile(self, src:str) -> bytes:
     with tempfile.NamedTemporaryFile(suffix=".cu") as srcf, tempfile.NamedTemporaryFile(suffix=".ptx") as libf:
       srcf.write(src.encode())
      srcf.flush()
-      subprocess.run(["nvcc", f"-arch={self.arch}", "-ptx", "-o", libf.name, srcf.name] + self.extra_options, check=True)
+      system(f"nvcc -arch={self.arch} -ptx -o {libf.name} {srcf.name}" + ' '.join(self.extra_options))
      return libf.read()
   def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch, ptx=True)
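
Usage note (not part of the patch): a minimal sketch, assuming this patch is applied, of how a backend now registers renderers. CompilerPair no longer needs an explicit compiler argument because the renderer constructs its own compiler in __init__, and a pair can carry an explicit name, which Compiled.__init__ uses as the registry key (cpair.name or _compiler_name(...), per the device.py hunk). The FAKE_PTX ContextVar and the "sm_89" arch string below are illustrative stand-ins for the real CUDA_PTX/NV_PTX variables and the probed device arch; instantiating CUDARenderer itself needs a working NVRTC install, so only the registration objects are built here.

# hedged sketch, assumes this patch is applied; FAKE_PTX and "sm_89" are illustrative only
import functools
from tinygrad.helpers import ContextVar
from tinygrad.device import CompilerPair, CompilerSet
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.ptx import PTXRenderer

FAKE_PTX = ContextVar("FAKE_PTX", 0)  # hypothetical switch; the real code uses CUDA_PTX / NV_PTX
arch = "sm_89"                        # example compute capability string

# CompilerPair now takes only a renderer factory; compiler defaults to None because the
# renderer builds its own compiler (NVRTCCompiler or NVCCCompiler) inside __init__.
compilers = CompilerSet([
  CompilerPair(functools.partial(CUDARenderer, arch, device="CUDA")),                    # default NVRTC path
  CompilerPair(functools.partial(PTXRenderer, arch, device="CUDA"), ctrl_var=FAKE_PTX),  # PTX renderer, gated by a ContextVar
  CompilerPair(functools.partial(CUDARenderer, arch, device="CUDA", use_nvcc=True), name="NVCC"),  # explicit registry key
])
print([p.name for p in compilers.cset])  # [None, None, 'NVCC']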
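
Also not part of the patch: a sketch of the consolidated NVRTC path. The former CUDACompiler (PTX output) and NVCompiler (cubin output) collapse into a single NVRTCCompiler whose ptx flag selects nvrtcGetPTX versus nvrtcGetCUBIN, which is what the new CUDARenderer.__init__ relies on. Running it requires libnvrtc, so treat this as illustrative.

# hedged sketch, assumes this patch plus a local libnvrtc; "sm_89" is an example arch
from tinygrad.runtime.support.compiler_cuda import NVRTCCompiler

src = 'extern "C" __global__ void noop() {}'                          # trivial CUDA kernel
ptx_compiler   = NVRTCCompiler("sm_89", ptx=True,  cache_key="cuda")  # PTX text, like the old CUDACompiler
cubin_compiler = NVRTCCompiler("sm_89", ptx=False, cache_key="nv")    # cubin bytes, like the old NVCompiler
print(ptx_compiler.compile(src)[:16], len(cubin_compiler.compile(src)))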