mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Determine fast reduce based on the parent layout (#1892)
This commit is contained in:
@@ -10,11 +10,31 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Translates `axis`, expressed in the coordinates of `layout`, into the
/// coordinates of the outermost non-slice parent layout.
///
/// While `layout` is a slice encoding, the axis is bumped by one whenever it
/// is not below that slice's `getDim()`, and the walk continues with the
/// slice's parent. A non-slice layout leaves `axis` unchanged.
int getParentAxis(Attribute layout, int axis) {
  while (auto slice = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    // Axes at or past the sliced dimension shift up by one in the parent.
    if (axis >= slice.getDim())
      ++axis;
    layout = slice.getParent();
  }
  return axis;
}
|
||||
|
||||
/// Returns `triton::gpu::getOrder` of the outermost non-slice parent of
/// `layout`; for a layout that is not a slice encoding this is simply the
/// order of `layout` itself.
SmallVector<unsigned> getParentOrder(Attribute layout) {
  // Peel off every level of slice encoding until a non-slice layout remains.
  auto slice = layout.dyn_cast<triton::gpu::SliceEncodingAttr>();
  while (slice) {
    layout = slice.getParent();
    slice = layout.dyn_cast<triton::gpu::SliceEncodingAttr>();
  }
  return triton::gpu::getOrder(layout);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ReduceOpHelper::isFastReduction() {
|
||||
// Disable fast reduction only for debugging purpose
|
||||
if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
|
||||
return false;
|
||||
return axis == triton::gpu::getOrder(getSrcLayout())[0];
|
||||
return getParentAxis(getSrcLayout(), axis) ==
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getInterWarpSize() {
|
||||
|
||||
Reference in New Issue
Block a user