[BACKEND] Decomposed getElemsPerThread to return a vector of the per-dim elements per thread (#1549)

This is a prerequisite for updating the semantics of SliceEncodingAttr.
This commit is contained in:
zahimoud
2023-04-19 16:52:16 -07:00
committed by GitHub
parent b42e3d06d4
commit 8f7424221f
16 changed files with 121 additions and 65 deletions

View File

@@ -21,7 +21,9 @@ namespace mlir {
namespace triton {
namespace gpu {
unsigned getElemsPerThread(Type type);
unsigned getTotalElemsPerThread(Type type);
SmallVector<unsigned> getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);

View File

@@ -35,7 +35,8 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
}];
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}

View File

@@ -5,10 +5,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -422,14 +422,14 @@ private:
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getElemsPerThread(srcTy);
// unsigned elems = getTotalElemsPerThread(srcTy);
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstTy);
unsigned outElems = getTotalElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
@@ -572,7 +572,7 @@ private:
// get source values
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
unsigned elems = getElemsPerThread(srcTy);
unsigned elems = getTotalElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together

View File

@@ -6,10 +6,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;

View File

@@ -7,10 +7,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;

View File

@@ -6,10 +6,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;

View File

@@ -14,7 +14,7 @@ Value loadC(Value tensor, Value llTensor,
ConversionPatternRewriter &rewriter) {
MLIRContext *ctx = tensor.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
size_t fcSize = triton::gpu::getElemsPerThread(tensor.getType());
size_t fcSize = triton::gpu::getTotalElemsPerThread(tensor.getType());
assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
"Currently, we only support $c with a mma layout.");

View File

@@ -2,7 +2,7 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
@@ -499,7 +499,7 @@ struct FpToFpOpConversion
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
auto elems = getTotalElemsPerThread(dstTensorType);
SmallVector<Value> resultVals;
bool isSrcFP8 =
srcEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
@@ -583,7 +583,7 @@ public:
auto resultTy = op.getType();
Location loc = op->getLoc();
unsigned elems = getElemsPerThread(resultTy);
unsigned elems = getTotalElemsPerThread(resultTy);
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<Type> types(elems, elemTy);

View File

@@ -8,7 +8,7 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Contains some helper functions for both Load and Store conversions.
@@ -72,7 +72,7 @@ struct LoadOpConversion
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
unsigned numElems = getTotalElemsPerThread(ptr.getType());
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));
@@ -276,7 +276,7 @@ struct StoreOpConversion
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned elemsPerThread = getElemsPerThread(ptr.getType());
unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType());
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
@@ -493,7 +493,7 @@ struct AtomicRMWOpConversion
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
auto elemsPerThread = getTotalElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
int numElems = 1;
@@ -776,7 +776,7 @@ struct InsertSliceAsyncOpConversion
unsigned inVec = getContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcTy);
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();

View File

@@ -5,8 +5,8 @@ using namespace mlir::triton;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;
struct ReduceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
@@ -64,7 +64,7 @@ private:
ConversionPatternRewriter &rewriter) const {
auto types = op.getInputTypes();
auto operands = adaptor.getOperands();
unsigned srcElems = getElemsPerThread(types[0]);
unsigned srcElems = getTotalElemsPerThread(types[0]);
SmallVector<SmallVector<Value>> srcValues(srcElems);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto values = getTypeConverter()->unpackLLElements(loc, operands[i],
@@ -158,7 +158,7 @@ private:
elemPtrTys[i]);
}
unsigned srcElems = getElemsPerThread(srcTys[0]);
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
// Emits indices of the original tensor that each thread
// would own
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
@@ -263,7 +263,7 @@ private:
auto resultLayout = resultTy.getEncoding();
unsigned resultElems = getElemsPerThread(resultTy);
unsigned resultElems = getTotalElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);
@@ -330,7 +330,7 @@ private:
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned srcElems = getElemsPerThread(srcTys[0]);
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
@@ -457,7 +457,7 @@ private:
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
unsigned resultElems = getElemsPerThread(resultTy);
unsigned resultElems = getTotalElemsPerThread(resultTy);
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
assert(resultIndices.size() == resultElems);

View File

@@ -6,7 +6,7 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
@@ -435,7 +435,7 @@ struct AddPtrOpConversion
auto ptrTy = op.getPtr().getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
if (resultTensorTy) {
unsigned elems = getElemsPerThread(resultTy);
unsigned elems = getTotalElemsPerThread(resultTy);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(),

View File

@@ -275,7 +275,7 @@ public:
auto srcEncoding = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
unsigned numElems = triton::gpu::getElemsPerThread(srcTy);
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
// swizzling params as described in TritonGPUAttrDefs.td
unsigned outVec = resSharedLayout.getVec();
unsigned perPhase = resSharedLayout.getPerPhase();
@@ -376,7 +376,7 @@ public:
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = triton::gpu::getElemsPerThread(srcTy);
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
assert(numElems == srcIndices.size());
auto inVals =
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
@@ -667,7 +667,7 @@ private:
threadOffset * sizePerThread[k] + elemOffset);
}
unsigned elemsPerThread = triton::gpu::getElemsPerThread(type);
unsigned elemsPerThread = triton::gpu::getTotalElemsPerThread(type);
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
for (unsigned n = 0; n < elemsPerThread; ++n) {

View File

@@ -8,7 +8,7 @@ using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -144,7 +144,7 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
return LLVM::LLVMStructType::getLiteral(ctx, types);
}
unsigned numElementsPerThread = getElemsPerThread(type);
unsigned numElementsPerThread = getTotalElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread, eltType);
return LLVM::LLVMStructType::getLiteral(ctx, types);
}

