mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108
Conflicts: bin/triton-translate.cpp, lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp, lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp, python/triton/compiler/compiler.py, python/triton/runtime/jit.py, python/tutorials/06-fused-attention.py, test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
|
||||
getParentOrder(getSrcLayout())[0];
|
||||
}
|
||||
|
||||
// Returns the order obtained from triton::gpu::getOrder for the source layout,
// rearranged so that `axis` is the first element. The relative order of the
// remaining elements is left unchanged.
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
  auto srcLayout = getSrcLayout();
  auto order = triton::gpu::getOrder(srcLayout);
  auto it = std::find(order.begin(), order.end(), axis);
  // Erasing the end iterator is undefined behavior, so fail loudly if `axis`
  // is not one of the entries of the layout order.
  assert(it != order.end() && "axis not found in the source layout order");
  // delete the axis from order
  order.erase(it);
  // insert axis at the beginning of order
  order.insert(order.begin(), axis);
  return order;
}
|
||||
|
||||
// Thread offset is the thread index offset of two adjacent threads on the
|
||||
// reduction axis within the warp.
|
||||
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
|
||||
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
|
||||
threadOffset = threadsPerWarp[sliceLayout.getDim()];
|
||||
} else {
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
|
||||
if (threadsPerWarp.size() == 1) {
|
||||
threadOffset = 1;
|
||||
} else {
|
||||
assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
|
||||
threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
|
||||
auto order = triton::gpu::getOrder(srcLayout);
|
||||
for (unsigned i = 0; i < order.size(); i++) {
|
||||
if (order[i] == axis)
|
||||
break;
|
||||
threadOffset *= threadsPerWarp[order[i]];
|
||||
}
|
||||
}
|
||||
return threadOffset;
|
||||
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
|
||||
}
|
||||
|
||||
bool ReduceOpHelper::isWarpSynchronous() {
|
||||
auto argsLayout = getSrcLayout();
|
||||
return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
|
||||
auto srcLayout = getSrcLayout();
|
||||
auto srcShape = getSrcShape();
|
||||
return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
|
||||
1;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
|
||||
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src && dst && src.getVersionMajor() == 3 &&
|
||||
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
|
||||
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
|
||||
dst.getWarpsPerCTA()[1] == 1;
|
||||
}
|
||||
|
||||
// Convenience overload: forwards the encodings of the two tensor types to the
// Attribute-based isMmaToMmaShortcut and returns its result.
//
// NOTE(review): the scraped diff stacked the pre-merge by-reference signature
// on top of this by-value one. Only the by-value signature is kept; it also
// accepts temporaries. Confirm the header declaration matches.
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
  return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}
|
||||
|
||||
@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
srcTy.getElementType().isF16();
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
|
||||
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
|
||||
return true;
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
|
||||
auto *currentOp = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentOp.
|
||||
backwardSlice.clear();
|
||||
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = backwardFilter;
|
||||
getBackwardSlice(currentOp, &backwardSlice, opt);
|
||||
slice.insert(backwardSlice.begin(), backwardSlice.end());
|
||||
|
||||
// Compute and insert the forwardSlice starting from currentOp.
|
||||
|
||||
Reference in New Issue
Block a user