Add OptimizeEpilogue pass. (#346)

* optimize_epilogue

* Add config

* Remove licenses

* Comment out Hopper-specific parameters when printing out configs

* Add benchmark parameters from flash-attention repo

* Add Z and H to the autotuner key

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
Author: oplavsic
Date: 2023-11-03 22:46:24 +01:00 (committed by GitHub)
Parent: cb02a0b346
Commit: c65f1e6211
5 changed files with 40 additions and 46 deletions


@@ -86,7 +86,7 @@ def _attn_fwd_inner(
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), # d64-False
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), # d64-True
     ],
-    key=['N_CTX', 'STAGE', 'BLOCK_DMODEL'],
+    key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
 )
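
The hunk above widens the autotuner cache key: with 'Z' (batch) and 'H' (heads) included, Triton re-benchmarks the config list whenever the batch size or head count changes, instead of reusing timings tuned for a different problem shape. A minimal sketch of the mechanism, using a hypothetical toy kernel rather than the flash-attention kernel from this commit:

import triton
import triton.language as tl

# Hypothetical toy kernel, for illustration only. Any change to the
# runtime values of Z, H, or N_CTX triggers a fresh autotuning sweep
# over `configs`; matching values reuse the cached best config.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64}, num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M': 128}, num_stages=1, num_warps=4),
    ],
    key=['Z', 'H', 'N_CTX'],
)
@triton.jit
def _toy_copy(X, Y, Z, H, N_CTX, BLOCK_M: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offs < Z * H * N_CTX
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)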
@@ -547,7 +547,7 @@ class _attention(torch.autograd.Function):
         )
         ## restore the grid for bwd kernel
-        best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
+        best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
         block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
         grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
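
For reference, the BLOCK_M extraction above parses the textual form of the winning triton.Config. A hedged sketch of that parsing with a sample config string (the string is an assumption modeled on this kernel's configs, not captured output; it relies on BLOCK_M being the first comma-separated field):

# Sample of what str(best_config) may look like for these configs.
cfg_str = "BLOCK_M: 128, BLOCK_N: 64, waves_per_eu: 3, pre_load_v: True, num_warps: 4, num_stages: 1"

# Same parsing as the diff: take the first comma-separated field and
# strip the "BLOCK_M:" label. Fragile if BLOCK_M ever stops being the
# first key in the config's string representation.
block_m = int(cfg_str.split(",")[0].split("BLOCK_M:")[1])
assert block_m == 128

The recovered block_m is then used to rebuild the launch grid (ceil(N_CTX / BLOCK_M) programs along the sequence axis, one per batch-head pair) so the backward pass can reuse the geometry chosen by the forward autotuner.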