mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fixed a bug related to split_k and prune unnecessary tuning space (#332)
* Refined the tuning script by adding prune_configs; also fixed a bug in generating tuning configs * Fixed a bug in returning the empty config
This commit is contained in:
@@ -12,8 +12,33 @@ import yaml
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
|
||||
# Global flag indicating whether the full tuning space is used
|
||||
tuning_full_space = False
|
||||
tuning_full_space = True
|
||||
|
||||
# Prune unreasonable configs
|
||||
def prune_configs(configs, named_args):
    """Filter out tuning configs whose block sizes do not fit a small problem.

    This function is registered as the ``early_config_prune`` callback of the
    autotuned kernels in this file.

    Args:
        configs: Candidate configs; each must expose a ``kwargs`` mapping
            holding ``BLOCK_SIZE_M`` and ``BLOCK_SIZE_N``.
        named_args: Mapping of kernel argument names to values. ``a_ptr`` and
            ``b_ptr`` must expose ``.shape``; ``a_ptr.shape[0]`` is read as M
            and ``b_ptr.shape[1]`` as N.

    Returns:
        ``configs`` unchanged when the module-level ``tuning_full_space`` flag
        is off. Otherwise a new list without the configs rejected by the rules
        below, or ``configs`` unchanged if every candidate would be rejected.

    Pruning rules:
        * M <= 32: keep only configs with ``BLOCK_SIZE_M == 32``.
        * N <= 32: keep only configs with ``BLOCK_SIZE_N == 32``.
    """
    # Pruning is only meant for the full tuning space.
    if not tuning_full_space:
        return configs

    size_m = named_args["a_ptr"].shape[0]
    size_n = named_args["b_ptr"].shape[1]
    small_m = size_m <= 32
    small_n = size_n <= 32

    # BLOCK_SIZE_K takes no part in the decision, so it is not looked up;
    # a config without that key is therefore no longer rejected with KeyError.
    pruned_configs = [
        config
        for config in configs
        if not (small_m and config.kwargs["BLOCK_SIZE_M"] != 32)
        and not (small_n and config.kwargs["BLOCK_SIZE_N"] != 32)
    ]

    # Never hand back an empty candidate list: if the rules reject everything,
    # fall back to the unpruned configs so there is still something to tune.
    return pruned_configs or configs
|
||||
|
||||
|
||||
def get_full_tuning_space(use_split_k):
|
||||
configs = []
|
||||
@@ -22,7 +47,7 @@ def get_full_tuning_space(use_split_k):
|
||||
|
||||
block_mn_range = [32, 64, 128]
|
||||
block_k_range = [32, 64]
|
||||
split_k_range = [2, 4, 5, 8, 10]
|
||||
split_k_range = [1, 2, 4, 5, 8, 10]
|
||||
num_warps_range = [1, 2, 4, 8]
|
||||
group_m_range = [1, 4, 8]
|
||||
|
||||
@@ -31,10 +56,12 @@ def get_full_tuning_space(use_split_k):
|
||||
for block_k in block_k_range:
|
||||
for num_warps in num_warps_range:
|
||||
for group_m in group_m_range:
|
||||
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=1, num_warps=num_warps))
|
||||
if use_split_k:
|
||||
for split_k in split_k_range:
|
||||
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k}, num_stages=1, num_warps=num_warps))
|
||||
else:
|
||||
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=1, num_warps=num_warps))
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@@ -59,6 +86,11 @@ def get_full_tuning_space(use_split_k):
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=1),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': prune_configs,
|
||||
'perf_model': None,
|
||||
"top_k": None
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0,
|
||||
@@ -170,6 +202,11 @@ def matmul_kernel_splitK(
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': prune_configs,
|
||||
'perf_model': None,
|
||||
"top_k": None
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
|
||||
|
||||
Reference in New Issue
Block a user