Merge commit '5df904233c11a65bd131ead7268f84cca7804275' into ifu230810-2

Conflicts:
	include/triton/Dialect/Triton/Transforms/Passes.h
	include/triton/Dialect/TritonGPU/IR/Dialect.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/triton/compiler/compiler.py
	python/triton/ops/flash_attention.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
	test/Target/tritongpu_to_llvmir.mlir
	test/Target/tritongpu_to_llvmir_noinline.mlir
Jason Furmanek
2023-09-01 03:25:33 +00:00
122 changed files with 7341 additions and 2234 deletions

@@ -1,8 +1,10 @@
 #include "ConvertLayoutOpToLLVM.h"
 #include "Utility.h"
 
+using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 using ::mlir::LLVM::getStridesFromShapeAndOrder;
+using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getOrder;
@@ -18,7 +20,7 @@ namespace SharedToDotOperandMMAv1 {
 using CoordTy = SmallVector<Value>;
 using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
-SmallVector<CoordTy> getMNCoords(Value thread,
+SmallVector<CoordTy> getMNCoords(Value thread, Location loc,
                                  ConversionPatternRewriter &rewriter,
                                  ArrayRef<unsigned int> wpt,
                                  const MmaEncodingAttr &mmaLayout,
@@ -160,9 +162,10 @@ private:
   Value laneId = urem(threadId, warpSize);
   Value warpId = udiv(threadId, warpSize);
   // TODO: fix the bug in MMAEncodingAttr document
-  SmallVector<Value> multiDimWarpId(2);
-  multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
-  multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
+  auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
+  auto order = triton::gpu::getOrder(mmaLayout);
+  SmallVector<Value> multiDimWarpId =
+      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
   Value _1 = i32_val(1);
   Value _2 = i32_val(2);
   Value _4 = i32_val(4);
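
The hunk above swaps the hand-rolled urem/udiv pair, which hard-coded dimension 0 as the fastest-varying one, for the shared delinearize helper, which honors the layout's order. A standalone sketch of the computation, using plain integers instead of MLIR Values and assuming delinearize peels dimensions off fastest-first according to `order` (delinearizeSketch is an illustrative name, not the MLIR API):

#include <cassert>
#include <cstdio>
#include <vector>

// Plain-integer model of the delinearize utility: split a linear id
// into per-dimension coordinates, peeling off the fastest-varying
// dimension first, as given by `order`.
std::vector<unsigned> delinearizeSketch(unsigned linear,
                                        const std::vector<unsigned> &shape,
                                        const std::vector<unsigned> &order) {
  std::vector<unsigned> coords(shape.size());
  for (unsigned dim : order) { // order[0] is the fastest-varying dimension
    coords[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return coords;
}

int main() {
  // warpsPerCTA = {4, 2}, order = {0, 1}: warp 6 -> (6 % 4, 6 / 4) = (2, 1),
  // matching the old multiDimWarpId[0]/multiDimWarpId[1] computation.
  auto coords = delinearizeSketch(6, {4, 2}, {0, 1});
  assert(coords[0] == 2 && coords[1] == 1);
  std::printf("warp 6 -> (%u, %u)\n", coords[0], coords[1]);
  return 0;
}

For warpsPerCTA = {4, 2} and order = {0, 1} this reproduces the old result exactly, but unlike the urem/udiv pair it stays correct when order[0] != 0.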
@@ -201,8 +204,8 @@ private:
   auto [isARow, isBRow, isAVec4, isBVec4, _] =
       mmaLayout.decodeVoltaLayoutStates();
   auto coords = SharedToDotOperandMMAv1::getMNCoords(
-      threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
-      isARow, isBRow, isAVec4, isBVec4);
+      threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout,
+      shape, isARow, isBRow, isAVec4, isBVec4);
   return coords[elemId];
 } else {
   llvm_unreachable("Unexpected MMALayout version");
@@ -465,12 +468,15 @@ private:
   }
   // Potentially we need to store for multiple CTAs in this replication
   auto accumNumReplicates = product<unsigned>(numReplicates);
   // unsigned elems = getTotalElemsPerThread(srcTy);
   auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                    rewriter, srcTy);
   unsigned inVec = 0;
   unsigned outVec = 0;
   auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
+  if (getElementTypeOrSelf(op.getType()).isa<mlir::Float8E4M3B11FNUZType>()) {
+    assert(inVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
+    assert(outVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
+  }
   unsigned outElems = getTotalElemsPerThread(dstTy);
   auto outOrd = getOrder(dstLayout);
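
The new guard in the last hunk asserts that the vector widths chosen by getScratchConfigForCvtLayout are multiples of 4 whenever the element type is Float8E4M3B11FNUZType (which the assert messages suggest stands in for Triton's FP8E4M3B15 format). A plausible reading, sketched under the assumption that these fp8 values travel packed four to a 32-bit word, so vectorized shared-memory traffic must move whole words (fp8VecWidthOk is an illustrative helper, not Triton code):

#include <cassert>

// Illustrative helper: if fp8 values are packed four to a 32-bit word,
// every vectorized load/store through the scratch buffer must cover
// whole words, so both vector widths have to be multiples of 4 --
// the condition the new asserts enforce.
bool fp8VecWidthOk(unsigned inVec, unsigned outVec) {
  return inVec % 4 == 0 && outVec % 4 == 0;
}

int main() {
  assert(fp8VecWidthOk(4, 8));  // whole 32-bit words: supported
  assert(!fp8VecWidthOk(2, 4)); // sub-word input vector: rejected
  return 0;
}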