mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Add waves_per_eu as kernel parameter (#319)
* Add waves_per_eu as kernel parameter
* Fix failing tests
* Add default value for waves_per_eu for ttgir_to_llir function
* Remove aot.py
This commit is contained in:
@@ -126,11 +126,13 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
llvm::LLVMContext llvmContext;
|
||||
mlir::triton::gpu::TMAMetadataTy tmaInfos;
|
||||
#ifdef USE_ROCM
|
||||
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
|
||||
SMArch.getValue(), tmaInfos, Target::ROCDL);
|
||||
auto llvmir =
|
||||
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
|
||||
tmaInfos, Target::ROCDL, 0 /*wavesPerEU*/);
|
||||
#else
|
||||
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
|
||||
SMArch.getValue(), tmaInfos, Target::Default);
|
||||
auto llvmir =
|
||||
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue(),
|
||||
tmaInfos, Target::Default, 0 /*wavesPerEU*/);
|
||||
#endif
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
|
||||
@@ -29,7 +29,7 @@ std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
Target target);
|
||||
Target target, int wavesPerEU);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
|
||||
@@ -59,7 +59,8 @@ struct NVVMMetadata {
|
||||
|
||||
// Add the nvvm related metadata to LLVM IR.
|
||||
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
Target target, const int threadsPerCTA) {
|
||||
Target target, const int threadsPerCTA,
|
||||
const int wavesPerEU) {
|
||||
auto *module = func->getParent();
|
||||
auto &ctx = func->getContext();
|
||||
|
||||
@@ -102,6 +103,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size",
|
||||
"1, " + std::to_string(threadsPerCTA));
|
||||
if (wavesPerEU > 0)
|
||||
func->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEU));
|
||||
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
|
||||
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
|
||||
} break;
|
||||
@@ -283,7 +286,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
Target target) {
|
||||
Target target, int wavesPerEU) {
|
||||
DialectRegistry registry;
|
||||
mlir::registerBuiltinDialectTranslation(registry);
|
||||
mlir::registerLLVMDialectTranslation(registry);
|
||||
@@ -331,7 +334,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
for (auto &func : llvmModule->functions()) {
|
||||
auto it = nvvmMetadata.find(func.getName());
|
||||
if (it != nvvmMetadata.end())
|
||||
amendLLVMFunc(&func, it->second, target, threadsPerCTA);
|
||||
amendLLVMFunc(&func, it->second, target, threadsPerCTA, wavesPerEU);
|
||||
}
|
||||
|
||||
return llvmModule;
|
||||
@@ -341,7 +344,7 @@ std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
Target target) {
|
||||
Target target, int wavesPerEU) {
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::registerPassManagerCLOptions();
|
||||
if (failed(applyPassManagerCLOptions(pm))) {
|
||||
@@ -385,7 +388,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target, wavesPerEU);
|
||||
if (!llvmIR) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
return nullptr;
|
||||
|
||||
@@ -1970,11 +1970,11 @@ void init_triton_translation(py::module &m) {
|
||||
"translate_triton_gpu_to_llvmir",
|
||||
[](mlir::ModuleOp op, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
mlir::triton::Target target) {
|
||||
mlir::triton::Target target, int wavesPerEU) {
|
||||
py::gil_scoped_release allow_threads;
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
|
||||
&llvmContext, op, computeCapability, tmaInfos, target);
|
||||
&llvmContext, op, computeCapability, tmaInfos, target, wavesPerEU);
|
||||
if (!llvmModule)
|
||||
llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");
|
||||
|
||||
|
||||
@@ -162,14 +162,14 @@ def _add_external_libs(mod, libs):
|
||||
add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
|
||||
def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0):
|
||||
if extern_libs:
|
||||
_add_external_libs(mod, extern_libs)
|
||||
# TODO: separate tritongpu_to_llvmir for different backends
|
||||
if _is_cuda(arch):
|
||||
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
|
||||
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu)
|
||||
else:
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu)
|
||||
|
||||
|
||||
# PTX translation
|
||||
@@ -308,6 +308,7 @@ def make_hash(fn, arch, env_vars, **kwargs):
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", 3)
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
enable_persistent = kwargs.get("enable_persistent", False)
|
||||
debug = kwargs.get("debug", False)
|
||||
@@ -315,7 +316,7 @@ def make_hash(fn, arch, env_vars, **kwargs):
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())]
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
@@ -472,6 +473,7 @@ def compile(fn, **kwargs):
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
num_ctas = kwargs.get("num_ctas", 1)
|
||||
num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
|
||||
waves_per_eu = kwargs.get("waves_per_eu", 0)
|
||||
# TODO[shuhaoj]: Default should be to enable warp specialization once possible
|
||||
enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
|
||||
# TODO[shuhaoj]: persistent can be decoupled with warp specialization
|
||||
@@ -499,7 +501,7 @@ def compile(fn, **kwargs):
|
||||
stages["ttgir"] = (lambda path: parse_mlir_module(path, context),
|
||||
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue))
|
||||
stages["llir"] = (lambda path: Path(path).read_text(),
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos))
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu))
|
||||
if is_cuda:
|
||||
add_cuda_stages(arch, extern_libs, stages)
|
||||
elif is_hip:
|
||||
@@ -571,6 +573,7 @@ def compile(fn, **kwargs):
|
||||
"warp_size": warp_size,
|
||||
"num_ctas": num_ctas,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu,
|
||||
"enable_warp_specialization": enable_warp_specialization,
|
||||
"enable_persistent": enable_persistent,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
@@ -689,6 +692,7 @@ class CompiledKernel:
|
||||
self.warp_size = metadata["warp_size"]
|
||||
self.num_ctas = metadata["num_ctas"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.waves_per_eu = metadata["waves_per_eu"]
|
||||
self.clusterDims = metadata["clusterDims"]
|
||||
if "tensormaps_info" in metadata:
|
||||
self.tensormaps_info = metadata["tensormaps_info"]
|
||||
|
||||
@@ -276,13 +276,13 @@ class JITFunction(KernelInterface[T]):
|
||||
constants = dict(zip(self.constexprs, constexpr_key))
|
||||
return constants
|
||||
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
|
||||
if JITFunction.cache_hook is None:
|
||||
return False
|
||||
name = self.fn.__name__
|
||||
module = self.fn.__module__
|
||||
arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])])
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})"
|
||||
key = str(key)
|
||||
|
||||
class LegacyCompiler:
|
||||
@@ -292,7 +292,7 @@ class JITFunction(KernelInterface[T]):
|
||||
pass
|
||||
|
||||
kwargs = dict(signature=signature, device=device, constants=constants,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs,
|
||||
configs=configs)
|
||||
|
||||
return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={
|
||||
@@ -364,7 +364,7 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
@@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
if num_stages is None:
|
||||
num_stages = get_arch_default_num_stages(device_type)
|
||||
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug)
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
|
||||
@@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
# Create tensormaps and append to args
|
||||
args = bin.assemble_tensormap_to_arg(args)
|
||||
if not warmup:
|
||||
|
||||
@@ -464,6 +464,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_N = 64
|
||||
num_warps = 4
|
||||
num_stages = 1
|
||||
waves_per_eu = 2 if causal else 3
|
||||
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
@@ -481,7 +482,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages)
|
||||
num_stages=num_stages, waves_per_eu=waves_per_eu)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
@@ -560,7 +561,7 @@ class _attention(torch.autograd.Function):
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
|
||||
num_stages=1,
|
||||
)
|
||||
# print(h.asm["ttgir"])
|
||||
|
||||
Reference in New Issue
Block a user