[MFMA][Dot] Swizzle mfma dot operands (#245)

* swizzling when loading dot operands

* [MFMA][Dot] Swizzle mfma dot operands

This PR supports swizzling in MFMA dot operands.

* fix comments

* Update TritonGPUAttrDefs.td

* Update TritonGPUAttrDefs.td 2

---------

Co-authored-by: weihanmines <wei.han3@amd.com>
This commit is contained in:
Alexander Efimov
2023-07-12 21:19:31 +03:00
committed by GitHub
parent f3e339e5f4
commit 4d0deef45f
2 changed files with 80 additions and 5 deletions

View File

@@ -80,11 +80,41 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
"Type":$eltTy), [{
#ifdef USE_ROCM
// ---- begin MI200 ----
// ---- begin GFX908/GFX90A ----
auto mfmaEnc = dotOpEnc.getParent().dyn_cast<MfmaEncodingAttr>();
if (mfmaEnc) {
// Swizzling is currently disabled for MFMA
return $_get(context, 1, 1, 1, order);
int opIdx = dotOpEnc.getOpIdx();
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
// This is just an example swizzle pattern; it is not production ready
unsigned vec = 8;
unsigned maxPhase = 2;
unsigned perPhase = 1;
return $_get(context, vec, perPhase, maxPhase, order);
}
// --- handle B operand ---
if (opIdx == 1) {
// This is just an example swizzle pattern; it is not production ready
unsigned vec = 2;
unsigned maxPhase = 2;
unsigned perPhase = 1;
return $_get(context, vec, perPhase, maxPhase, order);
}
llvm_unreachable("invalid operand index");
}
#endif
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();

View File

@@ -34,6 +34,49 @@ Value getWaveN(ConversionPatternRewriter &rewriter, Location loc, Value wave,
namespace SharedToDotOperandMFMA {
/**
 * @brief Apply the swizzle pattern encoded in a SharedEncodingAttr to the
 * (row, col) index of a tensor element.
 *
 * For a row-major tensor the computation is:
 *
 *   phase  = (row / perPhase) % maxPhase
 *   newCol = (((col / vec) ^ phase) * vec) + (col % vec)
 *
 * and the row is returned unchanged. When `attr.getOrder()[0] != 1` the roles
 * of row and col are exchanged: the column selects the phase and the row is
 * the index that gets swizzled.
 *
 * @param rewriter presumably consumed by the i32_val/udiv/urem/... helper
 * macros to emit the index arithmetic (not referenced directly here)
 * @param loc presumably consumed by the same helper macros
 * @param row row of target tensor element related to the start of smemObj
 * @param col col of target tensor element related to the start of smemObj
 * @param smemObj shared memory object, contains info about tensor in LDS;
 * not used by the current pattern
 * @param attr layout attribute providing vec, perPhase, maxPhase and order
 * @return swizzled (row, col) indexes in tensor notation
 */
std::pair<mlir::Value, mlir::Value>
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
               Value col, SharedMemoryObject smemObj, SharedEncodingAttr attr) {
  (void)smemObj; // unused in current pattern

  // "major" is the index that selects the phase, "minor" is the index whose
  // vec-sized chunks are permuted. They trade places for column-wise tensors.
  const bool isColumnWise = (attr.getOrder()[0] != 1);
  Value major = isColumnWise ? col : row;
  Value minor = isColumnWise ? row : col;

  Value vecVal = i32_val(attr.getVec());
  Value perPhaseVal = i32_val(attr.getPerPhase());
  Value maxPhaseVal = i32_val(attr.getMaxPhase());

  // Same scheme as getSwizzledSharedPtrs (TritonGPUToLLVMBase.h).
  // phase = (major / perPhase) % maxPhase
  Value phase = urem(udiv(major, perPhaseVal), maxPhaseVal);
  // XOR the chunk number with the phase, then scale back to an element index.
  Value chunk = udiv(minor, vecVal);
  Value swizzledChunk = xor_(chunk, phase);
  Value chunkBase = mul(swizzledChunk, vecVal);
  // Position inside the chunk is left untouched.
  Value inChunk = urem(minor, vecVal);
  Value swizzledMinor = add(chunkBase, inChunk);

  // Hand the result back in (row, col) order regardless of the layout.
  if (isColumnWise)
    return {swizzledMinor, major};
  return {major, swizzledMinor};
}
/**
* @brief This function maps a particular load of an MFMA dot operand to
* element indexes (row, col)
@@ -113,9 +156,11 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
                    Value row, Value col, SharedMemoryObject smemObj,
                    SharedEncodingAttr srcLayout) {
  // Computes swizzledRow * strides[0] + swizzledCol * strides[1], where the
  // swizzled indexes are obtained by applying the pattern in srcLayout to
  // (row, col) and the strides come from smemObj.
  auto [swizzledRow, swizzledCol] =
      swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout);
  const auto &strides = smemObj.strides;
  // The offsets are built from the swizzled indexes only. Declaring them a
  // second time from the raw row/col would be a redefinition and would bypass
  // the swizzle pattern.
  Value rowOffset = mul(swizzledRow, strides[0]);
  Value colOffset = mul(swizzledCol, strides[1]);
  return add(rowOffset, colOffset);
}