View File

@@ -4,7 +4,7 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
struct SplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
@@ -24,7 +24,7 @@ struct SplatOpConversion
auto tensorTy = resType.cast<RankedTensorType>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
size_t elemsPerThread = getTotalElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
return typeConverter->packLLElements(loc, elems, rewriter, resType);
}
@@ -93,7 +93,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
unsigned elems = getTotalElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);

View File

@@ -22,32 +22,54 @@ namespace gpu {
// so that all distributed layouts implement
// these utilities
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
Type eltTy) {
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
Type eltTy) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape, eltTy);
return blockedLayout.getTotalElemsPerThread(shape, eltTy);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return sliceLayout.getElemsPerThread(shape, eltTy);
return sliceLayout.getTotalElemsPerThread(shape, eltTy);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return mmaLayout.getElemsPerThread(shape, eltTy);
return mmaLayout.getTotalElemsPerThread(shape, eltTy);
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
return sharedLayout.getElemsPerThread(shape, eltTy);
return sharedLayout.getTotalElemsPerThread(shape, eltTy);
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return dotLayout.getElemsPerThread(shape, eltTy);
return dotLayout.getTotalElemsPerThread(shape, eltTy);
} else {
assert(0 && "getElemsPerThread not implemented");
return 0;
}
}
unsigned getElemsPerThread(Type type) {
SmallVector<unsigned> getElemsPerThread(Attribute layout,
ArrayRef<int64_t> shape, Type eltTy) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return blockedLayout.getElemsPerThread(shape, eltTy);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return sliceLayout.getElemsPerThread(shape, eltTy);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
return mmaLayout.getElemsPerThread(shape, eltTy);
} else {
assert(0 && "getElemsPerThread not implemented");
return SmallVector<unsigned>();
}
}
SmallVector<unsigned> getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return 1;
return SmallVector<unsigned>(1, 1);
auto tensorType = type.cast<RankedTensorType>();
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
tensorType.getElementType());
}
unsigned getTotalElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
tensorType.getElementType());
}
SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
@@ -230,7 +252,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
assert(0 && "Unimplemented usage of getOrder");
return {};
}
};
}
bool isaDistributedLayout(Attribute layout) {
return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
@@ -306,9 +328,9 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
return SliceEncodingAttr::get(getContext(), axis, *this);
}
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
SmallVector<unsigned>
BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
size_t rank = shape.size();
auto sizePerThread = getSizePerThread();
auto warpsPerCTA = getWarpsPerCTA();
@@ -320,7 +342,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
}
return product<unsigned>(elemsPerThread);
return elemsPerThread;
}
unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
return product<unsigned>(getElemsPerThread(shape, eltTy));
}
template <class T>
@@ -343,19 +369,26 @@ SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
SmallVector<unsigned>
SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
auto parent = getParent();
return ::getElemsPerThread(parent, paddedShape(shape), eltTy);
auto parentElemsPerThread =
::getElemsPerThread(parent, paddedShape(shape), eltTy);
return parentElemsPerThread;
}
unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
return product<unsigned>(getElemsPerThread(shape, eltTy));
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
SmallVector<unsigned>
MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
int res = 0;
SmallVector<unsigned> elemsPerThread(rank);
if (isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates();
static constexpr std::array<unsigned, 2> fpw{{2, 2}};
@@ -369,21 +402,34 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
unsigned wptN = getWarpsPerCTA()[1];
unsigned resM = repM * std::max<int>(1, shape[0] / (spwM * wptM));
unsigned resN = 2 * repN * std::max<int>(1, shape[1] / (spwN * wptN));
res = resM * resN;
elemsPerThread[0] = resM;
elemsPerThread[1] = resN;
} else if (isAmpere()) {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
unsigned elemsRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsCol = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
} else {
llvm_unreachable("Unexpected mma version");
}
return res;
return elemsPerThread;
}
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("Unexpected shared layout");
unsigned MmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
return product<unsigned>(getElemsPerThread(shape, eltTy));
}
SmallVector<unsigned>
SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("getElemsPerThread is not supported for shared layout");
return SmallVector<unsigned>();
}
unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("getElemsPerThread is not supported for shared layout");
return 0;
}
@@ -405,8 +451,15 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
}
}
unsigned DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
SmallVector<unsigned>
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
llvm_unreachable("getElemsPerThread is not supported for dot operand");
return SmallVector<unsigned>();
}
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
if (auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>()) {
int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0];
int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1];

View File

@@ -100,7 +100,7 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
"triton_gpu.num-warps");
if (numWarps) {
int sizePerThread = triton::gpu::getElemsPerThread(ptrType);
int sizePerThread = triton::gpu::getTotalElemsPerThread(ptrType);
if (ptrType.getNumElements() < numWarps.getInt() * 32)
return false;
}