[TESTS] refactor test-persistent-warp-specialized-gemm UTs (#2075)

Remove unnecessary skips. Decompose the UTs in persistent-warp-specialized-gemm into separate vintage and stylish variants, parametrized by USE_TMA.
Beal Wang
2023-08-10 14:57:04 +08:00
committed by GitHub
parent 776b3784c2
commit d1ce4c4950
2 changed files with 95 additions and 133 deletions
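
For orientation, a minimal sketch of the parametrization pattern this refactor adopts (the shape values and test name below are illustrative, not the real lists in the hunks that follow): each shape tuple is crossed with a USE_TMA flag, so the vintage path (manual pointer arithmetic) and the stylish path (TMA / make_block_ptr) run as separate test cases instead of back-to-back inside one test body.

import pytest

# Illustrative shapes only; the real lists appear in the diff below.
SHAPES = [
    (4096, 4096, 64, 64, 64, 16, 4, 1, False, True),
    (4096, 4096, 64, 64, 64, 32, 4, 1, False, True),
]

@pytest.mark.parametrize(
    'M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
    # Cross every shape with USE_TMA: False selects the vintage kernel,
    # True selects the stylish (TMA) kernel.
    [(*shape, use_tma) for shape in SHAPES for use_tma in [False, True]])
def test_gemm_sketch(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS,
                     TRANS_A, TRANS_B, USE_TMA):
    if USE_TMA:
        ...  # launch the TMA (make_block_ptr) kernel, then compare against torch.matmul
    else:
        ...  # launch the vintage pointer-arithmetic kernel, then compare against torch.matmul

This keeps each (shape, USE_TMA) combination as its own reportable test case.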


@@ -297,9 +297,6 @@ attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
pytest.skip('unspecified launch failure')
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()


@@ -114,17 +114,21 @@ def static_persistent_tma_matmul_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B', [
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
# TODO: fix issue for 8-warp persistent kernel
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
])
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
# TODO: fix issue for 8-warp persistent kernel
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
]
for use_tma in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B):
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -139,26 +143,13 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (num_SMs,)
def call_vintage():
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
return c
def call_stylish():
if USE_TMA:
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
return c
else:
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
th_c = torch.matmul(a, b)
# Test using old style of ptr calculation
tt_c = call_vintage()
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
# Clear c
c = torch.randn((M, N), device=a.device, dtype=torch.float32)
# Test using make_block_ptr
tt_c = call_stylish()
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
@triton.jit
@@ -240,40 +231,37 @@ def tma_warp_specialized_matmul_kernel(
tl.store(c_ptrs, accumulator, mask=mask)
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B', [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
[4096, 128, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 1, False, True],
[4096, 4096, 256, 128, 128, 16, 1, False, True],
[4096, 4096, 320, 128, 64, 64, 1, False, True],
[4096, 4096, 320, 64, 128, 64, 1, False, True],
[4096, 4096, 320, 128, 128, 64, 1, False, True],
[4096, 4096, 256, 256, 64, 16, 1, False, True],
[4096, 4096, 256, 256, 64, 64, 1, False, True],
[4096, 4096, 256, 64, 256, 16, 1, False, True],
[4096, 4096, 256, 64, 256, 64, 1, False, True],
[4096, 4096, 256, 256, 128, 16, 1, False, True],
[4096, 4096, 256, 256, 128, 64, 1, False, True],
[4096, 4096, 256, 128, 256, 16, 1, False, True],
[4096, 4096, 256, 128, 256, 64, 1, False, True],
# numCTAs > 1
[2048, 2048, 64, 128, 128, 64, 2, False, True],
[2048, 2048, 64, 128, 128, 64, 2, False, True],
[2048, 2048, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 256, 128, 256, 64, 4, False, True],
[4096, 4096, 256, 256, 256, 64, 4, False, True],
])
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
[4096, 128, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 1, False, True],
[4096, 4096, 256, 128, 128, 16, 1, False, True],
[4096, 4096, 320, 128, 64, 64, 1, False, True],
[4096, 4096, 320, 64, 128, 64, 1, False, True],
[4096, 4096, 320, 128, 128, 64, 1, False, True],
[4096, 4096, 256, 256, 64, 16, 1, False, True],
[4096, 4096, 256, 256, 64, 64, 1, False, True],
[4096, 4096, 256, 64, 256, 16, 1, False, True],
[4096, 4096, 256, 64, 256, 64, 1, False, True],
[4096, 4096, 256, 256, 128, 16, 1, False, True],
[4096, 4096, 256, 256, 128, 64, 1, False, True],
[4096, 4096, 256, 128, 256, 16, 1, False, True],
[4096, 4096, 256, 128, 256, 64, 1, False, True],
# numCTAs > 1
[2048, 2048, 64, 128, 128, 64, 2, False, True],
[2048, 2048, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 256, 128, 256, 64, 4, False, True],
[4096, 4096, 256, 256, 256, 64, 4, False, True],
]
for use_tma in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
if '-'.join(map(str, [M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B])) in [
'4096-4096-256-128-256-16-1-False-True',
'4096-4096-256-128-256-64-1-False-True'
]:
pytest.skip('Insufficient register resources')
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -288,20 +276,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
def call_vintage():
warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K,
num_warps=4,
num_ctas=NUM_CTAS,
enable_warp_specialization=True)
return c
def call_stylish():
if USE_TMA:
tma_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
@@ -312,20 +287,20 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
num_warps=4,
num_ctas=NUM_CTAS,
enable_warp_specialization=True)
return c
else:
warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K,
num_warps=4,
num_ctas=NUM_CTAS,
enable_warp_specialization=True)
th_c = torch.matmul(a, b)
# Test using old style of ptr calculation
tt_c = call_vintage()
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
# Clear c
c = torch.randn((M, N), device=a.device, dtype=torch.float32)
# Test using make_block_ptr
tt_c = call_stylish()
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
# # #############################################Performance Evaluation#############################################
# fn = lambda: call_vintage()
@@ -434,27 +409,31 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B', [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
[4096, 128, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 1, False, True],
[4096, 4096, 256, 128, 128, 16, 1, False, True],
[4096, 4096, 320, 128, 64, 64, 1, False, True],
[4096, 4096, 320, 64, 128, 64, 1, False, True],
[4096, 4096, 320, 128, 128, 64, 1, False, True],
[4096, 4096, 256, 256, 64, 16, 1, False, True],
[4096, 4096, 256, 256, 64, 64, 1, False, True],
[4096, 4096, 256, 64, 256, 16, 1, False, True],
[4096, 4096, 256, 64, 256, 64, 1, False, True],
[4096, 4096, 256, 256, 128, 16, 1, False, True],
[4096, 4096, 256, 256, 128, 64, 1, False, True],
[4096, 4096, 256, 128, 256, 16, 1, False, True],
[4096, 4096, 256, 128, 256, 64, 1, False, True],
])
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
[4096, 128, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 1, False, True],
[4096, 4096, 256, 128, 128, 16, 1, False, True],
[4096, 4096, 320, 128, 64, 64, 1, False, True],
[4096, 4096, 320, 64, 128, 64, 1, False, True],
[4096, 4096, 320, 128, 128, 64, 1, False, True],
[4096, 4096, 256, 256, 64, 16, 1, False, True],
[4096, 4096, 256, 256, 64, 64, 1, False, True],
[4096, 4096, 256, 64, 256, 16, 1, False, True],
[4096, 4096, 256, 64, 256, 64, 1, False, True],
[4096, 4096, 256, 256, 128, 16, 1, False, True],
[4096, 4096, 256, 256, 128, 64, 1, False, True],
[4096, 4096, 256, 128, 256, 16, 1, False, True],
[4096, 4096, 256, 128, 256, 64, 1, False, True],
]
for use_tma in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -469,17 +448,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (num_SMs,)
def call_vintage():
static_persistent_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
return c
def call_stylish():
if USE_TMA:
static_persistent_tma_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
@@ -487,21 +456,17 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
return c
else:
static_persistent_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
th_c = torch.matmul(a, b)
# Test using old style of ptr calculation
tt_c = call_vintage()
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
# Clear c
c = torch.randn((M, N), device=a.device, dtype=torch.float32)
# Test using make_block_ptr
tt_c = call_stylish()
torch.testing.assert_allclose(th_c, tt_c, atol=1e-2, rtol=0)
torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0)
# #############################################Performance Evaluation#############################################
# fn = lambda: call_stylish()
# ms = triton.testing.do_bench(fn, warmup=25, rep=100)
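
The commented-out performance evaluation above relies on triton.testing.do_bench; a minimal sketch of how that hook could be wired up, assuming a call_stylish closure like the one defined in the pre-refactor tests (the TFLOP/s conversion is added here purely for illustration):

import triton
import triton.testing

def bench_gemm(call_stylish, M, N, K):
    # do_bench repeatedly invokes the callable and reports a time in milliseconds.
    ms = triton.testing.do_bench(lambda: call_stylish(), warmup=25, rep=100)
    # Achieved throughput for a single M x N x K GEMM (2*M*N*K flops).
    tflops = 2 * M * N * K * 1e-12 / (ms * 1e-3)
    print(f'{ms:.3f} ms, {tflops:.2f} TFLOP/s')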