[FRONTEND] change hash to not require ptxas (#2476)

I noticed that Triton is using the `ptxas` version as part of the
version hash even for non-CUDA targets. This is an attempt at fixing
this. Moving the version calculation to the back-end makes sense to me
from an architectural standpoint, so that's my approach here. I'm not as
confident in the implementation, so please if folks have any feedback
let me know.
This commit is contained in:
ian Bearman
2023-10-17 10:28:51 -07:00
committed by GitHub
parent 376acb610b
commit 768fc1fcd9
6 changed files with 76 additions and 55 deletions

View File

@@ -13,12 +13,12 @@ import torch
import triton
import triton.language as tl
from triton.common.backend import BaseBackend, register_backend
from triton.common.backend import (BaseBackend, compute_core_version_key,
register_backend)
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
from triton.runtime.jit import version_key
def build_for_backend(name, src, srcdir):
@@ -125,6 +125,7 @@ class ExtensionBackend(BaseBackend):
def __init__(self, device_type: str) -> None:
super(ExtensionBackend, self).__init__(device_type)
self.driver = ExtensionDriver()
self.version_key = None
def add_stages(self, arch, extern_libs, stages):
filter_in_stages = ["ast", "ttir", "ttgir"]
@@ -163,9 +164,14 @@ class ExtensionBackend(BaseBackend):
def get_architecture_descriptor(self, **kwargs):
return ""
def get_version_key(self):
if self.version_key is None:
self.version_key = compute_core_version_key()
return self.version_key
def make_launcher_stub(self, name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants)
so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists