[FRONTEND] Override prototype (#2214)

Low-tech but very useful way to override kernels on the fly. This can be
used for debugging functionality or performance problems: it lets the user
dump, modify, and feed IR back into the JIT compiler.
This commit is contained in:
Thomas Raoux
2023-09-13 10:05:47 -07:00
committed by GitHub
parent cf7f8c5ea4
commit b63e8f87fc
2 changed files with 44 additions and 6 deletions

View File

@@ -20,7 +20,7 @@ from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
@@ -229,6 +229,9 @@ def make_hash(fn, arch, env_vars, **kwargs):
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
ignore_version = kwargs.get('ignore_version', False)
if (ignore_version):
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
@@ -433,6 +436,11 @@ def compile(fn, **kwargs):
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs))
# managers used to dump and override IR for debugging
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
fn_override_manager = get_override_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, arch, get_env_vars(), **kwargs, ignore_version=True))
# determine name and extension type of provided function
if isinstance(fn, JITFunction):
name, ext = fn.__name__, "ast"
@@ -493,6 +501,11 @@ def compile(fn, **kwargs):
else:
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
fn_cache_manager.put(next_module, ir_filename)
fn_dump_manager.put(next_module, ir_filename)
if (enable_override and fn_override_manager.has_file(ir_filename)):
print(f"\nOverriding kernel with file {ir_filename}")
full_name = fn_override_manager.get_file(ir_filename)
next_module = parse(full_name)
else:
if ir_name == "amdgcn":
extra_file_name = f"{name}.hsaco_path"

View File

@@ -10,6 +10,14 @@ def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
def default_override_dir():
    """Return the default override directory: ``<home>/.triton/override``."""
    # Build the path in two steps: the per-user ".triton" root first,
    # then the "override" sub-directory beneath it.
    triton_root = os.path.join(Path.home(), ".triton")
    return os.path.join(triton_root, "override")
def default_dump_dir():
    """Return the default dump directory: ``<home>/.triton/dump``."""
    # Build the path in two steps: the per-user ".triton" root first,
    # then the "dump" sub-directory beneath it.
    triton_root = os.path.join(Path.home(), ".triton")
    return os.path.join(triton_root, "dump")
class CacheManager(ABC):
def __init__(self, key):
pass
@@ -36,17 +44,26 @@ class CacheManager(ABC):
class FileCacheManager(CacheManager):
def __init__(self, key):
def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
if self.cache_dir:
if (dump):
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif (override):
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
raise RuntimeError("Could not create or locate cache dir")
# create cache directory if it doesn't exist
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)
@@ -131,3 +148,11 @@ def get_cache_manager(key) -> CacheManager:
__cache_cls_nme = user_cache_manager
return __cache_cls(key)
def get_override_manager(key) -> CacheManager:
    """Build a manager for ``key`` in override mode.

    Instantiates the module-level cache class with ``override=True``.
    """
    # NOTE(review): this uses whatever class `__cache_cls` is currently bound
    # to; a replacement class would need to accept the `override` keyword --
    # confirm against the code that assigns `__cache_cls`.
    override_manager = __cache_cls(key, override=True)
    return override_manager
def get_dump_manager(key) -> CacheManager:
    """Build a manager for ``key`` in dump mode.

    Instantiates the module-level cache class with ``dump=True``.
    """
    # NOTE(review): this uses whatever class `__cache_cls` is currently bound
    # to; a replacement class would need to accept the `dump` keyword --
    # confirm against the code that assigns `__cache_cls`.
    dump_manager = __cache_cls(key, dump=True)
    return dump_manager