mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Decomposed getElemsPerThread to return a vector of the per-dim elements per thread (#1549)
This is a prerequisite for updating the semantics of SliceEncodingAttr.
This commit is contained in:
@@ -8,7 +8,7 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
- using ::mlir::triton::gpu::getElemsPerThread;
|
||||
+ using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
@@ -72,7 +72,7 @@ struct LoadOpConversion
|
||||
Type valueElemTy =
|
||||
typeConverter->convertType(getElementTypeOrSelf(valueTy));
|
||||
unsigned vec = getVectorSize(ptr);
|
||||
- unsigned numElems = getElemsPerThread(ptr.getType());
|
||||
+ unsigned numElems = getTotalElemsPerThread(ptr.getType());
|
||||
if (llMask)
|
||||
vec = std::min<size_t>(vec, getMaskAlignment(mask));
|
||||
|
||||
@@ -276,7 +276,7 @@ struct StoreOpConversion
|
||||
typeConverter->convertType(getElementTypeOrSelf(valueTy));
|
||||
|
||||
unsigned vec = getVectorSize(ptr);
|
||||
- unsigned elemsPerThread = getElemsPerThread(ptr.getType());
|
||||
+ unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType());
|
||||
|
||||
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
|
||||
ptr.getType());
|
||||
@@ -493,7 +493,7 @@ struct AtomicRMWOpConversion
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
- auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
+ auto elemsPerThread = getTotalElemsPerThread(val.getType());
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
auto vec = getVectorSize(ptr);
|
||||
int numElems = 1;
|
||||
@@ -776,7 +776,7 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned inVec = getContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
- unsigned numElems = getElemsPerThread(srcTy);
|
||||
+ unsigned numElems = getTotalElemsPerThread(srcTy);
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
|
||||
Reference in New Issue
Block a user