From 715850aef9969fbf1a0a11261c7010e02b954086 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Sun, 24 Mar 2024 20:32:29 -0700 Subject: [PATCH] Fix sm89 PTX=1 compilation (#3915) * Fix sm89 PTX=1 compilation The minimum PTX version that supports sm89 is 7.8 (same version also supports sm90); without this ptxas fails when running tinygrad with PTX=1 on RTX 4090. * Use int(arch[3:]) for forward compat with SM10.0 if that happens --- tinygrad/renderer/assembly.py | 2 +- tinygrad/runtime/ops_cuda.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 86f576eead..c2c95eaf85 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -220,7 +220,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: return lang.render_kernel(kernel, function_name, bufs, c.items()) class PTXLanguage(AssemblyLanguage): - kernel_prefix = """.version 7.5 + kernel_prefix = """.version VERSION .target TARGET .address_size 64 .visible .entry""" diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 3a1895a285..ba85c81311 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -56,9 +56,10 @@ class PTXCompiler(Compiler): linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152) def __init__(self, arch:str): self.arch = arch + self.version = "7.8" if int(arch[3:]) >= 89 else "7.5" PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80) super().__init__(f"compile_ptx_{self.arch}") - def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch) + def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch).replace("VERSION", self.version) def compile(self, src:str) -> bytes: return src.encode() class CUDACompiler(Compiler):