mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Reformat Python code with yapf. (#2589)
I've added an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -15,16 +15,12 @@ input_dtypes = ["float32", "float64"]
|
||||
out_dtypes = ["float16", "float32"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, K, N, w_dtype, x_dtype, out_dtype",
|
||||
[
|
||||
(M, K, N, w, x, o)
|
||||
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)]
|
||||
for w in input_dtypes
|
||||
for x in input_dtypes
|
||||
for o in out_dtypes
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
|
||||
[(M, K, N, w, x, o) #
|
||||
for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] #
|
||||
for w in input_dtypes
|
||||
for x in input_dtypes #
|
||||
for o in out_dtypes])
|
||||
def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
|
||||
if x_dtype == w_dtype:
|
||||
pytest.skip("skip same dtype")
|
||||
@@ -44,15 +40,14 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
|
||||
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
|
||||
|
||||
@jit
|
||||
def matmul_kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
dot_out_dtype: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr
|
||||
):
|
||||
def matmul_kernel(A, B, C, M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
dot_out_dtype: tl.constexpr, #
|
||||
allow_tf32: tl.constexpr, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
|
||||
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
@@ -91,16 +86,15 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
tl.store(C, acc, mask=mask)
|
||||
|
||||
matmul_kernel[grid](a, b, out_triton, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
out_triton.stride(0), out_triton.stride(1),
|
||||
dot_out_dtype=triton_dtype,
|
||||
allow_tf32=allow_tf32,
|
||||
GROUP_M=8,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_K=BLOCK_K,
|
||||
)
|
||||
matmul_kernel[grid](
|
||||
a, b, out_triton, M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
|
||||
allow_tf32=allow_tf32, #
|
||||
GROUP_M=8, #
|
||||
BLOCK_M=BLOCK_M, #
|
||||
BLOCK_N=BLOCK_N, #
|
||||
BLOCK_K=BLOCK_K)
|
||||
|
||||
torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01)
|
||||
|
||||
@@ -14,18 +14,14 @@ def test_chained_matmul():
|
||||
return torch.einsum('MN,NK->MK', intermediate, c)
|
||||
|
||||
@triton.jit
|
||||
def chained_matmul_kernel(
|
||||
A, # shape: (m, k)
|
||||
B, # shape: (n, k)
|
||||
C, # shape: (n, k)
|
||||
out, # shape: (m, k)
|
||||
m, n, k: tl.constexpr,
|
||||
block_m: tl.constexpr,
|
||||
block_n: tl.constexpr,
|
||||
block_k: tl.constexpr):
|
||||
def chained_matmul_kernel(A, # shape: (m, k)
|
||||
B, # shape: (n, k)
|
||||
C, # shape: (n, k)
|
||||
out, # shape: (m, k)
|
||||
m, n, k: tl.constexpr, #
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
|
||||
|
||||
tl.static_assert(block_k == k,
|
||||
f"expected block_k == k but got {block_k} != {k}")
|
||||
tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}")
|
||||
|
||||
block_ix = tl.program_id(0)
|
||||
a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \
|
||||
@@ -55,35 +51,33 @@ def test_chained_matmul():
|
||||
m, n, k = 32, 64, 128
|
||||
block_m, block_n, block_k = 16, 32, k
|
||||
|
||||
grid = (triton.cdiv(m, block_m),)
|
||||
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16,
|
||||
device='cuda')
|
||||
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16,
|
||||
device='cuda')
|
||||
grid = (triton.cdiv(m, block_m), )
|
||||
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda')
|
||||
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda')
|
||||
c = torch.randint_like(b, low=0, high=2)
|
||||
triton_result = torch.zeros_like(a)
|
||||
|
||||
torch_result = chained_matmul_reference(a, b, c)
|
||||
chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k,
|
||||
block_m=block_m, block_n=block_n,
|
||||
block_k=block_k)
|
||||
chained_matmul_kernel[grid](
|
||||
a, b, c, triton_result, m, n, k, #
|
||||
block_m=block_m, block_n=block_n, block_k=block_k)
|
||||
|
||||
assert (torch_result == triton_result).all()
|
||||
|
||||
|
||||
def test_vecmat():
|
||||
|
||||
@triton.jit
|
||||
def batched_vecmat(
|
||||
# inputs
|
||||
A, # shape: [dim_m, dim_k]
|
||||
B, # shape: [dim_m, dim_n, dim_k]
|
||||
# dimensions
|
||||
# inputs
|
||||
A, # shape: [dim_m, dim_k]
|
||||
B, # shape: [dim_m, dim_n, dim_k]
|
||||
# dimensions
|
||||
dim_m, dim_n, dim_k,
|
||||
# outputs
|
||||
output,
|
||||
# block information
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
|
||||
):
|
||||
# outputs
|
||||
output,
|
||||
# block information
|
||||
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
|
||||
m_index = tl.program_id(0)
|
||||
n_index = tl.program_id(1)
|
||||
# Output tile
|
||||
@@ -125,9 +119,10 @@ def test_vecmat():
|
||||
|
||||
grid = (M // block_m, N // block_n)
|
||||
|
||||
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
|
||||
block_m=block_m, block_n=block_n, block_k=block_k,
|
||||
num_warps=4, num_stages=1)
|
||||
batched_vecmat[grid](
|
||||
A_tri, B_tri, M, N, K, C_tri, #
|
||||
block_m=block_m, block_n=block_n, block_k=block_k, #
|
||||
num_warps=4, num_stages=1)
|
||||
|
||||
A_expanded = A[:, np.newaxis, :]
|
||||
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
|
||||
@@ -137,18 +132,18 @@ def test_vecmat():
|
||||
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
|
||||
@pytest.mark.parametrize("type",
|
||||
["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"])
|
||||
def test_iv_dependent_matmul(type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
type: tl.constexpr
|
||||
):
|
||||
def kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
|
||||
type: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
pid_m = pid // num_pid_n
|
||||
@@ -216,15 +211,16 @@ def test_iv_dependent_matmul(type):
|
||||
b = torch.rand((K, N), device='cuda')
|
||||
|
||||
torch_output = torch.mm(a, b)
|
||||
triton_output = torch.empty_like(
|
||||
torch_output, device=torch_output.device)
|
||||
triton_output = torch.empty_like(torch_output, device=torch_output.device)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||
|
||||
num_stages = 4 if type == "post_load_three_iters" else 3
|
||||
kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1),
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
type=type, num_stages=num_stages)
|
||||
kernel[grid](
|
||||
a, b, triton_output, M, N, K, #
|
||||
a.stride(0), a.stride(1), b.stride(0), b.stride(1), #
|
||||
triton_output.stride(0), triton_output.stride(1), #
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, #
|
||||
num_stages=num_stages)
|
||||
torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@@ -26,7 +26,6 @@ sm_clocks = {'v100': 1350, 'a100': 1350}
|
||||
mem_clocks = {'v100': 877, 'a100': 1215}
|
||||
|
||||
matmul_data = {
|
||||
# NOTE:
|
||||
'a100': {
|
||||
# square
|
||||
(512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05},
|
||||
@@ -49,10 +48,9 @@ matmul_data = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str',
|
||||
[(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str)
|
||||
for M, N, K in matmul_data[DEVICE_NAME].keys()
|
||||
for dtype_str in ['float16']])
|
||||
def test_matmul(M, N, K, dtype_str):
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
@@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -136,11 +133,11 @@ def test_elementwise(N, dtype_str):
|
||||
print_perf(ms, cur_gpu_util, ref_gpu_util)
|
||||
triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01)
|
||||
|
||||
|
||||
#######################
|
||||
# Flash-Attention
|
||||
#######################
|
||||
|
||||
|
||||
flash_attention_data = {
|
||||
"a100": {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
|
||||
@@ -221,8 +218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
@@ -260,8 +256,8 @@ def test_reductions(N, dtype_str):
|
||||
y = torch.randn_like(z)
|
||||
else:
|
||||
info = torch.iinfo(dtype)
|
||||
x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda')
|
||||
x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
|
||||
y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda')
|
||||
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
|
||||
fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024)
|
||||
ms = triton.testing.do_bench_cudagraph(fn)
|
||||
|
||||
Reference in New Issue
Block a user