[FRONTEND] Add value restoration for autotuner (#2549)

For in-place kernels, neither the `reset_to_zero` option nor a `Config.prehook`
provided to the autotuner can restore values that are overwritten during the
tuning process, so I propose a value-restoration mechanism here.

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
Chenggang Zhao
2023-11-01 09:37:44 +08:00
committed by GitHub
parent 3650213218
commit e7fdfd76fb
2 changed files with 64 additions and 12 deletions

View File

@@ -20,3 +20,20 @@ def test_kwargs():
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
_kernel[grid](dst, src, N)
_kernel[grid](dst=dst, src=src, N=N)
def test_restore():
    """Check that `restore_value` undoes in-place mutation between tuning runs.

    The kernel below increments `src` in place, so every candidate-config
    launch during autotuning would otherwise leave `src` changed.  With
    `restore_value=['src']` the autotuner restores `src` after the tuning
    launches, so the net effect is a single increment.
    """
    N = 1024
    src = torch.zeros(N, device='cuda')
    configs = [
        triton.Config(kwargs={'BLOCK_SIZE': 32}),
        triton.Config(kwargs={'BLOCK_SIZE': 128}),
    ]

    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
    @triton.jit
    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
        offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = offs < N
        bumped = tl.load(src + offs, mask=in_bounds) + 1
        tl.store(src + offs, bumped, mask=in_bounds)

    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)
    _kernel[grid](src, N)

    # One effective launch: zeros incremented exactly once -> all ones.
    triton.testing.assert_close(src, torch.ones_like(src))