mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'upstream/main' into ifu_4_26_2023
This commit is contained in:
@@ -5,10 +5,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
@@ -426,14 +426,14 @@ private:
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||
// unsigned elems = getElemsPerThread(srcTy);
|
||||
// unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
|
||||
|
||||
unsigned outElems = getElemsPerThread(dstTy);
|
||||
unsigned outElems = getTotalElemsPerThread(dstTy);
|
||||
auto outOrd = getOrder(dstLayout);
|
||||
SmallVector<Value> outVals(outElems);
|
||||
|
||||
@@ -572,15 +572,11 @@ private:
|
||||
auto loc = op.getLoc();
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
|
||||
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
|
||||
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
|
||||
if (isMmaToDotShortcut(srcTy, dstTy)) {
|
||||
// get source values
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned elems = getElemsPerThread(srcTy);
|
||||
unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(srcTy.getElementType());
|
||||
// for the destination type, we need to pack values together
|
||||
|
||||
Reference in New Issue
Block a user