mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Improve grid calculation for persistent kernels to hoist pe… (#2283)
…rf on problems that need few blocks. constrain the number of launched blocks to what it exactely needs for persistent warp specialized kernel. It's useful when problems need very few blocks. e.g. MxNxK=800x800x60000, f16_f16_f32, block size=128x128x64, non-split-k. Experiments show it can achieve ~16% speedup.
This commit is contained in:
@@ -141,7 +141,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
grid = lambda META: (num_SMs,)
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
|
||||
if USE_TMA:
|
||||
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
|
||||
@@ -432,7 +432,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
grid = lambda META: (num_SMs,)
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
|
||||
if USE_TMA:
|
||||
static_persistent_tma_warp_specialized_matmul_kernel[grid](
|
||||
@@ -899,7 +899,7 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
|
||||
def grid(META):
|
||||
return (num_SMs,)
|
||||
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
full_static_persistent_matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
|
||||
Reference in New Issue
Block a user