mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Decomposed getElemsPerThread to return a vector of the per-dim elements per thread (#1549)
This is a prerequisite for updating the semantics of SliceEncodingAttr.
This commit is contained in:
@@ -21,7 +21,9 @@ namespace mlir {
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
|
||||
unsigned getElemsPerThread(Type type);
|
||||
unsigned getTotalElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
|
||||
|
||||
|
||||
@@ -35,7 +35,8 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
}];
|
||||
|
||||
code extraBaseClassDeclaration = [{
|
||||
unsigned getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
|
||||
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
|
||||
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
|
||||
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
@@ -422,14 +422,14 @@ private:
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||
// unsigned elems = getElemsPerThread(srcTy);
|
||||
// unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
|
||||
|
||||
unsigned outElems = getElemsPerThread(dstTy);
|
||||
unsigned outElems = getTotalElemsPerThread(dstTy);
|
||||
auto outOrd = getOrder(dstLayout);
|
||||
SmallVector<Value> outVals(outElems);
|
||||
|
||||
@@ -572,7 +572,7 @@ private:
|
||||
// get source values
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned elems = getElemsPerThread(srcTy);
|
||||
unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(srcTy.getElementType());
|
||||
// for the destination type, we need to pack values together
|
||||
|
||||
@@ -6,10 +6,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ Value loadC(Value tensor, Value llTensor,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
MLIRContext *ctx = tensor.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
size_t fcSize = triton::gpu::getElemsPerThread(tensor.getType());
|
||||
size_t fcSize = triton::gpu::getTotalElemsPerThread(tensor.getType());
|
||||
|
||||
assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
|
||||
"Currently, we only support $c with a mma layout.");
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
|
||||
struct FpToFpOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
|
||||
@@ -499,7 +499,7 @@ struct FpToFpOpConversion
|
||||
auto srcEltType = srcTensorType.getElementType();
|
||||
auto dstEltType = dstTensorType.getElementType();
|
||||
auto loc = op->getLoc();
|
||||
auto elems = getElemsPerThread(dstTensorType);
|
||||
auto elems = getTotalElemsPerThread(dstTensorType);
|
||||
SmallVector<Value> resultVals;
|
||||
bool isSrcFP8 =
|
||||
srcEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
|
||||
@@ -583,7 +583,7 @@ public:
|
||||
auto resultTy = op.getType();
|
||||
Location loc = op->getLoc();
|
||||
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
unsigned elems = getTotalElemsPerThread(resultTy);
|
||||
auto resultElementTy = getElementTypeOrSelf(resultTy);
|
||||
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
|
||||
@@ -8,7 +8,7 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
@@ -72,7 +72,7 @@ struct LoadOpConversion
|
||||
Type valueElemTy =
|
||||
typeConverter->convertType(getElementTypeOrSelf(valueTy));
|
||||
unsigned vec = getVectorSize(ptr);
|
||||
unsigned numElems = getElemsPerThread(ptr.getType());
|
||||
unsigned numElems = getTotalElemsPerThread(ptr.getType());
|
||||
if (llMask)
|
||||
vec = std::min<size_t>(vec, getMaskAlignment(mask));
|
||||
|
||||
@@ -276,7 +276,7 @@ struct StoreOpConversion
|
||||
typeConverter->convertType(getElementTypeOrSelf(valueTy));
|
||||
|
||||
unsigned vec = getVectorSize(ptr);
|
||||
unsigned elemsPerThread = getElemsPerThread(ptr.getType());
|
||||
unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType());
|
||||
|
||||
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
|
||||
ptr.getType());
|
||||
@@ -493,7 +493,7 @@ struct AtomicRMWOpConversion
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
auto elemsPerThread = getTotalElemsPerThread(val.getType());
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
auto vec = getVectorSize(ptr);
|
||||
int numElems = 1;
|
||||
@@ -776,7 +776,7 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned inVec = getContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned numElems = getElemsPerThread(srcTy);
|
||||
unsigned numElems = getTotalElemsPerThread(srcTy);
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
|
||||
@@ -5,8 +5,8 @@ using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::shflSync;
|
||||
using ::mlir::LLVM::storeShared;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
|
||||
struct ReduceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
|
||||
@@ -64,7 +64,7 @@ private:
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto types = op.getInputTypes();
|
||||
auto operands = adaptor.getOperands();
|
||||
unsigned srcElems = getElemsPerThread(types[0]);
|
||||
unsigned srcElems = getTotalElemsPerThread(types[0]);
|
||||
SmallVector<SmallVector<Value>> srcValues(srcElems);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto values = getTypeConverter()->unpackLLElements(loc, operands[i],
|
||||
@@ -158,7 +158,7 @@ private:
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTys[0]);
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
// Emits indices of the original tensor that each thread
|
||||
// would own
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
@@ -263,7 +263,7 @@ private:
|
||||
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
@@ -330,7 +330,7 @@ private:
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSize();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSize();
|
||||
|
||||
unsigned srcElems = getElemsPerThread(srcTys[0]);
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
|
||||
@@ -457,7 +457,7 @@ private:
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
// nd-tensor where n >= 1
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getElemsPerThread(resultTy);
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
|
||||
assert(resultIndices.size() == resultElems);
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
|
||||
@@ -435,7 +435,7 @@ struct AddPtrOpConversion
|
||||
auto ptrTy = op.getPtr().getType();
|
||||
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
|
||||
if (resultTensorTy) {
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
unsigned elems = getTotalElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
getTypeConverter()->convertType(resultTensorTy.getElementType());
|
||||
auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(),
|
||||
|
||||
@@ -275,7 +275,7 @@ public:
|
||||
|
||||
auto srcEncoding = srcTy.getEncoding();
|
||||
auto srcShape = srcTy.getShape();
|
||||
unsigned numElems = triton::gpu::getElemsPerThread(srcTy);
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
// swizzling params as described in TritonGPUAttrDefs.td
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
@@ -376,7 +376,7 @@ public:
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned perPhase = dstSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = dstSharedLayout.getMaxPhase();
|
||||
unsigned numElems = triton::gpu::getElemsPerThread(srcTy);
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
assert(numElems == srcIndices.size());
|
||||
auto inVals =
|
||||
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
|
||||
@@ -667,7 +667,7 @@ private:
|
||||
threadOffset * sizePerThread[k] + elemOffset);
|
||||
}
|
||||
|
||||
unsigned elemsPerThread = triton::gpu::getElemsPerThread(type);
|
||||
unsigned elemsPerThread = triton::gpu::getTotalElemsPerThread(type);
|
||||
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
|
||||
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
|
||||
for (unsigned n = 0; n < elemsPerThread; ++n) {
|
||||
|
||||
@@ -8,7 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
@@ -144,7 +144,7 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
|
||||
unsigned numElementsPerThread = getElemsPerThread(type);
|
||||
unsigned numElementsPerThread = getTotalElemsPerThread(type);
|
||||
SmallVector<Type, 4> types(numElementsPerThread, eltType);
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
|
||||
struct SplatOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
|
||||
@@ -24,7 +24,7 @@ struct SplatOpConversion
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(constVal, srcType);
|
||||
size_t elemsPerThread = getElemsPerThread(tensorTy);
|
||||
size_t elemsPerThread = getTotalElemsPerThread(tensorTy);
|
||||
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
|
||||
return typeConverter->packLLElements(loc, elems, rewriter, resType);
|
||||
}
|
||||
@@ -93,7 +93,7 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
unsigned elems = getElemsPerThread(resultTy);
|
||||
unsigned elems = getTotalElemsPerThread(resultTy);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
|
||||
@@ -22,32 +22,54 @@ namespace gpu {
|
||||
// so that all distributed layouts implement
|
||||
// these utilities
|
||||
|
||||
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
Type eltTy) {
|
||||
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
Type eltTy) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return blockedLayout.getElemsPerThread(shape, eltTy);
|
||||
return blockedLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
return sliceLayout.getElemsPerThread(shape, eltTy);
|
||||
return sliceLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
return mmaLayout.getElemsPerThread(shape, eltTy);
|
||||
return mmaLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
return sharedLayout.getElemsPerThread(shape, eltTy);
|
||||
return sharedLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return dotLayout.getElemsPerThread(shape, eltTy);
|
||||
return dotLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned getElemsPerThread(Type type) {
|
||||
SmallVector<unsigned> getElemsPerThread(Attribute layout,
|
||||
ArrayRef<int64_t> shape, Type eltTy) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return blockedLayout.getElemsPerThread(shape, eltTy);
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
return sliceLayout.getElemsPerThread(shape, eltTy);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
return mmaLayout.getElemsPerThread(shape, eltTy);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getElemsPerThread(Type type) {
|
||||
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
|
||||
return 1;
|
||||
return SmallVector<unsigned>(1, 1);
|
||||
auto tensorType = type.cast<RankedTensorType>();
|
||||
return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
|
||||
tensorType.getElementType());
|
||||
}
|
||||
|
||||
unsigned getTotalElemsPerThread(Type type) {
|
||||
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
|
||||
return 1;
|
||||
auto tensorType = type.cast<RankedTensorType>();
|
||||
return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
|
||||
tensorType.getElementType());
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return SmallVector<unsigned>(blockedLayout.getThreadsPerWarp().begin(),
|
||||
@@ -230,7 +252,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
|
||||
assert(0 && "Unimplemented usage of getOrder");
|
||||
return {};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
bool isaDistributedLayout(Attribute layout) {
|
||||
return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
|
||||
@@ -306,9 +328,9 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
||||
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
|
||||
return SliceEncodingAttr::get(getContext(), axis, *this);
|
||||
}
|
||||
|
||||
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
SmallVector<unsigned>
|
||||
BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
size_t rank = shape.size();
|
||||
auto sizePerThread = getSizePerThread();
|
||||
auto warpsPerCTA = getWarpsPerCTA();
|
||||
@@ -320,7 +342,11 @@ unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i];
|
||||
elemsPerThread[i] = ceil<unsigned>(shape[i], t) * sizePerThread[i];
|
||||
}
|
||||
return product<unsigned>(elemsPerThread);
|
||||
return elemsPerThread;
|
||||
}
|
||||
unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
return product<unsigned>(getElemsPerThread(shape, eltTy));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -343,19 +369,26 @@ SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
|
||||
template SmallVector<int64_t>
|
||||
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
|
||||
|
||||
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
SmallVector<unsigned>
|
||||
SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
auto parent = getParent();
|
||||
return ::getElemsPerThread(parent, paddedShape(shape), eltTy);
|
||||
auto parentElemsPerThread =
|
||||
::getElemsPerThread(parent, paddedShape(shape), eltTy);
|
||||
return parentElemsPerThread;
|
||||
}
|
||||
unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
return product<unsigned>(getElemsPerThread(shape, eltTy));
|
||||
}
|
||||
|
||||
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
SmallVector<unsigned>
|
||||
MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
|
||||
size_t rank = shape.size();
|
||||
assert(rank == 2 && "Unexpected rank of mma layout");
|
||||
assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
|
||||
|
||||
int res = 0;
|
||||
SmallVector<unsigned> elemsPerThread(rank);
|
||||
if (isVolta()) {
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates();
|
||||
static constexpr std::array<unsigned, 2> fpw{{2, 2}};
|
||||
@@ -369,21 +402,34 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
unsigned wptN = getWarpsPerCTA()[1];
|
||||
unsigned resM = repM * std::max<int>(1, shape[0] / (spwM * wptM));
|
||||
unsigned resN = 2 * repN * std::max<int>(1, shape[1] / (spwN * wptN));
|
||||
res = resM * resN;
|
||||
elemsPerThread[0] = resM;
|
||||
elemsPerThread[1] = resN;
|
||||
} else if (isAmpere()) {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||
res = elemsCol * elemsRow;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
} else {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
|
||||
return res;
|
||||
return elemsPerThread;
|
||||
}
|
||||
|
||||
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
llvm_unreachable("Unexpected shared layout");
|
||||
unsigned MmaEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
return product<unsigned>(getElemsPerThread(shape, eltTy));
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
llvm_unreachable("getElemsPerThread is not supported for shared layout");
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
llvm_unreachable("getElemsPerThread is not supported for shared layout");
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -405,8 +451,15 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
|
||||
}
|
||||
}
|
||||
|
||||
unsigned DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
SmallVector<unsigned>
|
||||
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
llvm_unreachable("getElemsPerThread is not supported for dot operand");
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
|
||||
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
Type eltTy) const {
|
||||
if (auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>()) {
|
||||
int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0];
|
||||
int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1];
|
||||
|
||||
@@ -100,7 +100,7 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
|
||||
"triton_gpu.num-warps");
|
||||
if (numWarps) {
|
||||
int sizePerThread = triton::gpu::getElemsPerThread(ptrType);
|
||||
int sizePerThread = triton::gpu::getTotalElemsPerThread(ptrType);
|
||||
if (ptrType.getNumElements() < numWarps.getInt() * 32)
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user