mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Use optimal number of VGPRs (#281)
* Use optimal number of VGPRs
* Fix tritongpu_to_hsaco test
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/IR/CallingConv.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
@@ -51,7 +52,7 @@ struct NVVMMetadata {
|
||||
|
||||
// Add the nvvm related metadata to LLVM IR.
|
||||
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
bool isROCM) {
|
||||
bool isROCM, const int threadsPerCTA) {
|
||||
auto *module = func->getParent();
|
||||
auto &ctx = func->getContext();
|
||||
|
||||
@@ -83,7 +84,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
if (metadata.isKernel) {
|
||||
if (isROCM) {
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
|
||||
func->addFnAttr("amdgpu-flat-work-group-size",
|
||||
"1, " + std::to_string(threadsPerCTA));
|
||||
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
|
||||
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
|
||||
} else {
|
||||
@@ -312,10 +314,14 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
|
||||
const int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
|
||||
const int threadsPerCTA = numWarps * warpSize;
|
||||
|
||||
for (auto &func : llvmModule->functions()) {
|
||||
auto it = nvvmMetadata.find(func.getName());
|
||||
if (it != nvvmMetadata.end())
|
||||
amendLLVMFunc(&func, it->second, isROCM);
|
||||
amendLLVMFunc(&func, it->second, isROCM, threadsPerCTA);
|
||||
}
|
||||
|
||||
return llvmModule;
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
// CHECK: .group_segment_fixed_size: 0
|
||||
// CHECK-NEXT: .kernarg_segment_align: 8
|
||||
// CHECK-NEXT: .kernarg_segment_size: 16
|
||||
// CHECK-NEXT: .max_flat_workgroup_size: 1024
|
||||
// CHECK-NEXT: .max_flat_workgroup_size: 256
|
||||
// CHECK-NEXT: .name: test_empty_kernel
|
||||
// CHECK-NEXT: .private_segment_fixed_size: 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user