[FRONTEND] use enum instead of bool to select target (#2118)

Before this PR, whether `TritonGPUToLLVMIRPass` generated
NVVM-compatible LLVM or ROCDL-compatible LLVM was controlled by a boolean
`isROCM`. This approach does not scale.
This PR changes it to use an enum instead, so that new targets can be
added easily when needed.

---------

Signed-off-by: Tsang, Whitney <whitney.tsang@intel.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Whitney Tsang
2023-08-17 21:37:09 -04:00
committed by GitHub
parent b33f97a682
commit 100cabd0e4
10 changed files with 67 additions and 43 deletions

View File

@@ -125,7 +125,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::LLVMContext llvmContext;
mlir::triton::gpu::TMAMetadataTy tmaInfos;
auto llvmir = translateTritonGPUToLLVMIR(
&llvmContext, *module, SMArch.getValue(), tmaInfos, false /*isRocm*/);
&llvmContext, *module, SMArch.getValue(), tmaInfos, Target::Default);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";

View File

@@ -30,9 +30,13 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
Option<"tmaMetadata", "tma-metadata",
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
"tma metadata to the runtime">,
Option<"isROCM", "is-rocm",
"bool", /*default*/"false",
"compile for ROCM-compatible LLVM">,
Option<"target", "target", "enum Target", "mlir::triton::Target::Default",
"compile for target compatible LLVM",
"llvm::cl::values("
"clEnumValN(mlir::triton::Target::NVVM, \"nvvm\", \"compile for "
"NVVM-compatible LLVM\"), "
"clEnumValN(mlir::triton::Target::ROCDL, \"rocdl\", \"compile for "
"ROCDL-compatible LLVM\"))">,
];
}

View File

@@ -14,6 +14,8 @@ template <typename T> class OperationPass;
namespace triton {
enum Target { NVVM, ROCDL, Default = NVVM };
#define GEN_PASS_DECL
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"

View File

@@ -1,5 +1,6 @@
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
@@ -28,15 +29,15 @@ std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
bool isROCM);
Target target);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
bool isROCM);
Target target);
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path, bool isROCM);
llvm::StringRef path, Target target);
} // namespace triton
} // namespace mlir

View File

@@ -63,14 +63,17 @@ static void addWSNamedAttrs(Operation *op,
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM)
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, Target target)
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addLegalDialect<LLVM::LLVMDialect>();
if (isROCM) {
addLegalDialect<ROCDL::ROCDLDialect>();
} else {
switch (target) {
case Target::NVVM:
addLegalDialect<NVVM::NVVMDialect>();
break;
case Target::ROCDL:
addLegalDialect<ROCDL::ROCDLDialect>();
break;
}
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
@@ -359,13 +362,16 @@ private:
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx, bool isROCM)
explicit TritonLLVMConversionTarget(MLIRContext &ctx, Target target)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
if (isROCM) {
addLegalDialect<ROCDL::ROCDLDialect>();
} else {
switch (target) {
case Target::NVVM:
addLegalDialect<NVVM::NVVMDialect>();
break;
case Target::ROCDL:
addLegalDialect<ROCDL::ROCDLDialect>();
break;
}
addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
addIllegalDialect<triton::TritonDialect>();
@@ -387,7 +393,7 @@ struct ConvertTritonGPUToLLVM
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMConversionTarget target(*context, isROCM);
TritonLLVMConversionTarget convTarget(*context, target);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
@@ -441,7 +447,7 @@ struct ConvertTritonGPUToLLVM
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
TritonLLVMFunctionConversionTarget funcTarget(*context, target);
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
/*benefit=*/1);
@@ -461,7 +467,7 @@ struct ConvertTritonGPUToLLVM
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
TritonLLVMFunctionConversionTarget funcTarget(*context, target);
RewritePatternSet funcPatterns(context);
funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
/*benefit=*/1);
@@ -539,16 +545,19 @@ struct ConvertTritonGPUToLLVM
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
// Native lowering patterns
if (isROCM) {
switch (target) {
case Target::NVVM:
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
break;
case Target::ROCDL:
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns,
mlir::gpu::amd::HIP);
} else {
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
break;
}
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
// Fold CTAId when there is only 1 CTA.

View File

