mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Zahi/slice reduce rebased (#1594)
[BACKEND] Enable slice layout support for reduce op
This commit is contained in:
@@ -17,19 +17,23 @@ unsigned ReduceOpHelper::getInterWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSize();
|
||||
return std::min(srcReduceDimSize / sizeIntraWarps,
|
||||
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
|
||||
triton::gpu::getWarpsPerCTAWithUniqueData(
|
||||
getSrcLayout(), getSrcShape())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getIntraWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
return std::min(srcReduceDimSize,
|
||||
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
|
||||
triton::gpu::getThreadsPerWarpWithUniqueData(
|
||||
getSrcLayout(), getSrcShape())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
auto srcLayout = getSrcLayout();
|
||||
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
|
||||
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
|
||||
auto srcShape = getSrcShape();
|
||||
return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
|
||||
srcShape)[axis] *
|
||||
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
|
||||
@@ -88,6 +92,9 @@ bool ReduceOpHelper::isSupportedLayout() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user