mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Conflicts: .gitignore bin/triton-translate.cpp include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir test/TritonGPU/coalesce.mlir unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
235 lines
6.6 KiB
C++
235 lines
6.6 KiB
C++
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
|
#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
// TODO(Superjomn): unify to llvm::raw_string_ostream
|
|
#include <sstream>
|
|
|
|
namespace mlir {
|
|
namespace triton {
|
|
|
|
PTXInstr::Operand *
|
|
PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
|
|
std::function<std::string(int)> formatter) {
|
|
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
|
|
auto *opr = argArchive.back().get();
|
|
opr->repr = formatter;
|
|
opr->idx = oprCounter++;
|
|
return opr;
|
|
}
|
|
|
|
void PTXBuilder::initOperand(Operand *opr) {
|
|
auto numBits = 0;
|
|
// Derive numBits from the constraint.
|
|
if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h')
|
|
numBits = 16;
|
|
else if (opr->constraint[1] == 'r')
|
|
numBits = 32;
|
|
else if (opr->constraint[1] == 'l')
|
|
numBits = 64;
|
|
else
|
|
llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str());
|
|
// If numBits is less than 16, we use 16 as default because PTX does not
|
|
// support 8-bit mov.
|
|
numBits = numBits < 16 ? 16 : numBits;
|
|
auto *zero = newConstantOperand(0);
|
|
auto &init = create<>("mov")->o("u" + std::to_string(numBits));
|
|
init(opr, zero);
|
|
}
|
|
|
|
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) {
|
|
// Constraint should be something like "=r"
|
|
assert(constraint.size() == 2 && constraint[0] == '=');
|
|
auto *opr = newOperand();
|
|
opr->idx = oprCounter++;
|
|
opr->constraint = constraint;
|
|
if (init) {
|
|
initOperand(opr);
|
|
}
|
|
return opr;
|
|
}
|
|
|
|
PTXBuilder::Operand *PTXBuilder::newOperand(unsigned operandIndex) {
|
|
assert(operandIndex < oprCounter && "operand index out of range");
|
|
auto *opr = newOperand();
|
|
opr->idx = oprCounter++;
|
|
opr->constraint = std::to_string(operandIndex);
|
|
return opr;
|
|
}
|
|
|
|
PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
|
|
argArchive.emplace_back(std::make_unique<Operand>());
|
|
argArchive.back()->repr = [v](int idx) { return v; };
|
|
return argArchive.back().get();
|
|
}
|
|
|
|
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
|
|
std::stringstream ss;
|
|
ss << "0x" << std::hex << v;
|
|
return newConstantOperand(ss.str());
|
|
}
|
|
|
|
std::string PTXBuilder::getConstraints() const {
|
|
auto args = getAllArgs();
|
|
llvm::SmallVector<std::string, 4> argReprs;
|
|
for (auto arg : args)
|
|
argReprs.push_back(arg->constraint);
|
|
return strJoin(argReprs, ",");
|
|
}
|
|
|
|
llvm::SmallVector<Value, 4> PTXBuilder::getAllMLIRArgs() const {
|
|
llvm::SmallVector<Value, 4> res;
|
|
for (auto &arg : argArchive) {
|
|
if (!arg->isList() && arg->value)
|
|
res.push_back(arg->value);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
SmallVector<PTXBuilder::Operand *, 4> PTXBuilder::getAllArgs() const {
|
|
llvm::SmallVector<Operand *, 4> res;
|
|
for (auto &x : argArchive)
|
|
if (!x->isList())
|
|
res.push_back(x.get());
|
|
return res;
|
|
}
|
|
|
|
mlir::Value PTXBuilder::launch(OpBuilder &rewriter, Location loc, Type resTy,
|
|
bool hasSideEffect, bool isAlignStack,
|
|
ArrayRef<Attribute> attrs) const {
|
|
auto *ctx = rewriter.getContext();
|
|
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
|
|
loc, resTy, getAllMLIRArgs(), // operands
|
|
dump(), // asm_string
|
|
getConstraints(), // constraints
|
|
hasSideEffect, // has_side_effects
|
|
isAlignStack, // is_align_stack
|
|
LLVM::AsmDialectAttr::get(ctx,
|
|
LLVM::AsmDialect::AD_ATT), // asm_dialect
|
|
ArrayAttr::get(ctx, attrs) // operand_attrs
|
|
);
|
|
|
|
return inlineAsm.getRes();
|
|
}
|
|
|
|
std::string PTXInstr::Operand::dump() const {
|
|
if (repr)
|
|
return repr(idx);
|
|
if (!isList())
|
|
return "$" + std::to_string(idx);
|
|
|
|
llvm::SmallVector<std::string> oprs;
|
|
for (auto *opr : list)
|
|
oprs.push_back(opr->dump());
|
|
return "{ " + strJoin(oprs, ", ") + " }";
|
|
}
|
|
|
|
PTXInstr::Operand *PTXBuilder::newAddrOperand(mlir::Value addr,
|
|
StringRef constraint, int off) {
|
|
auto *opr = newOperand(addr, constraint);
|
|
opr->repr = [off](int idx) -> std::string {
|
|
std::stringstream ss;
|
|
ss << "[ $" << idx << " + " << off << " ]";
|
|
return ss.str();
|
|
};
|
|
|
|
return opr;
|
|
}
|
|
|
|
std::string PTXBuilder::dump() const {
|
|
llvm::SmallVector<std::string> lines;
|
|
for (auto &exec : executions) {
|
|
lines.push_back(exec->dump());
|
|
}
|
|
|
|
return strJoin(lines, "\n\t");
|
|
}
|
|
|
|
PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
|
|
bool onlyAttachMLIRArgs) {
|
|
if (onlyAttachMLIRArgs) {
|
|
// Nearly impossible to make the $0,$1 in two PTX code snippets to point to
|
|
// the same MLIR values in onlyAttachMLIRArgs mode.
|
|
assert(builder->executions.empty() &&
|
|
"builder can only hold a single execution when onlyAttachMIIRArgs "
|
|
"is true.");
|
|
builder->reorderArgArchive(oprs);
|
|
}
|
|
|
|
builder->executions.emplace_back(
|
|
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
|
|
|
|
return *builder->executions.back();
|
|
}
|
|
|
|
PTXInstrExecution &PTXInstrCommon::operator()(ArrayRef<Operand *> oprs,
|
|
bool onlyAttachMLIRArgs) {
|
|
return call(oprs, onlyAttachMLIRArgs);
|
|
}
|
|
|
|
std::string PTXInstrExecution::dump() const {
|
|
std::string osStr;
|
|
llvm::raw_string_ostream os(osStr);
|
|
|
|
std::string instrRepr = strJoin(instr->instrParts, ".");
|
|
if (onlyAttachMLIRArgs)
|
|
return instrRepr;
|
|
|
|
if (pred) {
|
|
if (!pred->repr)
|
|
os << "@" << pred->dump() << " ";
|
|
else
|
|
os << pred->repr(pred->idx) << " ";
|
|
}
|
|
|
|
llvm::SmallVector<std::string, 4> argReprs;
|
|
for (auto *arg : argsInOrder) {
|
|
argReprs.push_back(arg->dump());
|
|
}
|
|
|
|
std::string argsRepr = strJoin(argReprs, ", ");
|
|
|
|
os << instrRepr << " " << argsRepr << ";";
|
|
os.flush();
|
|
return osStr;
|
|
}
|
|
|
|
SmallVector<PTXInstrExecution::Operand *>
|
|
PTXInstrExecution::getArgList() const {
|
|
SmallVector<Operand *> args;
|
|
for (auto *arg : argsInOrder) {
|
|
if (arg->isList())
|
|
args.insert(args.end(), arg->list.begin(), arg->list.end());
|
|
else
|
|
args.push_back(arg);
|
|
}
|
|
return args;
|
|
}
|
|
|
|
PTXInstr &PTXInstr::global() {
|
|
o("global");
|
|
return *this;
|
|
}
|
|
|
|
PTXInstr &PTXInstr::shared() {
|
|
o("shared");
|
|
return *this;
|
|
}
|
|
|
|
PTXInstr &PTXInstr::v(int vecWidth, bool predicate) {
|
|
if (vecWidth > 1) {
|
|
o("v" + std::to_string(vecWidth), predicate);
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
PTXInstr &PTXInstr::b(int width) {
|
|
o("b" + std::to_string(width));
|
|
return *this;
|
|
}
|
|
|
|
} // namespace triton
|
|
} // namespace mlir
|