mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Backend] Refactor sharedToDotOperandMFMA lowering (#439)
* Remove unnecessary xor computations for k-major swizzled tensors
* Support mfma16 and mfma4 in the fast path
* Choose warpsPerCTA according to nonKDim
* Set maxPhase=4 for mfma4
* Fix tests
  For now, we do not disable swizzling for k-major tensors
* Remove fastPathComputeOffsetsTy1
* Enable k-major + disabled swizzling in the normal path
This commit is contained in:
@@ -94,8 +94,9 @@ warpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
return warpsPerTile(dotOp, shape, numWarps, {32, 32});
|
||||
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
SmallVector<int64_t, 2> shapePerWarp) {
|
||||
return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
@@ -263,7 +264,8 @@ public:
|
||||
|
||||
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);
|
||||
|
||||
auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);
|
||||
auto warpsPerTile =
|
||||
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
|
||||
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
|
||||
|
||||
Reference in New Issue
Block a user