[BACKEND] Support scan on dimensions other that fastest moving one (#1863)

This relax the restriction in the scan lowering to support layout where
we scan along a dimension which isn't the fastest moving one. This is
done by relaxing how we accesses elements during scanning and allow
elements to be strided.
This commit is contained in:
Thomas
2023-06-30 12:40:48 -07:00
committed by GitHub
parent 66ed53d19d
commit 2e3182bab7
4 changed files with 200 additions and 127 deletions

View File

@@ -117,13 +117,12 @@ bool ReduceOpHelper::isSupportedLayout() {
return false;
}
unsigned ScanLoweringHelper::getAxisNumElementsPerThreads() {
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
return getEncoding().getSizePerThread()[getAxis()];
}
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
SmallVector<unsigned> sizePerThreads(getEncoding().getSizePerThread().begin(),
getEncoding().getSizePerThread().end());
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
sizePerThreads[getAxis()] = 1;
return product<unsigned>(sizePerThreads);
}
@@ -159,8 +158,9 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
unsigned axis = getAxis();
return type.getShape()[axis] /
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]);
return ceil<unsigned>(
type.getShape()[axis],
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
@@ -173,19 +173,18 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
if (i == axis)
continue;
numBlocks *= type.getShape()[i] /
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]);
numBlocks *= ceil<unsigned>(
type.getShape()[i],
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
}
return numBlocks;
}
bool ScanLoweringHelper::isSupported() {
// TODO: Support the following cases:
// 1. Scan on the non-fast changing dimension
// 2. Scan on non-blocking encodings
// 3. Scan with multiple operands
if (getAxis() != triton::gpu::getOrder(srcEncoding)[0] ||
!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
// 1. Scan on non-blocking encodings
// 2. Scan with multiple operands
if (!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
return false;
if (scanOp.getNumOperands() != 1)
return false;
@@ -194,16 +193,58 @@ bool ScanLoweringHelper::isSupported() {
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
unsigned numElement =
type.getNumElements() * type.getElementTypeBitWidth() / 8;
return numElement /
(getAxisNumElementsPerThreads() * getAxisNumThreadsPerWarp());
unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
auto mod = scanOp->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
unsigned numNonAxisElementsPerWapr =
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
unsigned numElements = numWarps * numNonAxisElementsPerWapr *
getAxisNumBlocks() * getNonAxisNumBlocks();
return elementSizeInBytes * numElements;
}
triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
return srcEncoding.cast<triton::gpu::BlockedEncodingAttr>();
}
unsigned ScanLoweringHelper::getAxisElementStride() {
auto order = triton::gpu::getOrder(srcEncoding);
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getContigPerThread(getEncoding())[dim];
}
llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisThreadStride() {
auto order = triton::gpu::getOrder(srcEncoding);
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getEncoding().getThreadsPerWarp()[dim];
}
llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisBlockStride() {
auto order = triton::gpu::getOrder(srcEncoding);
unsigned stride = 1;
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= type.getShape()[dim] /
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
}
llvm_unreachable("Axis not found in order");
}
bool maybeSharedAllocationOp(Operation *op) {
// TODO(Keren): This function can be replaced by adding
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to