mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit 'ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33' into ifu-rebase-again
Conflicts: .gitignore .gitmodules README.md bin/triton-translate.cpp include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Target/AMDGCN/AMDGCNTranslation.h include/triton/Target/HSACO/HSACOTranslation.h lib/Analysis/Allocation.cpp lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/CMakeLists.txt lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/Utility.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/HSACO/CMakeLists.txt lib/Target/HSACO/HSACOTranslation.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/language/test_core.py python/test/unit/operators/test_flash_attention.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -53,7 +53,6 @@ llvm_update_compile_flags(triton-translate)
|
||||
TritonNvidiaGPUTransforms
|
||||
TritonLLVMIR
|
||||
TritonPTX
|
||||
TritonHSACO
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# tests
|
||||
@@ -80,3 +79,20 @@ llvm_update_compile_flags(triton-translate)
|
||||
MLIRROCDLToLLVMIRTranslation
|
||||
)
|
||||
mlir_check_all_link_libraries(triton-translate)
|
||||
|
||||
add_llvm_executable(triton-llvm-opt
|
||||
triton-llvm-opt.cpp
|
||||
|
||||
DEPENDS
|
||||
intrinsics_gen
|
||||
SUPPORT_PLUGINS
|
||||
)
|
||||
target_link_libraries(triton-llvm-opt PRIVATE
|
||||
TritonLLVMIR
|
||||
|
||||
LLVMCore
|
||||
LLVMSupport
|
||||
LLVMOption
|
||||
LLVMCodeGen
|
||||
)
|
||||
export_executable_symbols_for_plugins(triton-llvm-opt)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#pragma once
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
@@ -11,6 +12,7 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -40,5 +42,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
|
||||
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
|
||||
mlir::gpu::GPUDialect>();
|
||||
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
|
||||
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect>();
|
||||
}
|
||||
|
||||
114
bin/triton-llvm-opt.cpp
Normal file
114
bin/triton-llvm-opt.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
|
||||
/// passes.
|
||||
#include "lib/Target/LLVMIR/LLVMPasses.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/DataLayout.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Passes/PassBuilder.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/SystemUtils.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "llvm/TargetParser/Triple.h"
|
||||
#include <optional>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
// Command-line options for the tool.

/// Positional argument naming the input file; defaults to "-".
static cl::opt<std::string> InputFilename(cl::Positional,
                                          cl::desc("<input bitcode file>"),
                                          cl::init("-"),
                                          cl::value_desc("filename"));

/// -data-layout: layout string handed to the IR parser; empty by default.
static cl::opt<std::string> ClDataLayout("data-layout",
                                         cl::desc("data layout string to use"),
                                         cl::value_desc("layout-string"),
                                         cl::init(""));

/// -mtriple: when non-empty, replaces the target triple of the loaded module.
static cl::opt<std::string>
    TargetTriple("mtriple", cl::desc("Override target triple for module"));

/// -break-struct-phi-nodes: opt-in switch (off by default) that adds
/// BreakStructPhiNodesPass to the function pipeline.
static cl::opt<bool>
    BreakStructPhiNodes("break-struct-phi-nodes",
                        cl::desc("run pass to break phi struct"),
                        cl::init(false));
|
||||
|
||||
namespace {
|
||||
static std::function<Error(Module *)> makeOptimizingPipeline() {
|
||||
return [](Module *m) -> Error {
|
||||
PipelineTuningOptions tuningOptions;
|
||||
PassBuilder pb(nullptr, tuningOptions);
|
||||
|
||||
LoopAnalysisManager lam;
|
||||
FunctionAnalysisManager fam;
|
||||
CGSCCAnalysisManager cgam;
|
||||
ModuleAnalysisManager mam;
|
||||
pb.registerModuleAnalyses(mam);
|
||||
pb.registerCGSCCAnalyses(cgam);
|
||||
pb.registerFunctionAnalyses(fam);
|
||||
pb.registerLoopAnalyses(lam);
|
||||
pb.crossRegisterProxies(lam, fam, cgam, mam);
|
||||
|
||||
ModulePassManager mpm;
|
||||
llvm::FunctionPassManager fpm;
|
||||
if (BreakStructPhiNodes)
|
||||
fpm.addPass(BreakStructPhiNodesPass());
|
||||
mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
|
||||
mpm.run(*m, mam);
|
||||
return Error::success();
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
InitLLVM X(argc, argv);
|
||||
cl::ParseCommandLineOptions(
|
||||
argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n");
|
||||
|
||||
LLVMContext Context;
|
||||
SMDiagnostic Err;
|
||||
|
||||
// Load the input module...
|
||||
auto SetDataLayout = [](StringRef, StringRef) -> std::optional<std::string> {
|
||||
if (ClDataLayout.empty())
|
||||
return std::nullopt;
|
||||
return ClDataLayout;
|
||||
};
|
||||
std::unique_ptr<Module> M;
|
||||
M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout));
|
||||
if (!M) {
|
||||
Err.print(argv[0], errs());
|
||||
return 1;
|
||||
}
|
||||
// If we are supposed to override the target triple or data layout, do so now.
|
||||
if (!TargetTriple.empty())
|
||||
M->setTargetTriple(Triple::normalize(TargetTriple));
|
||||
auto optPipeline = makeOptimizingPipeline();
|
||||
if (auto err = optPipeline(M.get())) {
|
||||
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
|
||||
}
|
||||
|
||||
if (verifyModule(*M, &errs())) {
|
||||
errs() << argv[0] << ": " << InputFilename
|
||||
<< ": error: input module is broken!\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Write to standard output.
|
||||
std::unique_ptr<ToolOutputFile> Out;
|
||||
std::string OutputFilename = "-";
|
||||
std::error_code EC;
|
||||
sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF;
|
||||
Out.reset(new ToolOutputFile(OutputFilename, EC, Flags));
|
||||
if (EC) {
|
||||
errs() << EC.message() << '\n';
|
||||
return 1;
|
||||
}
|
||||
Out->os() << *M << "\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -15,7 +15,6 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Target/HSACO/HSACOTranslation.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/Target/PTX/PTXTranslation.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
@@ -143,11 +142,14 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
} else if (targetKind == "ptx") {
|
||||
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
|
||||
ptxVersion.getValue());
|
||||
<<<<<<< HEAD
|
||||
} else if (targetKind == "hsaco") {
|
||||
auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO(
|
||||
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
|
||||
GCNFeatures.getValue());
|
||||
llvm::outs() << hsaco;
|
||||
=======
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
} else {
|
||||
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
|
||||
return failure();
|
||||
|
||||
Reference in New Issue
Block a user