[FRONTEND] Add value restoration for autotuner (#2549)
For in-place kernels, neither `reset_to_zero` nor the `Config.pre_hook` provided by the autotuner can restore values changed during the tuning process, so this change adds a recovery mechanism: arguments listed in a new `restore_value` option are snapshotted before each benchmarked config and copied back afterwards.

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
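As a rough usage sketch (the kernel and tensor names below are illustrative, not taken from this PR): an in-place kernel lists its mutated argument in `restore_value`, so the extra launches performed while benchmarking configs do not leave their side effects behind.

import torch
import triton
import triton.language as tl

configs = [triton.Config(kwargs={'BLOCK_SIZE': 64}), triton.Config(kwargs={'BLOCK_SIZE': 256})]

# `buf` is read and written in place, so every benchmarking launch would otherwise
# compound the update; `restore_value` snapshots it before each trial and copies the
# snapshot back afterwards.
@triton.autotune(configs=configs, key=['N'], restore_value=['buf'])
@triton.jit
def _double_inplace(buf, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(buf + offsets, mask=offsets < N)
    tl.store(buf + offsets, x * 2.0, mask=offsets < N)

N = 4096
buf = torch.ones(N, device='cuda')
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
_double_inplace[grid](buf, N)
# buf now holds exactly one application of the kernel (all twos), even though the
# autotuner launched it once per config while benchmarking.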
@@ -20,3 +20,20 @@ def test_kwargs():
     grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
     _kernel[grid](dst, src, N)
     _kernel[grid](dst=dst, src=src, N=N)
+
+
+def test_restore():
+    N = 1024
+    src = torch.zeros(N, device='cuda')
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
+
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
+    @triton.jit
+    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N) + 1
+        tl.store(src + offsets, x, mask=offsets < N)
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
+    _kernel[grid](src, N)
+    triton.testing.assert_close(src, torch.ones_like(src))
@@ -25,7 +25,18 @@ class OutOfResources(Exception):


 class Autotuner(KernelInterface):
-    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100):
+    def __init__(
+        self,
+        fn,
+        arg_names,
+        configs,
+        key,
+        reset_to_zero,
+        restore_value,
+        prune_configs_by: Dict = None,
+        warmup=25,
+        rep=100,
+    ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -38,18 +49,38 @@ class Autotuner(KernelInterface):
         self.configs = configs
         self.key_idx = [arg_names.index(k) for k in key]
         self.cache = {}
-        # hook to reset all required tensor to zeros before relaunching a kernel
-        self.hook = lambda args: 0
+        self.arg_names = arg_names
+
+        # Reset to zero or restore values
         self.reset_idx = []
         if reset_to_zero is not None:
             self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+        self.restore_idx = []
+        if restore_value is not None:
+            self.restore_idx = [arg_names.index(k) for k in restore_value]

-            def _hook(args):
+        # Hook to reset or restore for required tensors
+        self.pre_hook = lambda args, reset_only=False: 0
+        self.post_hook = lambda args: 0
+        if len(self.reset_idx) > 0 or len(self.restore_idx) > 0:
+
+            def _pre_hook(args, reset_only=False):
                 for i in self.reset_idx:
                     args[i].zero_()
+                if not reset_only:
+                    self.restore_copies = [args[i].clone() for i in self.restore_idx]

-            self.hook = _hook
-        self.arg_names = arg_names
-        # prune configs
+            self.pre_hook = _pre_hook
+        if len(self.restore_idx) > 0:
+
+            def _post_hook(args):
+                for i, j in enumerate(self.restore_idx):
+                    args[j].copy_(self.restore_copies[i])
+                self.restore_copies = []
+
+            self.post_hook = _post_hook
+
+        # Prune configs
         if prune_configs_by:
             perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
             if "early_config_prune" in prune_configs_by:
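Stepping outside the patch for a moment: the new hook pair is a snapshot-before / copy-back-after pattern around each benchmarked launch. A minimal standalone sketch of that pattern with plain PyTorch tensors (the function and argument names here are made up for illustration):

import torch

def run_trial_with_restore(run_once, restore_tensors, reset_tensors=()):
    # Corresponds to _pre_hook: zero the reset arguments, snapshot the restore arguments.
    for t in reset_tensors:
        t.zero_()
    snapshots = [t.clone() for t in restore_tensors]
    # The timed launch(es); these may mutate the tensors in place.
    run_once()
    # Corresponds to _post_hook: copy the snapshots back so the next trial sees the original data.
    for t, snap in zip(restore_tensors, snapshots):
        t.copy_(snap)

In the patch itself the snapshot happens in `_pre_hook` before `self.fn.run` and the copy-back in `_post_hook` after it; the extra `self.pre_hook(args, reset_only=True)` call after tuning re-zeroes the `reset_to_zero` arguments for the real launch without taking another snapshot.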
@@ -58,6 +89,7 @@ class Autotuner(KernelInterface):
             perf_model, top_k, early_config_prune = None, None, None
         self.perf_model, self.configs_top_k = perf_model, top_k
         self.early_config_prune = early_config_prune
+
         self.fn = fn
         self.warmup = warmup
         self.rep = rep
@@ -78,7 +110,7 @@ class Autotuner(KernelInterface):
         def kernel_call():
             if config.pre_hook:
                 config.pre_hook(full_nargs)
-            self.hook(args)
+            self.pre_hook(args)
             self.fn.run(
                 *args,
                 num_warps=config.num_warps,
@@ -88,6 +120,7 @@ class Autotuner(KernelInterface):
                 # enable_persistent=False,
                 **current,
             )
+            self.post_hook(args)

         try:
             return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
@@ -115,7 +148,7 @@ class Autotuner(KernelInterface):
                 bench_end = time.time()
                 self.bench_time = bench_end - bench_start
                 self.cache[key] = builtins.min(timings, key=timings.get)
-                self.hook(args)
+                self.pre_hook(args, reset_only=True)
                 self.configs_timings = timings
             config = self.cache[key]
         else:
@@ -201,7 +234,7 @@ class Config:
         self.num_ctas = num_ctas
         self.num_stages = num_stages
         self.enable_warp_specialization = enable_warp_specialization
-        # TODO[shuhaoj]: May make enable_persistent configurable in future if necessay.
+        # TODO[shuhaoj]: May make enable_persistent configurable in future if necessary.
         self.enable_persistent = False
         self.pre_hook = pre_hook

@@ -217,7 +250,7 @@ class Config:
         return ", ".join(res)


-def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25, rep=100):
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.

@@ -248,6 +281,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
         'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
     :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
     :type reset_to_zero: list[str]
+    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+    :type restore_value: list[str]
     :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
     :type warmup: int
     :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
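To make the difference between the two documented options concrete (the kernel and argument names below are hypothetical): `reset_to_zero` zeroes an argument before every trial, while `restore_value` puts an argument back to whatever it held before the trial.

import triton
import triton.language as tl

cfgs = [triton.Config(kwargs={'BLOCK_SIZE': 64}), triton.Config(kwargs={'BLOCK_SIZE': 256})]

@triton.autotune(configs=cfgs, key=['N'],
                 reset_to_zero=['out'],     # accumulator: zeroed before every benchmarked config
                 restore_value=['state'])   # in-place argument: snapshotted and restored around each config
@triton.jit
def _accumulate(state, out, N, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    s = tl.load(state + offs, mask=offs < N) + 1.0   # bump the state in place
    tl.store(state + offs, s, mask=offs < N)
    tl.atomic_add(out + offs, s, mask=offs < N)      # accumulate into the output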
@@ -255,7 +290,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, warmup=25,
     """

     def decorator(fn):
-        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, warmup, rep)
+        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)

     return decorator