[BACKEND] avoid code duplication for fully warp-synchronous reductions (#1978)

This commit is contained in:
Philippe Tillet
2023-07-21 16:06:00 -07:00
committed by GitHub
parent 07c346b948
commit 1db3bdc52e
4 changed files with 15 additions and 5 deletions

View File

@@ -152,7 +152,7 @@ jobs:
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] != 'arc770') }}
run: |
python3 -m pip install --upgrade pre-commit
python3 -m pre_commit run --all-files
python3 -m pre_commit run --all-files --verbose
- name: Check pre-commit arc770
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }}

View File

@@ -37,6 +37,8 @@ public:
bool isFastReduction();
/// Returns true when this is a fast reduction and the source layout places
/// exactly one warp per CTA along the reduction axis. Callers treat that case
/// as needing no inter-warp communication.
bool isWarpSynchronous();
unsigned getInterWarpSize();
unsigned getIntraWarpSize();

View File

@@ -81,6 +81,12 @@ SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
return smemShape;
}
bool ReduceOpHelper::isWarpSynchronous() {
auto argsLayout = getSrcLayout();
return isFastReduction() &&
(triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1);
}
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
SmallVector<SmallVector<unsigned>> smemShapes(3);
@@ -88,7 +94,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
// that case doesn't need inter-warp communication
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
if (isWarpSynchronous())
return {{0, 0}, {0, 0}};
/// shared memory block0

View File

@@ -335,7 +335,9 @@ private:
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
SmallVector<Value> smemBases(op.getNumOperands());
if (sizeInterWarps > 1) {
bool isWarpSync = helper.isWarpSynchronous();
if (!isWarpSync) {
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
@@ -403,7 +405,7 @@ private:
accumulate(rewriter, *combineOp, acc, shfl, false);
}
if (sizeInterWarps == 1) {
if (isWarpSync) {
finalAccs[key] = acc;
continue;
}
@@ -418,7 +420,7 @@ private:
}
}
if (sizeInterWarps == 1) {
if (isWarpSync) {
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =