mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA][Dot] Swizzle mfma dot operands (#245)
* swizzling when loading dot operands * [MFMA][Dot] Swizzle mfma dot operands This PR supports swizzling in MFMA dot operands. * fix comments * Update TritonGPUAttrDefs.td * Update TritonGPUAttrDefs.td 2 --------- Co-authored-by: weihanmines <wei.han3@amd.com>
This commit is contained in:
@@ -80,11 +80,41 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
"Type":$eltTy), [{
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// ---- begin MI200 ----
|
||||
// ---- begin GFX908/GFX90A ----
|
||||
auto mfmaEnc = dotOpEnc.getParent().dyn_cast<MfmaEncodingAttr>();
|
||||
|
||||
if (mfmaEnc) {
|
||||
// Swizzling is currently disabled for MFMA
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
|
||||
// number of rows per phase
|
||||
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if (eltTy.isInteger(8) && order[0] == inner)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
// This is just an example swizzle pattern; it is not production-ready
|
||||
unsigned vec = 8;
|
||||
unsigned maxPhase = 2;
|
||||
unsigned perPhase = 1;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
if (opIdx == 1) {
|
||||
// This is an example swizzle pattern; it is not production-ready
|
||||
unsigned vec = 2;
|
||||
unsigned maxPhase = 2;
|
||||
unsigned perPhase = 1;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
#endif
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
@@ -34,6 +34,49 @@ Value getWaveN(ConversionPatternRewriter &rewriter, Location loc, Value wave,
|
||||
|
||||
namespace SharedToDotOperandMFMA {
|
||||
|
||||
/**
|
||||
* @brief swizzling tensor element indexes according pattern encoded in
|
||||
* SharedEncodingAttr
|
||||
*
|
||||
* @param rewriter
|
||||
* @param loc
|
||||
* @param row row of target tensor element related to the start of smemObj
|
||||
* @param col col of target tensor element related to the start of smemObj
|
||||
* @param smemObj shared memory object, contains info about tensor in LDS
|
||||
* @param attr layout attribute, contains swizzling info
|
||||
* @return swizzled row, col indexes in tensor notation
|
||||
*/
|
||||
std::pair<mlir::Value, mlir::Value>
|
||||
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
|
||||
Value col, SharedMemoryObject smemObj, SharedEncodingAttr attr) {
|
||||
(void)smemObj; // unused in current pattern
|
||||
bool transposed = (attr.getOrder()[0] != 1);
|
||||
if (transposed) {
|
||||
// tensor is column-wise, so swapping col and row in computations
|
||||
std::swap(row, col);
|
||||
}
|
||||
auto vec = i32_val(attr.getVec());
|
||||
auto perPhase = i32_val(attr.getPerPhase());
|
||||
auto maxPhase = i32_val(attr.getMaxPhase());
|
||||
|
||||
// Original algorithm taken from getSwizzledSharedPtrs function
|
||||
// (TritonGPUToLLVMBase.h): Basic algorithm for row-major tensor is following:
|
||||
//
|
||||
// phase = (row // perPhase) % maxPhase
|
||||
// colOffSwizzled = ((col // vec) ^ phase) * vec
|
||||
// colOffOrdered = col % vec
|
||||
// colOff = colOffSwizzled + colOffOrdered
|
||||
auto phase = urem(udiv(row, perPhase), maxPhase);
|
||||
auto colOffSwizzled = mul(xor_(udiv(col, vec), phase), vec);
|
||||
auto colOffOrdered = urem(col, vec);
|
||||
auto colOff = add(colOffSwizzled, colOffOrdered);
|
||||
|
||||
if (transposed)
|
||||
return {colOff, row};
|
||||
else
|
||||
return {row, colOff};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function maps particular load of mfma dot operand to element
|
||||
* indexes(row, col)
|
||||
@@ -113,9 +156,11 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
|
||||
/**
 * @brief Compute the offset of a tensor element inside a shared memory
 * object, taking the layout's swizzling into account.
 *
 * The (row, col) index is first passed through swizzleIndexes; the resulting
 * indexes are then scaled by smemObj.strides[0] and smemObj.strides[1]
 * respectively and summed.
 *
 * @param rewriter rewriter used to emit the index arithmetic
 * @param loc location attached to the emitted operations
 * @param row row of the target tensor element
 * @param col col of the target tensor element
 * @param smemObj shared memory object that provides the strides
 * @param srcLayout shared layout attribute that carries the swizzling info
 * @return rowOffset + colOffset computed from the swizzled indexes
 */
Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
                    Value row, Value col, SharedMemoryObject smemObj,
                    SharedEncodingAttr srcLayout) {
  auto [swizzledRow, swizzledCol] =
      swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout);
  auto &strides = smemObj.strides;
  // Only the swizzled indexes may feed the offset; using the raw row/col here
  // would bypass the swizzle pattern entirely.
  Value rowOffset = mul(swizzledRow, strides[0]);
  Value colOffset = mul(swizzledCol, strides[1]);
  return add(rowOffset, colOffset);
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user