[FRONTEND] Improve grid calculation for persistent kernels to hoist perf on problems that need few blocks. (#2283)


Constrain the number of launched blocks to exactly what is needed for
persistent warp-specialized kernels. This is useful when a problem needs
very few blocks.
For example, with MxNxK=800x800x60000, f16_f16_f32, block size=128x128x64,
and non-split-k, only cdiv(800, 128) * cdiv(800, 128) = 7 * 7 = 49 output
tiles exist. Experiments show this change can achieve a ~16% speedup.
This commit is contained in:
jsh-20
2023-09-12 17:14:47 +08:00
committed by GitHub
parent ab9da3b2b8
commit fc5d7e6e7c

View File

@@ -141,7 +141,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (num_SMs,)
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
if USE_TMA:
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
@@ -432,7 +432,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (num_SMs,)
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
if USE_TMA:
static_persistent_tma_warp_specialized_matmul_kernel[grid](
@@ -899,7 +899,7 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
def grid(META):
return (num_SMs,)
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
full_static_persistent_matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
M=M, N=N, K=K,