[FRONTEND] Add input dtypes to autotuning key (#2534) (#374)

* [FRONTEND] Add input dtypes to autotuning key (#2534)

* Fix conflict in 06-fused-attention

* Fix get_best_config in FA-transV.py

* Fix leftover get_best_config()

---------

Co-authored-by: Adnan Akhundov <adnan.akhundov@gmail.com>
This commit is contained in:
Lixun Zhang
2023-11-07 19:36:57 -06:00
committed by GitHub
parent 3c1fe617c1
commit 1af893d8a2
5 changed files with 10 additions and 20 deletions

View File

@@ -547,7 +547,7 @@ class _attention(torch.autograd.Function):
)
## restore the grid for bwd kernel
-            best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk)
+            best_config = _attn_fwd.get_best_config()
block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)