mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] use enum instead of bool to select target (#2118)
Before this PR, the determination of `TritonGPUToLLVMIRPass` to generate NVVM-compatible LLVM or ROCDL-compatible LLVM is controlled by a boolean `isROCM`. This method is hard to scale. This PR changes it to use an enum instead, where new target can be added easily when needed. --------- Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com> Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -125,7 +125,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
llvm::LLVMContext llvmContext;
|
||||
mlir::triton::gpu::TMAMetadataTy tmaInfos;
|
||||
auto llvmir = translateTritonGPUToLLVMIR(
|
||||
&llvmContext, *module, SMArch.getValue(), tmaInfos, false /*isRocm*/);
|
||||
&llvmContext, *module, SMArch.getValue(), tmaInfos, Target::Default);
|
||||
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
|
||||
@@ -30,9 +30,13 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
Option<"tmaMetadata", "tma-metadata",
|
||||
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
|
||||
"tma metadata to the runtime">,
|
||||
Option<"isROCM", "is-rocm",
|
||||
"bool", /*default*/"false",
|
||||
"compile for ROCM-compatible LLVM">,
|
||||
Option<"target", "target", "enum Target", "mlir::triton::Target::Default",
|
||||
"compile for target compatible LLVM",
|
||||
"llvm::cl::values("
|
||||
"clEnumValN(mlir::triton::Target::NVVM, \"nvvm\", \"compile for "
|
||||
"NVVM-compatible LLVM\"), "
|
||||
"clEnumValN(mlir::triton::Target::ROCDL, \"rocdl\", \"compile for "
|
||||
"ROCDL-compatible LLVM\"))">,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
enum Target { NVVM, ROCDL, Default = NVVM };
|
||||
|
||||
#define GEN_PASS_DECL
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
|
||||
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Target/PTX/TmaMetadata.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
@@ -28,15 +29,15 @@ std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
bool isROCM);
|
||||
Target target);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
bool isROCM);
|
||||
Target target);
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
llvm::StringRef path, bool isROCM);
|
||||
llvm::StringRef path, Target target);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -63,14 +63,17 @@ static void addWSNamedAttrs(Operation *op,
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM)
|
||||
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<index::IndexDialect>();
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
if (isROCM) {
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
} else {
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
break;
|
||||
}
|
||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||
}
|
||||
@@ -359,13 +362,16 @@ private:
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx, bool isROCM)
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
if (isROCM) {
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
} else {
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
addLegalDialect<NVVM::NVVMDialect>();
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
addLegalDialect<ROCDL::ROCDLDialect>();
|
||||
break;
|
||||
}
|
||||
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
|
||||
addIllegalDialect<triton::TritonDialect>();
|
||||
@@ -387,7 +393,7 @@ struct ConvertTritonGPUToLLVM
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
option.overrideIndexBitwidth(32);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMConversionTarget target(*context, isROCM);
|
||||
TritonLLVMConversionTarget convTarget(*context, target);
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
|
||||
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
@@ -441,7 +447,7 @@ struct ConvertTritonGPUToLLVM
|
||||
{
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, target);
|
||||
RewritePatternSet funcPatterns(context);
|
||||
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
|
||||
/*benefit=*/1);
|
||||
@@ -461,7 +467,7 @@ struct ConvertTritonGPUToLLVM
|
||||
{
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, target);
|
||||
RewritePatternSet funcPatterns(context);
|
||||
funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
|
||||
/*benefit=*/1);
|
||||
@@ -539,16 +545,19 @@ struct ConvertTritonGPUToLLVM
|
||||
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
||||
|
||||
// Native lowering patterns
|
||||
if (isROCM) {
|
||||
switch (target) {
|
||||
case Target::NVVM:
|
||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||
break;
|
||||
case Target::ROCDL:
|
||||
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
|
||||
mlir::gpu::amd::HIP);
|
||||
} else {
|
||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||
break;
|
||||
}
|
||||
|
||||
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Fold CTAId when there is only 1 CTA.
|
||||
|
||||
@@ -55,7 +55,7 @@ struct NVVMMetadata {
|
||||
|
||||
// Add the nvvm related metadata to LLVM IR.
|
||||
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
bool isROCM) {
|
||||
Target target) {
|
||||
auto *module = func->getParent();
|
||||
auto &ctx = func->getContext();
|
||||
|
||||
@@ -85,16 +85,19 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
|
||||
}
|
||||
|
||||
if (metadata.isKernel) {
|
||||
if (isROCM) {
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
|
||||
} else {
|
||||
switch (target) {
|
||||
case Target::NVVM: {
|
||||
llvm::Metadata *mdArgs[] = {
|
||||
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
|
||||
llvm::ValueAsMetadata::get(
|
||||
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
|
||||
module->getOrInsertNamedMetadata("nvvm.annotations")
|
||||
->addOperand(llvm::MDNode::get(ctx, mdArgs));
|
||||
} break;
|
||||
case Target::ROCDL: {
|
||||
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
|
||||
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -240,7 +243,7 @@ static void linkLibdevice(llvm::Module &module) {
|
||||
}
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
llvm::StringRef path, bool isROCM) {
|
||||
llvm::StringRef path, Target target) {
|
||||
llvm::SMDiagnostic err;
|
||||
auto &ctx = module.getContext();
|
||||
|
||||
@@ -259,8 +262,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if ROCM
|
||||
if (!isROCM) {
|
||||
if (target == Target::NVVM) {
|
||||
if (name == "libdevice") {
|
||||
linkLibdevice(module);
|
||||
}
|
||||
@@ -274,7 +276,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
bool isROCM) {
|
||||
Target target) {
|
||||
DialectRegistry registry;
|
||||
mlir::registerBuiltinDialectTranslation(registry);
|
||||
mlir::registerLLVMDialectTranslation(registry);
|
||||
@@ -302,7 +304,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
// dead code.
|
||||
auto externLibs = getExternLibs(module);
|
||||
for (auto &lib : externLibs) {
|
||||
if (linkExternLib(*llvmModule, lib.first, lib.second, isROCM))
|
||||
if (linkExternLib(*llvmModule, lib.first, lib.second, target))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -318,7 +320,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
|
||||
for (auto &func : llvmModule->functions()) {
|
||||
auto it = nvvmMetadata.find(func.getName());
|
||||
if (it != nvvmMetadata.end())
|
||||
amendLLVMFunc(&func, it->second, isROCM);
|
||||
amendLLVMFunc(&func, it->second, target);
|
||||
}
|
||||
|
||||
return llvmModule;
|
||||
@@ -328,7 +330,7 @@ std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
bool isROCM) {
|
||||
Target target) {
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::registerPassManagerCLOptions();
|
||||
if (failed(applyPassManagerCLOptions(pm))) {
|
||||
@@ -351,7 +353,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
pm.addPass(mlir::createConvertIndexToLLVMPass());
|
||||
pm.addPass(
|
||||
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, isROCM}));
|
||||
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target}));
|
||||
pm.addPass(createConvertNVGPUToLLVMPass());
|
||||
pm.addPass(mlir::createArithToLLVMConversionPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
@@ -366,7 +368,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, isROCM);
|
||||
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
|
||||
if (!llvmIR) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
return nullptr;
|
||||
|
||||
@@ -56,7 +56,7 @@ static void linkExternal(llvm::Module &module) {
|
||||
// std::filesystem::path(__BUILD_DIR__) / "lib" / "Hopper" /
|
||||
// "libhopper_helpers.bc";
|
||||
if (mlir::triton::linkExternLib(module, "libhopper_helpers", path.string(),
|
||||
/*isROCM*/ false))
|
||||
mlir::triton::Target::NVVM))
|
||||
llvm::errs() << "Link failed for: libhopper_helpers.bc";
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,11 @@ void init_triton_runtime(py::module &&m) {
|
||||
.value("CUDA", CUDA)
|
||||
.value("ROCM", ROCM)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::Target>(m, "TARGET")
|
||||
.value("NVVM", mlir::triton::NVVM)
|
||||
.value("ROCDL", mlir::triton::ROCDL)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
// A custom op builder that keeps track of the last location
|
||||
@@ -1804,11 +1809,12 @@ void init_triton_translation(py::module &m) {
|
||||
m.def(
|
||||
"translate_triton_gpu_to_llvmir",
|
||||
[](mlir::ModuleOp op, int computeCapability,
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos, bool isROCM) {
|
||||
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
|
||||
mlir::triton::Target target) {
|
||||
py::gil_scoped_release allow_threads;
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
|
||||
&llvmContext, op, computeCapability, tmaInfos, isROCM);
|
||||
&llvmContext, op, computeCapability, tmaInfos, target);
|
||||
if (!llvmModule)
|
||||
llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import Any, Tuple
|
||||
|
||||
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
|
||||
compile_ptx_to_cubin, get_env_vars, get_num_warps,
|
||||
get_shared_memory_size, ir,
|
||||
get_shared_memory_size, ir, runtime,
|
||||
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
@@ -142,9 +142,9 @@ def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
|
||||
_add_external_libs(mod, extern_libs)
|
||||
# TODO: separate tritongpu_to_llvmir for different backends
|
||||
if _is_cuda(arch):
|
||||
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, False)
|
||||
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
|
||||
else:
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, True)
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
|
||||
|
||||
|
||||
# PTX translation
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" | FileCheck %s
|
||||
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||
|
||||
Reference in New Issue
Block a user