Merge remote-tracking branch 'upstream/main' into ifu_4_26_2023

This commit is contained in:
Michael Melesse
2023-04-26 12:04:21 -05:00
119 changed files with 5234 additions and 3454 deletions

View File

@@ -5,10 +5,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -426,14 +426,14 @@ private:
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getElemsPerThread(srcTy);
// unsigned elems = getTotalElemsPerThread(srcTy);
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstTy);
unsigned outElems = getTotalElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
@@ -572,15 +572,11 @@ private:
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
if (isMmaToDotShortcut(srcTy, dstTy)) {
// get source values
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned elems = getElemsPerThread(srcTy);
unsigned elems = getTotalElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together