Set vecSize and maxPhase more generically

This commit is contained in:
Lixun Zhang
2023-08-10 15:52:49 -05:00
committed by Lixun Zhang
parent 7156fcb0ef
commit 87e45cb011
2 changed files with 17 additions and 9 deletions

View File

@@ -89,21 +89,29 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
if (isKDimInner) {
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
// number of inner dimension rows per one pattern repeat
int outerDimGranularity = mfmaEnc.getNonKDim();
int typeBitWidth = eltTy.getIntOrFloatBitWidth();
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeBitWidth;
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// Note: the following settings are customized for mfma_32x32x8f16
// to avoid **load** bank conflicts
// vecSize is set to k_base, which is 4
// maxPhase is set to BLOCK_K/4 so that every 16 workitems will access
// different banks
int vecSize = 4;
int maxPhase = innerDimLength / 4;
// Note: the following settings are customized to avoid
// **load** bank conflicts
//
// vecSize is set to k_base, which is the number of elements each
// workitem loads for one mfma instruction.
// For now, the k_base rules are as follows
// 1. All selected mfma instructions produce a single block
// 2. For f16 data type, 2 VGPRs are used for operand A --> k_base = 4
// 3. For non-f16 data types, 1 VGPR is used for operand A -->
// k_base = 32 / elemTypeInBits
// 4. TODO: what about f64?
//
// maxPhase is set to SIMDWidth / perPhase
int vecSize = (eltTy.isF16() ? 64 : 32 ) / typeBitWidth;
int maxPhase = SIMDWidth / perPhase;
return $_get(context, vecSize, perPhase, maxPhase, order);
} else {

View File

@@ -77,7 +77,7 @@ def optimize_ttgir(mod, num_stages, arch):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
#pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
if _is_cuda(arch):
pm.add_tritongpu_accelerate_matmul_pass(arch)
# TODO change interface of accelerate_matmul_pass