diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py
index c425a3669..0c02830ec 100644
--- a/python/test/unit/runtime/test_autotuner.py
+++ b/python/test/unit/runtime/test_autotuner.py
@@ -20,3 +20,20 @@ def test_kwargs():
     grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
     _kernel[grid](dst, src, N)
     _kernel[grid](dst=dst, src=src, N=N)
+
+
+def test_restore():
+    N = 1024
+    src = torch.zeros(N, device='cuda')
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
+
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
+    @triton.jit
+    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N) + 1
+        tl.store(src + offsets, x, mask=offsets < N)
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
+    _kernel[grid](src, N)
+    triton.testing.assert_close(src, torch.ones_like(src))
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index 7b5e20730..7deca3711 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -25,7 +25,18 @@ class OutOfResources(Exception):
 
 class Autotuner(KernelInterface):
 
-    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
+    def __init__(
+        self,
+        fn,
+        arg_names,
+        configs,
+        key,
+        reset_to_zero,
+        restore_value,
+        prune_configs_by: Dict = None,
+        warmup=25,
+        rep=100,
+    ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -38,18 +49,38 @@ class Autotuner(KernelInterface):
         self.configs = configs
         self.key_idx = [arg_names.index(k) for k in key]
         self.cache = {}
-        # hook to reset all required tensor to zeros before relaunching a kernel
-        self.hook = lambda args: 0
+        self.arg_names = arg_names
+
+        # Reset to zero or restore values
+        self.reset_idx = []
         if reset_to_zero is not None:
             self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+        self.restore_idx = []
+        if restore_value is not None:
+            self.restore_idx = [arg_names.index(k) for k in restore_value]
 
-            def _hook(args):
+        # Hook to reset or restore for required tensors
+        self.pre_hook = lambda args, reset_only=False: 0
+        self.post_hook = lambda args: 0
+        if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
+
+            def _pre_hook(args, reset_only=False):
                 for i in self.reset_idx:
                     args[i].zero_()
+                if not reset_only:
+                    self.restore_copies = [args[i].clone() for i in self.restore_idx]
 
-            self.hook = _hook
-        self.arg_names = arg_names
-        # prune configs
+            self.pre_hook = _pre_hook
+        if len(self.restore_idx) > 0:
+
+            def _post_hook(args):
+                for i, j in enumerate(self.restore_idx):
+                    args[j].copy_(self.restore_copies[i])
+                self.restore_copies = []
+
+            self.post_hook = _post_hook
+
+        # Prune configs
         if prune_configs_by:
             perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
             if "early_config_prune" in prune_configs_by:
@@ -58,6 +89,7 @@
                 perf_model, top_k, early_config_prune = None, None, None
             self.perf_model, self.configs_top_k = perf_model, top_k
             self.early_config_prune = early_config_prune
+
         self.fn = fn
         self.warmup = warmup
         self.rep = rep
@@ -78,7 +110,7 @@
         def kernel_call():
             if config.pre_hook:
                 config.pre_hook(full_nargs)
-            self.hook(args)
+            self.pre_hook(args)
             self.fn.run(
                 *args,
                 num_warps=config.num_warps,
@@ -88,6 +120,7 @@
                 # enable_persistent=False,
                 **current,
             )
+            self.post_hook(args)
 
         try:
             return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
@@ -115,7 +148,7 @@
                 bench_end = time.time()
                 self.bench_time = bench_end - bench_start
                 self.cache[key] = builtins.min(timings, key=timings.get)
-                self.hook(args)
+                self.pre_hook(args, reset_only=True)
                 self.configs_timings = timings
             config = self.cache[key]
         else:
@@ -201,7 +234,7 @@
         self.num_ctas = num_ctas
         self.num_stages = num_stages
         self.enable_warp_specialization = enable_warp_specialization
-        # TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
+        # TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
         self.enable_persistent = False
         self.pre_hook = pre_hook
 
@@ -217,7 +250,7 @@
         return ", ".join(res)
 
 
-def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
 
@@ -248,6 +281,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
            'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
+    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+    :type restore_value: list[str]
     :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
     :type warmup: int
     :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
@@ -255,7 +290,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
     """
 
     def decorator(fn):
-        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
+        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
 
     return decorator
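
Usage sketch for the new restore_value option, adapted from the test_restore test added above (the kernel and variable names below are illustrative, not part of the patch). Passing restore_value=['src'] tells the autotuner to clone 'src' before benchmarking each candidate config and to copy the clone back afterwards, so an in-place kernel sees identical input for every config and its effect is applied to the real data only once, by the final cached launch.

    import torch
    import triton
    import triton.language as tl

    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}),
               triton.Config(kwargs={'BLOCK_SIZE': 128})]

    # 'src' is listed in restore_value because the kernel updates it in place.
    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
    @triton.jit
    def _increment(src, N, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(src + offsets, mask=offsets < N) + 1
        tl.store(src + offsets, x, mask=offsets < N)

    N = 1024
    src = torch.zeros(N, device='cuda')
    _increment[lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)](src, N)
    # src now holds ones: although both configs were benchmarked during
    # autotuning, the in-place increment is visible exactly once.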