@@ -55,7 +55,7 @@ struct NVVMMetadata {
// Add the nvvm related metadata to LLVM IR.
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
bool isROCM) {
Target target) {
auto *module = func->getParent();
auto &ctx = func->getContext();
@@ -85,16 +85,19 @@ static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata,
}
if (metadata.isKernel) {
if (isROCM) {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
} else {
switch (target) {
case Target::NVVM: {
llvm::Metadata *mdArgs[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, mdArgs));
} break;
case Target::ROCDL: {
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
func->addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
} break;
}
}
}
@@ -240,7 +243,7 @@ static void linkLibdevice(llvm::Module &module) {
}
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path, bool isROCM) {
llvm::StringRef path, Target target) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
@@ -259,8 +262,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
return true;
}
// check if ROCM
if (!isROCM) {
if (target == Target::NVVM) {
if (name == "libdevice") {
linkLibdevice(module);
}
@@ -274,7 +276,7 @@ bool linkExternLib(llvm::Module &module, llvm::StringRef name,
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
bool isROCM) {
Target target) {
DialectRegistry registry;
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
@@ -302,7 +304,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
// dead code.
auto externLibs = getExternLibs(module);
for (auto &lib : externLibs) {
if (linkExternLib(*llvmModule, lib.first, lib.second, isROCM))
if (linkExternLib(*llvmModule, lib.first, lib.second, target))
return nullptr;
}
@@ -318,7 +320,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
for (auto &func : llvmModule->functions()) {
auto it = nvvmMetadata.find(func.getName());
if (it != nvvmMetadata.end())
amendLLVMFunc(&func, it->second, isROCM);
amendLLVMFunc(&func, it->second, target);
}
return llvmModule;
@@ -328,7 +330,7 @@ std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
bool isROCM) {
Target target) {
mlir::PassManager pm(module->getContext());
mlir::registerPassManagerCLOptions();
if (failed(applyPassManagerCLOptions(pm))) {
@@ -351,7 +353,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, isROCM}));
createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target}));
pm.addPass(createConvertNVGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
@@ -366,7 +368,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, isROCM);
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, target);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;

View File

@@ -56,7 +56,7 @@ static void linkExternal(llvm::Module &module) {
// std::filesystem::path(__BUILD_DIR__) / "lib" / "Hopper" /
// "libhopper_helpers.bc";
if (mlir::triton::linkExternLib(module, "libhopper_helpers", path.string(),
/*isROCM*/ false))
mlir::triton::Target::NVVM))
llvm::errs() << "Link failed for: libhopper_helpers.bc";
}

View File

@@ -81,6 +81,11 @@ void init_triton_runtime(py::module &&m) {
.value("CUDA", CUDA)
.value("ROCM", ROCM)
.export_values();
py::enum_<mlir::triton::Target>(m, "TARGET")
.value("NVVM", mlir::triton::NVVM)
.value("ROCDL", mlir::triton::ROCDL)
.export_values();
}
// A custom op builder that keeps track of the last location
@@ -1804,11 +1809,12 @@ void init_triton_translation(py::module &m) {
m.def(
"translate_triton_gpu_to_llvmir",
[](mlir::ModuleOp op, int computeCapability,
mlir::triton::gpu::TMAMetadataTy &tmaInfos, bool isROCM) {
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
mlir::triton::Target target) {
py::gil_scoped_release allow_threads;
llvm::LLVMContext llvmContext;
auto llvmModule = ::mlir::triton::translateTritonGPUToLLVMIR(
&llvmContext, op, computeCapability, tmaInfos, isROCM);
&llvmContext, op, computeCapability, tmaInfos, target);
if (!llvmModule)
llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

View File

@@ -13,7 +13,7 @@ from typing import Any, Tuple
from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs,
compile_ptx_to_cubin, get_env_vars, get_num_warps,
get_shared_memory_size, ir,
get_shared_memory_size, ir, runtime,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend, path_to_ptxas
@@ -142,9 +142,9 @@ def ttgir_to_llir(mod, extern_libs, arch, tma_infos):
_add_external_libs(mod, extern_libs)
# TODO: separate tritongpu_to_llvmir for different backends
if _is_cuda(arch):
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, False)
return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM)
else:
return translate_triton_gpu_to_llvmir(mod, 0, True)
return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL)
# PTX translation

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" | FileCheck %s
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)