mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Support scan on dimensions other than the fastest-moving one (#1863)
This relaxes the restriction in the scan lowering to support layouts where we scan along a dimension that isn't the fastest-moving one. This is done by relaxing how we access elements during scanning and allowing elements to be strided.
This commit is contained in:
@@ -117,13 +117,12 @@ bool ReduceOpHelper::isSupportedLayout() {
|
||||
return false;
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisNumElementsPerThreads() {
|
||||
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
|
||||
return getEncoding().getSizePerThread()[getAxis()];
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
|
||||
SmallVector<unsigned> sizePerThreads(getEncoding().getSizePerThread().begin(),
|
||||
getEncoding().getSizePerThread().end());
|
||||
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
|
||||
sizePerThreads[getAxis()] = 1;
|
||||
return product<unsigned>(sizePerThreads);
|
||||
}
|
||||
@@ -159,8 +158,9 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
|
||||
unsigned axis = getAxis();
|
||||
return type.getShape()[axis] /
|
||||
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]);
|
||||
return ceil<unsigned>(
|
||||
type.getShape()[axis],
|
||||
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
|
||||
@@ -173,19 +173,18 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
|
||||
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
|
||||
if (i == axis)
|
||||
continue;
|
||||
numBlocks *= type.getShape()[i] /
|
||||
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]);
|
||||
numBlocks *= ceil<unsigned>(
|
||||
type.getShape()[i],
|
||||
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
|
||||
}
|
||||
return numBlocks;
|
||||
}
|
||||
|
||||
bool ScanLoweringHelper::isSupported() {
|
||||
// TODO: Support the following cases:
|
||||
// 1. Scan on the non-fast changing dimension
|
||||
// 2. Scan on non-blocking encodings
|
||||
// 3. Scan with multiple operands
|
||||
if (getAxis() != triton::gpu::getOrder(srcEncoding)[0] ||
|
||||
!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
|
||||
// 1. Scan on non-blocking encodings
|
||||
// 2. Scan with multiple operands
|
||||
if (!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
|
||||
return false;
|
||||
if (scanOp.getNumOperands() != 1)
|
||||
return false;
|
||||
@@ -194,16 +193,58 @@ bool ScanLoweringHelper::isSupported() {
|
||||
|
||||
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
unsigned numElement =
|
||||
type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
return numElement /
|
||||
(getAxisNumElementsPerThreads() * getAxisNumThreadsPerWarp());
|
||||
unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
|
||||
auto mod = scanOp->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
unsigned numNonAxisElementsPerWapr =
|
||||
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
|
||||
unsigned numElements = numWarps * numNonAxisElementsPerWapr *
|
||||
getAxisNumBlocks() * getNonAxisNumBlocks();
|
||||
return elementSizeInBytes * numElements;
|
||||
}
|
||||
|
||||
// Returns the scan operand's source encoding viewed as a blocked encoding.
// NOTE(review): this is a hard `cast`, not a `dyn_cast`; presumably callers
// are expected to have gone through isSupported() first -- confirm.
triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
  auto blockedEncoding = srcEncoding.cast<triton::gpu::BlockedEncodingAttr>();
  return blockedEncoding;
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisElementStride() {
|
||||
auto order = triton::gpu::getOrder(srcEncoding);
|
||||
unsigned stride = 1;
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= getContigPerThread(getEncoding())[dim];
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisThreadStride() {
|
||||
auto order = triton::gpu::getOrder(srcEncoding);
|
||||
unsigned stride = 1;
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= getEncoding().getThreadsPerWarp()[dim];
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
|
||||
unsigned ScanLoweringHelper::getAxisBlockStride() {
|
||||
auto order = triton::gpu::getOrder(srcEncoding);
|
||||
unsigned stride = 1;
|
||||
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= type.getShape()[dim] /
|
||||
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op) {
|
||||
// TODO(Keren): This function can be replaced by adding
|
||||
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
|
||||
|
||||
Reference in New Issue
Block a user