[BACKEND] avoid code duplication for fully warp-synchronous reductions (#1978)

This commit is contained in:
Philippe Tillet
2023-07-21 16:06:00 -07:00
committed by GitHub
parent 07c346b948
commit 1db3bdc52e
4 changed files with 15 additions and 5 deletions

View File

@@ -81,6 +81,12 @@ SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
return smemShape;
}
bool ReduceOpHelper::isWarpSynchronous() {
auto argsLayout = getSrcLayout();
return isFastReduction() &&
(triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1);
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
SmallVector<SmallVector<unsigned>> smemShapes(3);
@@ -88,7 +94,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
// that case doesn't need inter-warp communication
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
if (isWarpSynchronous())
return {{0, 0}, {0, 0}};
/// shared memory block0