Merge pull request #268 from ROCmSoftwarePlatform/improve_reduce_for_fa

[CHERRY-PICKED FROM UPSTREAM][BACKEND] no longer uses shared mem or barriers for single-warp reductions (openai#1915)
This commit is contained in:
jayfurmanek
2023-08-21 13:29:11 -05:00
committed by GitHub
6 changed files with 83 additions and 33 deletions

View File

@@ -60,9 +60,10 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto argLayout = getSrcLayout();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
// if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
// triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
// return {{1, 1}, {1, 1}};
// that case doesn't need inter-warp communication
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
return {{0, 0}, {0, 0}};
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());