Files
ROCm/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Ognjen 38fbb7e472 ROCM IFU: Enable slice layout for insertSliceAsync AMD path
Fix basic_insert_slice_async_1d lit test

Remove code added for debugging

Return hopper test
2023-11-17 01:27:57 +00:00

1481 lines
62 KiB
C++

#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after Utility.h (due to a conflict between the `store` macro
// and <atomic>).
#include "triton/Analysis/Allocation.h"
#include "TypeConverter.h"
//
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include <set>
// Debug category used by LLVM_DEBUG in code that includes this header.
#define DEBUG_TYPE "ttgpu_to_llvm"
// Attribute names "triton_gpu.num-tma-load" / "triton_gpu.num-tma-store".
// By their names they presumably record the number of TMA load/store
// descriptors; their readers/writers are not visible in this chunk — confirm.
constexpr ::llvm::StringLiteral kAttrNumTMALoadDescsName =
"triton_gpu.num-tma-load";
constexpr ::llvm::StringLiteral kAttrNumTMAStoreDescsName =
"triton_gpu.num-tma-store";
// NOTE(review): these using-directives/declarations sit at header scope and
// therefore apply to every translation unit that includes this header.
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
using ::mlir::triton::gpu::TMAMetadataTy;
// Short alias for the Triton NVIDIA GPU dialect namespace.
namespace ttng = ::mlir::triton::nvidia_gpu;
// Map from an operation to a triton::MakeTensorPtrOp (presumably the op that
// created the tensor pointer the key operation uses — confirm at call sites).
typedef DenseMap<Operation *, triton::MakeTensorPtrOp> TensorPtrMapT;
namespace mlir::LLVM {
/// Helper for emitting a printf from inside the LLVM conversion: prints `msg`
/// together with `args`, creating the ops through `rewriter`. The exact
/// formatting is defined by the implementation (not part of this header).
void vprintf(StringRef msg, ValueRange args,
             ConversionPatternRewriter &rewriter);

/// Array-printing companion of vprintf.
/// NOTE(review): parameter roles are inferred from their names — `thread`
/// identifies the printing thread, `arr` holds the values, `info` is a label
/// and `elem_repr` the per-element printf conversion; confirm against the
/// implementation.
void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
                   std::string elem_repr, ConversionPatternRewriter &builder);
} // namespace mlir::LLVM
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
// All the rights are reserved by the LLVM community.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<triton::FuncOp> {
protected:
  /// Only retain those attributes that are not constructed by
  /// `LLVMFuncOp::build`: the symbol name, the function type and
  /// `std.varargs` are always dropped. If `filterArgAttrs` is set, argument
  /// attributes are dropped too. Result attributes are always retained.
  static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs,
                                   SmallVectorImpl<NamedAttribute> &result) {
    for (const auto &attr : op->getAttrs()) {
      if (attr.getName() == SymbolTable::getSymbolAttrName() ||
          attr.getName() == op.getFunctionTypeAttrName() ||
          attr.getName() == "std.varargs" ||
          (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName()))
        continue;
      result.push_back(attr);
    }
  }

  /// Helper function for wrapping all attributes into a single DictionaryAttr
  /// under the key "llvm.struct_attrs".
  static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
    return DictionaryAttr::get(b.getContext(),
                               b.getNamedAttr("llvm.struct_attrs", attrs));
  }

