mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Reformat Python code with yapf. (#2589)
I've added an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -26,7 +26,6 @@ sm_clocks = {'v100': 1350, 'a100': 1350}
|
||||
mem_clocks = {'v100': 877, 'a100': 1215}
|
||||
|
||||
matmul_data = {
|
||||
# NOTE:
|
||||
'a100': {
|
||||
# square
|
||||
(512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05},
|
||||
@@ -49,10 +48,9 @@ matmul_data = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str',
|
||||
[(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
def test_matmul(M, N, K, dtype_str):
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
@@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -136,11 +133,11 @@ def test_elementwise(N, dtype_str):
|
||||
print_perf(ms, cur_gpu_util, ref_gpu_util)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
|
||||
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
#######################
|
||||
|
||||
|
||||
flash_attention_data = {
|
||||
"a100": {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
|
||||
@@ -221,8 +218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -260,8 +256,8 @@ def test_reductions(N, dtype_str):
|
||||
y = torch.randn_like(z)
|
||||
else:
|
||||
info = torch.iinfo(dtype)
|
||||
x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench_cudagraph(fn)
|
||||
|
||||
Reference in New Issue
Block a user