// ROCm/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
#include "TritonGPUToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
namespace {
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getSRegValue;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
ConversionPatternRewriter &rewriter) {
assert(axis >= 0);
assert(axis < 3);
assert(moduleOp);
#ifdef USE_ROCM
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]);
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
#else
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantics of GetProgramIdOp. If numCTAs == 1, then
// GetProgramIdOp is converted to "%ctaid"; otherwise it is converted to
// "%clusterid".
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid.";
sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
return getSRegValue(rewriter, loc, sreg);
#endif
}
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel functions always return nothing.
// TODO(Superjomn) add support for non-inline device functions
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "Only kernel function with nothing returned is supported.");
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};
struct BroadcastOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::BroadcastOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Following the order of indices in the legacy code, a broadcast of:
// [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)]
// =>
// [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)]
//
// logically maps to a broadcast within a thread's scope:
// [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1),
// 1,spt(k+1)..spt(n-1)]
// =>
// [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)]
//
// regardless of the order of the layout
//
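// For example (illustrative), broadcasting an 8x1 tensor to 8x4 reuses the
// source value at offset (i, 0) for every result offset (i, j); the lookup
// below implements this by zeroing the broadcast dimensions of each offset.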
Location loc = op->getLoc();
Value src = adaptor.getSrc();
Value result = op.getResult();
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto resultTy = result.getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto resultLayout = resultTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
assert(rank == resultTy.getRank());
auto order = triton::gpu::getOrder(srcLayout);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> srcVals =
getTypeConverter()->unpackLLElements(loc, src, rewriter, srcTy);
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
for (size_t i = 0; i < srcOffsets.size(); i++) {
srcValues[srcOffsets[i]] = srcVals[i];
}
SmallVector<Value> resultVals;
for (size_t i = 0; i < resultOffsets.size(); i++) {
auto offset = resultOffsets[i];
for (size_t j = 0; j < srcShape.size(); j++)
if (srcShape[j] == 1)
offset[j] = 0;
resultVals.push_back(srcValues.lookup(offset));
}
Value resultStruct =
getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
// The input print op contains:
// - a "prefix" (string) specified by the user, and
// - one or more "operands" (tensors).
//
// For each operand, we print all of the values contained in this GPU thread,
// one per line, along with the index of the value in its tensor.
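// For example (illustrative), printing one 2-D tensor might emit lines like:
//   pid (0, 0, 0) idx ( 1, 15) my_prefix: 42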
struct PrintOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::PrintOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
Value prefixStr =
LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix());
auto getPid = [&](int axis) {
return llGetPid(axis, loc, op->getParentOfType<ModuleOp>(), rewriter);
};
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};
// Simple printf of a string without any tensors.
if (op.getNumOperands() == 0) {
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
#ifdef USE_ROCM
os << "pid (" << getFormatSubstr(pid[0]) << ", "
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << op.getPrefix().str();
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), formatStr,
{pid[0], pid[1], pid[2]}, rewriter);
#else
os << "pid (" << getFormatSubstr(pid[0]) << ", "
<< getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s";
llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter);
#endif
} else {
for (size_t i = 0; i < op.getNumOperands(); i++) {
// Elements of the tensor that are resident in this GPU thread.
auto elems = getTypeConverter()->unpackLLElements(
loc, adaptor.getOperands()[i], rewriter,
op.getOperand(i).getType());
// Get the indices of `elems` within the tensor. Note that if `elems`
// has an "interesting" layout, then these will not be in any
// particularly nice order.
// Extract the shape of the tensor being printed and use it to figure
// out how many digits we need for each of the dimensions.
SmallVector<int, 8> dimWidths;
SmallVector<SmallVector<Value>> indices;
if (auto rankedTy =
op.getOperand(i).getType().dyn_cast<RankedTensorType>()) {
indices =
emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy);
for (int64_t dim : rankedTy.getShape()) {
if (dim > 0) {
dimWidths.push_back(static_cast<int>(std::ceil(std::log10(dim))));
} else {
dimWidths.push_back(0);
}
}
} else {
// We're printing a scalar.
assert(elems.size() == 1);
indices.push_back({});
}
if (!elems.empty()) {
printTensor(op, prefixStr, /*operand=*/i,
/*numOperands=*/op.getNumOperands(), elems, pid, indices,
dimWidths, rewriter);
}
}
}
rewriter.eraseOp(op);
return success();
}
void printTensor(triton::PrintOp op, Value prefixStr, size_t operand,
size_t numOperands, ArrayRef<Value> elems,
std::array<Value, 3> pid, ArrayRef<SmallVector<Value>> indices,
ArrayRef<int> dimWidths,
ConversionPatternRewriter &rewriter) const {
assert(!elems.empty());
assert(elems.size() == indices.size());
assert(dimWidths.size() == indices.front().size());
size_t rank = dimWidths.size();
// Format is:
// pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...)<prefix> (operand <n>) <elem>
// where we leave off "(operand <n>)" if there's only one operand.
//
// The Python wrapper munges `prefix` so that it prints nicely (e.g. starts
// with " " and ends with ": ").
Value formatStrValue;
ConvertTritonGPUOpToLLVMPatternBase::PrintFormatting formatting;
for (int i = 0; i < elems.size(); i++) {
std::string formatStr;
llvm::raw_string_ostream os(formatStr);
// nvptx printf can only accept 32 args; if we pass more than that, it
// will print garbage for the trailing args.
constexpr int kMaxPrintfOperands = 32;
SmallVector<Value, kMaxPrintfOperands> printfOperands;
// TODO(jlebar): We really should pad the pid, but because the max pid is
// not known at compile-time, this would require nontrivial device-side
// work.
os << "pid (";
for (int j = 0; j < pid.size(); j++) {
if (j != 0) {
os << ", ";
}
os << getFormatSubstr(pid[j]);
printfOperands.push_back(pid[j]);
}
os << ") ";
// If `rank` is large enough, we could end up exceeding
// kMaxPrintfOperands. In that case, just truncate the index.
// (Subtract 2 because we're going to add two operands after the index.)
int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2;
os << "idx (";
const auto &index = indices[i];
for (size_t dim = 0; dim < index.size(); dim++) {
if (dim != 0) {
os << ", ";
}
if (dim == maxAllowedRank) {
os << "... (truncated)";
break;
}
os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]);
printfOperands.push_back(index[dim]);
}
os << ")";
#ifdef USE_ROCM
os << op.getPrefix().str();
#else
os << "%s";
printfOperands.push_back(prefixStr);
#endif
if (numOperands > 1) {
os << "(operand " << operand << ") ";
}
auto elem = elems[i];
os << getFormatSubstr(elem);
printfOperands.push_back(elem);
// It's the same format string each iteration, but it's a lot easier if we
// construct the format string at the same time as we populate
// printfOperands. But we don't want to create BLOCK_SIZE duplicate
// strings, so we cache the Value.
if (i == 0) {
#ifdef USE_ROCM
formatting =
llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(),
formatStr, printfOperands, rewriter);
#else
formatStrValue = llPrintf(formatStr, printfOperands, rewriter);
#endif
} else {
#ifdef USE_ROCM
llPrintfHIP(op->getLoc(), op->getParentOfType<mlir::ModuleOp>(), formatting,
printfOperands, rewriter);
#else
llPrintf(formatStrValue, printfOperands, rewriter);
#endif
}
}
}
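// Returns the printf format specifier for `value`'s type, optionally with a
// field width, e.g. "%f" for an f32 and "%3u" for an unsigned i32 of width 3.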
std::string getFormatSubstr(Value value,
std::optional<int> width = std::nullopt) const {
std::string prefix = "%";
if (width.has_value()) {
prefix += std::to_string(*width);
}
Type type = value.getType();
if (type.isa<LLVM::LLVMPointerType>()) {
return prefix + "p";
} else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
return prefix + "f";
} else if (type.isSignedInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return prefix + "lli";
else
return prefix + "i";
} else if (type.isUnsignedInteger() || type.isSignlessInteger()) {
if (type.getIntOrFloatBitWidth() == 64)
return prefix + "llu";
else
return prefix + "u";
}
assert(false && "not supported type");
return "";
}
// Declare vprintf(i8*, i8*) as an external function.
static LLVM::LLVMFuncOp
getVprintfDeclaration(ConversionPatternRewriter &rewriter) {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("vprintf");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
auto *context = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(IntegerType::get(context, 8)),
ptr_ty(IntegerType::get(context, 8))};
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context), funcName,
funcType);
}
// Extend integers to int32 and floats to float64; this is required by
// vprintf's argument alignment rules.
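// For example, an i16 is sign-extended to i32 and an f16 is extended to f64
// before being stored into the argument buffer.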
static std::pair<Type, Value>
promoteValue(ConversionPatternRewriter &rewriter, Value value) {
auto *context = rewriter.getContext();
auto type = value.getType();
Value newOp = value;
Type newType = type;
auto loc = UnknownLoc::get(context);
bool bUnsigned = type.isUnsignedInteger();
if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
if (bUnsigned) {
newType = ui32_ty;
newOp = zext(newType, value);
} else {
newType = i32_ty;
newOp = sext(newType, value);
}
} else if (type.isBF16() || type.isF16() || type.isF32()) {
newType = f64_ty;
newOp = fpext(newType, value);
}
return {newType, newOp};
}
// Returns a Value for the format string, which you can reuse.
static Value llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
assert(!msg.empty() && "printf with empty string not supported");
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value msgValue =
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
rewriter, "printfFormat_", msgNewline);
llPrintf(msgValue, args, rewriter);
return msgValue;
}
static void llPrintf(Value msg, ValueRange args,
ConversionPatternRewriter &rewriter) {
Type int8Ptr = ptr_ty(i8_ty);
auto *ctx = rewriter.getContext();
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto funcOp = getVprintfDeclaration(rewriter);
auto loc = UnknownLoc::get(ctx);
Value one = i32_val(1);
Value zero = i32_val(0);
Value bufferPtr = null(int8Ptr);
SmallVector<Value, 16> newArgs;
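// vprintf takes the format string plus a single pointer to a buffer that
// holds all variadic arguments, so pack the promoted values into a
// stack-allocated struct and pass its address.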
if (args.size() >= 1) {
SmallVector<Type> argTypes;
for (auto arg : args) {
Type newType;
Value newArg;
std::tie(newType, newArg) = promoteValue(rewriter, arg);
argTypes.push_back(newType);
newArgs.push_back(newArg);
}
Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes);
auto allocated =
rewriter.create<LLVM::AllocaOp>(loc, ptr_ty(structTy), one,
/*alignment=*/0);
for (const auto &entry : llvm::enumerate(newArgs)) {
auto index = i32_val(entry.index());
auto fieldPtr = gep(ptr_ty(argTypes[entry.index()]), allocated,
ArrayRef<Value>{zero, index});
store(entry.value(), fieldPtr);
}
bufferPtr = bitcast(allocated, int8Ptr);
}
SmallVector<Value> operands{msg, bufferPtr};
call(funcOp, operands);
}
};
struct AssertOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ctx = rewriter.getContext();
auto elems = getTypeConverter()->unpackLLElements(
loc, adaptor.getCondition(), rewriter, op.getCondition().getType());
auto elemTy = elems[0].getType();
Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0);
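// Accumulate (elem == 0) over all elements: the assert fires if any element
// of the condition tensor is zero.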
for (auto elem : elems) {
if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) {
condition =
or_(condition,
icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
loc, elemTy, rewriter.getZeroAttr(elemTy))));
} else {
assert(false && "Unsupported type for assert");
return failure();
}
}
#ifdef USE_ROCM
llAssertHIP(op, condition, adaptor.getMessage(), adaptor.getFile(),
adaptor.getFunc(), adaptor.getLine(), rewriter);
#else
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
adaptor.getFunc(), adaptor.getLine(), rewriter);
#endif
rewriter.eraseOp(op);
return success();
}
// op: the op at which the assert is inserted. Unlike printf, we need to
// know about the op to split the block.
#ifdef USE_ROCM
void llAssertHIP(Operation *op, Value condition, StringRef message,
StringRef file, StringRef func, int line,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
auto ctx = rewriter.getContext();
auto loc = op->getLoc();
// #prevBlock
// if (condition) {
// #ifBlock
// print(message);
// halt;
// }
// #endBlock
Block *prevBlock = op->getBlock();
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToStart(ifBlock);
SmallString<256> tmpBuf;
message =
llvm::Twine("Assertion failed: " + message + ", File: " + file +
", Function: " + func + ", Line: " + llvm::Twine(line))
.toStringRef(tmpBuf);
// Print assert message.
llPrintfHIP(loc, op->getParentOfType<mlir::ModuleOp>(), message,
ValueRange(), rewriter, /*stderr*/ true);
// Halt the wavefront via s_endpgm.
// TODO: LLVM::Trap and LLVM::DebugTrap instructions don't work here.
GCNBuilder BuilderTrap;
BuilderTrap.create<>("s_endpgm")->operator()();
BuilderTrap.launch(rewriter, loc, void_ty(ctx));
// Split a block after the call.
Block *endBlock = rewriter.splitBlock(ifBlock, op->getIterator());
rewriter.setInsertionPointToEnd(ifBlock);
rewriter.create<cf::BranchOp>(loc, endBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, endBlock);
}
#else // USE_ROCM
static void llAssert(Operation *op, Value condition, StringRef message,
StringRef file, StringRef func, int line,
ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
auto ctx = rewriter.getContext();
auto loc = op->getLoc();
// #block1
// if (condition) {
// #block2
// __assertfail(message);
// }
// #block3
Block *prevBlock = op->getBlock();
Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
rewriter.setInsertionPointToStart(ifBlock);
auto funcOp = getAssertfailDeclaration(rewriter);
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
Value messageString =
LLVM::addStringToModule(loc, rewriter, "assertMessage_", message);
Value fileString =
LLVM::addStringToModule(loc, rewriter, "assertFile_", file);
Value funcString =
LLVM::addStringToModule(loc, rewriter, "assertFunc_", func);
Value lineNumber = i32_val(line);
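// __assertfail's final parameter is the size of a character; pass
// sizeof(char).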
Value charSize = int_val(sizeof(size_t) * 8, sizeof(char));
SmallVector<Value> operands = {messageString, fileString, lineNumber,
funcString, charSize};
auto ret = call(funcOp, operands);
// Split a block after the call.
Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
rewriter.setInsertionPointToEnd(ifBlock);
rewriter.create<cf::BranchOp>(loc, thenBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
}
static LLVM::LLVMFuncOp
getAssertfailDeclaration(ConversionPatternRewriter &rewriter) {
auto moduleOp =
rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("__assertfail");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);
// void __assertfail(const char *message, const char *file, unsigned line,
//                   const char *function, size_t charSize);
auto *ctx = rewriter.getContext();
SmallVector<Type> argsType{ptr_ty(i8_ty), ptr_ty(i8_ty), i32_ty,
ptr_ty(i8_ty),
rewriter.getIntegerType(sizeof(size_t) * 8)};
auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
funcType);
}
#endif // USE_ROCM
};
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
MakeRangeOpConversion(
TritonGPUToLLVMTypeConverter &converter,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
converter, indexCacheInfo, benefit) {}
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto rankedTy = op.getResult().getType().cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto layout = rankedTy.getEncoding();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart());
auto idxs = emitIndices(loc, rewriter, layout, rankedTy);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
// TODO: slice layouts have more elements than expected.
// This is unexpected behavior for make_range, but it is generally OK when
// followed by expand_dims + broadcast; other uses may behave unexpectedly.
for (const auto &multiDim : llvm::enumerate(idxs)) {
assert(multiDim.value().size() == 1);
retVals[multiDim.index()] = add(multiDim.value()[0], start);
}
Value result =
getTypeConverter()->packLLElements(loc, retVals, rewriter, rankedTy);
rewriter.replaceOp(op, result);
return success();
}
};
struct GetProgramIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(),
op->getParentOfType<ModuleOp>(), rewriter);
#ifdef USE_ROCM
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, programId);
#else
rewriter.replaceOp(op, programId);
#endif
return success();
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};
struct GetNumProgramsOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GetNumProgramsOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
Location loc = op->getLoc();
assert(op.getAxis() < 3);
Value blockId =
rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
return success();
#else
// It is not easy to get the compute capability here, so we use numCTAs to
// decide the semantics of GetNumProgramsOp. If numCTAs == 1, then
// GetNumProgramsOp is converted to "%nctaid"; otherwise it is converted to
// "%nclusterid".
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for GetNumProgramsOp");
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
Location loc = op->getLoc();
assert(op.getAxis() < 3);
std::string sreg = numCTAs == 1 ? "%nctaid." : "%nclusterid.";
sreg.append(1, 'x' + op.getAxis()); // 0 -> 'x', 1 -> 'y', 2 -> 'z'
Value numPrograms = getSRegValue(rewriter, loc, sreg);
rewriter.replaceOp(op, numPrograms);
return success();
#endif
}
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
mlir::gpu::Dimension::y,
mlir::gpu::Dimension::z};
};
// TODO[goostavz]: GetThreadIdOp/GetClusterCTAIdOp are a temporary solution
// until the async dialect is done. These concepts should appear at the ttgpu
// level, and they are planned to be deprecated along with the
// ttgpu.mbarrier_xxx ops.
struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetThreadIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetThreadIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::GetThreadIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, getThreadId(rewriter, op->getLoc()));
return success();
}
};
struct GetClusterCTAIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetClusterCTAIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetClusterCTAIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::nvidia_gpu::GetClusterCTAIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, getClusterCTAId(rewriter, op->getLoc()));
return success();
}
};
struct AddPtrOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType();
auto offsetTy = op.getOffset().getType();
auto ptrTy = op.getPtr().getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
if (resultTensorTy) {
unsigned elems = getTotalElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(),
rewriter, ptrTy);
auto offsets = getTypeConverter()->unpackLLElements(
loc, adaptor.getOffset(), rewriter, offsetTy);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
}
Value view = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
resultTy);
rewriter.replaceOp(op, view);
} else {
assert(resultTy.isa<triton::PointerType>());
Type llResultTy = getTypeConverter()->convertType(resultTy);
Value result = gep(llResultTy, adaptor.getPtr(), adaptor.getOffset());
rewriter.replaceOp(op, result);
}
return success();
}
};
struct AllocTensorOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto sharedLayout = resultTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
// Workaround for 3D tensors
// TODO: we need to modify the pipeline pass to give a proper shared
// encoding to 3D tensors
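// For example, a 2-D order {1, 0} becomes {2, 1, 0}: the existing
// dimensions are shifted up by one and the new dimension 0 is appended.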
SmallVector<unsigned> newOrder;
if (resultTy.getShape().size() != order.size()) {
for (auto i = 0; i < order.size(); ++i)
newOrder.push_back(order[i] + 1);
newOrder.push_back(0);
} else {
newOrder = SmallVector<unsigned>(order.begin(), order.end());
}
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj =
SharedMemoryObject(smemBase, shapePerCTA, newOrder, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
struct ExtractSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = extract_slice %src[%offsets]
Location loc = op->getLoc();
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
assert(op.hasUnitStride() &&
"Only unit stride supported by ExtractSliceOpConversion");
// newBase = base + offset
// Triton supports both static and dynamic offsets
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSource(), rewriter);
SmallVector<Value, 4> opOffsetVals;
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i)) {
// adaptor.getOffsets() returns the list of dynamic offsets; its size may
// not match the size of mixedOffsets
opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
++j;
} else
opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
}
// Compute the offset based on the original strides of the shared memory
// object
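// (i.e. offset = sum_i opOffsetVals[i] * smemObj.strides[i])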
auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides);
// newShape = rank_reduce(shape)
// Triton only supports static tensor sizes
SmallVector<Value, 4> strideVals;
for (auto i = 0; i < op.getStaticSizes().size(); ++i) {
if (op.getStaticSize(i) == 1) {
offsetVals.erase(offsetVals.begin() + i);
} else {
strideVals.emplace_back(smemObj.strides[i]);
}
}
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
strideVals, offsetVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
// clang-format off
/***
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# W0 # W1 # | #
# # # | #
# # # # # | #
# W2 # W3 # .... | #
# # # | SkipElems #
# # # # # | #
# | #
# Slice | #
# . / \ | #
# . / \ | #
# . / \| #
# # # # # # #
# # W0 # W1 # #
# # # # #
# # # # # # tensorStride #
# # W2 # W3 # --------------------------------#
# # # # #
# # # # # # #
# tensorStride # W0 # W1 # #
# ---------------------------------- # # # #
# # # # # # #
# # W2 # W3 # #
# # # # #
# # # # # # ---> lastIdx #
# . #
# . #
# . #
# #
# #
# #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
***/
// clang-format on
struct ViewSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp> {
using OpAdaptor = typename triton::gpu::ViewSliceOp::Adaptor;
explicit ViewSliceOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp>(typeConverter,
benefit) {}
LogicalResult processLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSource(), rewriter, srcTy);
auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy);
auto sizePerThread = mlir::triton::gpu::getSizePerThread(srcLayout);
auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
auto order = mlir::triton::gpu::getOrder(srcLayout);
auto shapePerCTA =
mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]);
shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]);
auto offsets = op.getStaticOffsets();
auto sizes = op.getStaticSizes();
// ViewSlice only supports slicing where offsets and sizes are multiples of
// shapePerCTA. This condition ensures that the slice has the same layout as
// the original tensor.
assert(offsets[0] % shapePerCTA[0] == 0);
assert(offsets[1] % shapePerCTA[1] == 0);
assert(sizes[0] % shapePerCTA[0] == 0);
assert(sizes[1] % shapePerCTA[1] == 0);
assert(op.hasUnitStride() &&
"Only unit stride supported by ViewSliceOpConversion");
// Calculate offsets and sizes in terms of CTA units.
std::vector<long int> CTAOffsets{offsets[0] / shapePerCTA[0],
offsets[1] / shapePerCTA[1]};
std::vector<long int> CTASizes{sizes[0] / shapePerCTA[0],
sizes[1] / shapePerCTA[1]};
std::vector<long int> CTAPerShape{srcShape[0] / shapePerCTA[0],
srcShape[1] / shapePerCTA[1]};
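// For example (illustrative): slicing a 128x128 tensor with
// shapePerCTA = {32, 32}, offsets = {32, 0}, and sizes = {64, 128} yields
// CTAOffsets = {1, 0}, CTASizes = {2, 4}, and CTAPerShape = {4, 4}.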
SmallVector<Value> resultVals;
// The diagram above illustrates the graphical representation of the
// skipElems, tensorStride, and lastIdx variables.
auto skipElems = CTAOffsets[order[1]] *
(elemsPerThread[order[0]] * sizePerThread[order[1]]) +
CTAOffsets[order[0]] * totalSizePerThread;
auto tensorStride =
(CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
auto lastIdx =
(CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
elemsPerThread[order[0]] * sizePerThread[order[1]] +
(CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;
assert(lastIdx <= vals.size());
for (int i = skipElems; i < lastIdx; i += tensorStride) {
for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
assert(i < lastIdx);
resultVals.push_back(vals[i]);
}
}
Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}
LogicalResult
matchAndRewrite(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
if (srcTy.getEncoding().isa<BlockedEncodingAttr>() ||
srcTy.getEncoding().isa<MfmaEncodingAttr>()) {
return processLayout(op, adaptor, rewriter);
} else {
assert(false && "Unsupported layout in viewSlice.");
return failure();
}
}
};
struct AsyncWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncWaitOp = *ptxBuilder.create<>("cp.async.wait_group");
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncWaitOp(ptxBuilder.newConstantOperand(num));
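// For example, num = 2 emits the PTX instruction "cp.async.wait_group 2".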
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = void_ty(ctx);
ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
struct AsyncCommitGroupOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncCommitGroupOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncCommitGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
ptxBuilder.create<>("cp.async.commit_group")->operator()();
ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext()));
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
struct AsyncBulkWaitOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncBulkWaitOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncBulkWaitOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncBulkWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
auto &asyncBulkWaitOp = *ptxBuilder.create<>("cp.async.bulk.wait_group");
auto num = op->getAttrOfType<IntegerAttr>("num").getInt();
asyncBulkWaitOp(ptxBuilder.newConstantOperand(num));
auto ctx = op.getContext();
auto loc = op.getLoc();
auto voidTy = void_ty(ctx);
ptxBuilder.launch(rewriter, loc, voidTy);
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
struct AsyncBulkCommitGroupOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncBulkCommitGroupOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::AsyncBulkCommitGroupOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::AsyncBulkCommitGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
PTXBuilder ptxBuilder;
ptxBuilder.create<>("cp.async.bulk.commit_group")->operator()();
ptxBuilder.launch(rewriter, op.getLoc(), void_ty(op.getContext()));
// Safe to remove the op since it doesn't have any return value.
rewriter.eraseOp(op);
return success();
}
};
} // namespace
namespace mlir::triton {
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &moduleAllocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<AllocTensorOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<AsyncBulkCommitGroupOpConversion>(typeConverter, benefit);
patterns.add<AsyncBulkWaitOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<ViewSliceOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);
patterns.add<GetClusterCTAIdOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
}
} // namespace mlir::triton