Add waves_per_eu as kernel parameter (#319)

* Add waves_per_eu as kernel parameter

* Fix failing tests

* Add default value for waves_per_eu for ttgir_to_llir function

* Remove aot.py
Author: oplavsic
Date: 2023-10-06 19:08:34 +02:00
Committed by: GitHub
Parent commit: be95edc63f
Commit: e801638b40
7 changed files with 36 additions and 26 deletions


@@ -1970,11 +1970,11 @@ void init_triton_translation(py::module &m) {
"translate_triton_gpu_to_llvmir",
[](mlir::ModuleOp op, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
mlir::triton::Target target) {
mlir::triton::Target target, int wavesPerEU) {
py::gil_scoped_release allow_threads;
llvm::LLVMContext llvmContext;
auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
&llvmContext, op, computeCapability, tmaInfos, target);
&llvmContext, op, computeCapability, tmaInfos, target, wavesPerEU);
if (!llvmModule)
llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

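On AMD hardware this knob bounds how many wavefronts the compiler should try to keep resident per execution unit, trading occupancy against registers available to each wave. The diff only threads the integer through to translateTritonGPUToLLVMIR; presumably it ends up as LLVM's "amdgpu-waves-per-eu" function attribute, which suggests a quick, hedged sanity check from Python:

```python
# Hedged sanity check, not part of the change: if the hint is honored on the ROCm path,
# the "amdgpu-waves-per-eu" attribute should appear in the kernel's LLVM IR dump.
# `compiled` stands for any CompiledKernel built with waves_per_eu != 0; the "llir"
# key matches the stage name used in compiler.py.
llir = compiled.asm.get("llir", "")
print("amdgpu-waves-per-eu" in llir)
```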

@@ -162,14 +162,14 @@ def _add_external_libs(mod, libs):
     add_external_libs(mod, list(libs.keys()), list(libs.values()))


-def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
+def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0):
     if extern_libs:
         _add_external_libs(mod, extern_libs)
     # TODO: separate tritongpu_to_llvmir for different backends
     if _is_cuda(arch):
-        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
+        return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
     else:
-        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
+        return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)


 # PTX translation
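A minimal sketch of calling the updated helper: the waves_per_eu=0 default keeps existing call sites working, while a nonzero value is forwarded to the translation step. Here `mod`, `arch`, and `tma_infos` are assumed to come from the earlier ttir/ttgir stages:

```python
# Sketch under the new signature; extern_libs=None simply skips _add_external_libs.
llir_default = ttgir_to_llir(mod, None, arch, tma_infos)                 # waves_per_eu=0: no hint
llir_hinted = ttgir_to_llir(mod, None, arch, tma_infos, waves_per_eu=2)  # explicit occupancy hint
```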
@@ -308,6 +308,7 @@ def make_hash(fn, arch, env_vars, **kwargs):
num_warps = kwargs.get("num_warps", 4)
num_ctas = kwargs.get("num_ctas", 1)
num_stages = kwargs.get("num_stages", 3)
waves_per_eu = kwargs.get("waves_per_eu", 0)
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
enable_persistent = kwargs.get("enable_persistent", False)
debug = kwargs.get("debug", False)
@@ -315,7 +316,7 @@ def make_hash(fn, arch, env_vars, **kwargs):
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
configs_key = [get_conf_key(conf) for conf in configs]
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
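Folding the new field into the cache key keeps binaries compiled with different occupancy hints from being confused with one another; a toy illustration (not the real key layout):

```python
import hashlib

# Two keys that differ only in waves_per_eu hash to different digests, so a cached
# binary built with one hint is never served for the other.
key_a = "my_kernel-num_warps=4-num_stages=1-waves_per_eu=0"
key_b = "my_kernel-num_warps=4-num_stages=1-waves_per_eu=2"
assert hashlib.md5(key_a.encode("utf-8")).hexdigest() != hashlib.md5(key_b.encode("utf-8")).hexdigest()
```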
@@ -472,6 +473,7 @@ def compile(fn, **kwargs):
     assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
     num_ctas = kwargs.get("num_ctas", 1)
     num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
+    waves_per_eu = kwargs.get("waves_per_eu", 0)
     # TODO[shuhaoj]: Default should be to enable warp specialization once possible
     enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
     # TODO[shuhaoj]: persistent can be decoupled with warp specialization
@@ -499,7 +501,7 @@ def compile(fn, **kwargs):
     stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
                        lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
     stages["llir"] = (lambda path: Path(path).read_text(),
-                      lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
+                      lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
     if is_cuda:
         add_cuda_stages(arch, extern_libs, stages)
     elif is_hip:
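With compile() accepting the kwarg and forwarding it into the "llir" stage, an ahead-of-time build might look like the sketch below. The kernel, signature, and constants are placeholders; only the waves_per_eu kwarg itself comes from this change:

```python
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


# Hedged sketch: the signature/constants formats follow compile() of this era and may
# differ in other versions; waves_per_eu=0 (the default) means "no explicit hint".
binary = triton.compile(copy_kernel,
                        signature={0: "*fp32", 1: "*fp32", 2: "i32"},
                        constants={"BLOCK": 1024},
                        num_warps=4,
                        waves_per_eu=2)
```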
@@ -571,6 +573,7 @@ def compile(fn, **kwargs):
"warp_size": warp_size,
"num_ctas": num_ctas,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"enable_warp_specialization": enable_warp_specialization,
"enable_persistent": enable_persistent,
"constants": _get_jsonable_constants(constants),
@@ -689,6 +692,7 @@ class CompiledKernel:
self.warp_size = metadata["warp_size"]
self.num_ctas = metadata["num_ctas"]
self.num_stages = metadata["num_stages"]
self.waves_per_eu = metadata["waves_per_eu"]
self.clusterDims = metadata["clusterDims"]
if "tensormaps_info" in metadata:
self.tensormaps_info = metadata["tensormaps_info"]

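Mirroring the field onto CompiledKernel makes it available for inspection after the fact, e.g. continuing the compile() sketch above:

```python
# `binary` is the CompiledKernel from the earlier triton.compile(...) sketch.
print(binary.num_warps, binary.num_stages, binary.waves_per_eu)
```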

@@ -276,13 +276,13 @@ class JITFunction(KernelInterface[T]):
         constants = dict(zip(self.constexprs, constexpr_key))
         return constants

-    def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
+    def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
         if JITFunction.cache_hook is None:
             return False
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
-        repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
+        repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
         key = str(key)

         class LegacyCompiler:
@@ -292,7 +292,7 @@ class JITFunction(KernelInterface[T]):
                 pass

         kwargs = dict(signature=signature, device=device, constants=constants,
-                      num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
+                      num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
                       configs=configs)
         return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
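Since the hook's kwargs now include waves_per_eu, a cache hook can log (or veto) compilations based on it. A minimal sketch; the hook contract (assign to JITFunction.cache_hook, return False to let compilation proceed) is Triton's, the body is illustrative:

```python
import triton


def report(key, repr, fn, compile, **kwargs):
    # `compile` is the kwargs dict assembled in _call_hook; the extra **kwargs absorb
    # arguments whose names may differ across versions.
    print(f"compiling {repr}; waves_per_eu={compile.get('waves_per_eu')}")
    return False  # False means: do not skip, go ahead and compile


triton.runtime.jit.JITFunction.cache_hook = report
```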
@@ -364,7 +364,7 @@ class JITFunction(KernelInterface[T]):
src = f"""
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
if num_stages is None:
num_stages = get_arch_default_num_stages(device_type)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
if not extern_libs is None:
key = (key, tuple(extern_libs.items()))
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
# Create tensormaps and append to args
args = bin.assemble_tensormap_to_arg(args)
if not warmup:

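At call sites, the hint is passed per launch just like num_warps or num_stages; reusing the copy_kernel sketch from above:

```python
import torch
import triton

# Illustrative launch; tensor sizes and grid are arbitrary.
x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
copy_kernel[grid](x, y, x.numel(), BLOCK=1024, num_warps=4, waves_per_eu=2)
```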

@@ -464,6 +464,7 @@ class _attention(torch.autograd.Function):
             BLOCK_N = 64
         num_warps = 4
         num_stages = 1
+        waves_per_eu = 2 if causal else 3
         grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
@@ -481,7 +482,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
             IS_CAUSAL=causal,
             num_warps=num_warps,
-            num_stages=num_stages)
+            num_stages=num_stages, waves_per_eu=waves_per_eu)
         ctx.save_for_backward(q, k, v, o, L)
         ctx.grid = grid
@@ -560,7 +561,7 @@ class _attention(torch.autograd.Function):
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
             BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
             num_stages=1,
         )
         # print(h.asm["ttgir"])