protected:
  using ConvertOpToLLVMPattern<triton::FuncOp>::ConvertOpToLLVMPattern;

  /// Converts `funcOp` into an LLVM::LLVMFuncOp using the LLVMTypeConverter
  /// provided to this legalization pattern.
  ///
  /// The body region is moved into the new function and its block argument
  /// types are converted. Argument/result attributes are remapped onto the
  /// converted signature. Linkage is taken from the `llvm.linkage` attribute
  /// when present, defaulting to External.
  ///
  /// Returns a null op on failure: signature conversion failed, `llvm.linkage`
  /// has the wrong type (an error is emitted), or region type conversion
  /// failed.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(triton::FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(), false,
        result);
    if (!llvmType)
      return nullptr;

    // Propagate argument/result attributes to all converted arguments/result
    // obtained after converting a given original argument/result.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes);
    if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
      assert(!resAttrDicts.empty() && "expected array to be non-empty");
      auto newResAttrDicts =
          (funcOp.getNumResults() == 1)
              ? resAttrDicts
              : rewriter.getArrayAttr(
                    {wrapAsStructAttrs(rewriter, resAttrDicts)});
      // filterFuncAttributes() keeps result attributes, so drop that copy
      // before adding the converted one; otherwise the attribute list would
      // contain the same name twice.
      auto resAttrsName = funcOp.getResAttrsAttrName();
      llvm::erase_if(attributes, [&](const NamedAttribute &attr) {
        return attr.getName() == resAttrsName;
      });
      attributes.push_back(
          rewriter.getNamedAttr(resAttrsName, newResAttrDicts));
    }
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        auto mapping = result.getInputMapping(i);
        assert(mapping && "unexpected deletion of function argument");
        // Every converted argument stemming from original argument `i`
        // inherits its attribute dictionary.
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(rewriter.getNamedAttr(
          funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
    }
    // `llvm.linkage` is passed to LLVMFuncOp through the `linkage` argument
    // below instead of being copied as a plain attribute.
    llvm::erase_if(attributes, [](const NamedAttribute &attr) {
      return attr.getName() == "llvm.linkage";
    });

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (auto linkageAttr = funcOp->getDiscardableAttr("llvm.linkage")) {
      auto attr = linkageAttr.dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
        attributes);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(),
                                           *typeConverter, &result)))
      return nullptr;
    return newFuncOp;
  }
};
// Key of the index caches: indices are memoized per layout attribute,
// ranked tensor type, and whether the CTA offset was added
// (see IndexCacheInfo / CacheKeyDenseMapInfo).
struct IndexCacheKeyT {
Attribute layout;
RankedTensorType type;
bool withCTAOffset;
};
struct CacheKeyDenseMapInfo {
static IndexCacheKeyT getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return {mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
RankedTensorType{}, true};
}
static IndexCacheKeyT getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
auto tombstone = llvm::DenseMapInfo<RankedTensorType>::getTombstoneKey();
return {mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
tombstone, true};
}
static unsigned getHashValue(IndexCacheKeyT key) {
return llvm::hash_combine(mlir::hash_value(key.layout),
mlir::hash_value(key.type),
llvm::hash_value(key.withCTAOffset));
}
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
return LHS.layout == RHS.layout && LHS.type == RHS.type &&
LHS.withCTAOffset == RHS.withCTAOffset;
}
};
class ConvertTritonGPUOpToLLVMPatternBase {
public:
  // Two levels of value cache in emitting indices calculation:
  //   - baseIndexCache: per-key base indices (emitBaseIndexForLayout)
  //   - indexCache: per-key full index lists (emitIndices)
  // Key: {layout, tensor type, withCTAOffset} (see IndexCacheKeyT).
  // indexInsertPoint: where cached index computations are emitted.
  //
  // Caching is enabled only when the cache pointer is non-null (both users
  // test it). Default-initialize every pointer to null so that an
  // IndexCacheInfo left out of a constructor's initializer list never carries
  // indeterminate pointers.
  struct IndexCacheInfo {
    DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
        *baseIndexCache = nullptr;
    DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
             CacheKeyDenseMapInfo> *indexCache = nullptr;
    OpBuilder::InsertPoint *indexInsertPoint = nullptr;
  };
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter)
: converter(&typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter,
IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), indexCacheInfo(indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation)
: converter(&typeConverter), allocation(&allocation) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), allocation(&allocation),
indexCacheInfo(indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
TMAMetadataTy *tmaMetadata)
: converter(&typeConverter), allocation(&allocation),
tmaMetadata(tmaMetadata) {}
/// Returns the Triton GPU -> LLVM type converter this pattern was built with.
TritonGPUToLLVMTypeConverter *getTypeConverter() const {
  return this->converter;
}
/// Packs every element of `smemObj`, in order, into a fresh LLVM literal
/// struct whose field types are `smemObj.getTypes()`, and returns that struct.
/// Asserts that no element is null.
static Value
getStructFromSharedMemoryObject(Location loc,
                                const SharedMemoryObject &smemObj,
                                ConversionPatternRewriter &rewriter) {
  auto fieldValues = smemObj.getElems();
  auto fieldTypes = smemObj.getTypes();
  auto packedTy =
      LLVM::LLVMStructType::getLiteral(rewriter.getContext(), fieldTypes);
  // Start from undef and insert the fields one by one.
  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedTy);
  for (size_t pos = 0, count = fieldValues.size(); pos < count; ++pos) {
    Value field = fieldValues[pos];
    assert(field && "can not insert null values");
    packed = insert_val(packedTy, packed, field, pos);
  }
  return packed;
}
/// Returns the CTA-level thread index: the x component of gpu.thread_id,
/// cast from index to i32.
Value getThreadIdInCTA(ConversionPatternRewriter &rewriter,
                       Location loc) const {
  auto xDim = ::mlir::gpu::Dimension::x;
  Value tidAsIndex = rewriter.create<::mlir::gpu::ThreadIdOp>(loc, xDim);
  return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tidAsIndex);
}
/// Returns the CTA-level thread index when warp specialization (ws) is off.
/// When the enclosing module reports ws support, returns the agent-level
/// index instead, i.e. the CTA-level index modulo 128 (signed remainder).
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
  Value ctaTid = getThreadIdInCTA(rewriter, loc);
  auto module = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
  if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(module))
    return ctaTid;
  Value agentWidth = rewriter.create<arith::ConstantIntOp>(loc, 128, 32);
  return rewriter.create<arith::RemSIOp>(loc, ctaTid, agentWidth);
}
/// Returns the id of this CTA within its cluster as an i32, produced by
/// triton::nvgpu::ClusterCTAIdOp.
Value getClusterCTAId(ConversionPatternRewriter &rewriter,
                      Location loc) const {
  Type resultTy = rewriter.getI32Type();
  return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(loc, resultTy);
}
// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
/// Returns a pointer (to the converted i8 type, address space 3) to the buffer
/// associated with `value`.
///
/// The pointer is the owning function's shared-memory base plus the buffer's
/// offset from the module allocation. `value` is either a pointer type (e.g.
/// `Operation *`), used via `->getParentOfType`, or a value type exposing
/// `getParentRegion()` (e.g. `Value`). Asserts that a buffer id exists.
template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                          T value) const {
  Type i8Ty = this->getTypeConverter()->convertType(rewriter.getI8Type());
  auto ptrTy = LLVM::LLVMPointerType::get(i8Ty, 3);
  // Locate the function that owns `value`.
  FunctionOpInterface owningFunc = [&]() -> FunctionOpInterface {
    if constexpr (std::is_pointer_v<T>)
      return value->template getParentOfType<FunctionOpInterface>();
    else
      return value.getParentRegion()
          ->template getParentOfType<FunctionOpInterface>();
  }();
  auto *funcData = allocation->getFuncData(owningFunc);
  auto smemBase = allocation->getFunctionSharedMemoryBase(owningFunc);
  auto id = funcData->getBufferId(value);
  assert(id != Allocation::InvalidBufferId && "BufferId not found");
  size_t bufferOffset = funcData->getOffset(id);
  Value bufferOffsetVal = i32_val(bufferOffset);
  return gep(ptrTy, smemBase, bufferOffsetVal);
}
/// Computes shared-memory element pointers for the swizzled shared layout
/// `resSharedLayout`.
///
/// The result maps `elemIdx` to the address of that element, for every
/// min(inVec, outVec)-th element that the current thread holds for `srcTy`.
/// - `resElemTy` is converted with the type converter and used as the pointee
///   type (address space 3).
/// - The base pointer is `smemObj.base + dot(offsetVals, smemObj.strides)`.
/// - `srcStrides` give the per-dimension strides for the row/column offsets.
/// - When `outOrder.size() != 2`, the row index and row stride are 0.
DenseMap<unsigned, Value>
getSwizzledSharedPtrs(Location loc, unsigned inVec, RankedTensorType srcTy,
                      triton::gpu::SharedEncodingAttr resSharedLayout,
                      Type resElemTy, SharedMemoryObject smemObj,
                      ConversionPatternRewriter &rewriter,
                      SmallVectorImpl<Value> &offsetVals,
                      SmallVectorImpl<Value> &srcStrides) const {
  // This utility computes the pointers for accessing the provided swizzled
  // shared memory layout `resSharedLayout`. More specifically, it computes,
  // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0,
  // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] +
  // colOff) where :
  //   phase = (row // perPhase) % maxPhase
  //   rowOff = row
  //   colOff = colOffSwizzled + colOffOrdered
  //     colOffSwizzled = ((col // outVec) ^ phase) * outVec
  //     colOffOrdered = (col % outVec) // minVec * minVec
  //
  // Note 1:
  // -------
  // Because swizzling happens at a granularity of outVec, we need to
  // decompose the offset into a swizzled factor and a non-swizzled
  // (ordered) factor
  //
  // Note 2:
  // -------
  // If we have x, y, z of the form:
  // x = 0b00000xxxx
  // y = 0byyyyy0000
  // z = 0b00000zzzz
  // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
  // This means that we can use some immediate offsets for shared memory
  // operations.
  resElemTy = getTypeConverter()->convertType(resElemTy);
  auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(resElemTy), 3);
  auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
  Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
  auto srcEncoding = srcTy.getEncoding();
  auto srcShape = srcTy.getShape();
  auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy);
  unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
  // swizzling params as described in TritonGPUAttrDefs.td
  unsigned outVec = resSharedLayout.getVec();
  unsigned perPhase = resSharedLayout.getPerPhase();
  unsigned maxPhase = resSharedLayout.getMaxPhase();
  // Order
  auto outOrder = triton::gpu::getOrder(resSharedLayout);
  assert((maxPhase == 1 || outVec * maxPhase <= srcShape[outOrder[0]]) &&
         "Swizzling would generate out of bounds memory accesses");
  // Tensor indices held by the current thread, as LLVM values
  auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
  // Swizzling with leading offsets (e.g. Hopper GMMA)
  unsigned swizzlingByteWidth = 0;
  if (resSharedLayout.getHasLeadingOffset()) {
    if (perPhase == 4 && maxPhase == 2)
      swizzlingByteWidth = 32;
    else if (perPhase == 2 && maxPhase == 4)
      swizzlingByteWidth = 64;
    else if (perPhase == 1 && maxPhase == 8)
      swizzlingByteWidth = 128;
    else
      llvm::report_fatal_error("Unsupported shared layout.");
  }
  // getIntOrFloatBitWidth() asserts on non-integer/float element types, so it
  // is only queried when leading-offset swizzling is active. Without
  // swizzling the result is 0 either way.
  unsigned numElemsPerSwizzlingRow =
      swizzlingByteWidth
          ? swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth()
          : 0;
  Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
  unsigned leadingDimOffset;
  if (outOrder.size() == 2) {
    leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
  } else {
    leadingDimOffset = numElemsPerSwizzlingRow;
  }
  Value leadingDimOffsetVal = i32_val(leadingDimOffset);
  // Return values
  DenseMap<unsigned, Value> ret;
  // cache for non-immediate offsets
  DenseMap<unsigned, Value> cacheCol, cacheRow;
  unsigned minVec = std::min(outVec, inVec);
  for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
    Value offset = i32_val(0);
    // Extract multi dimensional index for current element (no copy needed).
    const auto &idx = srcIndices[elemIdx];
    Value idxCol = idx[outOrder[0]]; // contiguous dimension
    Value idxRow, strideRow;
    if (outOrder.size() == 2) {
      idxRow = idx[outOrder[1]]; // discontiguous dimension
      strideRow = srcStrides[outOrder[1]];
    } else {
      idxRow = i32_val(0);
      strideRow = i32_val(0);
    }
    Value strideCol = srcStrides[outOrder[0]];
    // compute phase = (row // perPhase) % maxPhase
    Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
    // extract dynamic/static offset for immediate offsetting
    unsigned immediateOffCol = 0;
    unsigned immediateOffRow = 0;
    if (leadingDimOffset) {
      // hopper
      offset =
          mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal);
      // Shrink by swizzling blocks
      idxCol = urem(idxCol, numElemsPerSwizzlingRowVal);
      strideRow = numElemsPerSwizzlingRowVal;
    } else {
      // Split `base + cst` indices into a cached non-immediate part and an
      // immediate multiple of the swizzle period.
      if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
        if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
                add.getRhs().getDefiningOp())) {
          unsigned cst =
              _cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
          unsigned key = cst % (outVec * maxPhase);
          cacheCol.insert({key, idxCol});
          idxCol = cacheCol[key];
          immediateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
        }
      if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp()))
        if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
                add.getRhs().getDefiningOp())) {
          unsigned cst =
              _cst.getValue().cast<IntegerAttr>().getValue().getSExtValue();
          unsigned key = cst % (perPhase * maxPhase);
          cacheRow.insert({key, idxRow});
          idxRow = cacheRow[key];
          immediateOffRow =
              cst / (perPhase * maxPhase) * (perPhase * maxPhase);
        }
    }
    // row offset is simply row index
    Value rowOff = mul(idxRow, strideRow);
    // because swizzling happens at a granularity of outVec, we need to
    // decompose the offset into a swizzled factor and a non-swizzled
    // (ordered) factor: colOffSwizzled = ((col // outVec) ^ phase) * outVec
    // colOffOrdered = (col % outVec) // minVec * minVec
    Value colOffSwizzled = xor_(udiv(idxCol, i32_val(outVec)), phase);
    colOffSwizzled = mul(colOffSwizzled, i32_val(outVec));
    Value colOffOrdered = urem(idxCol, i32_val(outVec));
    colOffOrdered = udiv(colOffOrdered, i32_val(minVec));
    colOffOrdered = mul(colOffOrdered, i32_val(minVec));
    Value colOff = add(colOffSwizzled, colOffOrdered);
    // compute non-immediate offset
    offset = add(offset, add(rowOff, mul(colOff, strideCol)));
    Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
    // compute immediate offset
    Value immediateOff;
    if (outOrder.size() == 2) {
      immediateOff =
          add(mul(i32_val(immediateOffRow), srcStrides[outOrder[1]]),
              i32_val(immediateOffCol));
    } else {
      immediateOff = i32_val(immediateOffCol);
    }
    ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff);
  }
  return ret;
}
/// Loads, for the current thread, the elements of the rank-2 distributed
/// tensor `dst` from the swizzled shared-memory tensor `src`.
///
/// `smemObj` supplies the shared-memory base and strides. Loads are vectorized
/// by min(outVec, inVec) elements of `elemTy`. The returned values, one per
/// element held by the thread, all have type `elemTy`. `dstIndices` must have
/// one entry per such element.
SmallVector<Value>
loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
                        Value src, SharedMemoryObject smemObj, Type elemTy,
                        Location loc,
                        ConversionPatternRewriter &rewriter) const {
  auto dstTy = dst.getType().cast<RankedTensorType>();
  auto dstShape = dstTy.getShape();
  assert(dstShape.size() == 2 &&
         "Unexpected rank of loadSharedToDistributed");
  auto srcTy = src.getType().cast<RankedTensorType>();
  auto dstDistributedLayout = dstTy.getEncoding();
  if (auto mmaLayout = dstDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
    assert((!mmaLayout.isVolta()) &&
           "ConvertLayout Shared->MMAv1 is not supported yet");
  }
  auto srcSharedLayout =
      srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
  auto srcElemTy = srcTy.getElementType();
  auto inOrd = triton::gpu::getOrder(srcSharedLayout);
  auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
  // Vectorize along the contiguous dimension only when both orders agree.
  unsigned outVec = inOrd == outOrd
                        ? triton::gpu::getUniqueContigPerThread(
                              dstDistributedLayout, dstShape)[outOrd[0]]
                        : 1;
  unsigned inVec = srcSharedLayout.getVec();
  unsigned minVec = std::min(outVec, inVec);
  unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
  SmallVector<Value> offsetVals = {i32_val(0), i32_val(0)};
  assert(outElems == dstIndices.size());
  DenseMap<unsigned, Value> sharedPtrs =
      getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, srcElemTy,
                            smemObj, rewriter, offsetVals, smemObj.strides);
  assert(outElems % minVec == 0 && "Unexpected number of elements");
  unsigned numVecs = outElems / minVec;
  auto wordTy = vec_ty(elemTy, minVec);
  SmallVector<Value> outVals(outElems);
  for (unsigned i = 0; i < numVecs; ++i) {
    Value smemAddr = sharedPtrs[i * minVec];
    smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
    Value valVec = load(smemAddr);
    for (unsigned v = 0; v < minVec; ++v) {
      // Extract with `elemTy`, the element type of `wordTy`. An extractelement
      // result must match the vector's element type, which the unconverted
      // tensor element type is not guaranteed to do.
      Value currVal = extract_element(elemTy, valVec, i32_val(v));
      outVals[i * minVec + v] = currVal;
    }
  }
  return outVals;
}
/// Stores the current thread's elements of the distributed tensor `src`
/// (lowered value `llSrc`) into the swizzled shared-memory tensor `dst`.
///
/// `dst` has base pointer `smemBase` and per-dimension strides `dstStrides`.
/// Values are written with vector stores of min(inVec, outVec) elements of
/// `elemTy`. `srcIndices` must have one entry per element held by the thread.
/// Only rank-1 and rank-2 sources are supported.
void storeDistributedToShared(Value src, Value llSrc,
                              ArrayRef<Value> dstStrides,
                              ArrayRef<SmallVector<Value>> srcIndices,
                              Value dst, Value smemBase, Type elemTy,
                              Location loc,
                              ConversionPatternRewriter &rewriter) const {
  auto srcTy = src.getType().cast<RankedTensorType>();
  auto srcShape = srcTy.getShape();
  unsigned rank = srcShape.size();
  assert((rank == 1 || rank == 2) &&
         "Unexpected rank of storeDistributedToShared");
  assert(dstStrides.size() >= rank && "Expected one stride per dimension");
  auto dstTy = dst.getType().cast<RankedTensorType>();
  auto srcDistributedLayout = srcTy.getEncoding();
  if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
    assert((!mmaLayout.isVolta()) &&
           "ConvertLayout MMAv1->Shared is not supported yet");
  }
  auto dstSharedLayout =
      dstTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
  auto dstElemTy = dstTy.getElementType();
  auto inOrd = triton::gpu::getOrder(srcDistributedLayout);
  auto outOrd = dstSharedLayout.getOrder();
  // Vectorize along the contiguous dimension only when both orders agree.
  unsigned inVec = inOrd == outOrd
                       ? triton::gpu::getUniqueContigPerThread(
                             srcDistributedLayout, srcShape)[inOrd[0]]
                       : 1;
  unsigned outVec = dstSharedLayout.getVec();
  unsigned minVec = std::min(outVec, inVec);
  unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
  assert(numElems == srcIndices.size());
  // A word is only stored once it is full, so a trailing partial word would
  // be dropped silently. This mirrors the check in loadSharedToDistributed.
  assert(numElems % minVec == 0 && "Unexpected number of elements");
  auto inVals =
      getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
  auto wordTy = vec_ty(elemTy, minVec);
  Value word;
  SmallVector<Value> srcStrides;
  SmallVector<Value> offsetVals;
  for (unsigned i = 0; i < rank; ++i) {
    srcStrides.push_back(dstStrides[i]);
    offsetVals.push_back(i32_val(0));
  }
  SharedMemoryObject smemObj(smemBase, srcStrides, offsetVals);
  DenseMap<unsigned, Value> sharedPtrs =
      getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, dstElemTy,
                            smemObj, rewriter, offsetVals, srcStrides);
  for (unsigned i = 0; i < numElems; ++i) {
    // Start a new word at every minVec boundary...
    if (i % minVec == 0)
      word = undef(wordTy);
    word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
    // ...and store it once its last lane has been filled.
    if (i % minVec == minVec - 1) {
      Value smemAddr = sharedPtrs[i / minVec * minVec];
      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
      store(word, smemAddr);
    }
  }
}
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
// Returns an i1 predicate selecting the threads/CTAs that should perform a
// write of a value of type `valueTy`:
//  - ranked tensors: along each dimension where shape[dim] is smaller than
//    the layout's shapePerCTATile[dim] (i.e. data is replicated across
//    threads), only threads whose first element index is < shape[dim] stay
//    enabled; if the layout spans more than one CTA, CTAs holding duplicated
//    (multicast) data are additionally masked off;
//  - anything else (treated as a scalar): only thread 0 of cluster CTA 0.
Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc) const {
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
// Start from `true` and AND in the per-dimension / per-CTA conditions.
Value mask = int_val(1, 1);
auto tid = tid_val();
auto clusterCTAId = getClusterCTAId(rewriter, loc);
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
// Per-dimension warp and lane coordinates of this thread.
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
// Keep the thread only if its first element along `dim` is in bounds.
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
// Do not write duplicated data when multicast is enabled
if (triton::gpu::getNumCTAs(layout) > 1) {
auto _0 = i32_val(0);
auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
auto CTAOrder = triton::gpu::getCTAOrder(layout);
auto multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// Skip when multicast is not enabled in this dimension
if (CTAsPerCGA[dim] == CTASplitNum[dim])
continue;
// This wrapping rule must be consistent with emitCTAOffsetForLayout
unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum));
// Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
// CTA0 and CTA2 holds data of block0,
// CTA1 and CTA3 holds data of block1.
// Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
// be masked. We add the following mask:
// multiDimClusterCTAId[dim] / splitNum == 0
// Actually in all existing cases of multicast, splitNum is always 1.
// The mask is equivalent to:
// multiDimClusterCTAId[dim] == 0
mask = and_(mask, icmp_eq(repId, _0));
}
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 of
// CTA0 can write
mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0)));
mask = and_(mask, icmp_eq(tid, i32_val(0)));
}
return mask;
}
/// Emits the i32 dot product sum_i offsets[i] * strides[i], starting from a
/// constant 0. Both ranges are expected to have the same length (asserted);
/// only the common prefix is used otherwise.
Value dot(ConversionPatternRewriter &rewriter, Location loc,
          ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
  assert(offsets.size() == strides.size());
  const size_t terms =
      offsets.size() < strides.size() ? offsets.size() : strides.size();
  Value acc = i32_val(0);
  for (size_t d = 0; d < terms; ++d)
    acc = add(acc, mul(offsets[d], strides[d]));
  return acc;
}
/// DenseMap traits for `SmallVector<unsigned>` keys.
/// The empty key is the empty vector and the tombstone is the one-element
/// vector {UINT_MAX}; neither may be used as a real key.
struct SmallVectorKeyInfo {
  static unsigned getHashValue(const SmallVector<unsigned> &key) {
    auto rangeHash = llvm::hash_combine_range(key.begin(), key.end());
    return rangeHash;
  }
  static bool isEqual(const SmallVector<unsigned> &lhs,
                      const SmallVector<unsigned> &rhs) {
    return lhs == rhs;
  }
  static SmallVector<unsigned> getEmptyKey() { return {}; }
  static SmallVector<unsigned> getTombstoneKey() {
    return SmallVector<unsigned>(1, std::numeric_limits<unsigned>::max());
  }
};
// -----------------------------------------------------------------------
// Get offsets / indices for any layout
// -----------------------------------------------------------------------
/// Returns, per dimension, the element offset of this CTA's tile:
///   (clusterCTACoord[d] mod min(shape[d], CTASplitNum[d])) * shapePerCTA[d]
/// where clusterCTACoord is the cluster CTA id delinearized with CTAsPerCGA
/// and CTAOrder.
SmallVector<Value> emitCTAOffsetForLayout(Location loc,
                                          ConversionPatternRewriter &rewriter,
                                          Attribute layout,
                                          ArrayRef<int64_t> shape) const {
  const unsigned rank = shape.size();
  SmallVector<unsigned> ctasPerCGA = triton::gpu::getCTAsPerCGA(layout);
  SmallVector<unsigned> ctaSplitNum = triton::gpu::getCTASplitNum(layout);
  SmallVector<unsigned> ctaOrder = triton::gpu::getCTAOrder(layout);
  SmallVector<int64_t> tileShape =
      triton::gpu::getShapePerCTA(ctaSplitNum, shape);
  // Multi-dimensional coordinate of this CTA within the cluster.
  Value linearCTAId = getClusterCTAId(rewriter, loc);
  SmallVector<Value> ctaCoord =
      delinearize(rewriter, loc, linearCTAId, ctasPerCGA, ctaOrder);
  // Wrap the coordinate; this rule must stay consistent with getShapePerCTA.
  for (unsigned d = 0; d < rank; ++d) {
    unsigned wrap = std::min<unsigned>(shape[d], ctaSplitNum[d]);
    ctaCoord[d] = urem(ctaCoord[d], i32_val(wrap));
  }
  SmallVector<Value> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d)
    offsets.push_back(mul(ctaCoord[d], i32_val(tileShape[d])));
  return offsets;
}
/// Returns, per dimension, the base index of the current thread for `layout`
/// applied to `type`, plus the CTA offset when `withCTAOffset` is set.
///
/// When `indexCacheInfo.baseIndexCache` is set, results are memoized per
/// {layout, type, withCTAOffset}, and new computations are emitted at the
/// cached insertion point. The caller's insertion point is restored on return.
SmallVector<Value> emitBaseIndexForLayout(Location loc,
                                          ConversionPatternRewriter &rewriter,
                                          Attribute layout,
                                          RankedTensorType type,
                                          bool withCTAOffset) const {
  auto shape = type.getShape();
  IndexCacheKeyT key{layout, type, withCTAOffset};
  auto cache = indexCacheInfo.baseIndexCache;
  auto insertPt = indexCacheInfo.indexInsertPoint;
  // Single lookup instead of count() + lookup().
  if (cache) {
    auto it = cache->find(key);
    if (it != cache->end())
      return it->second;
  }
  ConversionPatternRewriter::InsertionGuard guard(rewriter);
  if (cache)
    restoreInsertionPointIfSet(insertPt, rewriter);
  SmallVector<Value> result;
  // Slice layouts get the CTA offset from their parent layout, so it must not
  // be added a second time below.
  bool addCTAOffset = withCTAOffset;
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
                                                    blockedLayout, type);
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    if (mmaLayout.isVolta())
      result = emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout,
                                                    type);
    if (mmaLayout.isAmpere() || mmaLayout.isHopper())
      result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter,
                                                      mmaLayout, type);
  } else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
    result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    // Compute the indices of the parent layout on the padded parent shape,
    // then drop the sliced dimension.
    auto parentLayout = sliceLayout.getParent();
    auto parentShape = sliceLayout.paddedShape(type.getShape());
    RankedTensorType parentTy = RankedTensorType::get(
        parentShape, type.getElementType(), parentLayout);
    result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy,
                                    withCTAOffset);
    result.erase(result.begin() + sliceLayout.getDim());
    // CTAOffset has been added in emitBaseIndexForLayout of parentLayout
    addCTAOffset = false;
  } else {
    llvm_unreachable("unsupported emitBaseIndexForLayout");
  }
  if (addCTAOffset) {
    auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape);
    assert(CTAOffset.size() == result.size() && "Rank mismatch");
    for (unsigned k = 0; k < result.size(); ++k)
      result[k] = add(result[k], CTAOffset[k]);
  }
  if (cache) {
    cache->insert(std::make_pair(key, result));
    *insertPt = rewriter.saveInsertionPoint();
  }
  return result;
}
/// Returns the element offsets produced by the layout-specific
/// emitOffsetFor* helper matching `layout` (blocked, MMA v1/v2/v3, MFMA or
/// slice). Any other layout, or an MMA layout of another version, reaches
/// llvm_unreachable.
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(Attribute layout, RankedTensorType type) const {
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>())
    return emitOffsetForBlockedLayout(blocked, type);
  if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    if (mma.isVolta())
      return emitOffsetForMmaLayoutV1(mma, type);
    if (mma.isAmpere())
      return emitOffsetForMmaLayoutV2(mma, type);
    if (mma.isHopper())
      return emitOffsetForMmaLayoutV3(mma, type);
    // No other layout kind can match an MMA attribute.
    llvm_unreachable("unsupported emitOffsetForLayout");
  }
  if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>())
    return emitOffsetForMfmaLayout(mfma, type);
  if (auto slice = layout.dyn_cast<SliceEncodingAttr>())
    return emitOffsetForSliceLayout(slice, type);
  llvm_unreachable("unsupported emitOffsetForLayout");
}
#ifdef USE_ROCM
/// Appends to `offsets` the (row, col) offsets of the current thread's
/// elements in the MFMA tile at tile coordinates (ctaOffsetX, ctaOffsetY).
///
/// The tile base is (ctaOffsetX * shapePerCTATile[0],
/// ctaOffsetY * shapePerCTATile[1]). Each group contributes 4 consecutive
/// elements. With a transposed layout the elements run along the column
/// axis; otherwise they run along the row axis. Asserts a warp size of 64.
void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
                          SmallVector<SmallVector<unsigned>> &offsets,
                          unsigned ctaOffsetX, unsigned ctaOffsetY) const {
  auto nonKDim = mfmaLayout.getNonKDim();
  // MFMA output tile consists of repeated "dot operand B" layout groups along
  // row axis. This variable defines number of these groups.
  const unsigned numGroups = nonKDim == 32 ? 4 : 1;
  const unsigned elemsPerThreadPerGroup = 4;
  auto warpSize = getWarpSize(mfmaLayout);
  assert(warpSize == 64);
  auto shapePerCta = getShapePerCTATile(mfmaLayout);
  const unsigned rowBase = ctaOffsetX * shapePerCta[0];
  const unsigned colBase = ctaOffsetY * shapePerCta[1];
  const bool transposed = mfmaLayout.getIsTransposed();
  for (unsigned group = 0; group < numGroups; ++group) {
    const unsigned groupOffset =
        group * elemsPerThreadPerGroup * warpSize / nonKDim;
    for (unsigned e = 0; e < elemsPerThreadPerGroup; ++e) {
      const unsigned shift = e + groupOffset;
      if (transposed)
        offsets.push_back({rowBase, colBase + shift});
      else
        offsets.push_back({rowBase + shift, colBase});
    }
  }
}
#endif
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
/// Returns the multi-dimensional indices of every element held by the current
/// thread for a tensor of `type` laid out with `layout`.
///
/// Supported layouts are blocked, MMA, MFMA and slice; anything else is
/// unreachable. When `indexCacheInfo.indexCache` is set, results are memoized
/// per {layout, type, withCTAOffset}, and new computations are emitted at the
/// cached insertion point. The caller's insertion point is restored on return.
SmallVector<SmallVector<Value>>
emitIndices(Location loc, ConversionPatternRewriter &b, Attribute layout,
            RankedTensorType type, bool withCTAOffset = true) const {
  IndexCacheKeyT key{layout, type, withCTAOffset};
  auto cache = indexCacheInfo.indexCache;
  auto insertPt = indexCacheInfo.indexInsertPoint;
  // Single lookup instead of count() + lookup().
  if (cache) {
    auto it = cache->find(key);
    if (it != cache->end())
      return it->second;
  }
  ConversionPatternRewriter::InsertionGuard guard(b);
  if (cache)
    restoreInsertionPointIfSet(insertPt, b);
  SmallVector<SmallVector<Value>> result;
  if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
    result =
        emitIndicesForDistributedLayout(loc, b, blocked, type, withCTAOffset);
  } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
    result =
        emitIndicesForDistributedLayout(loc, b, mma, type, withCTAOffset);
  } else if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>()) {
    result =
        emitIndicesForDistributedLayout(loc, b, mfma, type, withCTAOffset);
  } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
    result =
        emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset);
  } else {
    llvm_unreachable(
        "emitIndices for layouts other than blocked, mma, mfma & slice not "
        "implemented yet");
  }
  if (cache) {
    cache->insert(std::make_pair(key, result));
    *insertPt = b.saveInsertionPoint();
  }
  return result;
}
private:
/// Moves `rewriter` to `*insertPt` if that insertion point has been set.
/// Otherwise moves it to the start of the entry block of the LLVM::LLVMFuncOp
/// that encloses the current insertion block.
void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
                                ConversionPatternRewriter &rewriter) const {
  if (insertPt->isSet()) {
    rewriter.restoreInsertionPoint(*insertPt);
    return;
  }
  // Resolve the function from the insertion *block*. The insertion point
  // itself may be the block's end iterator, which must not be dereferenced.
  auto func = rewriter.getInsertionBlock()
                  ->getParent()
                  ->getParentOfType<LLVM::LLVMFuncOp>();
  assert(func && "expected the insertion point to be inside an llvm.func");
  rewriter.setInsertionPointToStart(&func.getBody().front());
}
// -----------------------------------------------------------------------
// Blocked layout indices
// -----------------------------------------------------------------------
// Get an index-base for each dimension for a \param blockedLayout.
/// Compute, for the calling thread, its CTA-local base coordinate along every
/// dimension of a tensor laid out with \p blockedLayout.
///
/// For dimension d the base is
///   sizePerThread[d] * (laneCoord[d] + warpCoord[d] * threadsPerWarp[d])
/// where the lane/warp coordinates come from delinearizing the lane id and
/// warp id with the layout order. Before use, they are wrapped (urem) so that
/// a layout tile larger than the per-CTA shape maps back onto valid data.
SmallVector<Value> emitBaseIndexWithinCTAForBlockedLayout(
    Location loc, ConversionPatternRewriter &rewriter,
    const BlockedEncodingAttr &blockedLayout, RankedTensorType type) const {
  // Split the flat thread id into (warp, lane).
  Value tid = getThreadId(rewriter, loc);
  Value lanesPerWarp = i32_val(triton::gpu::getWarpSize(blockedLayout));
  Value lane = urem(tid, lanesPerWarp);
  Value warp = udiv(tid, lanesPerWarp);

  auto elemsPerLane = blockedLayout.getSizePerThread();
  auto lanesPerDim = blockedLayout.getThreadsPerWarp();
  auto warpsPerDim = blockedLayout.getWarpsPerCTA();
  auto dimOrder = blockedLayout.getOrder();
  auto ctaShape =
      triton::gpu::getShapePerCTA(blockedLayout, type.getShape());
  const unsigned rank = type.getShape().size();

  // Per-dimension warp and lane coordinates.
  SmallVector<Value> warpCoord =
      delinearize(rewriter, loc, warp, warpsPerDim, dimOrder);
  SmallVector<Value> laneCoord =
      delinearize(rewriter, loc, lane, lanesPerDim, dimOrder);

  SmallVector<Value> base(rank);
  for (unsigned d = 0; d < rank; ++d) {
    // Wrap coordinates in case the layout tile along d exceeds ctaShape[d].
    unsigned warpLimit =
        ceil<unsigned>(ctaShape[d], elemsPerLane[d] * lanesPerDim[d]);
    unsigned laneLimit = ceil<unsigned>(ctaShape[d], elemsPerLane[d]);
    Value wrappedWarp = urem(warpCoord[d], i32_val(warpLimit));
    Value wrappedLane = urem(laneCoord[d], i32_val(laneLimit));
    Value lanesAlongD = i32_val(lanesPerDim[d]);
    Value elemsAlongD = i32_val(elemsPerLane[d]);
    base[d] =
        mul(elemsAlongD, add(wrappedLane, mul(wrappedWarp, lanesAlongD)));
  }
  return base;
}
/// Compute the constant per-element offsets, relative to the thread's base
/// index, of every element a thread owns in a tensor of \p type laid out with
/// \p blockedLayout.
///
/// Returns one rank-sized offset vector per owned element
/// (getTotalElemsPerThread(type) entries). Elements are grouped by layout
/// tile: the sizePerThread elements of one tile are adjacent, and both tiles
/// and in-tile elements are enumerated following the layout order.
SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
                           RankedTensorType type) const {
  auto sizePerThread = blockedLayout.getSizePerThread();
  auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
  auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
  auto order = blockedLayout.getOrder();
  auto tileShape = getShapePerCTATile(blockedLayout);
  auto ctaShape = triton::gpu::getShapePerCTA(blockedLayout, type.getShape());
  const unsigned rank = type.getShape().size();

  // Tiles needed to cover the CTA-local shape; at least one per dimension
  // even if the shape is smaller than a tile.
  SmallVector<unsigned> tilesAlongDim(rank);
  for (unsigned d = 0; d < rank; ++d)
    tilesAlongDim[d] = ceil<unsigned>(ctaShape[d], tileShape[d]);

  // coords[d] lists every coordinate along d covered by the layout, nested as
  // (tile, warp, thread, element).
  SmallVector<SmallVector<unsigned>> coords(rank);
  for (unsigned d = 0; d < rank; ++d) {
    const unsigned warpSpan = sizePerThread[d] * threadsPerWarp[d];
    const unsigned tileSpan = warpSpan * warpsPerCTA[d];
    for (unsigned tile = 0; tile < tilesAlongDim[d]; ++tile)
      for (unsigned w = 0; w < warpsPerCTA[d]; ++w)
        for (unsigned t = 0; t < threadsPerWarp[d]; ++t)
          for (unsigned e = 0; e < sizePerThread[d]; ++e)
            coords[d].push_back(tile * tileSpan + w * warpSpan +
                                t * sizePerThread[d] + e);
  }

  const unsigned numElems = triton::gpu::getTotalElemsPerThread(type);
  const unsigned elemsPerTile = product<unsigned>(sizePerThread);
  SmallVector<SmallVector<unsigned>> result(numElems);
  for (unsigned n = 0; n < numElems; ++n) {
    SmallVector<unsigned> tileCoord = getMultiDimIndex<unsigned>(
        n / elemsPerTile, tilesAlongDim, order);
    SmallVector<unsigned> elemCoord = getMultiDimIndex<unsigned>(
        n % elemsPerTile, sizePerThread, order);
    for (unsigned d = 0; d < rank; ++d) {
      // Warp/thread components are zero here: they live in the base index.
      unsigned pos =
          tileCoord[d] *
              (sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]) +
          elemCoord[d];
      result[n].push_back(coords[d][pos]);
    }
  }
  return result;
}
// -----------------------------------------------------------------------
// Mma layout indices
// -----------------------------------------------------------------------
// Compute the calling thread's CTA-local base (M, N) coordinate for an MMAv1
// (Volta) accumulator layout. Returns {offsetCM, offsetCN} as i32 Values.
// NOTE(review): several locals below (shape, the decoded Volta layout states,
// shapePerCTA) are computed but not used.
SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV1(
Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout, RankedTensorType type) const {
auto shape = type.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// Hard-coded fpw factors; fpw[0] / fpw[1] scale the M / N offsets below.
static constexpr std::array<int, 3> fpw{{2, 2, 1}};
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
Value thread = getThreadId(rewriter, loc);
auto *ctx = thread.getContext();
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _16 = i32_val(16);
Value warpSize = i32_val(triton::gpu::getWarpSize(mmaLayout));
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);
// A info: repetitions / shape-per-warp along M come from operand 0.
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info: repetitions / shape-per-warp along N come from operand 1.
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});
// Split the thread id into lane / warp, then the warp id into (warp0, warp1)
// with warp0 varying fastest.
Value lane = urem(thread, warpSize);
Value warp = udiv(thread, warpSize);
Value warp0 = urem(warp, i32_val(wpt[0]));
Value warp12 = udiv(warp, i32_val(wpt[0]));
Value warp1 = urem(warp12, i32_val(wpt[1]));
// warp offset
Value offWarpM = mul(warp0, i32_val(spw[0]));
Value offWarpN = mul(warp1, i32_val(spw[1]));
// quad offset: derived from bit 4 of the lane id (lane & 16).
Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0);
Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1);
// pair offset: derived from (lane % 16) / 4.
Value offPairM = udiv(urem(lane, _16), _4);
offPairM = urem(offPairM, _fpw0);
offPairM = mul(offPairM, _4);
Value offPairN = udiv(urem(lane, _16), _4);
offPairN = udiv(offPairN, _fpw0);
offPairN = urem(offPairN, _fpw1);
offPairN = mul(offPairN, _4);
// Scale pair/quad offsets by half the repetition count along each dim.
offPairM = mul(offPairM, i32_val(rep[0] / 2));
offQuadM = mul(offQuadM, i32_val(rep[0] / 2));
offPairN = mul(offPairN, i32_val(rep[1] / 2));
offQuadN = mul(offQuadN, i32_val(rep[1] / 2));
// quad pair offset
Value offLaneM = add(offPairM, offQuadM);
Value offLaneN = add(offPairN, offQuadN);
// a, b offset
// NOTE(review): offsetBN is not returned; its add is still emitted into the
// IR (dead code unless later cleaned up).
Value offsetAM = add(offWarpM, offLaneM);
Value offsetBN = add(offWarpN, offLaneN);
// m indices: lane bit 0 selects the odd row.
Value offsetCM = add(and_(lane, _1), offsetAM);
// n indices: lane bit 1 adds 2; uses offPairN, not offLaneN, so offQuadN does
// not contribute to the returned N coordinate.
Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN)));
return {offsetCM, offsetCN};
}
/// Compute the constant per-element (M, N) offsets for an MMAv1 (Volta)
/// accumulator of \p type. The result enumerates N positions in the outer
/// loop and M positions in the inner loop; each entry is {m, n}.
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
                         RankedTensorType type) const {
  auto shape = type.getShape();
  MLIRContext *ctx = type.getContext();
  static constexpr std::array<int, 3> fpw{{2, 2, 1}};

  // Repetitions / shape-per-warp along M come from the A operand (index 0),
  // along N from the B operand (index 1).
  auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
  auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
  const int repM = aEncoding.getMMAv1Rep()[0];
  const int repN = bEncoding.getMMAv1Rep()[1];
  auto wpt = mmaLayout.getWarpsPerCTA();
  const unsigned tileM = aEncoding.getMMAv1ShapePerWarp()[0] * wpt[0];
  const unsigned tileN = bEncoding.getMMAv1ShapePerWarp()[1] * wpt[1];

  // Row offsets: each tile contributes repM rows spaced by 2.
  SmallVector<unsigned> rowOffsets;
  for (unsigned m = 0; m < shape[0]; m += tileM)
    for (unsigned mm = 0; mm < repM; ++mm)
      rowOffsets.push_back(m + mm * 2);

  // Column offsets: each repetition contributes a pair of adjacent columns.
  SmallVector<unsigned> colOffsets;
  for (int n = 0; n < shape[1]; n += tileN) {
    for (int nn = 0; nn < repN; ++nn) {
      const int col = n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * repN;
      colOffsets.push_back(col);
      colOffsets.push_back(col + 1);
    }
  }

  SmallVector<SmallVector<unsigned>> ret;
  for (unsigned col : colOffsets)   // N
    for (unsigned row : rowOffsets) // M
      ret.push_back({row, col});
  return ret;
}
/// Compute the constant per-element (row, col) offsets for an MMAv2
/// accumulator of \p type. For every layout tile within the CTA-local shape,
/// the thread owns four elements: (r, c), (r, c+1), (r+8, c), (r+8, c+1).
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
                         RankedTensorType type) const {
  auto ctaShape = getShapePerCTA(mmaLayout, type.getShape());
  auto tile = getShapePerCTATile(mmaLayout);
  SmallVector<SmallVector<unsigned>> ret;
  for (unsigned row = 0; row < ctaShape[0]; row += tile[0])
    for (unsigned col = 0; col < ctaShape[1]; col += tile[1])
      for (unsigned dRow : {0u, 8u})
        for (unsigned dCol : {0u, 1u})
          ret.push_back({row + dRow, col + dCol});
  return ret;
}
// Compute the calling thread's CTA-local base (row, col) coordinate for an
// MMAv2/MMAv3 (Hopper) accumulator layout. Returns two i32 Values:
//   row = lane / 4       + warpId0 * instrShape[0]
//   col = 2 * (lane % 4) + warpId1 * instrShape[1]
// where warpId0/warpId1 are the warp coordinates wrapped to warpsM/warpsN.
SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV2V3(
Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout, RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
auto order = triton::gpu::getOrder(mmaLayout);
ArrayRef<unsigned int> instrShape = mmaLayout.getInstrShape();
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mmaLayout));
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// repM/repN: how many times warps x instrShape covers the CTA-local shape
// (integer division). When > 1, only warpsPerCTA / rep distinct warp
// positions exist along that dim; otherwise the count is shape / instrShape.
// NOTE(review): the else branches use the full `shape`, not `shapePerCTA`
// as repM/repN do — confirm this is intended for multi-CTA layouts.
// NOTE(review): warpsM/warpsN can become 0 (e.g. repM > _warpsPerCTA[0] or
// shape[0] < instrShape[0]), which would make the urem below divide by zero —
// presumably callers rule this out; verify.
uint32_t repM = (_warpsPerCTA[0] * instrShape[0]) / shapePerCTA[0];
uint32_t repN = (_warpsPerCTA[1] * instrShape[1]) / shapePerCTA[1];
uint32_t warpsM;
if (repM > 1)
warpsM = _warpsPerCTA[0] / repM;
else
warpsM = shape[0] / instrShape[0];
uint32_t warpsN;
if (repN > 1)
warpsN = _warpsPerCTA[1] / repN;
else
warpsN = shape[1] / instrShape[1];
SmallVector<Value> multiDimWarpId(2);
if (mmaLayout.isHopper()) {
// TODO[goostavz]: the tiling order from CTA->warp level is different for
// MMAv2/3. This is a workaround since we don't explicitly have warpGrp
// level in the layout definition, and the tiling order of warpGrp->warp
// must be fixed to meet the HW's needs. We may need to consider to
// explicitly define warpGrpPerCTA for MMAv3 layout.
multiDimWarpId[0] = urem(warpId, warpsPerCTA[0]);
multiDimWarpId[1] = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
} else {
multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
}
// Wrap warp coordinates to the number of distinct warp positions computed above.
Value warpId0 = urem(multiDimWarpId[0], i32_val(warpsM));
Value warpId1 = urem(multiDimWarpId[1], i32_val(warpsN));
Value offWarp0 = mul(warpId0, i32_val(instrShape[0]));
Value offWarp1 = mul(warpId1, i32_val(instrShape[1]));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
return multiDimBase;
}
/// Compute the constant per-element (row, col) offsets for an MMAv3
/// accumulator of \p type. For every layout tile and every 8-column step k
/// across instrShape[1], the thread owns (r, c+k), (r, c+k+1), (r+8, c+k),
/// (r+8, c+k+1).
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV3(const MmaEncodingAttr &mmaLayout,
                         RankedTensorType type) const {
  auto ctaShape = getShapePerCTA(mmaLayout, type.getShape());
  auto tile = getShapePerCTATile(mmaLayout);
  ArrayRef<unsigned int> instrShape = mmaLayout.getInstrShape();
  SmallVector<SmallVector<unsigned>> ret;
  for (unsigned row = 0; row < ctaShape[0]; row += tile[0])
    for (unsigned col = 0; col < ctaShape[1]; col += tile[1])
      for (unsigned k = 0; k < instrShape[1]; k += 8)
        for (unsigned dRow : {0u, 8u})
          for (unsigned dCol : {0u, 1u})
            ret.push_back({row + dRow, col + k + dCol});
  return ret;
}
// -----------------------------------------------------------------------
// Mfma layout indices
// -----------------------------------------------------------------------
// Compute the calling thread's base (dim0, dim1) coordinate for an MFMA
// accumulator layout. Returns two i32 Values.
// Warp coordinates: warpId0 = (warpId % warpsPerCTA[0]) wrapped to
// shape[0] / nonKDim, and warpId1 = ((warpId / warpsPerCTA[0]) %
// warpsPerCTA[1]) wrapped to shape[1] / nonKDim; each is scaled by nonKDim.
// NOTE(review): shape[d] / nonKDim is integer division; if shape[d] < nonKDim
// the urem divisor is 0 — presumably excluded by callers; verify.
SmallVector<Value>
emitBaseIndexForMfmaLayout(Location loc, ConversionPatternRewriter &rewriter,
const MfmaEncodingAttr &mfmaLayout,
RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
int nonKDim = mfmaLayout.getNonKDim();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 =
urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / nonKDim));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
i32_val(shape[1] / nonKDim));
Value offWarp0 = mul(warpId0, i32_val(nonKDim));
Value offWarp1 = mul(warpId1, i32_val(nonKDim));
SmallVector<Value> multiDimBase(2);
// Lane mapping: one dim gets 4 * (lane / nonKDim), the other lane % nonKDim.
// Transposed layouts swap which dim receives which term.
if (mfmaLayout.getIsTransposed()) {
multiDimBase[1] =
add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp1);
multiDimBase[0] = add(urem(laneId, i32_val(nonKDim)), offWarp0);
} else {
multiDimBase[0] =
add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp0);
multiDimBase[1] = add(urem(laneId, i32_val(nonKDim)), offWarp1);
}
return multiDimBase;
}
/// Compute the constant per-element offsets for an MFMA accumulator of
/// \p type. The offsets for each block a warp covers are appended by
/// emitMfmaOffsetForCTA, visiting blocks row-major over (dim0, dim1).
SmallVector<SmallVector<unsigned>>
emitOffsetForMfmaLayout(const MfmaEncodingAttr &mfmaLayout,
                        RankedTensorType type) const {
  auto tensorShape = type.getShape();
  auto ctaShape = getShapePerCTA(mfmaLayout, tensorShape);
  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
  const unsigned nonKDim = mfmaLayout.getNonKDim();

  // Number of nonKDim-sized blocks each warp covers along each dimension.
  std::array<unsigned, 2> blocksPerWarp;
  for (unsigned d = 0; d < 2; ++d) {
    unsigned elemsInCTA = std::min<unsigned>(tensorShape[d], ctaShape[d]);
    unsigned elemsPerWarp = ceil<unsigned>(elemsInCTA, warpsPerCTA[d]);
    blocksPerWarp[d] = ceil<unsigned>(elemsPerWarp, nonKDim);
  }

  SmallVector<SmallVector<unsigned>> offsets;
  for (unsigned i = 0; i < blocksPerWarp[0]; ++i)
    for (unsigned j = 0; j < blocksPerWarp[1]; ++j)
      emitMfmaOffsetForCTA(mfmaLayout, offsets, i, j);
  return offsets;
}
/// Emit the i32 index Values of every element the calling thread owns in a
/// tensor of \p type with distributed \p layout.
///
/// Returns an [elemsPerThread x rank] matrix: entry [n][d] is the thread's
/// base index along d (from emitBaseIndexForLayout, optionally including the
/// CTA offset) plus the constant offset of its n-th element along d (from
/// emitOffsetForLayout). Element order is that of emitOffsetForLayout.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
    Location loc, ConversionPatternRewriter &rewriter, Attribute layout,
    RankedTensorType type, bool withCTAOffset) const {
  // Runtime base coordinate of this thread, one Value per dimension.
  auto base =
      emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset);
  // Compile-time offsets of each owned element relative to the base.
  auto elemOffsets = emitOffsetForLayout(layout, type);

  const unsigned rank = type.getShape().size();
  SmallVector<SmallVector<Value>> indices;
  indices.reserve(elemOffsets.size());
  for (const auto &elemOffset : elemOffsets) {
    SmallVector<Value> idx(rank);
    for (unsigned d = 0; d < rank; ++d)
      idx[d] = add(base[d], i32_val(elemOffset[d]));
    indices.push_back(std::move(idx));
  }
  return indices;
}
/// Compute the constant per-element offsets for a slice layout. The offsets
/// of the parent layout (on the padded parent shape) are computed, the sliced
/// dimension is dropped from each, and only the first occurrence of each
/// resulting offset is kept (first-seen order is preserved).
SmallVector<SmallVector<unsigned>>
emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
                         RankedTensorType type) const {
  auto parentEncoding = sliceLayout.getParent();
  const unsigned slicedDim = sliceLayout.getDim();
  auto parentShape = sliceLayout.paddedShape(type.getShape());
  RankedTensorType parentTy = RankedTensorType::get(
      parentShape, type.getElementType(), parentEncoding);
  auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy);

  SmallVector<SmallVector<unsigned>> resultOffsets;
  std::set<SmallVector<unsigned>> seen;
  for (SmallVector<unsigned> offset : parentOffsets) {
    offset.erase(offset.begin() + slicedDim);
    // set::insert reports whether the offset was new.
    if (seen.insert(offset).second)
      resultOffsets.push_back(std::move(offset));
  }
  return resultOffsets;
}
private:
/// Return the first name "printfFormat_<N>" (N = 0, 1, 2, ...) that is not
/// already a symbol in \p moduleOp.
static SmallString<16> getUniqueFormatGlobalName(mlir::ModuleOp moduleOp) {
  const char formatStringPrefix[] = "printfFormat_";
  SmallString<16> candidate;
  for (unsigned suffix = 0;; ++suffix) {
    candidate.clear();
    (formatStringPrefix + Twine(suffix)).toStringRef(candidate);
    if (!moduleOp.lookupSymbol(candidate))
      return candidate;
  }
}
/// Look up the LLVM function \p name in \p moduleOp. If it does not exist,
/// declare it with \p type and external linkage at the start of the module
/// body. The caller's insertion point is preserved either way.
template <typename T>
static LLVM::LLVMFuncOp
getOrDefineFunction(T &moduleOp, const Location loc,
                    ConversionPatternRewriter &rewriter, StringRef name,
                    LLVM::LLVMFunctionType type) {
  if (auto existing = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))
    return existing;
  ConversionPatternRewriter::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
                                           LLVM::Linkage::External);
}
protected:
// Lower a device-side printf to the ockl printf hostcall functions.
// The code is borrowed from https://reviews.llvm.org/D110448
// from GPUPrintfOpToHIPLowering::matchAndRewrite().
//
// Emitted sequence:
//   1. __ockl_printf_begin(0), or __ockl_fprintf_stderr_begin() when
//      useStderr is set, producing a descriptor;
//   2. __ockl_printf_append_string_n with the format string \p msg plus a
//      trailing '\n' and NUL;
//   3. __ockl_printf_append_args for \p args in groups of at most 7, each
//      group zero-padded to 7 values with the last group flagged.
// Every argument is passed as i64: floats are extended to f64 and bitcast,
// LLVM pointers are converted with ptrtoint, other values narrower than 64
// bits are zero-extended.
// The flag is named useStderr rather than `stderr` because `stderr` is a
// <cstdio> macro that some C libraries expand to an expression.
void llPrintfHIP(mlir::Location loc, mlir::ModuleOp moduleOp, StringRef msg,
                 ValueRange args, ConversionPatternRewriter &rewriter,
                 bool useStderr = false) const {
  auto typeConverter = getTypeConverter();
  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
  mlir::Type i8Ptr = typeConverter->getPointerType(llvmI8);
  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
  auto ocklBegin = getOrDefineFunction(
      moduleOp, loc, rewriter,
      (useStderr ? "__ockl_fprintf_stderr_begin" : "__ockl_printf_begin"),
      (LLVM::LLVMFunctionType::get(llvmI64, useStderr ? ArrayRef<mlir::Type>()
                                                      : llvmI64)));
  LLVM::LLVMFuncOp ocklAppendArgs;
  if (!args.empty()) {
    ocklAppendArgs = getOrDefineFunction(
        moduleOp, loc, rewriter, "__ockl_printf_append_args",
        LLVM::LLVMFunctionType::get(llvmI64,
                                    {llvmI64, /*numArgs*/ llvmI32, llvmI64,
                                     llvmI64, llvmI64, llvmI64, llvmI64,
                                     llvmI64, llvmI64, /*isLast*/ llvmI32}));
  }
  auto ocklAppendStringN = getOrDefineFunction(
      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
      LLVM::LLVMFunctionType::get(
          llvmI64,
          {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));

  /// Start the printf hostcall
  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
  auto printfBeginCall = rewriter.create<LLVM::CallOp>(
      loc, ocklBegin, useStderr ? ValueRange() : zeroI64);
  Value printfDesc = printfBeginCall.getResult();

  SmallString<32> formatString(msg);
  formatString.push_back('\n'); // Triton appends a newline to each print.
  formatString.push_back('\0'); // Null terminate for C
  size_t formatStringSize = formatString.size_in_bytes();

  // addStringToModule receives the "printfFormat_" key for the global.
  Value prefixString =
      LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatString);
  auto prefixPtrType = ocklAppendStringN.getArgumentTypes()[1];
  prefixString = bitcast(prefixString, prefixPtrType);

  Value stringLen =
      rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
  Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
  // The format string is the last append only when there are no arguments.
  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
      loc, ocklAppendStringN,
      ValueRange{printfDesc, prefixString, stringLen,
                 args.empty() ? oneI32 : zeroI32});
  printfDesc = appendFormatCall.getResult();

  // __ockl_printf_append_args takes 7 values per append call
  constexpr size_t argsPerAppend = 7;
  size_t nArgs = args.size();
  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
    size_t bound = std::min(group + argsPerAppend, nArgs);
    size_t numArgsThisCall = bound - group;

    SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
    arguments.push_back(printfDesc);
    arguments.push_back(
        rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
    for (size_t i = group; i < bound; ++i) {
      Value arg = args[i];
      if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
        if (!floatType.isF64())
          arg = rewriter.create<LLVM::FPExtOp>(
              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
        arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
      } else if (arg.getType().isa<LLVM::LLVMPointerType>()) {
        // Pointers have no int/float bit width (getIntOrFloatBitWidth would
        // assert); pass the address as an i64 instead.
        arg = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, arg);
      }
      if (arg.getType().getIntOrFloatBitWidth() != 64)
        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
      arguments.push_back(arg);
    }
    // Pad out to 7 arguments since the hostcall always needs 7
    for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
      arguments.push_back(zeroI64);
    }
    auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
    arguments.push_back(isLast);
    auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
    printfDesc = call.getResult();
  }
}
TritonGPUToLLVMTypeConverter *converter;
ModuleAllocation *allocation;
IndexCacheInfo indexCacheInfo;
mlir::triton::gpu::TMAMetadataTy *tmaMetadata;
};
/// Base class for TritonGPU -> LLVM conversion patterns on \p SourceOp.
/// Combines MLIR's ConvertOpToLLVMPattern with the Triton helpers in
/// ConvertTritonGPUOpToLLVMPatternBase. The constructors forward the optional
/// ModuleAllocation, IndexCacheInfo, or TMA metadata to the Triton base.
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
    : public ConvertOpToLLVMPattern<SourceOp>,
      public ConvertTritonGPUOpToLLVMPatternBase {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  explicit ConvertTritonGPUOpToLLVMPattern(
      TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}

  explicit ConvertTritonGPUOpToLLVMPattern(
      TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
      PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}

  explicit ConvertTritonGPUOpToLLVMPattern(
      TritonGPUToLLVMTypeConverter &typeConverter,
      IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, indexCacheInfo) {}

  explicit ConvertTritonGPUOpToLLVMPattern(
      TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
      IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
                                            indexCacheInfo) {}

  explicit ConvertTritonGPUOpToLLVMPattern(
      TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
      mlir::triton::gpu::TMAMetadataTy *tmaMetadata, PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
                                            tmaMetadata) {}

protected:
  /// Return the Triton type converter. Both bases declare a
  /// getTypeConverter(); this overload resolves the ambiguity in favour of
  /// ConvertTritonGPUOpToLLVMPatternBase. Named casts replace the former
  /// C-style casts, which silently cast away the constness of `this`.
  TritonGPUToLLVMTypeConverter *getTypeConverter() const {
    LLVMTypeConverter *ret =
        static_cast<const ConvertTritonGPUOpToLLVMPatternBase *>(this)
            ->getTypeConverter();
    return static_cast<TritonGPUToLLVMTypeConverter *>(ret);
  }
};
#endif