mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '5df904233c11a65bd131ead7268f84cca7804275' into ifu230810-2
Conflicts: include/triton/Dialect/Triton/Transforms/Passes.h include/triton/Dialect/TritonGPU/IR/Dialect.h include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td lib/Analysis/Allocation.cpp lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/triton/compiler/compiler.py python/triton/ops/flash_attention.py python/triton/runtime/autotuner.py python/triton/runtime/jit.py python/triton/tools/aot.py python/tutorials/06-fused-attention.py test/Conversion/tritongpu_to_llvm.mlir test/Target/tritongpu_to_llvmir.mlir test/Target/tritongpu_to_llvmir_noinline.mlir
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::getContigPerThread;
|
||||
using ::mlir::triton::gpu::getOrder;
|
||||
@@ -18,7 +20,7 @@ namespace SharedToDotOperandMMAv1 {
|
||||
using CoordTy = SmallVector<Value>;
|
||||
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
|
||||
|
||||
SmallVector<CoordTy> getMNCoords(Value thread,
|
||||
SmallVector<CoordTy> getMNCoords(Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<unsigned int> wpt,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
@@ -160,9 +162,10 @@ private:
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
// TODO: fix the bug in MMAEncodingAttr document
|
||||
SmallVector<Value> multiDimWarpId(2);
|
||||
multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||
multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
Value _1 = i32_val(1);
|
||||
Value _2 = i32_val(2);
|
||||
Value _4 = i32_val(4);
|
||||
@@ -201,8 +204,8 @@ private:
|
||||
auto [isARow, isBRow, isAVec4, isBVec4, _] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
auto coords = SharedToDotOperandMMAv1::getMNCoords(
|
||||
threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
|
||||
isARow, isBRow, isAVec4, isBVec4);
|
||||
threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout,
|
||||
shape, isARow, isBRow, isAVec4, isBVec4);
|
||||
return coords[elemId];
|
||||
} else {
|
||||
llvm_unreachable("Unexpected MMALayout version");
|
||||
@@ -465,12 +468,15 @@ private:
|
||||
}
|
||||
// Potentially we need to store for multiple CTAs in this replication
|
||||
auto accumNumReplicates = product<unsigned>(numReplicates);
|
||||
// unsigned elems = getTotalElemsPerThread(srcTy);
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
|
||||
if (getElementTypeOrSelf(op.getType()).isa<mlir::Float8E4M3B11FNUZType>()) {
|
||||
assert(inVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
|
||||
assert(outVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
|
||||
}
|
||||
|
||||
unsigned outElems = getTotalElemsPerThread(dstTy);
|
||||
auto outOrd = getOrder(dstLayout);
|
||||
|
||||
Reference in New Issue
Block a user