mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Add get_version_key for ROCM backend
This commit is contained in:
19
python/triton/third_party/hip/hip_backend.py
vendored
19
python/triton/third_party/hip/hip_backend.py
vendored
@@ -8,8 +8,8 @@ from typing import Any, Tuple
|
||||
|
||||
|
||||
from triton.common import _build
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.compiler.make_launcher import get_cache_manager, version_key, make_so_cache_key
|
||||
from triton.common.backend import BaseBackend, register_backend, compute_core_version_key
|
||||
from triton.compiler.make_launcher import get_cache_manager, make_so_cache_key
|
||||
from triton.compiler.utils import generate_cu_signature
|
||||
from triton.runtime import jit
|
||||
from triton.runtime.driver import HIPDriver
|
||||
@@ -25,7 +25,7 @@ else:
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_key = make_so_cache_key(compute_core_version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
@@ -414,11 +414,21 @@ def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_featu
|
||||
|
||||
|
||||
class HIPBackend(BaseBackend):
|
||||
_cached_rocm_version_key = None
|
||||
|
||||
def __init__(self, device_type: str) -> None:
|
||||
super(HIPBackend, self).__init__(device_type)
|
||||
self.driver = HIPDriver()
|
||||
self.stub_so_path = ""
|
||||
|
||||
def get_version_key(self):
|
||||
if self._cached_rocm_version_key is None:
|
||||
key = compute_core_version_key()
|
||||
### TODO: Append ROCM version here if needed
|
||||
|
||||
self._cached_rocm_version_key = key
|
||||
return self._cached_rocm_version_key
|
||||
|
||||
def is_standalone(self):
|
||||
return not HIP_BACKEND_MODE
|
||||
|
||||
@@ -500,7 +510,6 @@ class HIPBackend(BaseBackend):
|
||||
return arch
|
||||
|
||||
def make_launcher_stub(self, name, signature, constants, ids):
|
||||
# print("HIPBackend.make_launcher_stub")
|
||||
self.stub_so_path = make_stub(name, signature, constants, ids)
|
||||
return self.stub_so_path
|
||||
|
||||
@@ -517,4 +526,4 @@ class HIPBackend(BaseBackend):
|
||||
return _triton.get_num_warps(module)
|
||||
|
||||
def get_matrix_core_version(self):
|
||||
return gpu_matrix_core_version()
|
||||
return gpu_matrix_core_version()
|
||||
|
||||
Reference in New Issue
Block a user