[MFMA] FP8 and BF8 support (#355)

* [MFMA] FP8 and BF8 support

This PR adds support of fp8 and bf8 in AccelerateMatmul pass and
Introduces generation of float8 mfma instructions in ttg to llvm conversion.

* add tests

* fix tests

* review fix: fix variable naming and dot operand promotion.

* review comments fixes

---------

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
Alexander Efimov
2023-10-25 20:27:10 +02:00
committed by GitHub
parent 8547694665
commit 5a86b46bb1
8 changed files with 199 additions and 30 deletions

View File

@@ -392,6 +392,34 @@ static bool supportMFMAGranularity(int m, int n, int k) {
return false;
}
bool supportMFMATypes(Type a, Type b) {
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
return false;
auto F8E4M3FNUZ = TypeID::get<mlir::Float8E4M3FNUZType>();
auto F8E5M2FNUZ = TypeID::get<mlir::Float8E5M2FNUZType>();
auto F16 = TypeID::get<mlir::Float16Type>();
auto BF16 = TypeID::get<mlir::BFloat16Type>();
auto F32 = TypeID::get<mlir::Float32Type>();
auto Int = TypeID::get<mlir::IntegerType>();
DenseSet<std::pair<mlir::TypeID, mlir::TypeID>> supportedTypes = {
{F32, F32},
{F16, F16},
{BF16, BF16},
{F8E4M3FNUZ, F8E4M3FNUZ},
{F8E4M3FNUZ, F8E5M2FNUZ},
{F8E5M2FNUZ, F8E4M3FNUZ},
{F8E5M2FNUZ, F8E5M2FNUZ},
{Int, Int}};
if (!supportedTypes.contains({a.getTypeID(), b.getTypeID()}))
return false;
if (a.isIntOrIndex() && a.getIntOrFloatBitWidth() != 8)
return false;
return true;
}
bool supportMFMA(triton::DotOp op) {
auto aTy = op.getA().getType().cast<RankedTensorType>();
auto bTy = op.getB().getType().cast<RankedTensorType>();
@@ -399,7 +427,7 @@ bool supportMFMA(triton::DotOp op) {
auto aElemTy = aTy.getElementType();
auto bElemTy = bTy.getElementType();
if (aElemTy != bElemTy)
if (!supportMFMATypes(aElemTy, bElemTy))
return false;
auto aShape = aTy.getShape();
@@ -409,8 +437,7 @@ bool supportMFMA(triton::DotOp op) {
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
return false;
return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
aElemTy.isInteger(8);
return true;
}
#endif