mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Update test_user_defined_persistent_warp_specialized_gemm for num-CTA > 1 (#2101)
- Remove auto-tuning from test_user_defined_persistent_warp_specialized_gemm.
- Remove the unnecessary performance-evaluation parts.
- Add test cases with num-CTA > 1 to test_user_defined_persistent_warp_specialized_gemm.
This commit is contained in:
@@ -302,20 +302,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
|
||||
th_c = torch.matmul(a, b)
|
||||
torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
|
||||
|
||||
# # #############################################Performance Evaluation#############################################
|
||||
# fn = lambda: call_vintage()
|
||||
# ms = triton.testing.do_bench(fn, warmup=25, rep=100)
|
||||
# cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
|
||||
# print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
|
||||
# triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def static_persistent_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
@@ -355,13 +342,6 @@ def static_persistent_warp_specialized_matmul_kernel(
|
||||
tl.store(c_ptrs, accumulator)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
|
||||
# triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
@@ -429,6 +409,12 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
[4096, 4096, 256, 256, 128, 64, 1, False, True],
|
||||
[4096, 4096, 256, 128, 256, 16, 1, False, True],
|
||||
[4096, 4096, 256, 128, 256, 64, 1, False, True],
|
||||
# numCTAs > 1
|
||||
[2048, 2048, 64, 128, 128, 64, 2, False, True],
|
||||
[2048, 2048, 128, 256, 128, 64, 4, False, True],
|
||||
[4096, 4096, 128, 256, 128, 64, 4, False, True],
|
||||
[4096, 4096, 256, 128, 256, 64, 4, False, True],
|
||||
[4096, 4096, 256, 256, 256, 64, 4, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
@@ -455,7 +441,9 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
|
||||
num_warps=4, num_ctas=NUM_CTAS,
|
||||
enable_warp_specialization=True)
|
||||
else:
|
||||
static_persistent_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
@@ -463,15 +451,12 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
|
||||
num_warps=4, num_ctas=NUM_CTAS,
|
||||
enable_warp_specialization=True)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
|
||||
# #############################################Performance Evaluation#############################################
|
||||
# fn = lambda: call_stylish()
|
||||
# ms = triton.testing.do_bench(fn, warmup=25, rep=100)
|
||||
# cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
|
||||
# print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user