Fix sm89 PTX=1 compilation (#3915)

* Fix sm89 PTX=1 compilation

The minimum PTX version that supports sm89 is 7.8 (same version also
supports sm90); without this ptxas fails when running tinygrad with
PTX=1 on RTX 4090.

* Use int(arch[3:]) for forward compat with SM10.0 if that happens
This commit is contained in:
Arseny Kapoulkine
2024-03-24 20:32:29 -07:00
committed by GitHub
parent 83f39a8ceb
commit 715850aef9
2 changed files with 3 additions and 2 deletions

View File

@@ -56,9 +56,10 @@ class PTXCompiler(Compiler):
linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
def __init__(self, arch:str):
self.arch = arch
self.version = "7.8" if int(arch[3:]) >= 89 else "7.5"
PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
super().__init__(f"compile_ptx_{self.arch}")
def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch)
def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch).replace("VERSION", self.version)
def compile(self, src:str) -> bytes: return src.encode()
class CUDACompiler(Compiler):