mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Unify slow/fast reduce codegen (#2220)
This commit is contained in:
@@ -33,14 +33,39 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ReduceOpHelper::isFastReduction() {
|
||||
// Disable fast reduction only for debugging purpose
|
||||
if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
|
||||
return false;
|
||||
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
|
||||
return getParentAxis(getSrcLayout(), axis) ==
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
// Thread offset is the thread index offset of two adjacent threads on the
|
||||
// reduction axis within the warp.
|
||||
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
|
||||
auto srcLayout = getSrcLayout();
|
||||
|
||||
// If the reduction axis is the fast axis of the parent layout
|
||||
if (isReductionOnLayoutFastAxis()) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
unsigned threadOffset = 1;
|
||||
if (auto sliceLayout =
|
||||
srcLayout.dyn_cast<mlir::triton::gpu::SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(parentLayout);
|
||||
threadOffset = threadsPerWarp[sliceLayout.getDim()];
|
||||
} else {
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
||||
if (threadsPerWarp.size() == 1) {
|
||||
threadOffset = 1;
|
||||
} else {
|
||||
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
|
||||
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
|
||||
}
|
||||
}
|
||||
return threadOffset;
|
||||
}
|
||||
|
||||
// Cases where distributed shared memory is not required in ConvertLayout:
|
||||
// (1) numCTAs == 1
|
||||
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
|
||||
@@ -124,53 +149,26 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
|
||||
auto smemShape = convertType<unsigned>(getSrcShape());
|
||||
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
|
||||
SmallVector<unsigned> smemShape;
|
||||
// that case doesn't need inter-warp communication
|
||||
if (isWarpSynchronous())
|
||||
return {0, 0};
|
||||
|
||||
smemShape = convertType<unsigned>(getSrcShape());
|
||||
smemShape[axis] = getInterWarpSizeWithUniqueData();
|
||||
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return isFastReduction() &&
|
||||
(triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1);
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
SmallVector<SmallVector<unsigned>> smemShapes(3);
|
||||
|
||||
auto argLayout = getSrcLayout();
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
|
||||
// that case doesn't need inter-warp communication
|
||||
if (isWarpSynchronous())
|
||||
return {{0, 0}, {0, 0}};
|
||||
|
||||
/// shared memory block0
|
||||
smemShapes[0] = convertType<unsigned>(getSrcShape());
|
||||
smemShapes[0][axis] = getInterWarpSize();
|
||||
|
||||
/// FIXME(Qingyi): This size is actually larger than required.
|
||||
/// shared memory block1:
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
unsigned threadsPerWarp =
|
||||
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
smemShapes[1].push_back(numWarps * threadsPerWarp);
|
||||
|
||||
return smemShapes;
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getScratchSizeInBytes() {
|
||||
unsigned elems = 0;
|
||||
if (isFastReduction()) {
|
||||
auto smemShapes = getScratchConfigsFast();
|
||||
for (const auto &smemShape : smemShapes)
|
||||
elems = std::max(elems, product<unsigned>(smemShape));
|
||||
} else {
|
||||
auto smemShape = getScratchConfigBasic();
|
||||
elems = product<unsigned>(smemShape);
|
||||
}
|
||||
auto smemShape = getScratchConfig();
|
||||
auto elems = product<unsigned>(smemShape);
|
||||
|
||||
unsigned bytesPerElem = 0;
|
||||
for (const auto &ty : srcElementTypes) {
|
||||
@@ -179,7 +177,21 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
|
||||
return bytesPerElem * elems;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isReduceWithinCTA() {
|
||||
auto axis = getAxis();
|
||||
auto srcLayout = getSrcLayout();
|
||||
auto CTASplitNum = mlir::triton::gpu::getCTASplitNum(srcLayout);
|
||||
assert(axis < CTASplitNum.size());
|
||||
return CTASplitNum[axis] == 1;
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isSupportedLayout() {
|
||||
// Layout optimization passes such as PlanCTAPass and
|
||||
// RemoveLayoutConversionPass should avoid cross-CTA reduction
|
||||
if (!isReduceWithinCTA()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto srcLayout = getSrcLayout();
|
||||
if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
|
||||
return true;
|
||||
|
||||
Reference in New Issue
Block a user