[BACKEND] Adding support for slice layout in InsertSliceAsyncOp (#2438)

This commit is contained in:
Zahi Moudallal
2023-10-03 20:59:53 -07:00
committed by GitHub
parent 6b860e7a74
commit 0d84a7d70c
4 changed files with 66 additions and 23 deletions

View File

@@ -1277,10 +1277,12 @@ struct InsertSliceAsyncOpConversion
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.getDst();
@@ -1345,25 +1347,15 @@ struct InsertSliceAsyncOpConversion
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
DenseMap<unsigned, Value> sharedPtrs =
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
smemObj, rewriter, offsetVals, srcStrides);
// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// 16 * 8 = 128bits
auto maxBitWidth =