from __future__ import annotations

import functools
import hashlib
import json
import os
import re
from collections import namedtuple
from pathlib import Path
from typing import Any

from dataclasses import dataclass

from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars,
                                   get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx,
                                   translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas
from ..common.build import is_hip
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability)
from ..tools.disasm import get_sass
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info)

CUDA_DEFAULT_WARP_SIZE = 32
@dataclass
class CudaTargetDescriptor:
    capability: int
    num_warps: int
    enable_fp_fusion: bool


def _is_cuda(target):
    return isinstance(target, CudaTargetDescriptor)


class LazyDict(dict):

    def __getitem__(self, key):
        val = dict.__getitem__(self, key)
        if callable(val):
            return val()
        return val
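# Values stored in a LazyDict may be zero-argument callables; they are only evaluated when the
# entry is read. Illustrative sketch (the names below are hypothetical):
#
#     asm = LazyDict()
#     asm["sass"] = lambda: get_sass(cubin)   # nothing disassembled yet
#     asm["sass"]                             # get_sass runs here, on access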
def inline_triton_ir(mod):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    pm.add_inliner_pass()
    pm.run(mod)
    return mod


def ttir_compute_capability_rewrite(mod, target):
    # For hardware that lacks block-pointer support, rewrite all load/store ops that use
    # block (tensor) pointers into tensors of pointers.
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    if _is_cuda(target):
        pm.add_rewrite_tensor_pointer_pass(target.capability, False)
    elif is_hip():
        pm.add_rewrite_tensor_pointer_pass(target["capability"], True)
    else:
        assert False, "unsupported target"
    pm.run(mod)
    return mod
def optimize_ttir(mod, target):
    mod = inline_triton_ir(mod)
    mod = ttir_compute_capability_rewrite(mod, target)
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    pm.add_inliner_pass()
    pm.add_triton_combine_pass()
    pm.add_canonicalizer_pass()
    pm.add_reorder_broadcast_pass()
    pm.add_cse_pass()
    pm.add_licm_pass()
    pm.add_symbol_dce_pass()
    pm.run(mod)
    return mod
def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    if is_hip():
        pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0)
    else:
        pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, target.capability)
    pm.run(mod)
    return mod
def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization,
                   enable_persistent, optimize_epilogue, matrix_inst_type, slice_k_tile):
    is_cuda = _is_cuda(target)
    if is_cuda:
        capability = target.capability
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    pm.add_tritongpu_coalesce_pass()
    # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
    pm.add_plan_cta_pass(cluster_info)
    if is_cuda:
        pm.add_tritongpu_rewrite_tensor_pointer_pass(capability)
        pm.add_plan_cta_pass(cluster_info)
    pm.add_tritongpu_remove_layout_conversions_pass()
    if is_cuda:
        pm.add_tritongpu_accelerate_matmul_pass(capability)
    # TODO: change the interface of accelerate_matmul_pass
    if is_hip():
        gfx_arch = target["gfx_arch"]
        matrix_inst_size = matrix_inst_type
        pm.add_tritonamdgpu_accelerate_matmul_pass(gfx_arch, matrix_inst_size)
    pm.add_tritongpu_remove_layout_conversions_pass()
    # if optimize_epilogue:
    pm.add_tritongpu_optimize_epilogue_pass()
    pm.add_tritonamdgpu_dot_slicing_pass(slice_k_tile)
    pm.add_tritongpu_optimize_dot_operands_pass()
    if num_stages == 0 and is_hip() and target["matrix_core_version"] != 0:
        pm.add_tritongpu_stream_pipeline_pass()
        pm.add_canonicalizer_pass()
    ws_enabled = False
    # `num_warps` does not mean the total number of warps of a CTA when
    # warp specialization is enabled.
    # It is the responsibility of the compiler to figure out the exact
    # `num_warps` to use.
    # TODO: support the case where `num_warps` from the user is not 4.
    if is_cuda and capability // 10 >= 9 and enable_warp_specialization and num_warps == 4:
        pm.add_tritongpu_ws_feasibility_checking_pass(capability)
        pm.run(mod)
        ws_enabled = ir.is_ws_supported(mod)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
    if ws_enabled:
        pm.add_tritongpu_wsdecomposing_pass(capability)
        pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability)
        pm.add_tritongpu_wsmutex_pass(capability)
        pm.add_tritongpu_wsmaterialization_pass(capability)
        pm.add_licm_pass()
        pm.add_cse_pass()
    else:
        if is_hip():
            pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, 0)
        else:
            pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability)
    if is_hip():
        pm.add_tritongpu_materialize_load_store_pass(num_warps, 0)
    else:
        pm.add_tritongpu_materialize_load_store_pass(num_warps, capability)
    if is_cuda and capability // 10 <= 8:
        pm.add_tritongpu_prefetch_pass()
    pm.add_tritongpu_optimize_dot_operands_pass()
    pm.add_tritongpu_remove_layout_conversions_pass()
    pm.add_tritongpu_decompose_conversions_pass()
    pm.add_tritongpu_ws_fixup_missing_attrs_pass()
    if is_hip() and num_stages != 0:
        pm.add_tritonamdgpu_reorder_instructions_pass()
    else:
        pm.add_tritongpu_reorder_instructions_pass()

    pm.add_cse_pass()
    pm.add_symbol_dce_pass()
    if is_cuda and capability // 10 >= 9:
        pm.add_tritongpu_fence_insertion_pass()
    pm.add_tritongpu_ws_fixup_missing_attrs_pass()
    pm.add_tritongpu_optimize_thread_locality_pass()
    pm.add_canonicalizer_pass()
    pm.run(mod)
    return mod
def _add_external_libs(mod, libs):
    for name, path in libs.items():
        if len(name) == 0 or len(path) == 0:
            return
    add_external_libs(mod, list(libs.keys()), list(libs.values()))
def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0):
    if extern_libs:
        _add_external_libs(mod, extern_libs)
    # TODO: separate tritongpu_to_llvmir for different backends
    if _is_cuda(target):
        return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
    else:
        return translate_triton_gpu_to_llvmir(mod, target["capability"], TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
# PTX translation


@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the current CUDA driver.
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    if major == 12:
        return 80 + minor
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    raise RuntimeError("Triton only supports CUDA 10.0 or higher")
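# For example, the arithmetic above maps CUDA 12.1 to PTX ISA version 81 and CUDA 11.8 to 78
# (illustrative values derived from the branches above, not an exhaustive table).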
def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str:
    '''
    Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
    :return: PTX code
    '''
    if ptx_version is None:
        _, cuda_version = path_to_ptxas()
        ptx_version = ptx_get_version(cuda_version)
    return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion)
def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
    '''
    Compile TritonGPU module to cubin.
    :param ptx: ptx code
    :param target: CUDA target descriptor (compute capability and fp-fusion flag)
    :return: str
    '''
    ptxas, _ = path_to_ptxas()
    return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion)
# ------------------------------------------------------------------------------
# compiler
# ------------------------------------------------------------------------------
def get_kernel_name(src: str, pattern: str) -> str:
    '''
    Get the kernel name from PTX code.
    This kernel name is required when launching the kernel.
    '''
    # Kernel names are mangled during PTX codegen, so the original kernel names in
    # Triton IR are not available in the PTX/cubin.
    assert src
    for line in src.split('\n'):
        line = line.strip()
        if line.startswith(pattern):
            return line.split()[-1]
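# Illustrative: scanning PTX for a line such as
#   // .globl       matmul_kernel_0d1d2d
# with pattern='// .globl' returns "matmul_kernel_0d1d2d" (the kernel name here is hypothetical).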
def convert_type_repr(x):
    # Currently we only capture the pointer type and assume the pointer is on global memory.
    # TODO: Capture and support shared memory space
    match = re.search(r'!tt\.ptr<([^,]+)', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x
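# Illustrative: a repr such as "!tt.ptr<f32, 1>" becomes "*f32", while a plain scalar type
# like "i32" is returned unchanged (assuming the address-space suffix follows a comma).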
def make_hash(fn, target, env_vars, device_backend, **kwargs):
    if device_backend is None:
        version_key = get_cuda_version_key()
    else:
        version_key = device_backend.get_version_key()
    if isinstance(fn, JITFunction):
        configs = kwargs["configs"]
        signature = kwargs["signature"]
        constants = kwargs.get("constants", dict())
        num_warps = kwargs.get("num_warps", 4)
        num_ctas = kwargs.get("num_ctas", 1)
        num_stages = kwargs.get("num_stages", 3)
        waves_per_eu = kwargs.get("waves_per_eu", 0)
        slice_k_tile = kwargs.get("slice_k_tile", 0)
        matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
        enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
        enable_persistent = kwargs.get("enable_persistent", False)
        debug = kwargs.get("debug", False)
        # Get a unique key for the compiled code
        get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1),
                                     sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
        configs_key = [get_conf_key(conf) for conf in configs]
        env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
        key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{slice_k_tile}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}"
        return hashlib.md5(key.encode("utf-8")).hexdigest()
    assert isinstance(fn, str)
    ignore_version = kwargs.get('ignore_version', False)
    if ignore_version:
        return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
    return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest()
# - ^\s*tt\.func\s+ : match the start of a line, any leading whitespace, the keyword tt.func,
#                     and any following whitespace
# - (?:public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
#            (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
#   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture them as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
    "ttir": mlir_prototype_pattern,
    "ttgir": mlir_prototype_pattern,
    "ptx": ptx_prototype_pattern,
}
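# Illustrative TTIR line the MLIR pattern is intended to match (the kernel name and argument
# types are hypothetical):
#   tt.func public @add_kernel(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: i32) {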
# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either:
#   [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol.
#   |: OR
#   <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between.
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
    "ttir": mlir_arg_type_pattern,
    "ttgir": mlir_arg_type_pattern,
    "ptx": ptx_arg_type_pattern,
}
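# Illustrative: applied with re.findall to the argument list
#   "(%arg0: !tt.ptr<f32, 1>, %arg1: i32)"
# the MLIR pattern yields ['!tt.ptr<f32, 1>', 'i32'] (argument names are hypothetical).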
# The num-warps attribute is spelled the same way in HIP and CUDA TTGIR, so a single pattern
# covers both backends.
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
def _get_jsonable_constants(constants):

    def _is_jsonable(x):
        try:
            json.dumps(x)
            return True
        except (TypeError, OverflowError):
            return False

    serialized_constants = {}
    for constant in constants:
        if _is_jsonable(constants[constant]):
            serialized_constants[constant] = constants[constant]
    return serialized_constants
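# Illustrative: {"BLOCK": 64, "ACTIVATION": some_callable} is reduced to {"BLOCK": 64}; entries
# whose values cannot be serialized to JSON are silently dropped (the constant names are hypothetical).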
def parse_mlir_module(path, context):
    module = ir.parse_mlir_module(path, context)
    # module takes ownership of the context
    module.context = context
    return module
instance_descriptor = namedtuple("instance_descriptor",
                                 ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
                                 defaults=[set(), set(), set(), set()])
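# Illustrative: instance_descriptor(divisible_by_16={0, 1}, equal_to_1={2}) records the argument
# indices known to be divisible by 16 or equal to 1; the remaining fields default to empty sets.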
def is_hip():
    # NOTE: this definition shadows the is_hip imported from ..common.build above.
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")
    return torch.version.hip is not None
def get_cuda_capability(capability):
    if capability is None:
        device = get_current_device()
        capability = get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
    return capability
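# Illustrative: when no capability is supplied, a device reporting (9, 0) is folded to 90.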
@functools.lru_cache
def get_arch_default_num_warps(device_type):
    if device_type in ["cuda", "hip"]:
        num_warps = 4
    else:
        _device_backend = get_backend(device_type)
        assert _device_backend
        arch = _device_backend.get_architecture_descriptor()
        num_warps = arch["num_warps"]
    return num_warps
@functools.lru_cache
def get_arch_default_num_stages(device_type, capability=None):
    if device_type == "cuda":
        num_stages = 3 if get_cuda_capability(capability) >= 75 else 2
    else:
        _device_backend = get_backend(device_type)
        assert _device_backend
        arch = _device_backend.get_architecture_descriptor()
        num_stages = arch["num_stages"]
    return num_stages
def add_cuda_stages(target, extern_libs, stages):
    stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target))
    stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target))
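# Each entry of `stages` maps an IR name to a pair of callables:
#     (parse_cached_artifact_from_path, compile_from_previous_stage)
# so, illustratively, stages["cubin"][1](ptx_src) lowers PTX to a cubin while stages["cubin"][0](path)
# reloads a previously cached binary from disk.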
def compile(fn, **kwargs):
    # Get the device type to decide which backend should be used
    device_type = kwargs.get("device_type", "cuda")
    capability = kwargs.get("cc", None)

    if is_hip():
        device_type = "hip"
        capability = None

    is_cuda = device_type == "cuda"
    if is_hip():
        is_cuda = False
    context = ir.context()
    constants = kwargs.get("constants", dict())
    num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    num_ctas = kwargs.get("num_ctas", 1)
    num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
    waves_per_eu = kwargs.get("waves_per_eu", 0)
    slice_k_tile = kwargs.get("slice_k_tile", 0)
    matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
    enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
    # TODO[shuhaoj]: Default should be to enable warp specialization once possible
    enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
    # TODO[shuhaoj]: persistent can be decoupled from warp specialization
    enable_persistent = kwargs.get("enable_persistent", enable_warp_specialization)
    extern_libs = kwargs.get("extern_libs", dict())
    if extern_libs is None:
        extern_libs = dict()
    debug = kwargs.get("debug", False)
    # Flag to control whether to store the mma layout directly
    optimize_epilogue = False
    if os.environ.get('OPTIMIZE_EPILOGUE', '') == '1':
        optimize_epilogue = True

    cluster_info = ClusterInfo()
    if "clusterDims" in kwargs:
        cluster_info.clusterDimX = kwargs["clusterDims"][0]
        cluster_info.clusterDimY = kwargs["clusterDims"][1]
        cluster_info.clusterDimZ = kwargs["clusterDims"][2]
    tma_infos = TMAInfos()
    # build the architecture descriptor
    if device_type == "cuda":
        _device_backend = get_backend(device_type)
        target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps,
                                      enable_fp_fusion=enable_fp_fusion)
    else:
        _device_backend = get_backend(device_type)
        assert _device_backend
        target = _device_backend.get_architecture_descriptor(**kwargs)
    warp_size = CUDA_DEFAULT_WARP_SIZE if is_cuda else target["warp_size"]
    # build compilation stages
    stages = dict()
    stages["ast"] = (lambda path: fn, None)
    stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
        ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
    if is_cuda:
        stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
            ttir_to_ttgir(src, num_warps, warp_size, num_ctas, target), num_stages, num_warps, num_ctas, target,
            cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim,
            slice_k_tile))
        stages["llir"] = (lambda path: Path(path).read_text(),
                          lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
        add_cuda_stages(target, extern_libs, stages)
    elif device_type == "hip":
        # pass the user's configuration to the backend device.
        target["num_warps"] = num_warps
        target["num_stages"] = num_stages
        target["num_ctas"] = num_ctas

        other = {}
        other["context"] = context
        other["warp_size"] = warp_size
        other["cluster_info"] = cluster_info
        other["enable_warp_specialization"] = enable_warp_specialization
        other["enable_persistent"] = enable_persistent
        other["optimize_epilogue"] = optimize_epilogue
        other["tma_infos"] = tma_infos
        other["waves_per_eu"] = waves_per_eu
        other["slice_k_tile"] = slice_k_tile
        other["matrix_instr_nonkdim"] = matrix_instr_nonkdim

        _device_backend.add_stages(target, extern_libs, stages, other)
    elif device_type == "xpu":
        stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
                           lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, target),
                                                      num_stages, num_warps, num_ctas, target, cluster_info,
                                                      enable_warp_specialization, enable_persistent,
                                                      optimize_epilogue, matrix_instr_nonkdim, slice_k_tile))
        stages["llir"] = (lambda path: Path(path).read_text(),
                          lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos))
        _device_backend.add_stages(target, extern_libs, stages)
    else:
        # pass the user's configuration to the backend device.
        target["num_warps"] = num_warps
        target["num_stages"] = num_stages
        target["num_ctas"] = num_ctas
        _device_backend.add_stages(target, extern_libs, stages)
    # find out the signature of the function
    if isinstance(fn, JITFunction):
        configs = kwargs.get("configs", None)
        signature = kwargs["signature"]
        if configs is None:
            configs = [instance_descriptor()]
        assert len(configs) == 1
        kwargs["configs"] = configs
        name = fn.__name__
        first_stage = 0
        if isinstance(signature, str):
            signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
        kwargs["signature"] = signature
    else:
        assert isinstance(fn, str)
        _, ir_name = os.path.basename(fn).split(".")
        src = Path(fn).read_text()
        import re
        match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
        # TODO: support function attributes at group 3 (e.g., device function)
        name, signature = match.group(1), match.group(2)
        types = re.findall(arg_type_pattern[ir_name], signature)
        if ir_name == 'ttgir':
            num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
            assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
            assert "num_warps" not in kwargs or int(
                num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
            num_warps = int(num_warps_matches[0])
        param_tys = [convert_type_repr(ty) for ty in types]
        signature = {k: v for k, v in enumerate(param_tys)}
        first_stage = list(stages.keys()).index(ir_name)
    # create cache manager
    fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs))
    # managers used to dump and override IR for debugging
    enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
    fn_override_manager = get_override_manager(
        make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))
    fn_dump_manager = get_dump_manager(
        make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True))

    # determine the name and extension type of the provided function
    if isinstance(fn, JITFunction):
        name, ext = fn.__name__, "ast"
    else:
        name, ext = os.path.basename(fn).split(".")

    # load metadata if any
    metadata = None
    metadata_filename = f"{name}.json"

    # The group is addressed by the metadata
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}

    metadata_path = metadata_group.get(metadata_filename)

    if metadata_path is not None:
        with open(metadata_path) as f:
            metadata = json.load(f)
            if 'tensormaps_info' in metadata:
                metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
    else:
        metadata = {
            "num_warps": num_warps,
            "warp_size": warp_size,
            "num_ctas": num_ctas,
            "num_stages": num_stages,
            "waves_per_eu": waves_per_eu,
            "slice_k_tile": slice_k_tile,
            "matrix_instr_nonkdim": matrix_instr_nonkdim,
            "enable_warp_specialization": enable_warp_specialization,
            "enable_persistent": enable_persistent,
            "constants": _get_jsonable_constants(constants),
            "debug": debug,
            "target": target,
        }
        metadata.update(get_env_vars())
        if ext == "ptx":
            assert "shared" in kwargs, "ptx compilation must provide shared memory size"
            metadata["shared"] = kwargs["shared"]

    # Add the device type to the meta information
    metadata["device_type"] = device_type

    first_stage = list(stages.keys()).index(ext)
    asm = LazyDict()
    module = fn
    # run the compilation pipeline and populate metadata
    for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]:
        ir_filename = f"{name}.{ir_name}"

        if ir_name == ext:
            next_module = parse(fn)
        else:
            path = metadata_group.get(ir_filename)
            if path is None:
                next_module = compile_kernel(module)
                if ir_name == "amdgcn":
                    extra_file_name = f"{name}.hsaco_path"
                    metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename)
                    metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name)
                else:
                    metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
                    fn_dump_manager.put(next_module, ir_filename)
                if enable_override and fn_override_manager.has_file(ir_filename):
                    print(f"\nOverriding kernel with file {ir_filename}")
                    full_name = fn_override_manager.get_file(ir_filename)
                    next_module = parse(full_name)
            else:
                if ir_name == "amdgcn":
                    extra_file_name = f"{name}.hsaco_path"
                    hsaco_path = metadata_group.get(extra_file_name)
                    assert hsaco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
                    next_module = (parse(path), parse(hsaco_path))
                else:
                    next_module = parse(path)

        if ir_name == "cubin":
            asm[ir_name] = next_module
            asm["sass"] = lambda: get_sass(next_module)
        elif ir_name == "amdgcn":
            asm[ir_name] = str(next_module[0])
        else:
            asm[ir_name] = str(next_module)
        if ir_name == "llir" and "shared" not in metadata:
            if is_hip():
                metadata["shared"] = _device_backend.get_shared_memory_size(module)
            else:
                metadata["shared"] = get_shared_memory_size(module)
        if ir_name == "ttgir":
            metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module)
            if metadata["enable_warp_specialization"]:
                if is_hip():
                    metadata["num_warps"] = _device_backend.get_num_warps(next_module)
                else:
                    metadata["num_warps"] = get_num_warps(next_module)
        if ir_name == "ptx":
            metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
        if ir_name == "amdgcn":
            metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
            asm["hsaco_path"] = next_module[1]
        if not is_cuda and not is_hip():
            _device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
        module = next_module
    ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else ()
    if "clusterDims" not in metadata:
        metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ]

    if len(tma_infos) > 0:
        metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args)
    # set constant
    if "tensormaps_info" in metadata:
        for i, _ in enumerate(metadata["tensormaps_info"]):
            metadata["tensormaps_info"][i].ids_of_folded_args = ids_of_folded_args

    ids_of_tensormaps = get_ids_of_tensormaps(metadata.get("tensormaps_info", None))
    if isinstance(fn, JITFunction) and "tensormaps_info" in metadata:
        fn.tensormaps_info = metadata["tensormaps_info"]

    ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else ()
    ids = {
        "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args,
        "ids_of_const_exprs": ids_of_const_exprs
    }
    # build the launcher stub
    if is_cuda:
        so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization)
    else:
        so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
    # write back metadata, if it didn't come from the cache
    if metadata_path is None:
        metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                                 binary=False)
        fn_cache_manager.put_group(metadata_filename, metadata_group)

    # return a handle to the compiled kernel
    return CompiledKernel(fn, so_path, metadata, asm)
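# Sketch of the stage pipeline that the loop above walks through (the exact stages depend on the
# backend registered at runtime, so this is illustrative rather than exhaustive):
#   CUDA: ast -> ttir -> ttgir -> llir -> ptx -> cubin
#   HIP:  ast -> ttir -> the stages added by the ROCm backend via add_stages
#         (the loop above special-cases an "amdgcn" stage whose artifact carries an hsaco path).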
class CompiledKernel:

    # Hooks for external tools to monitor the execution of triton kernels
    launch_enter_hook = None
    launch_exit_hook = None
    tensormap_manager = TensorMapManager()

    def __init__(self, fn, so_path, metadata, asm):
        # initialize launcher
        import importlib.util
        spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        self.fn = fn
        spec.loader.exec_module(mod)
        self.c_wrapper = getattr(mod, "launch")
        # initialize metadata
        self.shared = metadata["shared"]
        self.num_warps = metadata["num_warps"]
        self.warp_size = metadata["warp_size"]
        self.num_ctas = metadata["num_ctas"]
        self.num_stages = metadata["num_stages"]
        self.waves_per_eu = metadata["waves_per_eu"]
        self.slice_k_tile = metadata["slice_k_tile"]
        self.clusterDims = metadata["clusterDims"]
        if "tensormaps_info" in metadata:
            self.tensormaps_info = metadata["tensormaps_info"]
        self.constants = metadata["constants"]
        self.device_type = metadata["device_type"]
        self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda"] else None
        # initialize asm dict
        self.asm = asm
        # binaries are lazily initialized because doing so involves runtime queries
        # (e.g., checking the amount of shared memory on the current device)
        self.metadata = metadata
        self.cu_module = None
        self.cu_function = None
    def _init_handles(self):
        if self.cu_module is not None:
            return

        if self.device_type in ["cuda"]:
            device = get_current_device()
            bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend]
            max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
            fn_load_binary = driver.utils.load_binary
        else:
            assert self.device_backend
            device = self.device_backend.get_current_device()
            bin_path = self.device_backend.get_kernel_bin()
            max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"]
            fn_load_binary = self.device_backend.get_load_binary_fn()

        if self.shared > max_shared:
            raise OutOfResources(self.shared, max_shared, "shared memory")

        mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)

        self.n_spills = n_spills
        self.n_regs = n_regs
        self.cu_module = mod
        self.cu_function = func
    def __getattribute__(self, name):
        if name == 'c_wrapper':
            self._init_handles()
        return super().__getattribute__(name)

    # capture args and expand args with cutensormap*
    def assemble_tensormap_to_arg(self, args):
        args_with_tma = list(args)
        if hasattr(self, 'tensormaps_info'):
            # tuple for hashability
            args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args])
            for e in self.tensormaps_info:
                args_with_tma.append(CompiledKernel.tensormap_manager[(e, args_ptr)])
        return args_with_tma

    def __getitem__(self, grid):
        self._init_handles()

        def runner(*args, stream=None):
            args_expand = self.assemble_tensormap_to_arg(args)
            if stream is None:
                if self.device_type in ["cuda"]:
                    stream = get_cuda_stream()
                else:
                    stream = get_backend(self.device_type).get_stream(None)
            self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
                           self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)

        return runner
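# Illustrative usage (the kernel, signature and arguments below are hypothetical): indexing a
# CompiledKernel with a grid returns a launcher bound to that grid.
#
#     kernel = compile(my_jit_fn, signature="*fp32, *fp32, i32", device_type="hip")
#     kernel[(num_blocks, 1, 1)](x_ptr, y_ptr, n_elements)
#
# The launcher picks the current stream when none is given, appends any tensormap arguments and
# forwards everything to the generated C wrapper.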