#include "ConvertLayoutOpToLLVM.h"
|
|
#include "Utility.h"
|
|
|
|
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
|
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
|
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
|
using ::mlir::triton::gpu::getContigPerThread;
|
|
using ::mlir::triton::gpu::getOrder;
|
|
using ::mlir::triton::gpu::getShapePerCTA;
|
|
using ::mlir::triton::gpu::getSizePerThread;
|
|
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
|
using ::mlir::triton::gpu::isaDistributedLayout;
|
|
using ::mlir::triton::gpu::SharedEncodingAttr;
|
|
|
|
// Forward declarations
|
|
|
|
namespace SharedToDotOperandMMAv1 {
|
|
using CoordTy = SmallVector<Value>;
|
|
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
|
|
|
|
SmallVector<CoordTy> getMNCoords(Value thread,
|
|
ConversionPatternRewriter &rewriter,
|
|
ArrayRef<unsigned int> wpt,
|
|
const MmaEncodingAttr &mmaLayout,
|
|
ArrayRef<int64_t> shape, bool isARow,
|
|
bool isBRow, bool isAVec4, bool isBVec4);
|
|
|
|
Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj,
|
|
Value thread, Location loc,
|
|
TritonGPUToLLVMTypeConverter *typeConverter,
|
|
ConversionPatternRewriter &rewriter, Type resultTy);
|
|
|
|
} // namespace SharedToDotOperandMMAv1
|
|
|
|
namespace SharedToDotOperandMMAv2 {
|
|
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
|
Location loc, Value tensor,
|
|
DotOperandEncodingAttr bEncoding,
|
|
const SharedMemoryObject &smemObj,
|
|
TritonGPUToLLVMTypeConverter *typeConverter, Value thread);
|
|
}
|
|
|
|
namespace SharedToDotOperandFMA {
|
|
Value convertLayout(int opIdx, Value B, Value llB, BlockedEncodingAttr dLayout,
|
|
Value thread, Location loc,
|
|
TritonGPUToLLVMTypeConverter *typeConverter,
|
|
ConversionPatternRewriter &rewriter);
|
|
}
|
|
|
|
struct ConvertLayoutOpConversion
|
|
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
|
|
public:
|
|
using ConvertTritonGPUOpToLLVMPattern<
|
|
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
|
|
|
|
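
  // Dispatch on the (source, destination) encoding pair. Four conversions are
  // currently supported: distributed -> shared, shared -> dot operand,
  // distributed -> distributed, and MMA -> dot operand.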
  LogicalResult
  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = op.getSrc();
    Value dst = op.getResult();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto dstTy = dst.getType().cast<RankedTensorType>();
    Attribute srcLayout = srcTy.getEncoding();
    Attribute dstLayout = dstTy.getEncoding();
    if (isaDistributedLayout(srcLayout) &&
        dstLayout.isa<SharedEncodingAttr>()) {
      return lowerDistributedToShared(op, adaptor, rewriter);
    }
    if (srcLayout.isa<SharedEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
      return lowerSharedToDotOperand(op, adaptor, rewriter);
    }
    if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
      return lowerDistributedToDistributed(op, adaptor, rewriter);
    }
    if (srcLayout.isa<MmaEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
      return lowerMmaToDotOperand(op, adaptor, rewriter);
    }
    // TODO: to be implemented
    llvm_unreachable("unsupported layout conversion");
    return failure();
  }

private:
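  // Computes the multi-dimensional offset (in elements) of the elemId-th value
  // owned by the current thread within the padded replication tile, handling
  // blocked, slice, and MMA encodings.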
  SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
                                       ConversionPatternRewriter &rewriter,
                                       unsigned elemId, RankedTensorType type,
                                       ArrayRef<unsigned> multiDimCTAInRepId,
                                       ArrayRef<unsigned> shapePerCTA) const {
    auto shape = type.getShape();
    unsigned rank = shape.size();
    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
      auto multiDimOffsetFirstElem =
          emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
      SmallVector<Value> multiDimOffset(rank);
      SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
          elemId, getSizePerThread(layout), getOrder(layout));
      for (unsigned d = 0; d < rank; ++d) {
        multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
                                i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
                                        multiDimElemId[d]));
      }
      return multiDimOffset;
    }
    if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
      unsigned dim = sliceLayout.getDim();
      auto parentEncoding = sliceLayout.getParent();
      auto parentSizePerThread = getSizePerThread(parentEncoding);
      auto parentShape = sliceLayout.paddedShape(shape);
      auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
                                            parentEncoding);
      auto offsets = emitOffsetForLayout(layout, type);
      auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
      SmallVector<int> idxs;
      for (SmallVector<unsigned> off : offsets) {
        off.insert(off.begin() + dim, 0);
        auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
        idxs.push_back(std::distance(parentOffset.begin(), it));
      }
      auto multiDimOffsetParent = getMultiDimOffset(
          parentEncoding, loc, rewriter, idxs[elemId], parentTy,
          sliceLayout.paddedShape(multiDimCTAInRepId),
          sliceLayout.paddedShape(shapePerCTA));
      SmallVector<Value> multiDimOffset(rank);
      for (unsigned d = 0; d < rank + 1; ++d) {
        if (d == dim)
          continue;
        unsigned slicedD = d < dim ? d : (d - 1);
        multiDimOffset[slicedD] = multiDimOffsetParent[d];
      }
      return multiDimOffset;
    }
    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
      SmallVector<Value> mmaColIdx(4);
      SmallVector<Value> mmaRowIdx(2);
      Value threadId = getThreadId(rewriter, loc);
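      // Decompose the thread id into lane and warp ids. ROCm wavefronts have
      // 64 lanes; CUDA warps have 32.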
#ifdef USE_ROCM
      Value warpSize = i32_val(64);
#else
      Value warpSize = i32_val(32);
#endif
      Value laneId = urem(threadId, warpSize);
      Value warpId = udiv(threadId, warpSize);
      // TODO: fix the bug in the MmaEncodingAttr documentation
      SmallVector<Value> multiDimWarpId(2);
      multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
      multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
      Value _1 = i32_val(1);
      Value _2 = i32_val(2);
      Value _4 = i32_val(4);
      Value _8 = i32_val(8);
      Value _16 = i32_val(16);
      if (mmaLayout.isAmpere()) {
        multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
        multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
        Value mmaGrpId = udiv(laneId, _4);
        Value mmaGrpIdP8 = add(mmaGrpId, _8);
        Value mmaThreadIdInGrp = urem(laneId, _4);
        Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
        Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
        Value rowWarpOffset = mul(multiDimWarpId[0], _16);
        mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
        mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
        Value colWarpOffset = mul(multiDimWarpId[1], _8);
        mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
        mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
      } else if (mmaLayout.isVolta()) {
        // Volta doesn't follow the pattern here.
      } else {
        llvm_unreachable("Unexpected MMALayout version");
      }

      assert(rank == 2);
      SmallVector<Value> multiDimOffset(rank);
      if (mmaLayout.isAmpere()) {
        multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
        multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
        multiDimOffset[0] = add(
            multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
        multiDimOffset[1] = add(
            multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
      } else if (mmaLayout.isVolta()) {
        auto [isARow, isBRow, isAVec4, isBVec4, _] =
            mmaLayout.decodeVoltaLayoutStates();
        auto coords = SharedToDotOperandMMAv1::getMNCoords(
            threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
            isARow, isBRow, isAVec4, isBVec4);
        return coords[elemId];
      } else {
        llvm_unreachable("Unexpected MMALayout version");
      }
      return multiDimOffset;
    }
    llvm_unreachable("unexpected layout in getMultiDimOffset");
  }

  // Shared memory read/store for blocked or MMA layouts, with data padding.
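  // If stNotRd is true, each thread stores its values from `vals` into the
  // padded scratch buffer at `smemBase`; otherwise it loads them back into
  // `vals`. i1 elements are widened to i8 and pointers are converted to i64 so
  // that they round-trip through shared memory.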
  void processReplica(Location loc, ConversionPatternRewriter &rewriter,
                      bool stNotRd, RankedTensorType type,
                      ArrayRef<unsigned> numCTAsEachRep,
                      ArrayRef<unsigned> multiDimRepId, unsigned vec,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
                      Value smemBase) const {
    auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
    auto layout = type.getEncoding();
    auto rank = type.getRank();
    auto sizePerThread = getSizePerThread(layout);
    auto accumSizePerThread = product<unsigned>(sizePerThread);
    SmallVector<unsigned> numCTAs(rank);
    auto shapePerCTA = getShapePerCTA(layout, type.getShape());
    auto order = getOrder(layout);
    for (unsigned d = 0; d < rank; ++d) {
      numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
    }
    auto elemTy = type.getElementType();
    bool isInt1 = elemTy.isInteger(1);
    bool isPtr = elemTy.isa<triton::PointerType>();
    auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
    if (isInt1)
      elemTy = IntegerType::get(elemTy.getContext(), 8);
    else if (isPtr)
      elemTy = IntegerType::get(elemTy.getContext(), 64);

    auto llvmElemTy = getTypeConverter()->convertType(elemTy);

    for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
      auto multiDimCTAInRepId =
          getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
      SmallVector<unsigned> multiDimCTAId(rank);
      for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
        auto d = it.index();
        multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
      }

      auto linearCTAId =
          getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
      // TODO: This index calculation is redundant; consider caching the result
      // if it shows up as a performance issue.
      for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
        SmallVector<Value> multiDimOffset =
            getMultiDimOffset(layout, loc, rewriter, elemId, type,
                              multiDimCTAInRepId, shapePerCTA);
        Value offset =
            linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);

        auto elemPtrTy = ptr_ty(llvmElemTy, 3);
        Value ptr = gep(elemPtrTy, smemBase, offset);
        auto vecTy = vec_ty(llvmElemTy, vec);
        ptr = bitcast(ptr, ptr_ty(vecTy, 3));
        if (stNotRd) {
          Value valVec = undef(vecTy);
          for (unsigned v = 0; v < vec; ++v) {
            auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
            if (isInt1)
              currVal = zext(llvmElemTy, currVal);
            else if (isPtr)
              currVal = ptrtoint(llvmElemTy, currVal);
            valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
          }
          store(valVec, ptr);
        } else {
          Value valVec = load(ptr);
          for (unsigned v = 0; v < vec; ++v) {
            Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
            if (isInt1)
              currVal = icmp_ne(currVal,
                                rewriter.create<LLVM::ConstantOp>(
                                    loc, i8_ty, rewriter.getI8IntegerAttr(0)));
            else if (isPtr)
              currVal = inttoptr(llvmElemTyOrig, currVal);
            vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
          }
        }
      }
    }
  }

  // The MMAv1 result layout is quite different from the existing "Replica"
  // structure, so add a separate, simpler implementation for it instead of
  // modifying the logic of the existing one.
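  // Used by lowerDistributedToDistributed when the source or destination of
  // the conversion is an MMAv1 (Volta) layout.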
  void processReplicaForMMAV1(Location loc, ConversionPatternRewriter &rewriter,
                              bool stNotRd, RankedTensorType type,
                              ArrayRef<unsigned> multiDimRepId, unsigned vec,
                              ArrayRef<unsigned> paddedRepShape,
                              ArrayRef<unsigned> outOrd,
                              SmallVector<Value> &vals, Value smemBase,
                              ArrayRef<int64_t> shape,
                              bool isDestMma = false) const {
    unsigned accumNumCTAsEachRep = 1;
    auto layout = type.getEncoding();
    MmaEncodingAttr mma = layout.dyn_cast<MmaEncodingAttr>();
    auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
    if (sliceLayout)
      mma = sliceLayout.getParent().cast<MmaEncodingAttr>();

    auto order = getOrder(layout);
    auto rank = type.getRank();
    int accumSizePerThread = vals.size();

    SmallVector<unsigned> numCTAs(rank, 1);
    SmallVector<unsigned> numCTAsEachRep(rank, 1);
    SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
    auto elemTy = type.getElementType();

    int ctaId = 0;

    auto multiDimCTAInRepId =
        getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
    SmallVector<unsigned> multiDimCTAId(rank);
    for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
      auto d = it.index();
      multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
    }

    std::vector<std::pair<SmallVector<Value>, Value>> coord2valT(
        accumSizePerThread);
    bool needTrans = outOrd[0] != 0;
    if (sliceLayout || isDestMma)
      needTrans = false;

    vec = needTrans ? 2 : 1;
    {
      // We need to transpose the coordinates and values here to enable vec=2
      // when storing to smem.
      std::vector<std::pair<SmallVector<Value>, Value>> coord2val(
          accumSizePerThread);
      for (unsigned elemId = 0; elemId < accumSizePerThread; ++elemId) {
        // TODO[Superjomn]: Move the coordinate computation out of the loop; it
        // is duplicated on Volta.
        SmallVector<Value> multiDimOffset =
            getMultiDimOffset(layout, loc, rewriter, elemId, type,
                              multiDimCTAInRepId, shapePerCTA);
        coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
      }

      if (needTrans) {
        // do transpose
        auto aEncoding =
            DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
        int numM = aEncoding.getMMAv1NumOuter(shape);
        int numN = accumSizePerThread / numM;

        for (int r = 0; r < numM; r++) {
          for (int c = 0; c < numN; c++) {
            coord2valT[r * numN + c] = std::move(coord2val[c * numM + r]);
          }
        }
      } else {
        coord2valT = std::move(coord2val);
      }
    }

    // Now coord2valT holds the transposed, contiguous elements (with vec=2);
    // the original vals are no longer needed.
    for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
      auto coord = coord2valT[elemId].first;
      Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd);
      auto elemPtrTy = ptr_ty(elemTy, 3);
      Value ptr = gep(elemPtrTy, smemBase, offset);
      auto vecTy = vec_ty(elemTy, vec);
      ptr = bitcast(ptr, ptr_ty(vecTy, 3));
      if (stNotRd) {
        Value valVec = undef(vecTy);
        for (unsigned v = 0; v < vec; ++v) {
          auto currVal = coord2valT[elemId + v].second;
          valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
        }
        store(valVec, ptr);
      } else {
        Value valVec = load(ptr);
        for (unsigned v = 0; v < vec; ++v) {
          Value currVal = extract_element(elemTy, valVec, i32_val(v));
          vals[elemId + v] = currVal;
        }
      }
    }
  }

  // blocked/mma -> blocked/mma.
  // Data padding in shared memory to avoid bank conflict.
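  // Overall scheme:
  //   for each replication round:
  //     store the thread's values into the padded scratch buffer (src layout)
  //     barrier
  //     load the values back from the scratch buffer (dst layout)
  // Multiple rounds are needed when the scratch buffer cannot hold the whole
  // tensor at once.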
  LogicalResult
  lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    Value src = op.getSrc();
    Value dst = op.getResult();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto dstTy = dst.getType().cast<RankedTensorType>();
    Attribute srcLayout = srcTy.getEncoding();
    Attribute dstLayout = dstTy.getEncoding();
    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
    auto elemPtrTy = ptr_ty(llvmElemTy, 3);
    smemBase = bitcast(smemBase, elemPtrTy);
    auto shape = dstTy.getShape();
    unsigned rank = dstTy.getRank();
    SmallVector<unsigned> numReplicates(rank);
    SmallVector<unsigned> inNumCTAsEachRep(rank);
    SmallVector<unsigned> outNumCTAsEachRep(rank);
    SmallVector<unsigned> inNumCTAs(rank);
    SmallVector<unsigned> outNumCTAs(rank);
    auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
    auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);

    // For Volta, all the coords for a CTA are calculated.
    bool isSrcMmaV1{}, isDstMmaV1{};
    if (auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
      isSrcMmaV1 = mmaLayout.isVolta();
    }
    if (auto sliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>()) {
      isSrcMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
                   sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
    }
    if (auto mmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
      isDstMmaV1 = mmaLayout.isVolta();
    }
    if (auto sliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>()) {
      isDstMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
                   sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
    }

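    // For each dimension, one replication round covers
    // max(srcShapePerCTA[d], dstShapePerCTA[d]) elements, so the side with the
    // smaller per-CTA tile has to process several CTA tiles per round.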
    for (unsigned d = 0; d < rank; ++d) {
      unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
      unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
      unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
      numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
      inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
      outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
      assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
      inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
      outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
    }
    // Potentially we need to store for multiple CTAs in this replication
    auto accumNumReplicates = product<unsigned>(numReplicates);
    // unsigned elems = getTotalElemsPerThread(srcTy);
    auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                     rewriter, srcTy);
    unsigned inVec = 0;
    unsigned outVec = 0;
    auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);

    unsigned outElems = getTotalElemsPerThread(dstTy);
    auto outOrd = getOrder(dstLayout);
    SmallVector<Value> outVals(outElems);

    for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
      auto multiDimRepId =
          getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
      if (repId != 0)
        barrier();
      if (srcLayout.isa<BlockedEncodingAttr>() ||
          srcLayout.isa<SliceEncodingAttr>() ||
          srcLayout.isa<MmaEncodingAttr>()) {
        if (isSrcMmaV1)
          processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ true, srcTy,
                                 multiDimRepId, inVec, paddedRepShape, outOrd,
                                 vals, smemBase, shape);
        else
          processReplica(loc, rewriter, /*stNotRd*/ true, srcTy,
                         inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
                         outOrd, vals, smemBase);
      } else {
        assert(0 && "ConvertLayout with input layout not implemented");
        return failure();
      }

      barrier();
      if (dstLayout.isa<BlockedEncodingAttr>() ||
          dstLayout.isa<SliceEncodingAttr>() ||
          dstLayout.isa<MmaEncodingAttr>()) {
        if (isDstMmaV1)
          processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy,
                                 multiDimRepId, outVec, paddedRepShape, outOrd,
                                 outVals, smemBase, shape, /*isDestMma=*/true);
        else
          processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
                         outNumCTAsEachRep, multiDimRepId, outVec,
                         paddedRepShape, outOrd, outVals, smemBase);
      } else {
        assert(0 && "ConvertLayout with output layout not implemented");
        return failure();
      }
    }

    SmallVector<Type> types(outElems, llvmElemTy);
    auto *ctx = llvmElemTy.getContext();
    Type structTy = struct_ty(types);
    Value result =
        getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
    rewriter.replaceOp(op, result);

    return success();
  }

  // blocked -> shared.
  // Swizzling in shared memory to avoid bank conflict. Normally used for
  // A/B operands of dots.
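  // Note: the swizzled per-element offsets are computed inside
  // storeDistributedToShared from the destination shared encoding.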
  LogicalResult
  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    Value src = op.getSrc();
    Value dst = op.getResult();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto srcShape = srcTy.getShape();
    auto dstTy = dst.getType().cast<RankedTensorType>();
    auto dstShape = dstTy.getShape();
    assert(srcShape.size() == 2 &&
           "Unexpected rank of ConvertLayout(blocked->shared)");
    auto srcLayout = srcTy.getEncoding();
    auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
    auto inOrd = getOrder(srcLayout);
    auto outOrd = dstSharedLayout.getOrder();
    Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
    auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
    auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
    smemBase = bitcast(smemBase, elemPtrTy);

    auto dstStrides =
        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
    storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
                             smemBase, elemTy, loc, rewriter);
    auto smemObj =
        SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
    auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
    rewriter.replaceOp(op, retVal);
    return success();
  }

  // shared -> mma_operand
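  // K is the size of the reduction dimension of the dot, read off the shared
  // layout's order; isOuter marks the degenerate K == 1 (outer product) case.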
  LogicalResult
  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    Value src = op.getSrc();
    Value dst = op.getResult();
    auto dstTensorTy = dst.getType().cast<RankedTensorType>();
    auto srcTensorTy = src.getType().cast<RankedTensorType>();
    auto dotOperandLayout =
        dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
    auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();

    bool isOuter{};
    int K{};
    if (dotOperandLayout.getOpIdx() == 0) // $a
      K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]];
    else // $b
      K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]];
    isOuter = K == 1;

    Value res;
    if (auto mmaLayout =
            dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>()) {
      res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout,
                                       dotOperandLayout, isOuter);
    } else if (auto blockedLayout =
                   dotOperandLayout.getParent()
                       .dyn_cast_or_null<BlockedEncodingAttr>()) {
      auto dotOpLayout =
          dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
      auto thread = getThreadId(rewriter, loc);
      res = SharedToDotOperandFMA::convertLayout(
          dotOpLayout.getOpIdx(), src, adaptor.getSrc(), blockedLayout, thread,
          loc, getTypeConverter(), rewriter);
    } else {
      assert(false && "Unsupported dot operand layout found");
    }

    rewriter.replaceOp(op, res);
    return success();
  }

  // mma -> dot_operand
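  // Fast path: when isMmaToDotShortcut holds, the MMA accumulator registers
  // are repacked directly into dot-operand form without a round trip through
  // shared memory.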
  LogicalResult
  lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
    auto dstTy = op.getResult().getType().cast<RankedTensorType>();
    if (isMmaToDotShortcut(srcTy, dstTy)) {
      // get source values
      auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                       rewriter, srcTy);
      unsigned elems = getTotalElemsPerThread(srcTy);
      Type elemTy =
          this->getTypeConverter()->convertType(srcTy.getElementType());
      // for the destination type, we need to pack values together
      // so they can be consumed by tensor core operations
      SmallVector<Value> vecVals;
      SmallVector<Type> types;
      // For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
      // instructions to pack & unpack sub-word integers. A workaround is to
      // store the results of ldmatrix in i32.
      auto elemSize = elemTy.getIntOrFloatBitWidth();
      if (elemTy.isa<IntegerType>() && elemSize <= 16) {
        auto fold = 32 / elemSize;
        for (unsigned i = 0; i < elems; i += fold) {
          Value val = i32_val(0);
          for (unsigned j = 0; j < fold; j++) {
            auto ext =
                shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
            val = or_(i32_ty, val, ext);
          }
          vecVals.push_back(val);
        }
        elems = elems / (32 / elemSize);
        types = SmallVector<Type>(elems, i32_ty);
      } else {
        unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
        Type vecTy = vec_ty(elemTy, vecSize);
        types = SmallVector<Type>(elems / vecSize, vecTy);
        for (unsigned i = 0; i < elems; i += vecSize) {
          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
          for (unsigned j = 0; j < vecSize; j++)
            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
          vecVals.push_back(packed);
        }
      }

      // This needs to be ordered the same way that ldmatrix.x4 would order it.
      // TODO: this needs to be refactored so we don't implicitly depend on how
      // emitOffsetsForMMAV2 is implemented.
      SmallVector<Value> reorderedVals;
      for (unsigned i = 0; i < vecVals.size(); i += 4) {
        reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
        reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
        reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
        reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
      }

      Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
                                                      rewriter, dstTy);
      rewriter.replaceOp(op, view);
      return success();
    }
    return failure();
  }

  // shared -> dot_operand if the result layout is mma
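  // Dispatches to the MMAv2 (Ampere) or MMAv1 (Volta) shared -> dot-operand
  // loader depending on the parent MMA layout; returns a null Value for
  // unsupported MMAv1 orderings.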
  Value lowerSharedToDotOperandMMA(
      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
      const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
    auto loc = op.getLoc();
    Value src = op.getSrc();
    Value dst = op.getResult();

    auto smemObj =
        getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
    Value res;

    if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
      res = SharedToDotOperandMMAv2::convertLayout(
          dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
          smemObj, getTypeConverter(), tid_val());
    } else if (!isOuter && mmaLayout.isVolta() &&
               supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
      bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
      auto srcSharedLayout = src.getType()
                                 .cast<RankedTensorType>()
                                 .getEncoding()
                                 .cast<SharedEncodingAttr>();

      // Can only convert [1, 0] to row or [0, 1] to col for now
      if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
          (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
        llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
        return Value();
      }

      res = SharedToDotOperandMMAv1::convertLayout(
          dotOperandLayout.getOpIdx(), src, smemObj, getThreadId(rewriter, loc),
          loc, getTypeConverter(), rewriter, dst.getType());
    } else {
      assert(false && "Unsupported mma layout found");
    }
    return res;
  }
}; // struct ConvertLayoutOpConversion
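
// Registers the ConvertLayoutOp lowering pattern, wiring in the type
// converter, the shared-memory allocation, and the index cache shared by the
// TritonGPU-to-LLVM patterns.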
void populateConvertLayoutOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    ModuleAllocation &allocation,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
    PatternBenefit benefit) {
  patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation,
                                          indexCacheInfo, benefit);
}