Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108

Conflicts:
	bin/triton-translate.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-08 18:51:23 +00:00
72 changed files with 1623 additions and 838 deletions

View File

@@ -14,7 +14,6 @@ add_mlir_translation_library(TritonLLVMIR
PUBLIC
MLIRArithToLLVM
MLIRBuiltinToLLVMIRTranslation
MLIRExecutionEngineUtils
MLIRIndexToLLVM
MLIRIR
MLIRLLVMDialect

View File

@@ -44,7 +44,8 @@ static bool findAndReplace(std::string &str, const std::string &begin,
return true;
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
bool enable_fp_fusion) {
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -84,13 +85,15 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto target =
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
if (enable_fp_fusion)
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
opt.TrapUnreachable = true;
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt, llvm::CodeGenOpt::Aggressive)};
std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
@@ -106,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
llvm::legacy::PassManager pass;
// emit
machine->addPassesToEmitFile(pass, pstream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
llvm::CodeGenFileType::AssemblyFile);
pass.run(module);
}
// post-process