Update test_user_defined_persistent_warp_specialized_gemm for num-CTA > 1 (#2101)

- remove auto-tune from test_user_defined_persistent_warp_specialized_gemm.
- remove the unnecessary performance-evaluation parts.
- add test cases with num-CTA > 1 to test_user_defined_persistent_warp_specialized_gemm.
Author: jsh-20
Date: 2023-08-14 16:51:35 +08:00 (committed by GitHub)
Parent: facc1dcbac
Commit: 9055af1a5d
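For orientation before the raw diff: the sketch below shows the shape of the change, assembled from the hunks that follow. It is a minimal, abbreviated sketch rather than the code as it appears in the file; the kernel parameter list and body are elided, and names such as grid, num_SMs, and the a/b/c tensors are taken from the diff context.

import triton

# Before: each persistent kernel carried an autotune decorator that fixed the
# launch tuning, and the test's launch call did not pass num_ctas or the
# warp-specialization flag.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def static_persistent_warp_specialized_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    # ... remaining kernel parameters as in the file
):
    ...  # kernel body elided

# After: the decorator is removed and the test passes the launch configuration
# explicitly, so the new numCTAs > 1 cases take effect.
static_persistent_warp_specialized_matmul_kernel[grid](
    a, b, c,
    M, N, K,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
    num_warps=4, num_ctas=NUM_CTAS,
    enable_warp_specialization=True)

With the decorator gone, the same kernel is exercised under whatever num_warps and num_ctas the pytest parametrization supplies, which is what the added numCTAs > 1 cases rely on.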


@@ -302,20 +302,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
 th_c = torch.matmul(a, b)
 torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
-# # #############################################Performance Evaluation#############################################
-# fn = lambda: call_vintage()
-# ms = triton.testing.do_bench(fn, warmup=25, rep=100)
-# cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
-# print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
-@triton.autotune(
-configs=[
-triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
-# triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
-],
-key=['M', 'N', 'K'],
-)
 @triton.jit
 def static_persistent_warp_specialized_matmul_kernel(
 a_ptr, b_ptr, c_ptr,
@@ -355,13 +342,6 @@ def static_persistent_warp_specialized_matmul_kernel(
 tl.store(c_ptrs, accumulator)
-@triton.autotune(
-configs=[
-triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
-# triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
-],
-key=['M', 'N', 'K'],
-)
 @triton.jit
 def static_persistent_tma_warp_specialized_matmul_kernel(
 a_ptr, b_ptr, c_ptr,
@@ -429,6 +409,12 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
 [4096, 4096, 256, 256, 128, 64, 1, False, True],
 [4096, 4096, 256, 128, 256, 16, 1, False, True],
 [4096, 4096, 256, 128, 256, 64, 1, False, True],
+# numCTAs > 1
+[2048, 2048, 64, 128, 128, 64, 2, False, True],
+[2048, 2048, 128, 256, 128, 64, 4, False, True],
+[4096, 4096, 128, 256, 128, 64, 4, False, True],
+[4096, 4096, 256, 128, 256, 64, 4, False, True],
+[4096, 4096, 256, 256, 256, 64, 4, False, True],
 ]
 for use_tma in [False, True]
 ])
@@ -455,7 +441,9 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
 a.stride(0), a.stride(1),
 b.stride(0), b.stride(1),
 c.stride(0), c.stride(1),
-BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
+BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
+num_warps=4, num_ctas=NUM_CTAS,
+enable_warp_specialization=True)
 else:
 static_persistent_warp_specialized_matmul_kernel[grid](
 a, b, c,
@@ -463,15 +451,12 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
 a.stride(0), a.stride(1),
 b.stride(0), b.stride(1),
 c.stride(0), c.stride(1),
-BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
+BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
+num_warps=4, num_ctas=NUM_CTAS,
+enable_warp_specialization=True)
 th_c = torch.matmul(a, b)
 torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
-# #############################################Performance Evaluation#############################################
-# fn = lambda: call_stylish()
-# ms = triton.testing.do_bench(fn, warmup=25, rep=100)
-# cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
-# print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
 @triton.jit