[BACKEND] Decomposed getElemsPerThread to return a vector of the per-dim elements per thread (#1549)

This is a prerequisite for updating the semantics of SliceEncodingAttr.
This commit is contained in:
zahimoud
2023-04-19 16:52:16 -07:00
committed by GitHub
parent b42e3d06d4
commit 8f7424221f
16 changed files with 121 additions and 65 deletions

View File

@@ -8,7 +8,7 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Contains some helper functions for both Load and Store conversions.
@@ -72,7 +72,7 @@ struct LoadOpConversion
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
unsigned numElems = getTotalElemsPerThread(ptr.getType());
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));
@@ -276,7 +276,7 @@ struct StoreOpConversion
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned elemsPerThread = getElemsPerThread(ptr.getType());
unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType());
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
@@ -493,7 +493,7 @@ struct AtomicRMWOpConversion
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
auto elemsPerThread = getTotalElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
int numElems = 1;
@@ -776,7 +776,7 @@ struct InsertSliceAsyncOpConversion
unsigned inVec = getContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcTy);
unsigned numElems = getTotalElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();