mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fixed up ConvertLayout for slices (#1616)
This commit is contained in:
@@ -107,14 +107,19 @@ private:
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
auto parentEncoding = sliceLayout.getParent();
|
||||
auto parentSizePerThread = getSizePerThread(parentEncoding);
|
||||
unsigned stride = 1;
|
||||
if (getOrder(parentEncoding)[0] == dim)
|
||||
stride = parentSizePerThread[dim];
|
||||
auto parentShape = sliceLayout.paddedShape(shape);
|
||||
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
|
||||
parentEncoding);
|
||||
auto offsets = emitOffsetForLayout(layout, type);
|
||||
auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
|
||||
SmallVector<int> idxs;
|
||||
for (SmallVector<unsigned> off : offsets) {
|
||||
off.insert(off.begin() + dim, 0);
|
||||
auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
|
||||
idxs.push_back(std::distance(parentOffset.begin(), it));
|
||||
}
|
||||
auto multiDimOffsetParent = getMultiDimOffset(
|
||||
parentEncoding, loc, rewriter, elemId * stride, parentTy,
|
||||
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
|
||||
sliceLayout.paddedShape(multiDimCTAInRepId),
|
||||
sliceLayout.paddedShape(shapePerCTA));
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
|
||||
Reference in New Issue
Block a user