[FRONTEND] change hash to not require ptxas (#2476)

I noticed that Triton is using the `ptxas` version as part of the
version hash even for non-CUDA targets. This is an attempt at fixing
this. Moving the version calculation to the back-end makes sense to me
from an architectural standpoint, so that's my approach here. I'm less
confident in the implementation, so if folks have any feedback, please
let me know.
This commit is contained in:
ian Bearman
2023-10-17 10:28:51 -07:00
committed by GitHub
parent 376acb610b
commit 768fc1fcd9
6 changed files with 76 additions and 55 deletions

View File

@@ -13,12 +13,12 @@ import torch
import triton
import triton.language as tl
from triton.common.backend import BaseBackend, register_backend
from triton.common.backend import (BaseBackend, compute_core_version_key,
register_backend)
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
from triton.runtime.jit import version_key
def build_for_backend(name, src, srcdir):
@@ -125,6 +125,7 @@ class ExtensionBackend(BaseBackend):
def __init__(self, device_type: str) -> None:
    """Initialise the extension backend for ``device_type``.

    The base class is initialised with ``device_type``, a fresh
    ``ExtensionDriver`` is attached as ``self.driver`` and the version key
    starts out as ``None``.
    """
    super().__init__(device_type)
    # No version key has been computed yet.
    self.version_key = None
    self.driver = ExtensionDriver()
def add_stages(self, arch, extern_libs, stages):
filter_in_stages = ["ast", "ttir", "ttgir"]
@@ -163,9 +164,14 @@ class ExtensionBackend(BaseBackend):
def get_architecture_descriptor(self, **kwargs):
    """Return the architecture descriptor, which is always the empty string.

    Any keyword arguments are accepted and ignored.
    """
    descriptor = ""
    return descriptor
def get_version_key(self):
    """Return this backend's version key, computing it on first use.

    When ``self.version_key`` is still ``None`` it is filled in from
    ``compute_core_version_key()``; later calls return the stored value.
    """
    if self.version_key is not None:
        return self.version_key
    self.version_key = compute_core_version_key()
    return self.version_key
def make_launcher_stub(self, name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants)
so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists

View File

@@ -1,5 +1,6 @@
import functools
import hashlib
import importlib
import importlib.util
import os
@@ -10,6 +11,9 @@ from typing import Dict
from ..runtime.driver import DriverBase
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
class BaseBackend:
def __init__(self, device_type: str) -> None:
@@ -132,3 +136,44 @@ def path_to_cuobjdump():
@functools.lru_cache()
def path_to_nvdisasm():
    """Return the result of ``_path_to_binary("nvdisasm")``.

    The result is memoized by ``functools.lru_cache``, so the lookup runs
    only on the first call.
    """
    nvdisasm = _path_to_binary("nvdisasm")
    return nvdisasm
@functools.lru_cache()
def compute_core_version_key():
    """Return a key identifying the installed Triton core files.

    The key is ``TRITON_VERSION`` followed by SHA-1 hex digests of, in order:
    this file, every module that ``pkgutil`` finds in ``TRITON_PATH/compiler``,
    the ``TRITON_PATH/_C/libtriton.so`` binary, and every module that
    ``pkgutil`` finds in ``TRITON_PATH/language``.  All components are joined
    with ``-``.  The result is memoized by ``functools.lru_cache``, so the
    files are read only on the first call.

    Returns:
        str: the version key, e.g. ``"2.1.0-<sha1>-<sha1>-..."``.

    Raises:
        OSError: propagated from ``open()`` if a hashed file cannot be read.
    """
    import pkgutil

    def _module_digests(package_dir):
        # One SHA-1 hex digest per module pkgutil finds directly in package_dir.
        digests = []
        for lib in pkgutil.iter_modules([package_dir]):
            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
                digests.append(hashlib.sha1(f.read()).hexdigest())
        return digests

    contents = []
    # frontend
    # NOTE(review): ``__file__`` is this module; confirm that it is the file
    # intended to stand for the "frontend" in the key.
    with open(__file__, "rb") as f:
        contents.append(hashlib.sha1(f.read()).hexdigest())
    # compiler
    contents += _module_digests(os.path.join(TRITON_PATH, 'compiler'))
    # backend
    libtriton_hash = hashlib.sha1()
    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
        # Hash in 1 MiB chunks so the shared library is never read in one piece.
        for chunk in iter(lambda: f.read(1024 ** 2), b""):
            libtriton_hash.update(chunk)
    contents.append(libtriton_hash.hexdigest())
    # language
    contents += _module_digests(os.path.join(TRITON_PATH, 'language'))
    # TRITON_VERSION is a string: '-'.join(TRITON_VERSION) would put a '-'
    # between each of its characters.  Add it as one leading component so the
    # version and every digest are separated by exactly one '-'.
    return '-'.join([TRITON_VERSION, *contents])
_cached_cuda_version_key = None


def get_cuda_version_key():
    """Return the CUDA version key, computing it on the first call.

    The key is the core version key from ``compute_core_version_key()``
    followed by ``-`` and the SHA-1 hex digest of the output of
    ``<ptxas> --version``, where ``<ptxas>`` is the first element returned by
    ``path_to_ptxas()``.  The value is stored in the module-level
    ``_cached_cuda_version_key`` and reused by later calls.
    """
    global _cached_cuda_version_key
    if _cached_cuda_version_key is not None:
        return _cached_cuda_version_key
    core_key = compute_core_version_key()
    ptxas_binary = path_to_ptxas()[0]
    version_output = subprocess.check_output([ptxas_binary, "--version"])
    ptxas_digest = hashlib.sha1(version_output).hexdigest()
    _cached_cuda_version_key = core_key + '-' + ptxas_digest
    return _cached_cuda_version_key

View File

@@ -16,7 +16,7 @@ from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
@@ -24,7 +24,7 @@ from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
get_device_capability)
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
@@ -235,7 +235,11 @@ def convert_type_repr(x):
return x
def make_hash(fn, target, env_vars, **kwargs):
def make_hash(fn, target, env_vars, device_backend, **kwargs):
if device_backend is None:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
if isinstance(fn, JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
@@ -250,13 +254,13 @@ def make_hash(fn, target, env_vars, **kwargs):
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
ignore_version = kwargs.get('ignore_version', False)
if (ignore_version):
return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
@@ -452,11 +456,11 @@ def compile(fn, **kwargs):
first_stage = list(stages.keys()).index(ir_name)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs))
fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
# managers used to dump and override IR for debugging
enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True))
fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
# determine name and extension type of provided function
if isinstance(fn, JITFunction):

