mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Add waves_per_eu as kernel parameter (#319)
* Add waves_per_eu as kernel parameter
* Fix failing tests
* Add default value for waves_per_eu for ttgir_to_llir function
* Remove aot.py
This commit is contained in:
@@ -59,7 +59,8 @@ struct NVVMMetadata {
|
||||
|
||||
// Add the nvvm related metadata to LLVM IR.
|
||||
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
Target target, const int threadsPerCTA) {
|
||||
Target target, const int threadsPerCTA,
|
||||
const int wavesPerEU) {
|
||||
auto *module = func->getParent();
|
||||
auto &ctx = func->getContext();
|
||||
|
||||
@@ -102,6 +103,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size",
|
||||
"1, " + std::to_string(threadsPerCTA));
|
||||
if (wavesPerEU > 0)
|
||||
func->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEU));
|
||||
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
|
||||
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
|
||||
} break;
|
||||
@@ -283,7 +286,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
Target target) {
|
||||
Target target, int wavesPerEU) {
|
||||
DialectRegistry registry;
|
||||
mlir::registerBuiltinDialectTranslation(registry);
|
||||
mlir::registerLLVMDialectTranslation(registry);
|
||||
@@ -331,7 +334,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
for (auto &func : llvmModule->functions()) {
|
||||
auto it = nvvmMetadata.find(func.getName());
|
||||
if (it != nvvmMetadata.end())
|
||||
amendLLVMFunc(&func, it->second, target, threadsPerCTA);
|
||||
amendLLVMFunc(&func, it->second, target, threadsPerCTA, wavesPerEU);
|
||||
}
|
||||
|
||||
return llvmModule;
|
||||
@@ -341,7 +344,7 @@ std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
Target target) {
|
||||
Target target, int wavesPerEU) {
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::registerPassManagerCLOptions();
|
||||
if (failed(applyPassManagerCLOptions(pm))) {
|
||||
@@ -385,7 +388,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target, wavesPerEU);
|
||||
if (!llvmIR) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user