Files
ROCm/python/test/unit/runtime/test_autotuner.py
Justin Lebar df08301e76 Reformat Python code with yapf. (#2589)
I've added an option to yapf to do what we want for long lines, see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
2023-11-02 20:44:17 -07:00

42 lines
1.3 KiB
Python

import torch
import triton
import triton.language as tl
def test_kwargs():
    """Launch an autotuned copy kernel with positional and with keyword arguments.

    No assertion is made on the output; the test passes as long as both
    launch styles run without raising.
    """
    num_elements = 1024
    configs = [triton.Config(kwargs={'BLOCK_SIZE': block}) for block in (32, 128)]
    src = torch.empty(num_elements, device='cuda')
    dst = torch.empty(num_elements, device='cuda')

    @triton.autotune(configs=configs, key=['N'])
    @triton.jit
    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
        # Each program handles BLOCK_SIZE consecutive indices; indices >= N
        # are masked out of both the load and the store.
        block_start = tl.program_id(0) * BLOCK_SIZE
        idx = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N
        values = tl.load(src + idx, mask=in_bounds)
        tl.store(dst + idx, values, mask=in_bounds)

    def grid(meta):
        # 1-D launch grid sized from the BLOCK_SIZE of the config being run.
        return (triton.cdiv(num_elements, meta['BLOCK_SIZE']), )

    # Same launch twice: first all-positional, then all-keyword.
    _kernel[grid](dst, src, num_elements)
    _kernel[grid](dst=dst, src=src, N=num_elements)
def test_restore():
    """Check that an in-place kernel autotuned with ``restore_value=['src']``
    leaves ``src`` incremented exactly once.

    ``src`` starts as all zeros and the kernel adds 1 in place, so the final
    comparison against all ones only holds if a single increment is visible
    after the launch.
    NOTE(review): presumably the autotuner launches the kernel several times
    while benchmarking configs and ``restore_value`` undoes those writes --
    confirm against the ``triton.autotune`` documentation.
    """
    num_elements = 1024
    configs = [triton.Config(kwargs={'BLOCK_SIZE': block}) for block in (32, 128)]
    src = torch.zeros(num_elements, device='cuda')

    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
    @triton.jit
    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
        # Read-modify-write on src: load the block, add 1, store it back.
        # Indices >= N are masked out of both the load and the store.
        block_start = tl.program_id(0) * BLOCK_SIZE
        idx = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N
        incremented = tl.load(src + idx, mask=in_bounds) + 1
        tl.store(src + idx, incremented, mask=in_bounds)

    def grid(meta):
        # 1-D launch grid sized from the BLOCK_SIZE of the config being run.
        return (triton.cdiv(num_elements, meta['BLOCK_SIZE']), )

    _kernel[grid](src, num_elements)

    expected = torch.ones_like(src)
    triton.testing.assert_close(src, expected)