View File

@@ -3,9 +3,9 @@ import os
import tempfile
from ..common import _build
from ..common.backend import get_cuda_version_key
from ..common.build import is_hip
from ..runtime.cache import get_cache_manager
from ..runtime.jit import version_key
from .utils import generate_cu_signature
# ----- stub --------
@@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
def make_stub(name, signature, constants, ids, **kwargs):
# name of files that are cached
so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs)
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists

View File

@@ -1,8 +1,7 @@
from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune,
heuristics)
from .driver import driver
from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret,
version_key)
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
__all__ = [
"driver",
@@ -12,7 +11,6 @@ __all__ = [
"heuristics",
"JITFunction",
"KernelInterface",
"version_key",
"reinterpret",
"TensorWrapper",
"OutOfResources",

View File

@@ -5,20 +5,16 @@ import functools
import hashlib
import inspect
import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..common.backend import get_backend, get_cuda_version_key
from ..language.core import dtype
from .interpreter import InterpretedFunction
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
def get_cuda_stream(idx=None):
if idx is None:
@@ -99,38 +95,6 @@ class DependenciesFinder(ast.NodeVisitor):
# -----------------------------------------------------------------------------
@functools.lru_cache()
def version_key():
    """Return a key identifying the installed Triton files and ``ptxas``.

    The key is ``TRITON_VERSION``, then the SHA-1 hex digest of the output of
    ``<ptxas> --version`` (``<ptxas>`` being the first element returned by
    ``path_to_ptxas()``), then SHA-1 hex digests of, in order: this file,
    every module that ``pkgutil`` finds in ``TRITON_PATH/compiler``, the
    ``TRITON_PATH/_C/libtriton.so`` binary, and every module that ``pkgutil``
    finds in ``TRITON_PATH/language``.  All components are joined with ``-``.
    The result is memoized by ``functools.lru_cache``.

    Returns:
        str: the version key.

    Raises:
        OSError: propagated from ``open()`` if a hashed file cannot be read.
    """
    import pkgutil

    def _module_digests(package_dir):
        # One SHA-1 hex digest per module pkgutil finds directly in package_dir.
        digests = []
        for lib in pkgutil.iter_modules([package_dir]):
            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
                digests.append(hashlib.sha1(f.read()).hexdigest())
        return digests

    contents = []
    # frontend
    with open(__file__, "rb") as f:
        contents.append(hashlib.sha1(f.read()).hexdigest())
    # compiler
    contents += _module_digests(os.path.join(TRITON_PATH, 'compiler'))
    # backend
    libtriton_hash = hashlib.sha1()
    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
        # Hash in 1 MiB chunks so the shared library is never read in one piece.
        for chunk in iter(lambda: f.read(1024 ** 2), b""):
            libtriton_hash.update(chunk)
    contents.append(libtriton_hash.hexdigest())
    # language
    contents += _module_digests(os.path.join(TRITON_PATH, 'language'))
    # ptxas version
    ptxas = path_to_ptxas()[0]
    ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest()
    # TRITON_VERSION is a string: '-'.join(TRITON_VERSION) would put a '-'
    # between each of its characters, so it is added as one component instead.
    return '-'.join([TRITON_VERSION, ptxas_version, *contents])
def _normalize_ty(ty) -> str:
if isinstance(ty, type):
return ty.__name__
@@ -396,7 +360,11 @@ class JITFunction(KernelInterface[T]):
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
if device_type in ['cuda']:
version_key = get_cuda_version_key()
else:
version_key = device_backend.get_version_key()
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug)
if extern_libs is not None:
key = (key, tuple(extern_libs.items()))
@@ -492,7 +460,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if self.hash is None:
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.hash = dependencies_finder.ret + version_key()
self.hash = dependencies_finder.ret
return self.hash
def warmup(self, *args, **kwargs):