Add waves_per_eu as kernel parameter (#319)

* Add waves_per_eu as kernel parameter

* Fix failing tests

* Add default value for waves_per_eu for ttgir_to_llir function

* Remove aot.py
This commit is contained in:
oplavsic
2023-10-06 19:08:34 +02:00
committed by GitHub
parent be95edc63f
commit e801638b40
7 changed files with 36 additions and 26 deletions

View File

@@ -59,7 +59,8 @@ struct NVVMMetadata {
// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
Target target, const int threadsPerCTA) {
Target target, const int threadsPerCTA,
const int wavesPerEU) {
auto *module = func->getParent();
auto &ctx = func->getContext();
@@ -102,6 +103,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size",
"1, " + std::to_string(threadsPerCTA));
if (wavesPerEU > 0)
func->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEU));
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
} break;
@@ -283,7 +286,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
Target target) {
Target target, int wavesPerEU) {
DialectRegistry registry;
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
@@ -331,7 +334,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
for (auto &func : llvmModule->functions()) {
auto it = nvvmMetadata.find(func.getName());
if (it != nvvmMetadata.end())
amendLLVMFunc(&func, it->second, target, threadsPerCTA);
amendLLVMFunc(&func, it->second, target, threadsPerCTA, wavesPerEU);
}
return llvmModule;
@@ -341,7 +344,7 @@ std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
Target target) {
Target target, int wavesPerEU) {
mlir::PassManager pm(module->getContext());
mlir::registerPassManagerCLOptions();
if (failed(applyPassManagerCLOptions(pm))) {
@@ -385,7 +388,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target, wavesPerEU);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;