[BACKEND] Unify slow/fast reduce codegen (#2220)

This commit is contained in:
Zahi Moudallal
2023-09-12 08:46:19 -07:00
committed by GitHub
parent fc5d7e6e7c
commit a47f1f5c28
7 changed files with 277 additions and 381 deletions

View File

@@ -33,14 +33,39 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
} // namespace
bool ReduceOpHelper::isFastReduction() {
// Disable fast reduction only for debugging purpose
if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
return false;
// Tells whether the reduction runs along the fast axis of the source layout,
// i.e. whether the parent axis matching `axis` is the first entry of the
// parent layout's order.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
  auto layout = getSrcLayout();
  auto fastAxis = getParentOrder(layout)[0];
  return getParentAxis(layout, axis) == fastAxis;
}
// Thread offset is the thread index offset of two adjacent threads on the
// reduction axis within the warp.
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
auto srcLayout = getSrcLayout();
// If the reduction axis is the fast axis of the parent layout
if (isReductionOnLayoutFastAxis()) {
return 1;
}
unsigned threadOffset = 1;
if (auto sliceLayout =
srcLayout.dyn_cast<mlir::triton::gpu::SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(parentLayout);
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
if (threadsPerWarp.size() == 1) {
threadOffset = 1;
} else {
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
}
}
return threadOffset;
}
// Cases where distributed shared memory is not required in ConvertLayout:
// (1) numCTAs == 1
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
@@ -124,53 +149,26 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
auto smemShape = convertType<unsigned>(getSrcShape());
smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
// A reduction is warp-synchronous when the source layout places exactly one
// warp per CTA along the reduction axis.
bool ReduceOpHelper::isWarpSynchronous() {
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getSrcLayout());
  return warpsPerCTA[axis] == 1;
}
// Returns the scratch shape for the reduction: the source shape with the
// reduction axis replaced by getInterWarpSizeWithUniqueData(), or {0, 0}
// when the reduction is warp-synchronous.
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
  // A warp-synchronous reduction needs no inter-warp communication.
  if (isWarpSynchronous())
    return {0, 0};

  SmallVector<unsigned> shape = convertType<unsigned>(getSrcShape());
  shape[axis] = getInterWarpSizeWithUniqueData();
  return shape;
}
// A reduction is warp-synchronous when it is a fast reduction and the source
// layout places exactly one warp per CTA along the reduction axis.
bool ReduceOpHelper::isWarpSynchronous() {
  auto layout = getSrcLayout();
  if (!isFastReduction())
    return false;
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
  return warpsPerCTA[axis] == 1;
}
// Returns the scratch shapes used by the fast reduction path.
//
// For a warp-synchronous reduction two {0, 0} shapes are returned. Otherwise
// three shapes are returned:
//   [0] the source shape with the reduction axis replaced by
//       getInterWarpSize();
//   [1] a single dimension of numWarps * threadsPerWarp, both read from the
//       enclosing module;
//   [2] an empty shape (never filled in here).
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
  // A warp-synchronous reduction needs no inter-warp communication.
  if (isWarpSynchronous())
    return {{0, 0}, {0, 0}};

  SmallVector<SmallVector<unsigned>> smemShapes(3);

  /// shared memory block0
  smemShapes[0] = convertType<unsigned>(getSrcShape());
  smemShapes[0][axis] = getInterWarpSize();

  /// FIXME(Qingyi): This size is actually larger than required.
  /// shared memory block1:
  auto mod = op->getParentOfType<ModuleOp>();
  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  unsigned threadsPerWarp =
      triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
  smemShapes[1].push_back(numWarps * threadsPerWarp);

  return smemShapes;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned elems = 0;
if (isFastReduction()) {
auto smemShapes = getScratchConfigsFast();
for (const auto &smemShape : smemShapes)
elems = std::max(elems, product<unsigned>(smemShape));
} else {
auto smemShape = getScratchConfigBasic();
elems = product<unsigned>(smemShape);
}
auto smemShape = getScratchConfig();
auto elems = product<unsigned>(smemShape);
unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
@@ -179,7 +177,21 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
return bytesPerElem * elems;
}
bool ReduceOpHelper::isReduceWithinCTA() {
auto axis = getAxis();
auto srcLayout = getSrcLayout();
auto CTASplitNum = mlir::triton::gpu::getCTASplitNum(srcLayout);
assert(axis < CTASplitNum.size());
return CTASplitNum[axis] == 1;
}
bool ReduceOpHelper::isSupportedLayout() {
// Layout optimization passes such as PlanCTAPass and
// RemoveLayoutConversionPass should avoid cross-CTA reduction
if (!isReduceWithinCTA()) {
return false;
}
auto srcLayout = getSrcLayout();
if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
return true;