mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] FP8 and BF8 support (#355)
* [MFMA] FP8 and BF8 support This PR adds support of fp8 and bf8 in AccelerateMatmul pass and Introduces generation of float8 mfma instructions in ttg to llvm conversion. * add tests * fix tests * review fix: fix variable naming and dot operand promotion. * review comments fixes --------- Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
@@ -392,6 +392,34 @@ static bool supportMFMAGranularity(int m, int n, int k) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supportMFMATypes(Type a, Type b) {
|
||||
if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
|
||||
return false;
|
||||
|
||||
auto F8E4M3FNUZ = TypeID::get<mlir::Float8E4M3FNUZType>();
|
||||
auto F8E5M2FNUZ = TypeID::get<mlir::Float8E5M2FNUZType>();
|
||||
auto F16 = TypeID::get<mlir::Float16Type>();
|
||||
auto BF16 = TypeID::get<mlir::BFloat16Type>();
|
||||
auto F32 = TypeID::get<mlir::Float32Type>();
|
||||
auto Int = TypeID::get<mlir::IntegerType>();
|
||||
DenseSet<std::pair<mlir::TypeID, mlir::TypeID>> supportedTypes = {
|
||||
{F32, F32},
|
||||
{F16, F16},
|
||||
{BF16, BF16},
|
||||
{F8E4M3FNUZ, F8E4M3FNUZ},
|
||||
{F8E4M3FNUZ, F8E5M2FNUZ},
|
||||
{F8E5M2FNUZ, F8E4M3FNUZ},
|
||||
{F8E5M2FNUZ, F8E5M2FNUZ},
|
||||
{Int, Int}};
|
||||
|
||||
if (!supportedTypes.contains({a.getTypeID(), b.getTypeID()}))
|
||||
return false;
|
||||
|
||||
if (a.isIntOrIndex() && a.getIntOrFloatBitWidth() != 8)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool supportMFMA(triton::DotOp op) {
|
||||
auto aTy = op.getA().getType().cast<RankedTensorType>();
|
||||
auto bTy = op.getB().getType().cast<RankedTensorType>();
|
||||
@@ -399,7 +427,7 @@ bool supportMFMA(triton::DotOp op) {
|
||||
auto aElemTy = aTy.getElementType();
|
||||
auto bElemTy = bTy.getElementType();
|
||||
|
||||
if (aElemTy != bElemTy)
|
||||
if (!supportMFMATypes(aElemTy, bElemTy))
|
||||
return false;
|
||||
|
||||
auto aShape = aTy.getShape();
|
||||
@@ -409,8 +437,7 @@ bool supportMFMA(triton::DotOp op) {
|
||||
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
|
||||
return false;
|
||||
|
||||
return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
|
||||
aElemTy.isInteger(8);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user