mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[FRONTEND] change hash to not require ptxas (#2476)
I noticed that Triton is using the `ptxas` version as part of the version hash even for non-CUDA targets. This is an attempt at fixing this. Moving the version calculation to the back-end makes sense to me from an architectural standpoint, so that's my approach here. I'm not as confident in the implementation, so please if folks have any feedback let me know.
This commit is contained in:
@@ -13,12 +13,12 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.common.backend import (BaseBackend, compute_core_version_key,
|
||||
register_backend)
|
||||
from triton.common.build import quiet
|
||||
from triton.compiler.make_launcher import make_so_cache_key
|
||||
from triton.runtime.cache import get_cache_manager
|
||||
from triton.runtime.driver import DriverBase
|
||||
from triton.runtime.jit import version_key
|
||||
|
||||
|
||||
def build_for_backend(name, src, srcdir):
|
||||
@@ -125,6 +125,7 @@ class ExtensionBackend(BaseBackend):
|
||||
def __init__(self, device_type: str) -> None:
|
||||
super(ExtensionBackend, self).__init__(device_type)
|
||||
self.driver = ExtensionDriver()
|
||||
self.version_key = None
|
||||
|
||||
def add_stages(self, arch, extern_libs, stages):
|
||||
filter_in_stages = ["ast", "ttir", "ttgir"]
|
||||
@@ -163,9 +164,14 @@ class ExtensionBackend(BaseBackend):
|
||||
def get_architecture_descriptor(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def get_version_key(self):
|
||||
if self.version_key is None:
|
||||
self.version_key = compute_core_version_key()
|
||||
return self.version_key
|
||||
|
||||
def make_launcher_stub(self, name, signature, constants):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants)
|
||||
so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
|
||||
Reference in New Issue
Block a user