Initial commit to resolve merge conflicts

rename tl.float8e4 to tl.float8e4nv to align with upstream

ROCM IFU: Fix python arch issues

ROCM IFU: Fix kernel launcher

ROCM IFU: Fix merge conflicts

fix debug build

Set correct threadsPerCTA
This commit is contained in:
Jason Furmanek
2023-09-12 20:43:59 +00:00
parent 74fd8e9754
commit e5d7bb4fae
36 changed files with 414 additions and 1005 deletions

View File

@@ -356,9 +356,6 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
<<<<<<< HEAD
=======
if (version == 3) {
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
@@ -374,7 +371,6 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
}
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if (aElemTy.isF32() && bElemTy.isF32()) {
return (op.getAllowTF32() && version == 2) || version == 3;
}
@@ -446,7 +442,6 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
<<<<<<< HEAD
#ifdef USE_ROCM
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto srcLayout = srcTy.getEncoding();
@@ -464,7 +459,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
#endif
=======
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
@@ -475,7 +470,6 @@ bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
srcElemsPerThread == dstElemsPerThread;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.