mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial commit to resolve merge conflicts
rename tl.float8e4 to tl.float8e4nv to align with upstream
ROCM IFU: Fix python arch issues
ROCM IFU: Fix kernel launcher
ROCM IFU: Fix merge conflicts
fix debug build
Set correct threadsPerCTA
This commit is contained in:
@@ -356,9 +356,6 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
|
||||
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
|
||||
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
<<<<<<< HEAD
|
||||
|
||||
=======
|
||||
if (version == 3) {
|
||||
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
|
||||
return false;
|
||||
@@ -374,7 +371,6 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if (aElemTy.isF32() && bElemTy.isF32()) {
|
||||
return (op.getAllowTF32() && version == 2) || version == 3;
|
||||
}
|
||||
@@ -446,7 +442,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
#ifdef USE_ROCM
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
@@ -464,7 +459,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
}
|
||||
#endif
|
||||
=======
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
|
||||
@@ -475,7 +470,6 @@ bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
|
||||
srcElemsPerThread == dstElemsPerThread;
|
||||
}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
|
||||
Reference in New Issue
Block a user