mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Add OptimizeEpilogue pass. (#346)
This change adds the optimize_epilogue pass and its config; removes licenses; comments out Hopper-specific parameters when printing out configs; adds benchmark parameters from the flash-attention repo; and adds Z and H to the autotuner key. --------- Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
@@ -86,7 +86,7 @@ def _attn_fwd_inner(
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
|
||||
],
|
||||
key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
||||
key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
|
||||
)
|
||||
|
||||
|
||||
@@ -547,7 +547,7 @@ class _attention(torch.autograd.Function):
|
||||
)
|
||||
|
||||
## restore the grid for bwd kernel
|
||||
best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
|
||||
best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
|
||||
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
|
||||
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user