[BACKEND] Determine fast reduce based on the parent layout (#1892)

This commit is contained in:
Keren Zhou
2023-07-05 14:52:22 -04:00
committed by GitHub
parent ae0ee5248f
commit 4255ef0e9e

View File

@@ -10,11 +10,31 @@
namespace mlir {
namespace {
int getParentAxis(Attribute layout, int axis) {
if (auto sliceEncoding = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
return getParentAxis(sliceEncoding.getParent(), axis);
}
return axis;
}
SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return getParentOrder(sliceEncoding.getParent());
}
return triton::gpu::getOrder(layout);
}
} // namespace
bool ReduceOpHelper::isFastReduction() {
// Disable fast reduction only for debugging purpose
if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
return false;
return axis == triton::gpu::getOrder(getSrcLayout())[0];
return getParentAxis(getSrcLayout(), axis) ==
getParentOrder(getSrcLayout())[0];
}
unsigned ReduceOpHelper::getInterWarpSize() {