mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Override prototype (#2214)
Low tech but very useful way to override kernels on the fly. This can be use for debugging functionality or performance problems this lets user dump modify and feed back IR into the jit compiler.
This commit is contained in:
@@ -20,7 +20,7 @@ from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
@@ -229,6 +229,9 @@ def make_hash(fn, arch, env_vars, **kwargs):
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
if (ignore_version):
|
||||
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@@ -433,6 +436,11 @@ def compile(fn, **kwargs):
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
@@ -493,6 +501,11 @@ def compile(fn, **kwargs):
|
||||
else:
|
||||
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_dump_manager.put(next_module, ir_filename)
|
||||
if (enable_override and fn_override_manager.has_file(ir_filename)):
|
||||
print(f"\nOverriding kernel with file {ir_filename}")
|
||||
full_name = fn_override_manager.get_file(ir_filename)
|
||||
next_module = parse(full_name)
|
||||
else:
|
||||
if ir_name == "amdgcn":
|
||||
extra_file_name = f"{name}.hsaco_path"
|
||||
|
||||
@@ -10,6 +10,14 @@ def default_cache_dir():
|
||||
return os.path.join(Path.home(), ".triton", "cache")
|
||||
|
||||
|
||||
def default_override_dir():
|
||||
return os.path.join(Path.home(), ".triton", "override")
|
||||
|
||||
|
||||
def default_dump_dir():
|
||||
return os.path.join(Path.home(), ".triton", "dump")
|
||||
|
||||
|
||||
class CacheManager(ABC):
|
||||
def __init__(self, key):
|
||||
pass
|
||||
@@ -36,17 +44,26 @@ class CacheManager(ABC):
|
||||
|
||||
|
||||
class FileCacheManager(CacheManager):
|
||||
def __init__(self, key):
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
if (dump):
|
||||
self.cache_dir = default_dump_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
elif (override):
|
||||
self.cache_dir = default_override_dir()
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
|
||||
def _make_path(self, filename) -> str:
|
||||
return os.path.join(self.cache_dir, filename)
|
||||
@@ -131,3 +148,11 @@ def get_cache_manager(key) -> CacheManager:
|
||||
__cache_cls_nme = user_cache_manager
|
||||
|
||||
return __cache_cls(key)
|
||||
|
||||
|
||||
def get_override_manager(key) -> CacheManager:
|
||||
return __cache_cls(key, override=True)
|
||||
|
||||
|
||||
def get_dump_manager(key) -> CacheManager:
|
||||
return __cache_cls(key, dump=True)
|
||||
|
||||
Reference in New Issue
Block a user