mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Set vecSize and maxPhase more generically
This commit is contained in:
@@ -89,21 +89,29 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
if (isKDimInner) {
|
||||
const int numBanks = 32;
|
||||
const int bankBitWidth = 32;
|
||||
const int SIMDWidth = 16;
|
||||
|
||||
// number of inner dimension rows per one pattern repeat
|
||||
int outerDimGranularity = mfmaEnc.getNonKDim();
|
||||
int typeBitWidth = eltTy.getIntOrFloatBitWidth();
|
||||
int innerDimLength = shape[order[0]];
|
||||
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeBitWidth;
|
||||
|
||||
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
|
||||
// Note: the following settings is customized for mfma_32x32x8f16
|
||||
// to avoid **load** bank conflicts
|
||||
// vecSize is set to k_base, which is 4
|
||||
// maxPhase is set to BLOCK_K/4 so that every 16 workitems will access
|
||||
// difference banks
|
||||
int vecSize = 4;
|
||||
int maxPhase = innerDimLength / 4;
|
||||
// Note: the following settings is customized to avoid
|
||||
// **load** bank conflicts
|
||||
//
|
||||
// vecSize is set to k_base, which is the number of elements each
|
||||
// workitem loads for one mfma instruction.
|
||||
// For now, the k_base rules are as follows
|
||||
// 1. All selected mfma instructions produce a single block
|
||||
// 2. For f16 data type, 2 VGPRs are used for operand A --> k_base = 4
|
||||
// 3. For non-f16 data types, 1 VGPR are used for operand A
|
||||
// k_base = 32 / elemTypeInBits
|
||||
// 4. TODO: what about f64?
|
||||
//
|
||||
// maxPhase is set to SIMDWidth / perPhase
|
||||
int vecSize = (eltTy.isF16() ? 64 : 32 ) / typeBitWidth;
|
||||
int maxPhase = SIMDWidth / perPhase;
|
||||
|
||||
return $_get(context, vecSize, perPhase, maxPhase, order);
|
||||
} else {
|
||||
|
||||
@@ -77,7 +77,7 @@ def optimize_ttgir(mod, num_stages, arch):
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_tritongpu_coalesce_pass()
|
||||
#pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
if _is_cuda(arch):
|
||||
pm.add_tritongpu_accelerate_matmul_pass(arch)
|
||||
# TODO change interface of accelerate_matmul_pass
|
||||
|
||||
Reference in New Issue
Block a user