Fixed a bug related to split_k and prune unnecessary tuning space (#332)

* refine tuning script by adding prune_configs; also fix a bug in generating tuning configs

* fixed a bug in returning the empty config
This commit is contained in:
Shucai Xiao
2023-09-21 23:47:14 -05:00
committed by GitHub
parent a8574be74d
commit 10795d8fd3
2 changed files with 41 additions and 4 deletions

View File

@@ -12,8 +12,33 @@ import yaml
import os
import subprocess
# global flag indicating whether to use the full tuning space
tuning_full_space = False
tuning_full_space = True
# prune unreasonable configs
def prune_configs(configs, named_args):
    """Filter out tuning configs whose block sizes are oversized for the problem.

    Intended to be used as the ``early_config_prune`` callback of an autotuner.

    Args:
        configs: Sequence of config objects; each exposes a ``kwargs`` mapping
            containing at least ``BLOCK_SIZE_M`` and ``BLOCK_SIZE_N``.
        named_args: Mapping of kernel argument names to values. ``a_ptr`` and
            ``b_ptr`` must expose ``.shape``; ``a_ptr.shape[0]`` is taken as
            the M extent and ``b_ptr.shape[1]`` as the N extent.

    Returns:
        ``configs`` unchanged when the full tuning space is not in use.
        Otherwise a new list with the configs that survive pruning, or
        ``configs`` unchanged if pruning would have removed every candidate.
    """
    # Pruning only applies to the full tuning space.
    if not tuning_full_space:
        return configs

    size_m = named_args["a_ptr"].shape[0]
    size_n = named_args["b_ptr"].shape[1]

    def _is_reasonable(config):
        # For a dimension of at most 32 elements, only a block size of 32 is
        # kept along that dimension.
        kw = config.kwargs
        if size_m <= 32 and kw["BLOCK_SIZE_M"] != 32:
            return False
        if size_n <= 32 and kw["BLOCK_SIZE_N"] != 32:
            return False
        return True

    pruned_configs = [config for config in configs if _is_reasonable(config)]

    # Never hand back an empty candidate list: the tuner would have nothing
    # to benchmark. Pruning only saves tuning time, so the unpruned list is a
    # safe fallback.
    return pruned_configs if pruned_configs else configs
def get_full_tuning_space(use_split_k):
configs = []
@@ -22,7 +47,7 @@ def get_full_tuning_space(use_split_k):
block_mn_range = [32, 64, 128]
block_k_range = [32, 64]
split_k_range = [2, 4, 5, 8, 10]
split_k_range = [1, 2, 4, 5, 8, 10]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8]
@@ -31,10 +56,12 @@ def get_full_tuning_space(use_split_k):
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=1, num_warps=num_warps))
if use_split_k:
for split_k in split_k_range:
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k}, num_stages=1, num_warps=num_warps))
else:
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=1, num_warps=num_warps))
return configs
@@ -59,6 +86,11 @@ def get_full_tuning_space(use_split_k):
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'SPLIT_K': 10}, num_stages=1, num_warps=1),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': prune_configs,
'perf_model': None,
"top_k": None
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_SIZE_K'] * args['SPLIT_K']) == 0,
@@ -170,6 +202,11 @@ def matmul_kernel_splitK(
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=1, num_warps=2),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': prune_configs,
'perf_model': None,
"top_k": None
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,