Mirror of https://github.com/ROCm/ROCm.git
[Tutorial] Fix post IFU issues with FA (#398)
* [Tutorial] Fix post IFU issues with FA
* Remove redundant kernels in 06-fused-attention.py
* Added README for scripts in perf-kernels dir
* Fix bwd kernel

Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
committed by Jason Furmanek
parent 096def0c9b
commit dfb76540b4
@@ -17,6 +17,10 @@ import torch
 import triton
 import triton.language as tl
 
+torch_dtype:tl.constexpr = torch.float16
+TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
+if TORCH_HAS_FP8E5:
+    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
 
 @triton.jit
 def max_fn(x, y):
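The added lines feature-detect fp8 support: use torch's fnuz fp8 type when the installed build exposes it, and fall back to fp16 otherwise. A standalone sketch of the same detection pattern (the `tl.constexpr` annotation is dropped here since it only matters inside Triton kernels):

import torch

# Prefer the fp8 e5m2 fnuz type when this torch build provides it; otherwise fp16.
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
torch_dtype = torch.float8_e5m2fnuz if TORCH_HAS_FP8E5 else torch.float16

q = torch.randn(16, 64, dtype=torch.float16)
q = q.to(torch_dtype)  # fp8 cast on builds that have it, fp16 no-op otherwise
print(q.dtype)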
@@ -145,7 +149,7 @@ def _attn_fwd(
     qk_scale = sm_scale * 1.44269504
     # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
     q = tl.load(Q_block_ptr)
-    q = (q * qk_scale).to(tl.float16)
+    q = (q * qk_scale).to(q.dtype)
     # stage 1: off-band
     # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
     # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
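This one-line change replaces the hard-coded `tl.float16` cast with `q.dtype`, so the scaled query is cast back to whatever dtype it was loaded as (fp16 or fp8), which is what the dtype selection above relies on. A minimal sketch of the dtype-preserving cast in a hypothetical Triton kernel (names and shapes are illustrative, not from the tutorial):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(X, Out, scale, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    # x * scale promotes to fp32; .to(x.dtype) casts back to the element
    # type of X instead of pinning the result to fp16.
    tl.store(Out + offs, (x * scale).to(x.dtype), mask=mask)

x = torch.randn(1024, device='cuda', dtype=torch.float16)
out = torch.empty_like(x)
scale_kernel[(triton.cdiv(x.numel(), 256),)](x, out, 1.44269504, x.numel(), BLOCK=256)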
python/perf-kernels/README.md (new file, 30 lines)
@@ -0,0 +1,30 @@
# AMD Perf Kernels

This directory contains customized/tuned/experimental kernels for AMD MI-series GPUs.

## `06-fused-attention-transV.py`

This script is a copy of `tutorials/06-fused-attention.py` with the following
two changes:

- Tensor V is transposed so that the seqlen/N_CTX dimension becomes the
  fastest-changing (a.k.a. leading, or least-strided) dimension; a layout
  sketch follows this list.
  This script achieves better performance than `tutorials/06-fused-attention.py`
  because the transposed layout has better LDS access efficiency for tensor V.
  Note that in the future, we'll improve the LDS access efficiency for
  non-transposed tensor V, i.e. with the head dimension as the fastest-changing
  dimension.
- Only the fwd kernel is benchmarked.
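As a sketch of what the transposed layout means (the tensor names and sizes below are illustrative assumptions, not taken from the script): for a V of shape (Z, H, N_CTX, D), making seqlen/N_CTX the fastest-changing dimension amounts to permuting the last two axes and materializing the result.

```python
import torch

Z, H, N_CTX, D = 2, 4, 1024, 64  # illustrative sizes
v = torch.randn(Z, H, N_CTX, D, dtype=torch.float16)
print(v.stride())    # head dim D is fastest changing: last-axis stride is 1

# transV layout: seqlen/N_CTX becomes the fastest-changing (least-strided) dim.
v_t = v.permute(0, 1, 3, 2).contiguous()
print(v_t.shape, v_t.stride())  # last axis is now N_CTX with stride 1
```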
## `06-fused-attention-fwd-transV.py`

This script is used to produce the best performance for the fwd kernel.
It is a copy of `06-fused-attention-transV.py` with the following changes:

- All bwd kernels are removed.
- Storing `m` at the end of the fwd kernel is removed.
- The autotuner is removed. All parameters for D=64 and D=128 are pre-tuned
  on MI250X and hard-coded; a sketch of this pattern follows the list.
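A minimal sketch of the hard-coded-parameters pattern (the table below is hypothetical; the real values come from offline tuning on MI250X and live in the script):

```python
import triton

# Hypothetical pre-tuned launch parameters keyed by head dimension D.
PRETUNED = {
    64:  {'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 1},
    128: {'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8, 'num_stages': 1},
}

def launch_fwd(kernel, grid, args, D):
    # Look up the fixed config instead of invoking @triton.autotune.
    cfg = PRETUNED[D]
    return kernel[grid](*args, BLOCK_M=cfg['BLOCK_M'], BLOCK_N=cfg['BLOCK_N'],
                        num_warps=cfg['num_warps'], num_stages=cfg['num_stages'])
```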
Note that this script is also used to benchmark FA performance with 2 GCDs.
Check the [2GCD benchmark script](https://github.com/ROCmSoftwarePlatform/triton/blob/triton-mlir/scripts/amd/benchmark_flash_attention.py) for more details.