[Tutorial] Fix post IFU issues with FA (#398)

* [Tutorial] Fix post IFU issues with FA

* Remove redundant kernels in 06-fused-attention.py

* Added README for scripts in perf-kernels dir

* Fix bwd kernel

---------

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
Alexander Efimov
2023-11-14 17:46:45 +01:00
committed by Jason Furmanek
parent 096def0c9b
commit dfb76540b4
2 changed files with 35 additions and 1 deletions

View File

@@ -17,6 +17,10 @@ import torch
import triton
import triton.language as tl
torch_dtype:tl.constexpr = torch.float16
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
@triton.jit
def max_fn(x, y):
@@ -145,7 +149,7 @@ def _attn_fwd(
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
q = (q * qk_scale).to(q.dtype)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE

View File

@@ -0,0 +1,30 @@
# AMD Perf Kernels
This directory contains customized/tuned/experimental kernels on AMD MI series GPUs.
## `06-fused-attention-transV.py`
This script is a copy of `tutorials/06-fused-attention.py` with the following
two changes:
- Tensor V is transposed in the way that seqlen/N_CTX dimension becomes the
fastest changing (a.k.a. leading or least strided) dimension.
This script produces better performance than `tutorials/06-fused-attention.py`
since it has better LDS access efficiency for tensor V.
Note that in the future, we'll improve the LDS access efficiency for
non-transposed tensor V, i.e. head dimension is the fastest changing dimension.
- Only fwd kernel is benchmarked.
## `06-fused-attention-fwd-transV.py`
This script is used to produce the best performance for fwd kernel.
It is a copy of `06-fused-attention-transV.py` with the following
changes:
- All bwd kernels are removed.
- Storing `m` at the end of the fwd kernel is removed.
- Autotuner is removed. All parameters for D=64 ad D=128 are pre-tuned
on MI250X and hard coded.
Note that this script is also used to benchmark FA performance with 2 GCDs.
Check the [2GCD benchmark script](https://github.com/ROCmSoftwarePlatform/triton/blob/triton-mlir/scripts/amd/benchmark_flash_attention.py) for more details.