ROCM IFU: Add get_version_key for ROCM backend

This commit is contained in:
Jason Furmanek
2023-11-28 00:11:44 +00:00
parent 71547e4fdb
commit f5f6b3c0a3

View File

@@ -8,8 +8,8 @@ from typing import Any, Tuple
from triton.common import _build
from triton.common.backend import BaseBackend, register_backend
from triton.compiler.make_launcher import get_cache_manager, version_key, make_so_cache_key
from triton.common.backend import BaseBackend, register_backend, compute_core_version_key
from triton.compiler.make_launcher import get_cache_manager, make_so_cache_key
from triton.compiler.utils import generate_cu_signature
from triton.runtime import jit
from triton.runtime.driver import HIPDriver
@@ -25,7 +25,7 @@ else:
def make_stub(name, signature, constants, ids, **kwargs):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
so_cache_key = make_so_cache_key(compute_core_version_key(), signature, constants, ids, **kwargs)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
@@ -414,11 +414,21 @@ def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_featu
class HIPBackend(BaseBackend):
_cached_rocm_version_key = None
def __init__(self, device_type: str) -> None:
super(HIPBackend, self).__init__(device_type)
self.driver = HIPDriver()
self.stub_so_path = ""
def get_version_key(self):
if self._cached_rocm_version_key is None:
key = compute_core_version_key()
### TODO: Append ROCM version here if needed
self._cached_rocm_version_key = key
return self._cached_rocm_version_key
def is_standalone(self):
return not HIP_BACKEND_MODE
@@ -500,7 +510,6 @@ class HIPBackend(BaseBackend):
return arch
def make_launcher_stub(self, name, signature, constants, ids):
# print("HIPBackend.make_launcher_stub")
self.stub_so_path = make_stub(name, signature, constants, ids)
return self.stub_so_path
@@ -517,4 +526,4 @@ class HIPBackend(BaseBackend):
return _triton.get_num_warps(module)
def get_matrix_core_version(self):
return gpu_matrix_core_version()
return gpu_matrix_core_version()