mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
I've added an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
import torch
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
def test_kwargs():
    """Launch an autotuned copy kernel with positional arguments, then with keyword arguments."""
    N = 1024
    src = torch.empty(N, device='cuda')
    dst = torch.empty(N, device='cuda')

    # Two candidate BLOCK_SIZE values handed to the autotuner.
    configs = [triton.Config(kwargs={'BLOCK_SIZE': block}) for block in (32, 128)]

    @triton.autotune(configs=configs, key=['N'])
    @triton.jit
    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
        # Copy one BLOCK_SIZE-wide tile from src to dst, masked to offsets below N.
        idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N
        tile = tl.load(src + idx, mask=in_bounds)
        tl.store(dst + idx, tile, mask=in_bounds)

    def grid(META):
        # One program per BLOCK_SIZE-wide tile, using the BLOCK_SIZE found in META.
        return (triton.cdiv(N, META['BLOCK_SIZE']), )

    # Same launch twice: first positional, then every tensor/scalar argument by name.
    _kernel[grid](dst, src, N)
    _kernel[grid](dst=dst, src=src, N=N)
|
|
|
|
|
|
def test_restore():
    """Run an in-place increment kernel autotuned with restore_value=['src'] and expect src to end as ones."""
    N = 1024
    src = torch.zeros(N, device='cuda')

    # Two candidate BLOCK_SIZE values handed to the autotuner.
    configs = [triton.Config(kwargs={'BLOCK_SIZE': block}) for block in (32, 128)]

    # NOTE(review): presumably restore_value resets src between autotuning trials so the
    # in-place update is not applied once per candidate config -- confirm against Triton docs.
    @triton.autotune(configs=configs, key=['N'], restore_value=['src'])
    @triton.jit
    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
        # Add 1 in place to one BLOCK_SIZE-wide tile of src, masked to offsets below N.
        idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N
        incremented = tl.load(src + idx, mask=in_bounds) + 1
        tl.store(src + idx, incremented, mask=in_bounds)

    def grid(META):
        # One program per BLOCK_SIZE-wide tile, using the BLOCK_SIZE found in META.
        return (triton.cdiv(N, META['BLOCK_SIZE']), )

    _kernel[grid](src, N)

    # src started as zeros, so a single net increment must leave every element equal to 1.
    expected = torch.ones_like(src)
    triton.testing.assert_close(src, expected)
|