Zahi/slice reduce rebased (#1594)

[BACKEND] Enable slice layout support for reduce op
This commit is contained in:
Zahi Moudallal
2023-05-01 18:00:23 -07:00
committed by GitHub
parent 26d80f026d
commit 3449a9d40d
5 changed files with 187 additions and 16 deletions

View File

@@ -17,19 +17,23 @@ unsigned ReduceOpHelper::getInterWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
triton::gpu::getWarpsPerCTAWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}
unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
triton::gpu::getThreadsPerWarpWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}
unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
auto srcShape = getSrcShape();
return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
srcShape)[axis] *
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
@@ -88,6 +92,9 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
}
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return true;
}
return false;
}