Support WMMA layout in TritonAMDGPUAccelerateMatmulPass

-Introduce WmmaEncodingAttr for WMMA output
-Introduce BlockedToWMMA rewrite pattern in TritonAMDGPUAccelerateMatmulPass
-Provide a flag tho check if wmma instructions are supported by target

Signed-off-by: joviliast <iveselov.nn@gmail.com>
This commit is contained in:
joviliast
2023-12-11 20:53:06 +02:00
committed by Lixun Zhang
parent b7a412d82a
commit af15da2f84
9 changed files with 258 additions and 23 deletions

View File

@@ -448,7 +448,7 @@ bool supportMFMATypes(Type a, Type b) {
auto BF16 = TypeID::get<mlir::BFloat16Type>();
auto F32 = TypeID::get<mlir::Float32Type>();
auto Int = TypeID::get<mlir::IntegerType>();
DenseSet<std::pair<mlir::TypeID, mlir::TypeID>> supportedTypes = {
const static DenseSet<std::pair<mlir::TypeID, mlir::TypeID>> supportedTypes = {
{F32, F32},
{F16, F16},
{BF16, BF16},
@@ -485,6 +485,55 @@ bool supportMFMA(triton::DotOp op) {
return true;
}
static bool supportWMMAGranularity(int m, int n, int k) {
return m % 16 == 0 && n % 16 == 0 && k % 16 == 0;
}
bool supportWMMATypes(Type a, Type b, Type c, Type d) {
if (a != b || c != d)
return false;
if (a.isIntOrIndex()) {
if (!c.isIntOrIndex())
return false;
auto aWidth = a.getIntOrFloatBitWidth();
auto cWidth = c.getIntOrFloatBitWidth();
bool aValid = a.isUnsignedInteger() && (aWidth == 4 || aWidth == 8);
bool cValid = c.isSignedInteger() && cWidth == 32;
return aValid && cValid;
} else if (a.isa<FloatType>()) {
if (a.isBF16())
return c.isBF16() || c.isF32();
if (a.isF16())
return c.isF16() || c.isF32();
}
return false;
}
// TODO: check C D operands
bool supportWMMA(triton::DotOp op) {
auto aTy = op.getA().getType().cast<RankedTensorType>();
auto bTy = op.getB().getType().cast<RankedTensorType>();
auto cTy = op.getC().getType().cast<RankedTensorType>();
auto dTy = op.getResult().getType().cast<RankedTensorType>();
auto aElemTy = aTy.getElementType();
auto bElemTy = bTy.getElementType();
auto cElemTy = cTy.getElementType();
auto dElemTy = dTy.getElementType();
if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy))
return false;
auto aShape = aTy.getShape();
auto bShape = bTy.getShape();
assert(aShape[1] == bShape[0]);
if (!supportWMMAGranularity(aShape[0], bShape[1], aShape[1]))
return false;
return true;
}
#endif
bool supportMMA(Value value, int version) {