mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] avoid code duplication for fully warp-synchronous reductions (#1978)
This commit is contained in:
@@ -81,6 +81,12 @@ SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return isFastReduction() &&
|
||||
(triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1);
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
SmallVector<SmallVector<unsigned>> smemShapes(3);
|
||||
|
||||
@@ -88,7 +94,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
|
||||
// that case doesn't need inter-warp communication
|
||||
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
if (isWarpSynchronous())
|
||||
return {{0, 0}, {0, 0}};
|
||||
|
||||
/// shared memory block0
|
||||
|
||||
Reference in New Issue
Block a user