Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-27 03:01:52 -04:00
resolve some merge conflicts
fix more conflicts
Resolve merge conflicts
Some more build and conflict fixes
Resolve conflicts for 06-fused-attention.py
resolve merge conflicts for the tutorial group gemm example
Fixes for some LIT tests
resolve remaining conflicts in tests
Fix empty kernel set capability 0
@@ -18,11 +18,8 @@ using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
<<<<<<< HEAD
using ::mlir::triton::gpu::MfmaEncodingAttr;
=======
using ::mlir::triton::gpu::getUniqueContigPerThread;
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -79,7 +76,6 @@ SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
  }
}

<<<<<<< HEAD
#ifdef USE_ROCM
  if (srcLayout.isa<MfmaEncodingAttr>() &&
      srcLayout.dyn_cast<MfmaEncodingAttr>().getIsTransposed() &&
@@ -88,18 +84,7 @@ SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
    return {};
#endif

  assert(srcLayout && dstLayout &&
         "Unexpected layout in getScratchConfigForCvtLayout()");
  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
  unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
  unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
  // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
  // that we cannot do vectorization.
  inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
  outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
=======
  assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()");
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33

  auto srcShapePerCTA = getShapePerCTA(srcTy);
  auto dstShapePerCTA = getShapePerCTA(dstTy);

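The HEAD side of the hunk above picks the load/store vector widths for the conversion scratch buffer from the per-thread contiguity along each layout's fastest-varying dimension, and falls back to scalar accesses whenever that dimension is dimension 0. A minimal standalone C++ sketch of that selection rule; pickCvtVecWidths and CvtVecWidths are hypothetical names, not the Triton API:

#include <array>
#include <cassert>
#include <cstdio>

struct CvtVecWidths {
  unsigned inVec;
  unsigned outVec;
};

// inOrd/outOrd: dimension order, fastest-varying first.
// srcContigPerThread/dstContigPerThread: contiguous elements owned by one
// thread along inOrd[0] / outOrd[0].
CvtVecWidths pickCvtVecWidths(const std::array<unsigned, 2> &inOrd,
                              const std::array<unsigned, 2> &outOrd,
                              unsigned srcContigPerThread,
                              unsigned dstContigPerThread) {
  CvtVecWidths v;
  // Same rule as the HEAD side above: if the fastest-varying dimension of
  // either order is dimension 0, use width 1 (no vectorization).
  v.inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
  v.outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
  return v;
}

int main() {
  // Row-major-like source and destination (fastest dim == 1): both widths kept.
  CvtVecWidths v = pickCvtVecWidths({{1, 0}}, {{1, 0}}, 4, 8);
  assert(v.inVec == 4 && v.outVec == 8);
  // Destination fastest dim == 0: vectorization disabled on both sides.
  v = pickCvtVecWidths({{1, 0}}, {{0, 1}}, 4, 8);
  assert(v.inVec == 1 && v.outVec == 1);
  std::printf("inVec=%u outVec=%u\n", v.inVec, v.outVec);
  return 0;
}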
@@ -347,15 +347,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
  for (unsigned dim : order) {
    if (dim == getAxis())
      return stride;
<<<<<<< HEAD
    stride *= ceil<unsigned int>(type.getShape()[dim], sizePerThreads[dim] *
                                                           threadsPerWarp[dim] *
                                                           warpsPerCTA[dim]);
=======
    stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
                                                      threadsPerWarp[dim] *
                                                      warpsPerCTA[dim]);
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
  }
  llvm_unreachable("Axis not found in order");
}
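Both sides of the hunk above accumulate the block stride the same way, by ceil-dividing each dimension's extent by the elements a single block covers in that dimension (sizePerThread * threadsPerWarp * warpsPerCTA); they differ only in whether the shape comes from type.getShape() or getShape(). A self-contained sketch of that accumulation with made-up layout numbers; axisBlockStride and ceilDiv are illustrative stand-ins, not Triton code:

#include <cassert>
#include <vector>

// ceil(a / b) for positive integers, mirroring ceil<unsigned int> above.
static unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

// Walk dimensions in layout order; every dimension that comes before the
// scan axis contributes ceil(shape / elements-covered-per-block) to the stride.
unsigned axisBlockStride(const std::vector<unsigned> &order,
                         const std::vector<unsigned> &shape,
                         const std::vector<unsigned> &sizePerThread,
                         const std::vector<unsigned> &threadsPerWarp,
                         const std::vector<unsigned> &warpsPerCTA,
                         unsigned axis) {
  unsigned stride = 1;
  for (unsigned dim : order) {
    if (dim == axis)
      return stride;
    stride *= ceilDiv(shape[dim], sizePerThread[dim] * threadsPerWarp[dim] *
                                      warpsPerCTA[dim]);
  }
  assert(false && "Axis not found in order");
  return 0;
}

int main() {
  // Illustrative numbers only: a 256x64 tensor, 2 elements per thread,
  // 16x4 threads per warp, 2x2 warps per CTA, scanning along dimension 1.
  unsigned stride = axisBlockStride({0, 1}, {256, 64}, {2, 2}, {16, 4}, {2, 2}, 1);
  // Dimension 0 covers 2 * 16 * 2 = 64 elements per block, so ceil(256/64) = 4.
  assert(stride == 4);
  return 0;
}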
@@ -543,7 +537,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
         !srcTy.getElementType().isF32();
}

<<<<<<< HEAD
#ifdef USE_ROCM
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto srcLayout = srcTy.getEncoding();
@@ -562,19 +555,6 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
}
#endif

bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
  auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
  auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy);
  auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy);
  // when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
  return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
         dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
         srcElemsPerThread == dstElemsPerThread;
}

=======
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
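Shortcut predicates like isMmaToDotShortcut, isMfmaToDotShortcut, and isMmaToMmaShortcut decide whether a layout conversion can stay entirely in registers and skip the shared-memory round trip. A hedged, standalone sketch of how such a predicate can feed the scratch-shape decision; the MmaInfo struct and the scratchShapeForCvt caller are illustrative stand-ins under that assumption, not Triton's types or API:

#include <cstdio>
#include <vector>

// Condensed stand-in for the encoding properties the real check reads.
struct MmaInfo {
  int versionMajor;
  unsigned warpsPerCTA_N;  // warpsPerCTA[1]
  unsigned elemsPerThread; // total elements owned by one thread
};

// Same shape of check as isMmaToMmaShortcut above: both sides MMAv3, a single
// warp along N, and identical per-thread element counts.
bool isMmaToMmaShortcut(const MmaInfo &src, const MmaInfo &dst) {
  return src.versionMajor == 3 && src.warpsPerCTA_N == 1 &&
         dst.versionMajor == 3 && dst.warpsPerCTA_N == 1 &&
         src.elemsPerThread == dst.elemsPerThread;
}

// Hypothetical caller: when the shortcut applies, report an empty scratch
// shape, meaning the conversion never touches shared memory.
std::vector<unsigned> scratchShapeForCvt(const MmaInfo &src, const MmaInfo &dst) {
  if (isMmaToMmaShortcut(src, dst))
    return {};
  return {64, 64}; // placeholder tile when a shared-memory round trip is needed
}

int main() {
  MmaInfo src{3, 1, 128}, dst{3, 1, 128};
  std::printf("scratch dims: %zu\n", scratchShapeForCvt(src, dst).size()); // 0
  return 0;
}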
bool isSingleValue(Value value) {
  // Don't consider load as expensive if it is loading a scalar.
  if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())