[Backend] Refactor sharedToDotOperandMFMA lowering (#439)

* Remove unnecessary xor computations for k-major swizzled tensors

* Support mfma16 and mfma4 in the fast path

* Choose warpsPerCTA according to nonKDim

* Set maxPhase=4 for mfma4

* Fix tests

For now, we do not disable swizzling for k-major tensors

* Remove fastPathComputeOffsetsTy1

* Enable k-major + disabled swizzling in the normal path
This commit is contained in:
Lixun Zhang
2024-01-12 12:50:18 -06:00
committed by GitHub
parent a7bb38ea79
commit 2e217c5a5c
5 changed files with 86 additions and 116 deletions

View File

@@ -94,8 +94,9 @@ warpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
}
SmallVector<unsigned, 2>
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
return warpsPerTile(dotOp, shape, numWarps, {32, 32});
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
SmallVector<int64_t, 2> shapePerWarp) {
return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
}
SmallVector<unsigned, 2>
@@ -263,7 +264,8 @@ public:
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);
auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);
auto warpsPerTile =
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,