Reformat Python code with yapf. (#2589)

I've added an option to yapf to do what we want for long lines, see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
Justin Lebar
2023-11-02 20:44:17 -07:00
committed by GitHub
parent dced22c4b7
commit df08301e76
85 changed files with 3802 additions and 3880 deletions

View File

@@ -15,16 +15,12 @@ input_dtypes = ["float32", "float64"]
out_dtypes = ["float16", "float32"]
@pytest.mark.parametrize(
"M, K, N, w_dtype, x_dtype, out_dtype",
[
(M, K, N, w, x, o)
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)]
for w in input_dtypes
for x in input_dtypes
for o in out_dtypes
]
)
@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
[(M, K, N, w, x, o) #
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] #
for w in input_dtypes
for x in input_dtypes #
for o in out_dtypes])
def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
if x_dtype == w_dtype:
pytest.skip("skip same dtype")
@@ -44,15 +40,14 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
@jit
def matmul_kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr
):
def matmul_kernel(A, B, C, M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
# matrix multiplication
pid = tl.program_id(0)
grid_m = tl.cdiv(M, BLOCK_M)
@@ -91,16 +86,15 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
mask = (rm < M)[:, None] & (rn < N)[None, :]
tl.store(C, acc, mask=mask)
matmul_kernel[grid](a, b, out_triton, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
out_triton.stride(0), out_triton.stride(1),
dot_out_dtype=triton_dtype,
allow_tf32=allow_tf32,
GROUP_M=8,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
matmul_kernel[grid](
a, b, out_triton, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
allow_tf32=allow_tf32, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
BLOCK_K=BLOCK_K)
torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)

View File

@@ -14,18 +14,14 @@ def test_chained_matmul():
return torch.einsum('MN,NK->MK', intermediate, c)
@triton.jit
def chained_matmul_kernel(
A, # shape: (m, k)
B, # shape: (n, k)
C, # shape: (n, k)
out, # shape: (m, k)
m, n, k: tl.constexpr,
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr):
def chained_matmul_kernel(A, # shape: (m, k)
B, # shape: (n, k)
C, # shape: (n, k)
out, # shape: (m, k)
m, n, k: tl.constexpr, #
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
tl.static_assert(block_k == k,
f"expected block_k == k but got {block_k} != {k}")
tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}")
block_ix = tl.program_id(0)
a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \
@@ -55,35 +51,33 @@ def test_chained_matmul():
m, n, k = 32, 64, 128
block_m, block_n, block_k = 16, 32, k
grid = (triton.cdiv(m, block_m),)
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16,
device='cuda')
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16,
device='cuda')
grid = (triton.cdiv(m, block_m), )
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda')
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda')
c = torch.randint_like(b, low=0, high=2)
triton_result = torch.zeros_like(a)
torch_result = chained_matmul_reference(a, b, c)
chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k,
block_m=block_m, block_n=block_n,
block_k=block_k)
chained_matmul_kernel[grid](
a, b, c, triton_result, m, n, k, #
block_m=block_m, block_n=block_n, block_k=block_k)
assert (torch_result == triton_result).all()
def test_vecmat():
@triton.jit
def batched_vecmat(
# inputs
A, # shape: [dim_m, dim_k]
B, # shape: [dim_m, dim_n, dim_k]
# dimensions
# inputs
A, # shape: [dim_m, dim_k]
B, # shape: [dim_m, dim_n, dim_k]
# dimensions
dim_m, dim_n, dim_k,
# outputs
output,
# block information
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
):
# outputs
output,
# block information
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
m_index = tl.program_id(0)
n_index = tl.program_id(1)
# Output tile
@@ -125,9 +119,10 @@ def test_vecmat():
grid = (M // block_m, N // block_n)
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
block_m=block_m, block_n=block_n, block_k=block_k,
num_warps=4, num_stages=1)
batched_vecmat[grid](
A_tri, B_tri, M, N, K, C_tri, #
block_m=block_m, block_n=block_n, block_k=block_k, #
num_warps=4, num_stages=1)
A_expanded = A[:, np.newaxis, :]
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
@@ -137,18 +132,18 @@ def test_vecmat():
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
@pytest.mark.parametrize("type",
["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
def test_iv_dependent_matmul(type):
@triton.jit
def kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
type: tl.constexpr
):
def kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
type: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
@@ -216,15 +211,16 @@ def test_iv_dependent_matmul(type):
b = torch.rand((K, N), device='cuda')
torch_output = torch.mm(a, b)
triton_output = torch.empty_like(
torch_output, device=torch_output.device)
triton_output = torch.empty_like(torch_output, device=torch_output.device)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
num_stages = 4 if type == "post_load_three_iters" else 3
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
type=type, num_stages=num_stages)
kernel[grid](
a, b, triton_output, M, N, K, #
a.stride(0), a.stride(1), b.stride(0), b.stride(1), #
triton_output.stride(0), triton_output.stride(1), #
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, #
num_stages=num_stages)
torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)

View File

@@ -26,7 +26,6 @@ sm_clocks = {'v100': 1350, 'a100': 1350}
mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
# NOTE:
'a100': {
# square
(512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05},
@@ -49,10 +48,9 @@ matmul_data = {
}
@pytest.mark.parametrize('M, N, K, dtype_str',
[(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str)
for M, N, K in matmul_data[DEVICE_NAME].keys()
for dtype_str in ['float16']])
def test_matmul(M, N, K, dtype_str):
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
@@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str):
@triton.jit
def _add(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -136,11 +133,11 @@ def test_elementwise(N, dtype_str):
print_perf(ms, cur_gpu_util, ref_gpu_util)
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
#######################
# Flash-Attention
#######################
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
@@ -221,8 +218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
@triton.jit
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
BLOCK_SIZE: tl.constexpr):
def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -260,8 +256,8 @@ def test_reductions(N, dtype_str):
y = torch.randn_like(z)
else:
info = torch.iinfo(dtype)
x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench_cudagraph(fn)