[autotuner] Add an option to print best_config for each key

This commit is contained in:
Lixun Zhang
2023-08-25 09:48:32 -05:00
committed by Lixun Zhang
parent ff7e707f87
commit b834f42ae4

View File

@@ -25,7 +25,7 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None):
def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None):
'''
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -58,6 +58,7 @@ class Autotuner(KernelInterface):
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
self.verbose = verbose
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
@@ -101,6 +102,8 @@ class Autotuner(KernelInterface):
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
if self.verbose:
print(str(key) + ": " + str(self.cache[key]))
config = self.cache[key]
else:
config = self.configs[0]
@@ -171,7 +174,7 @@ class Config:
return ', '.join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -202,9 +205,11 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
:param verbose: a boolean that controls whether the best_config for each key is printed
:type verbose: bool
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by)
return decorator