mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Support WMMA layout in TritonAMDGPUAccelerateMatmulPass
- Introduce WmmaEncodingAttr for WMMA output
- Introduce BlockedToWMMA rewrite pattern in TritonAMDGPUAccelerateMatmulPass
- Provide a flag to check if WMMA instructions are supported by the target

Signed-off-by: joviliast <iveselov.nn@gmail.com>
This commit is contained in:
@@ -448,7 +448,7 @@ bool supportMFMATypes(Type a, Type b) {
|
||||
auto BF16 = TypeID::get<mlir::BFloat16Type>();
|
||||
auto F32 = TypeID::get<mlir::Float32Type>();
|
||||
auto Int = TypeID::get<mlir::IntegerType>();
|
||||
DenseSet<std::pair<mlir::TypeID, mlir::TypeID>> supportedTypes = {
|
||||
const static DenseSet<std::pair<mlir::TypeID, mlir::TypeID>> supportedTypes = {
|
||||
{F32, F32},
|
||||
{F16, F16},
|
||||
{BF16, BF16},
|
||||
@@ -485,6 +485,55 @@ bool supportMFMA(triton::DotOp op) {
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true only when each of the three extents (m, n, k) is an exact
// multiple of 16; returns false as soon as one of them is not.
static bool supportWMMAGranularity(int m, int n, int k) {
  constexpr int kRequiredMultiple = 16;
  if (m % kRequiredMultiple != 0)
    return false;
  if (n % kRequiredMultiple != 0)
    return false;
  return k % kRequiredMultiple == 0;
}
|
||||
|
||||
// Checks whether a combination of dot element types is accepted for WMMA.
//
// Parameters:
//   a, b - element types of the two multiplicands; they must be identical.
//   c, d - element types of the accumulator and of the result; they must be
//          identical.
//
// Returns true for exactly these combinations:
//   - a is an unsigned 4-bit or 8-bit integer and c is a signed 32-bit integer;
//   - a is bf16 and c is bf16 or f32;
//   - a is f16  and c is f16  or f32.
// Everything else, including `index` element types, returns false.
bool supportWMMATypes(Type a, Type b, Type c, Type d) {
  if (a != b || c != d)
    return false;

  if (a.isIntOrIndex()) {
    if (!c.isIntOrIndex())
      return false;
    // `index` passes isIntOrIndex(), but getIntOrFloatBitWidth() asserts on
    // any type that is neither an integer nor a float. Reject it here so the
    // query below is always legal and the function just answers "no".
    if (a.isIndex() || c.isIndex())
      return false;

    const auto aWidth = a.getIntOrFloatBitWidth();
    const auto cWidth = c.getIntOrFloatBitWidth();
    // NOTE(review): isUnsignedInteger()/isSignedInteger() are false for
    // signless integer types; confirm that rejecting signless i8/i32 element
    // types is intended.
    const bool aValid = a.isUnsignedInteger() && (aWidth == 4 || aWidth == 8);
    const bool cValid = c.isSignedInteger() && cWidth == 32;
    return aValid && cValid;
  }

  if (a.isa<FloatType>()) {
    if (a.isBF16())
      return c.isBF16() || c.isF32();
    if (a.isF16())
      return c.isF16() || c.isF32();
  }
  return false;
}
|
||||
|
||||
// TODO: check C D operands
|
||||
bool supportWMMA(triton::DotOp op) {
|
||||
auto aTy = op.getA().getType().cast<RankedTensorType>();
|
||||
auto bTy = op.getB().getType().cast<RankedTensorType>();
|
||||
auto cTy = op.getC().getType().cast<RankedTensorType>();
|
||||
auto dTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
|
||||
auto aElemTy = aTy.getElementType();
|
||||
auto bElemTy = bTy.getElementType();
|
||||
auto cElemTy = cTy.getElementType();
|
||||
auto dElemTy = dTy.getElementType();
|
||||
|
||||
if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy))
|
||||
return false;
|
||||
|
||||
auto aShape = aTy.getShape();
|
||||
auto bShape = bTy.getShape();
|
||||
|
||||
assert(aShape[1] == bShape[0]);
|
||||
if (!supportWMMAGranularity(aShape[0], bShape[1], aShape[1]))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool supportMMA(Value value, int version) {
|
||||
|
||||
Reference in New Issue
Block a user