Use optimal number of VGPRs (#281)

* Use optimal number of VGPRs

* Fix tritongpu_to_hsaco test
This commit is contained in:
oplavsic
2023-08-04 17:46:53 +02:00
committed by GitHub
parent e1de24cd5c
commit 138844568d
2 changed files with 10 additions and 4 deletions

View File

@@ -15,6 +15,7 @@
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/CallingConv.h"
#include "llvm/ADT/APInt.h"
@@ -51,7 +52,7 @@ struct NVVMMetadata {
// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
bool isROCM) {
bool isROCM, const int threadsPerCTA) {
auto *module = func->getParent();
auto &ctx = func->getContext();
@@ -83,7 +84,8 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
if (metadata.isKernel) {
if (isROCM) {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
func->addFnAttr("amdgpu-flat-work-group-size",
"1, " + std::to_string(threadsPerCTA));
func->addFnAttr("denormal-fp-math-f32", "preserve-sign");
func->addFnAttr("amdgpu-unsafe-fp-atomics", "true");
} else {
@@ -312,10 +314,14 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
return nullptr;
}
const int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
const int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
const int threadsPerCTA = numWarps * warpSize;
for (auto &func : llvmModule->functions()) {
auto it = nvvmMetadata.find(func.getName());
if (it != nvvmMetadata.end())
amendLLVMFunc(&func, it->second, isROCM);
amendLLVMFunc(&func, it->second, isROCM, threadsPerCTA);
}
return llvmModule;

View File

@@ -29,7 +29,7 @@
// CHECK: .group_segment_fixed_size: 0
// CHECK-NEXT: .kernarg_segment_align: 8
// CHECK-NEXT: .kernarg_segment_size: 16
// CHECK-NEXT: .max_flat_workgroup_size: 1024
// CHECK-NEXT: .max_flat_workgroup_size: 256
// CHECK-NEXT: .name: test_empty_kernel
// CHECK-NEXT: .private_segment_fixed_size: 0