// ROCm/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
#include "ConvertLayoutOpToLLVM.h"
#include "Utility.h"
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Forward declarations
namespace SharedToDotOperandMMAv1 {
using CoordTy = SmallVector<Value>;
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
SmallVector<CoordTy> getMNCoords(Value thread,
ConversionPatternRewriter &rewriter,
ArrayRef<unsigned int> wpt,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4);
Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj,
Value thread, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Type resultTy);
} // namespace SharedToDotOperandMMAv1
namespace SharedToDotOperandMMAv2 {
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Location loc, Value tensor,
DotOperandEncodingAttr bEncoding,
const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread);
}
namespace SharedToDotOperandFMA {
Value convertLayout(int opIdx, Value B, Value llB, BlockedEncodingAttr dLayout,
Value thread, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
}
struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isaDistributedLayout(srcLayout) &&
dstLayout.isa<SharedEncodingAttr>()) {
return lowerDistributedToShared(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
if (srcLayout.isa<MmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
}
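// Illustrative example (encodings and shapes are assumptions, not taken from
// any particular kernel): a conversion such as
//   %1 = triton_gpu.convert_layout %0
//        : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #shared>
// is handled by lowerDistributedToShared above, while e.g. #mma -> #blocked
// goes through lowerDistributedToDistributed, staging data in shared memory.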
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
auto shape = type.getShape();
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto offsets = emitOffsetForLayout(layout, type);
auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
SmallVector<int> idxs;
for (SmallVector<unsigned> off : offsets) {
off.insert(off.begin() + dim, 0);
auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
idxs.push_back(std::distance(parentOffset.begin(), it));
}
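// Worked example (shapes assumed for illustration): with dim == 1 and a
// sliced offset off = {o0}, the padded offset becomes {o0, 0}; its position
// in parentOffset gives the parent elemId whose multi-dim offset is reused
// below, with the sliced dimension dropped again afterwards.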
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
continue;
unsigned slicedD = d < dim ? d : (d - 1);
multiDimOffset[slicedD] = multiDimOffsetParent[d];
}
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
Value warpSize = i32_val(64);
#else
Value warpSize = i32_val(32);
#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in the MMAEncodingAttr documentation
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _8 = i32_val(8);
Value _16 = i32_val(16);
if (mmaLayout.isAmpere()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
Value colWarpOffset = mul(multiDimWarpId[1], _8);
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.isVolta()) {
// Volta doesn't follow the pattern here.
} else {
llvm_unreachable("Unexpected MMALayout version");
}
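// Worked example for the Ampere path (values assumed for illustration): with
// warpsPerCTA = {1, 1} and laneId = 5, mmaGrpId = 5 / 4 = 1 and
// mmaThreadIdInGrp = 5 % 4 = 1, so this thread covers rows {1, 9} and
// columns {2, 3} of its m16n8 accumulator tile, matching the mma.sync
// fragment layout.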
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
if (mmaLayout.isAmpere()) {
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
auto coords = SharedToDotOperandMMAv1::getMNCoords(
threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
isARow, isBRow, isAVec4, isBVec4);
return coords[elemId];
} else {
llvm_unreachable("Unexpected MMALayout version");
}
return multiDimOffset;
}
llvm_unreachable("unexpected layout in getMultiDimOffset");
}
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout, type.getShape());
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
bool isPtr = elemTy.isa<triton::PointerType>();
auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
if (isInt1)
elemTy = IntegerType::get(elemTy.getContext(), 8);
else if (isPtr)
elemTy = IntegerType::get(elemTy.getContext(), 64);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank);
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
auto linearCTAId =
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
// TODO: This index calculation is redundant; consider caching the result if
// it shows up as a performance issue.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
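// Illustrative note (padding values are assumptions): linearize folds the
// multi-dim offset into a scalar using strides derived from paddedRepShape
// in the order outOrd. E.g. with paddedRepShape = {32, 33} and
// outOrd = {1, 0}, an element at (r, c) maps to offset r * 33 + c; the extra
// column of padding staggers rows across shared-memory banks.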
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
else if (isPtr)
currVal = ptrtoint(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
if (isInt1)
currVal = icmp_ne(currVal,
rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
else if (isPtr)
currVal = inttoptr(llvmElemTyOrig, currVal);
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}
}
}
}
// MMAv1's result layout is quite different from the existing "replica"
// structure, so we add a separate, simpler implementation for it rather than
// modifying the logic of the existing one.
void processReplicaForMMAV1(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase,
ArrayRef<int64_t> shape,
bool isDestMma = false) const {
unsigned accumNumCTAsEachRep = 1;
auto layout = type.getEncoding();
MmaEncodingAttr mma = layout.dyn_cast<MmaEncodingAttr>();
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
if (sliceLayout)
mma = sliceLayout.getParent().cast<MmaEncodingAttr>();
auto order = getOrder(layout);
auto rank = type.getRank();
int accumSizePerThread = vals.size();
SmallVector<unsigned> numCTAs(rank, 1);
SmallVector<unsigned> numCTAsEachRep(rank, 1);
SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
auto elemTy = type.getElementType();
int ctaId = 0;
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank);
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
std::vector<std::pair<SmallVector<Value>, Value>> coord2valT(
accumSizePerThread);
bool needTrans = outOrd[0] != 0;
if (sliceLayout || isDestMma)
needTrans = false;
vec = needTrans ? 2 : 1;
{
// We need to transpose the coordinates and values here to enable vec=2 when
// storing to smem.
std::vector<std::pair<SmallVector<Value>, Value>> coord2val(
accumSizePerThread);
for (unsigned elemId = 0; elemId < accumSizePerThread; ++elemId) {
// TODO[Superjomn]: Move the coordinate computation out of the loop; it is
// duplicated on Volta.
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
}
if (needTrans) {
// do transpose
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;
for (int r = 0; r < numM; r++) {
for (int c = 0; c < numN; c++) {
coord2valT[r * numN + c] = std::move(coord2val[c * numM + r]);
}
}
} else {
coord2valT = std::move(coord2val);
}
}
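// Worked example of the transpose above (numM and numN assumed for
// illustration): with numM = 2 and numN = 4, destination slot r * numN + c
// reads source slot c * numM + r, e.g. (r, c) = (1, 2) maps slot 6 to
// slot 5, so consecutive destination slots are contiguous along N and can be
// stored with vec = 2.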
// Now coord2valT holds the transposed, contiguous elements (with vec=2); the
// original vals are no longer needed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
auto coord = coord2valT[elemId].first;
Value offset = linearize(rewriter, loc, coord, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(elemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(elemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = coord2valT[elemId + v].second;
valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(elemTy, valVec, i32_val(v));
vals[elemId + v] = currVal;
}
}
}
}
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflicts.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);
// For Volta, all the coords for a CTA are calculated.
bool isSrcMmaV1{}, isDstMmaV1{};
if (auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
isSrcMmaV1 = mmaLayout.isVolta();
}
if (auto sliceLayout = srcLayout.dyn_cast<SliceEncodingAttr>()) {
isSrcMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
}
if (auto mmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
isDstMmaV1 = mmaLayout.isVolta();
}
if (auto sliceLayout = dstLayout.dyn_cast<SliceEncodingAttr>()) {
isDstMmaV1 = sliceLayout.getParent().isa<MmaEncodingAttr>() &&
sliceLayout.getParent().cast<MmaEncodingAttr>().isVolta();
}
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
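// Worked example (per-CTA shapes assumed for illustration): for a dimension
// with shape[d] = 128, srcShapePerCTA[d] = 64 and dstShapePerCTA[d] = 32,
// we get maxPerCTA = 64, numReplicates[d] = 2, inNumCTAsEachRep[d] = 1,
// outNumCTAsEachRep[d] = 2, inNumCTAs[d] = 2 and outNumCTAs[d] = 4.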
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getTotalElemsPerThread(srcTy);
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getTotalElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
if (repId != 0)
barrier();
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
if (isSrcMmaV1)
processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ true, srcTy,
multiDimRepId, inVec, paddedRepShape, outOrd,
vals, smemBase, shape);
else
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy,
inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
outOrd, vals, smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
barrier();
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
if (isDstMmaV1)
processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy,
multiDimRepId, outVec, paddedRepShape, outOrd,
outVals, smemBase, shape, /*isDestMma=*/true);
else
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec,
paddedRepShape, outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);
return success();
}
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflicts. Normally used for the
// A/B operands of dots.
LogicalResult
lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcLayout = srcTy.getEncoding();
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = getOrder(srcLayout);
auto outOrd = dstSharedLayout.getOrder();
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto dstStrides =
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
auto smemObj =
SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
// shared -> mma_operand
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
bool isOuter{};
int K{};
if (dotOperandLayout.getOpIdx() == 0) // $a
K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]];
else // $b
K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]];
isOuter = K == 1;
Value res;
if (auto mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>()) {
res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout,
dotOperandLayout, isOuter);
} else if (auto blockedLayout =
dotOperandLayout.getParent()
.dyn_cast_or_null<BlockedEncodingAttr>()) {
auto dotOpLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto thread = getThreadId(rewriter, loc);
res = SharedToDotOperandFMA::convertLayout(
dotOpLayout.getOpIdx(), src, adaptor.getSrc(), blockedLayout, thread,
loc, getTypeConverter(), rewriter);
} else {
assert(false && "Unsupported dot operand layout found");
}
rewriter.replaceOp(op, res);
return success();
}
// mma -> dot_operand
LogicalResult
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
if (isMmaToDotShortcut(srcTy, dstTy)) {
// get source values
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned elems = getTotalElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
SmallVector<Value> vecVals;
SmallVector<Type> types;
// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of ldmatrix in i32.
auto elemSize = elemTy.getIntOrFloatBitWidth();
if (elemTy.isa<IntegerType>() && elemSize <= 16) {
auto fold = 32 / elemSize;
for (unsigned i = 0; i < elems; i += fold) {
Value val = i32_val(0);
for (unsigned j = 0; j < fold; j++) {
auto ext =
shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
val = or_(i32_ty, val, ext);
}
vecVals.push_back(val);
}
elems = elems / (32 / elemSize);
types = SmallVector<Type>(elems, i32_ty);
} else {
unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
Type vecTy = vec_ty(elemTy, vecSize);
types = SmallVector<Type>(elems / vecSize, vecTy);
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
}
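// Worked example of the packing above (element type assumed for
// illustration): for i16 values with elemSize = 16, fold = 2, so each output
// word is (zext(vals[i]) << 0) | (zext(vals[i + 1]) << 16), halving the
// number of elements and storing them as i32.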
// This needs to be ordered the same way that ldmatrix.x4 would order it.
// TODO: this needs to be refactored so we don't implicitly depend on how
// emitOffsetsForMMAV2 is implemented.
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
}
Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
return failure();
}
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
res = SharedToDotOperandMMAv2::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());
} else if (!isOuter && mmaLayout.isVolta() &&
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
res = SharedToDotOperandMMAv1::convertLayout(
dotOperandLayout.getOpIdx(), src, smemObj, getThreadId(rewriter, loc),
loc, getTypeConverter(), rewriter, dst.getType());
} else {
assert(false && "Unsupported mma layout found");
}
return res;
}
}; // struct ConvertLayoutOpConversion
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation,
indexCacheInfo, benefit);
}