[FRONTEND] Add value restoration for autotuner (#2549)
For in-place kernels, neither the autotuner's `reset_to_zero` option nor `Config.prehook` can restore values that the kernel overwrites during the tuning process, so I propose a recovery mechanism here: a `restore_value` argument naming the kernel arguments whose contents should be saved before tuning and restored afterwards.

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
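To make the mechanism concrete, here is a minimal sketch of the idea, assuming PyTorch tensor arguments; the helper `bench_with_restore` and its signature are hypothetical illustrations, not the actual Triton implementation:

    import torch

    def bench_with_restore(kernel_call, named_args, restore_value):
        # Hypothetical helper: snapshot every argument listed in
        # `restore_value` before timing one candidate config, then
        # copy the saved contents back so an in-place kernel leaves
        # no trace between configs.
        snapshots = {name: named_args[name].clone() for name in restore_value}
        try:
            kernel_call(**named_args)  # run the kernel once for timing
        finally:
            for name, saved in snapshots.items():
                named_args[name].copy_(saved)  # undo in-place mutation

Unlike `reset_to_zero`, which can only zero-fill an argument before each run, restoring a snapshot preserves the exact pre-tuning contents, which is what an in-place kernel needs.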
@@ -20,3 +20,20 @@ def test_kwargs():
     grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
     _kernel[grid](dst, src, N)
     _kernel[grid](dst=dst, src=src, N=N)
+
+
+def test_restore():
+    N = 1024
+    src = torch.zeros(N, device='cuda')
+
+    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
+
+    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
+    @triton.jit
+    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(src + offsets, mask=offsets < N) + 1
+        tl.store(src + offsets, x, mask=offsets < N)
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
+    _kernel[grid](src, N)
+    triton.testing.assert_close(src, torch.ones_like(src))
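The final assertion doubles as a check of the restoration itself: every timed run of a candidate config increments `src` in place, so without `restore_value=['src']` the tensor would accumulate one increment per benchmark run and would no longer compare equal to `torch.ones_like(src)`; with restoration, only the single post-tuning launch is reflected in `src`.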