mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] avoid code duplication for fully warp-synchronous reductions (#1978)
This commit is contained in:
2
.github/workflows/integration-tests.yml
vendored
2
.github/workflows/integration-tests.yml
vendored
@@ -152,7 +152,7 @@ jobs:
|
||||
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] != 'arc770') }}
|
||||
run: |
|
||||
python3 -m pip install --upgrade pre-commit
|
||||
python3 -m pre_commit run --all-files
|
||||
python3 -m pre_commit run --all-files --verbose
|
||||
|
||||
- name: Check pre-commit arc770
|
||||
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }}
|
||||
|
||||
@@ -37,6 +37,8 @@ public:
|
||||
|
||||
bool isFastReduction();
|
||||
|
||||
bool isWarpSynchronous();
|
||||
|
||||
unsigned getInterWarpSize();
|
||||
|
||||
unsigned getIntraWarpSize();
|
||||
|
||||
@@ -81,6 +81,12 @@ SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return isFastReduction() &&
|
||||
(triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1);
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
SmallVector<SmallVector<unsigned>> smemShapes(3);
|
||||
|
||||
@@ -88,7 +94,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
|
||||
// that case doesn't need inter-warp communication
|
||||
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
if (isWarpSynchronous())
|
||||
return {{0, 0}, {0, 0}};
|
||||
|
||||
/// shared memory block0
|
||||
|
||||
@@ -335,7 +335,9 @@ private:
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
if (sizeInterWarps > 1) {
|
||||
bool isWarpSync = helper.isWarpSynchronous();
|
||||
|
||||
if (!isWarpSync) {
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
@@ -403,7 +405,7 @@ private:
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
if (sizeInterWarps == 1) {
|
||||
if (isWarpSync) {
|
||||
finalAccs[key] = acc;
|
||||
continue;
|
||||
}
|
||||
@@ -418,7 +420,7 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (sizeInterWarps == 1) {
|
||||
if (isWarpSync) {
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
|
||||
Reference in New Issue
Block a user