mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'oai/main' into ifu230601
Conflicts: python/test/unit/language/assert_helper.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -87,6 +87,10 @@ public:
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
return lowerMmaToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
if (srcLayout.isa<SharedEncodingAttr>() &&
|
||||
isaDistributedLayout(dstLayout)) {
|
||||
return lowerSharedToDistributed(op, adaptor, rewriter);
|
||||
}
|
||||
// TODO: to be implemented
|
||||
llvm_unreachable("unsupported layout conversion");
|
||||
return failure();
|
||||
@@ -544,9 +548,40 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> types(outElems, llvmElemTy);
|
||||
auto *ctx = llvmElemTy.getContext();
|
||||
Type structTy = struct_ty(types);
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
Value src = op.getSrc();
|
||||
Value dst = op.getResult();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
assert(dstShape.size() == 2 &&
|
||||
"Unexpected rank of ConvertLayout(shared->blocked)");
|
||||
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto inOrd = getOrder(srcSharedLayout);
|
||||
|
||||
auto smemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
|
||||
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
||||
|
||||
auto srcStrides =
|
||||
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
|
||||
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy);
|
||||
|
||||
SmallVector<Value> outVals = loadSharedToDistributed(
|
||||
dst, dstIndices, src, smemObj, elemTy, loc, rewriter);
|
||||
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
Reference in New Issue
Block a user