Merge remote-tracking branch 'oai/main' into ifu230601

Conflicts:
	python/test/unit/language/assert_helper.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-06-01 20:53:33 +00:00
43 changed files with 1381 additions and 409 deletions

View File

@@ -87,6 +87,10 @@ public:
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
isaDistributedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, rewriter);
}
// TODO: lower the remaining layout conversions; any pair not handled above is currently unsupported.
llvm_unreachable("unsupported layout conversion");
return failure();
@@ -544,9 +548,40 @@ private:
}
}
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);
return success();
}
LogicalResult
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(dstShape.size() == 2 &&
"Unexpected rank of ConvertLayout(shared->blocked)");
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
auto dstLayout = dstTy.getEncoding();
auto inOrd = getOrder(srcSharedLayout);
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto srcStrides =
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy);
SmallVector<Value> outVals = loadSharedToDistributed(
dst, dstIndices, src, smemObj, elemTy, loc, rewriter);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);