[Backend] Refactor sharedToDotOperandMFMA lowering (#439)

* Remove unnecessary xor computations for k-major swizzled tensors

* Support mfma16 and mfma4 in the fast path

* Choose warpsPerCTA according to nonKDim

* Set maxPhase=4 for mfma4

* Fix tests

For now, we do not disable swizzling for k-major tensors

* Remove fastPathComputeOffsetsTy1

* Enable k-major tensors with swizzling disabled in the normal path
Author: Lixun Zhang
Date: 2024-01-12 12:50:18 -06:00
Committed by: GitHub
Parent: a7bb38ea79
Commit: 2e217c5a5c
5 changed files with 86 additions and 116 deletions


@@ -157,6 +157,9 @@ compared to 1*64 when the hasLeadingOffset is false.
// maxPhase is set to SIMDWidth / perPhase
int vecSize = ((typeWidthInBit == 16) ? 64 : 32) / typeWidthInBit;
int maxPhase = SIMDWidth / perPhase;
// TODO (zhanglx): figure out better parameters for mfma4
if (mfmaEnc.getNonKDim() == 4)
maxPhase = 4;
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
} else {