diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 86f576eead..c2c95eaf85 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -220,7 +220,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: return lang.render_kernel(kernel, function_name, bufs, c.items()) class PTXLanguage(AssemblyLanguage): - kernel_prefix = """.version 7.5 + kernel_prefix = """.version VERSION .target TARGET .address_size 64 .visible .entry""" diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 3a1895a285..ba85c81311 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -56,9 +56,10 @@ class PTXCompiler(Compiler): linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152) def __init__(self, arch:str): self.arch = arch + self.version = "7.8" if int(arch[3:]) >= 89 else "7.5" PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80) super().__init__(f"compile_ptx_{self.arch}") - def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch) + def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch).replace("VERSION", self.version) def compile(self, src:str) -> bytes: return src.encode() class CUDACompiler(Compiler):