[BACKEND] Fixed up ConvertLayout for slices (#1616)

This commit is contained in:
Philippe Tillet
2023-05-04 07:06:54 -07:00
committed by GitHub
parent 19e7238d50
commit f387a6c863

View File

@@ -107,14 +107,19 @@ private:
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
unsigned stride = 1;
if (getOrder(parentEncoding)[0] == dim)
stride = parentSizePerThread[dim];
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto offsets = emitOffsetForLayout(layout, type);
auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
SmallVector<int> idxs;
for (SmallVector<unsigned> off : offsets) {
off.insert(off.begin() + dim, 0);
auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
idxs.push_back(std::distance(parentOffset.begin(), it));
}
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, elemId * stride, parentTy,
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);