mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[FRONTEND] change hash to not require ptxas (#2476)
I noticed that Triton is using the `ptxas` version as part of the version hash even for non-CUDA targets. This is an attempt at fixing this. Moving the version calculation to the back-end makes sense to me from an architectural standpoint, so that's my approach here. I'm not as confident in the implementation, so please if folks have any feedback let me know.
This commit is contained in:
@@ -13,12 +13,12 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.common.backend import (BaseBackend, compute_core_version_key,
|
||||
register_backend)
|
||||
from triton.common.build import quiet
|
||||
from triton.compiler.make_launcher import make_so_cache_key
|
||||
from triton.runtime.cache import get_cache_manager
|
||||
from triton.runtime.driver import DriverBase
|
||||
from triton.runtime.jit import version_key
|
||||
|
||||
|
||||
def build_for_backend(name, src, srcdir):
|
||||
@@ -125,6 +125,7 @@ class ExtensionBackend(BaseBackend):
|
||||
def __init__(self, device_type: str) -> None:
|
||||
super(ExtensionBackend, self).__init__(device_type)
|
||||
self.driver = ExtensionDriver()
|
||||
self.version_key = None
|
||||
|
||||
def add_stages(self, arch, extern_libs, stages):
|
||||
filter_in_stages = ["ast", "ttir", "ttgir"]
|
||||
@@ -163,9 +164,14 @@ class ExtensionBackend(BaseBackend):
|
||||
def get_architecture_descriptor(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def get_version_key(self):
|
||||
if self.version_key is None:
|
||||
self.version_key = compute_core_version_key()
|
||||
return self.version_key
|
||||
|
||||
def make_launcher_stub(self, name, signature, constants):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants)
|
||||
so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
@@ -10,6 +11,9 @@ from typing import Dict
|
||||
|
||||
from ..runtime.driver import DriverBase
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
def __init__(self, device_type: str) -> None:
|
||||
@@ -132,3 +136,44 @@ def path_to_cuobjdump():
|
||||
@functools.lru_cache()
|
||||
def path_to_nvdisasm():
|
||||
return _path_to_binary("nvdisasm")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def compute_core_version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
return '-'.join(TRITON_VERSION) + '-'.join(contents)
|
||||
|
||||
|
||||
_cached_cuda_version_key = None
|
||||
|
||||
|
||||
def get_cuda_version_key():
|
||||
global _cached_cuda_version_key
|
||||
if _cached_cuda_version_key is None:
|
||||
key = compute_core_version_key()
|
||||
ptxas = path_to_ptxas()[0]
|
||||
_cached_cuda_version_key = key + '-' + hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
return _cached_cuda_version_key
|
||||
|
||||
@@ -16,7 +16,7 @@ from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
|
||||
from ..common.build import is_hip
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
@@ -24,7 +24,7 @@ from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
get_device_capability)
|
||||
from ..tools.disasm import get_sass
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
@@ -235,7 +235,11 @@ def convert_type_repr(x):
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, target, env_vars, **kwargs):
|
||||
def make_hash(fn, target, env_vars, device_backend, **kwargs):
|
||||
if device_backend is None:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
@@ -250,13 +254,13 @@ def make_hash(fn, target, env_vars, **kwargs):
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
ignore_version = kwargs.get('ignore_version', False)
|
||||
if (ignore_version):
|
||||
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
@@ -452,11 +456,11 @@ def compile(fn, **kwargs):
|
||||
first_stage = list(stages.keys()).index(ir_name)
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
|
||||
# managers used to dump and override IR for debugging
|
||||
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
|
||||
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
|
||||
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, JITFunction):
|
||||
|
||||
@@ -3,9 +3,9 @@ import os
|
||||
import tempfile
|
||||
|
||||
from ..common import _build
|
||||
from ..common.backend import get_cuda_version_key
|
||||
from ..common.build import is_hip
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.jit import version_key
|
||||
from .utils import generate_cu_signature
|
||||
|
||||
# ----- stub --------
|
||||
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
|
||||
|
||||
def make_stub(name, signature, constants, ids, **kwargs):
|
||||
# name of files that are cached
|
||||
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
|
||||
so_cache_manager = get_cache_manager(so_cache_key)
|
||||
so_name = f"{name}.so"
|
||||
# retrieve stub from cache if it exists
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
|
||||
heuristics)
|
||||
from .driver import driver
|
||||
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
|
||||
version_key)
|
||||
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
|
||||
|
||||
__all__ = [
|
||||
"driver",
|
||||
@@ -12,7 +11,6 @@ __all__ = [
|
||||
"heuristics",
|
||||
"JITFunction",
|
||||
"KernelInterface",
|
||||
"version_key",
|
||||
"reinterpret",
|
||||
"TensorWrapper",
|
||||
"OutOfResources",
|
||||
|
||||
@@ -5,20 +5,16 @@ import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..common.backend import get_backend, get_cuda_version_key
|
||||
from ..language.core import dtype
|
||||
from .interpreter import InterpretedFunction
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
|
||||
if idx is None:
|
||||
@@ -99,38 +95,6 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
contents = []
|
||||
# frontend
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# backend
|
||||
libtriton_hash = hashlib.sha1()
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1024 ** 2)
|
||||
if not chunk:
|
||||
break
|
||||
libtriton_hash.update(chunk)
|
||||
contents.append(libtriton_hash.hexdigest())
|
||||
# language
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.sha1(f.read()).hexdigest()]
|
||||
# ptxas version
|
||||
ptxas = path_to_ptxas()[0]
|
||||
ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
def _normalize_ty(ty) -> str:
|
||||
if isinstance(ty, type):
|
||||
return ty.__name__
|
||||
@@ -396,7 +360,11 @@ class JITFunction(KernelInterface[T]):
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
if device_type in ['cuda']:
|
||||
version_key = get_cuda_version_key()
|
||||
else:
|
||||
version_key = device_backend.get_version_key()
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
|
||||
if extern_libs is not None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -492,7 +460,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
self.hash = dependencies_finder.ret
|
||||
return self.hash
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user