Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108

Conflicts:
	bin/triton-translate.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
Jason Furmanek committed 2023-11-08 18:51:23 +00:00
72 changed files with 1623 additions and 838 deletions


@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
          getParentOrder(getSrcLayout())[0];
 }
 
+SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+  auto srcLayout = getSrcLayout();
+  auto order = triton::gpu::getOrder(srcLayout);
+  auto it = std::find(order.begin(), order.end(), axis);
+  // delete the axis from order
+  order.erase(it);
+  // insert axis at the beginning of order
+  order.insert(order.begin(), axis);
+  return order;
+}
+
 // Thread offset is the thread index offset of two adjacent threads on the
 // reduction axis within the warp.
 unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
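The new getOrderWithAxisAtBeginning() helper simply rotates the reduction axis to the front of the layout order. A minimal standalone sketch of the same logic (the function name here is just for illustration, with std::vector standing in for SmallVector and made-up values):

#include <algorithm>
#include <vector>

// Move `axis` to the front of `order`, preserving the relative order of
// the remaining dimensions.
std::vector<unsigned> orderWithAxisAtBeginning(std::vector<unsigned> order,
                                               unsigned axis) {
  order.erase(std::find(order.begin(), order.end(), axis));
  order.insert(order.begin(), axis);
  return order;
}

// Example: orderWithAxisAtBeginning({2, 0, 1}, /*axis=*/0) yields {0, 2, 1}.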
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
-    if (threadsPerWarp.size() == 1) {
-      threadOffset = 1;
-    } else {
-      assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
-      threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
+    auto order = triton::gpu::getOrder(srcLayout);
+    for (unsigned i = 0; i < order.size(); i++) {
+      if (order[i] == axis)
+        break;
+      threadOffset *= threadsPerWarp[order[i]];
     }
   }
   return threadOffset;
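The loop generalizes the old hard-coded 2D cases: it multiplies the warp's thread counts over every dimension that is faster-varying than the reduction axis in the layout order. A hedged sketch of the effect (this assumes threadOffset is initialized to 1 earlier in the function, as the *= suggests; the values are hypothetical):

#include <vector>

unsigned threadOffsetOnAxis(const std::vector<unsigned> &order,
                            const std::vector<unsigned> &threadsPerWarp,
                            unsigned axis) {
  unsigned threadOffset = 1;
  // Accumulate the thread counts of all dimensions ordered before `axis`.
  for (unsigned i = 0; i < order.size(); i++) {
    if (order[i] == axis)
      break;
    threadOffset *= threadsPerWarp[order[i]];
  }
  return threadOffset;
}

// With order = {0, 1}, threadsPerWarp = {16, 2}, and axis = 1, this returns
// 16: threads tile the faster axis 0 first, so two lanes adjacent along
// axis 1 are 16 thread IDs apart.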
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
 }
 
 bool ReduceOpHelper::isWarpSynchronous() {
-  auto argsLayout = getSrcLayout();
-  return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
+  auto srcLayout = getSrcLayout();
+  auto srcShape = getSrcShape();
+  return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
+         1;
 }
 
 SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
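The switch to getWarpsPerCTAWithUniqueData matters when the tensor is smaller than what the layout can cover, so some warps hold only broadcast copies of the data. A hypothetical illustration of the difference, as I read the new check:

// Suppose warpsPerCTA = {4, 2} along reduction axis 0, but the tensor shape
// supplies unique rows for only one of those 4 warps (the other three hold
// replicated data).
//   old:  getWarpsPerCTA(layout)[0] == 1                      -> false,
//         forcing an unnecessary cross-warp reduction;
//   new:  getWarpsPerCTAWithUniqueData(layout, shape)[0] == 1 -> true,
//         so the reduction is correctly treated as warp-synchronous.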
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
   // when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
   return src && dst && src.getVersionMajor() == 3 &&
          src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
-         dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
+         dst.getWarpsPerCTA()[1] == 1;
 }
 
-bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
 }
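Dropping the references here (and for isMmaToDotShortcut in the next hunk) follows MLIR convention: RankedTensorType is a value-semantic handle around uniqued, immutable storage, so copying it costs no more than copying a pointer, and pass-by-value lets callers hand in temporaries, which a non-const lvalue reference would reject. A hedged usage sketch, with a hypothetical cvtOp:

// The temporaries produced by cast<RankedTensorType>() bind to by-value
// parameters but would not bind to RankedTensorType&.
bool shortcut = isMmaToMmaShortcut(
    cvtOp.getOperand().getType().cast<RankedTensorType>(),
    cvtOp.getResult().getType().cast<RankedTensorType>());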
@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
          srcTy.getElementType().isF16();
 }
 
-bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
     return true;
   // dot_op<opIdx=0, parent=#mma> = #mma
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
     auto *currentOp = (slice)[currentIndex];
     // Compute and insert the backwardSlice starting from currentOp.
     backwardSlice.clear();
-    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+    mlir::BackwardSliceOptions opt;
+    opt.omitBlockArguments = true;
+    opt.filter = backwardFilter;
+    getBackwardSlice(currentOp, &backwardSlice, opt);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
     // Compute and insert the forwardSlice starting from currentOp.
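This updates the call to MLIR's newer options-struct API for slice analysis, declared in mlir/Analysis/SliceAnalysis.h. A minimal self-contained sketch of the pattern, with a hypothetical filter and function name:

#include "mlir/Analysis/SliceAnalysis.h"

void collectBackwardDeps(mlir::Operation *root,
                         llvm::SetVector<mlir::Operation *> &slice) {
  mlir::BackwardSliceOptions opt;
  // Stop at block arguments instead of tracing them to their producers.
  opt.omitBlockArguments = true;
  // Hypothetical filter: stay within the root op's parent region.
  opt.filter = [&](mlir::Operation *op) {
    return op->getParentRegion() == root->getParentRegion();
  };
  mlir::getBackwardSlice(root, &slice, opt);
}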