[FRONTEND] Add value restoration for autotuner (#2549)

For in-place kernels, neither `reset_to_zero` nor `Config.prehook`
provided in the autotuner can restore the values changed during the
tuning process, so I propose a recovery mechanism here.

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
Chenggang Zhao
2023-11-01 09:37:44 +08:00
committed by GitHub
parent 3650213218
commit e7fdfd76fb
2 changed files with 64 additions and 12 deletions

View File

@@ -20,3 +20,20 @@ def test_kwargs():
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
_kernel[grid](dst, src, N)
_kernel[grid](dst=dst, src=src, N=N)
def test_restore():
N = 1024
src = torch.zeros(N, device='cuda')
configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
@triton.autotune(configs=configs, key=['N'], restore_value=['src'])
@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N) + 1
tl.store(src + offsets, x, mask=offsets < N)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
_kernel[grid](src, N)
triton.testing.assert_close(src, torch.ones_like(src))

View File

@@ -25,7 +25,18 @@ class OutOfResources(Exception):
class Autotuner(KernelInterface):
def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
restore_value,
prune_configs_by: Dict = None,
warmup=25,
rep=100,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -38,18 +49,38 @@ class Autotuner(KernelInterface):
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
self.arg_names = arg_names
# Reset to zero or restore values
self.reset_idx = []
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
self.restore_idx = []
if restore_value is not None:
self.restore_idx = [arg_names.index(k) for k in restore_value]
def _hook(args):
# Hook to reset or restore for required tensors
self.pre_hook = lambda args, reset_only=False: 0
self.post_hook = lambda args: 0
if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
def _pre_hook(args, reset_only=False):
for i in self.reset_idx:
args[i].zero_()
if not reset_only:
self.restore_copies = [args[i].clone() for i in self.restore_idx]
self.hook = _hook
self.arg_names = arg_names
# prune configs
self.pre_hook = _pre_hook
if len(self.restore_idx) > 0:
def _post_hook(args):
for i, j in enumerate(self.restore_idx):
args[j].copy_(self.restore_copies[i])
self.restore_copies = []
self.post_hook = _post_hook
# Prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
if "early_config_prune" in prune_configs_by:
@@ -58,6 +89,7 @@ class Autotuner(KernelInterface):
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
self.warmup = warmup
self.rep = rep
@@ -78,7 +110,7 @@ class Autotuner(KernelInterface):
def kernel_call():
if config.pre_hook:
config.pre_hook(full_nargs)
self.hook(args)
self.pre_hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
@@ -88,6 +120,7 @@ class Autotuner(KernelInterface):
# enable_persistent=False,
**current,
)
self.post_hook(args)
try:
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
@@ -115,7 +148,7 @@ class Autotuner(KernelInterface):
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.pre_hook(args, reset_only=True)
self.configs_timings = timings
config = self.cache[key]
else:
@@ -201,7 +234,7 @@ class Config:
self.num_ctas = num_ctas
self.num_stages = num_stages
self.enable_warp_specialization = enable_warp_specialization
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
# TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
self.enable_persistent = False
self.pre_hook = pre_hook
@@ -217,7 +250,7 @@ class Config:
return ", ".join(res)
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
@@ -248,6 +281,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
:param restore_value: a list of argument names whose value will be restored after evaluating any configs.
:type restore_value: list[str]
:param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
:type warmup: int
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
@@ -255,7 +290,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
return decorator