mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Adding support for slice layout in InsertSliceAsyncOp (#2438)
This commit is contained in:
@@ -1277,10 +1277,12 @@ struct InsertSliceAsyncOpConversion
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto resTy = dst.getType().cast<RankedTensorType>();
|
||||
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
|
||||
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
|
||||
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
|
||||
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
assert(srcShape.size() == 2 &&
|
||||
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
|
||||
"insert_slice_async: Unexpected rank of %src");
|
||||
|
||||
Value llDst = adaptor.getDst();
|
||||
@@ -1345,25 +1347,15 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned numElems = getTotalElemsPerThread(srcTy);
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
DenseMap<unsigned, Value> sharedPtrs =
|
||||
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
|
||||
smemObj, rewriter, offsetVals, srcStrides);
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we will have elements
|
||||
// that share the same tile indices. The index calculation will
|
||||
// be cached.
|
||||
auto numSwizzleRows = std::max<unsigned>(
|
||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||
// A sharedLayout encoding has a "vec" parameter.
|
||||
// On the column dimension, if inVec > outVec, it means we have to divide
|
||||
// single vector read into multiple ones
|
||||
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
|
||||
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
|
||||
|
||||
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
|
||||
// 16 * 8 = 128bits
|
||||
auto maxBitWidth =
|
||||
|
||||
Reference in New Issue
Block a user