Update Flash Attention guidance in "Model acceleration libraries" (#5793)

* flash attention update

Signed-off-by: seungrok.jung <seungrok.jung@amd.com>

sentence-case heading

* Update docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst

Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

---------

Co-authored-by: seungrok.jung <seungrok.jung@amd.com>
Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>
peterjunpark
2025-12-19 08:48:52 -05:00
committed by GitHub
parent cbab9a465d
commit 52c0a47e84


@@ -24,94 +24,102 @@ performance.
   :alt: Attention module of a large language model utilizing tiling
   :align: center

Installation prerequisites
--------------------------

Before installing Flash Attention 2, ensure the following are available:

* ROCm-enabled PyTorch
* Triton

These can be installed by following the official
`PyTorch installation guide <https://pytorch.org/get-started/locally/>`_. Alternatively, for a simpler setup, you can use a preconfigured
`ROCm PyTorch Docker image <https://github.com/ROCm/rocm-docker-images>`_, which already includes the required libraries.
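
To confirm these prerequisites are in place, you can query PyTorch for its ROCm (HIP) build information and import Triton. This is a minimal check; the exact version strings depend on your installation.

.. code-block:: python

   # Verify that a ROCm-enabled PyTorch build and Triton are installed.
   import torch

   print("PyTorch version:", torch.__version__)
   print("HIP (ROCm) runtime:", torch.version.hip)   # None on non-ROCm builds of PyTorch
   print("GPU available:", torch.cuda.is_available())

   import triton
   print("Triton version:", triton.__version__)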

Installing Flash Attention 2
----------------------------

`Flash Attention <https://github.com/Dao-AILab/flash-attention>`_ supports two backend implementations on AMD GPUs:

* `Composable Kernel (CK) <https://github.com/ROCm/composable_kernel>`_ - the default backend
* `OpenAI Triton <https://github.com/triton-lang/triton>`_ - an alternative backend

You can switch between these backends using the ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` environment variable:

* ``FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"`` uses the Composable Kernel (CK) backend (the default).
* ``FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"`` uses the OpenAI Triton backend.
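
When selecting a backend from Python code rather than the shell, set the environment variable before importing ``flash_attn``. The sketch below assumes the library reads ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` when it is imported; exporting the variable in the shell before launching Python works equally well.

.. code-block:: python

   # Sketch: select the Triton backend before flash_attn is imported.
   # Assumes FLASH_ATTENTION_TRITON_AMD_ENABLE is honored at import time.
   import os
   os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

   import flash_attn
   print("flash_attn version:", flash_attn.__version__)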

To install Flash Attention 2, use the following commands:

.. code-block:: shell

   git clone https://github.com/Dao-AILab/flash-attention.git
   cd flash-attention/
   pip install ninja

   # To install the CK backend flash attention
   python setup.py install

   # To install the Triton backend flash attention
   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install

   # To install both the CK and Triton backend flash attention
   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_SKIP_CK_BUILD="FALSE" python setup.py install

For detailed installation instructions, see `Flash Attention <https://github.com/Dao-AILab/flash-attention>`_.
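
After installation, a short smoke test confirms that the extension loads and runs on the GPU. The sketch below calls ``flash_attn_func`` on random inputs; tensors use the ``(batch, seqlen, num_heads, head_dim)`` layout and must be fp16 or bf16 on the GPU. The shapes are arbitrary and only serve as an example.

.. code-block:: python

   # Smoke test: run one Flash Attention 2 forward pass on random data.
   import torch
   from flash_attn import flash_attn_func

   batch, seqlen, nheads, headdim = 2, 128, 8, 64
   q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
   k = torch.randn_like(q)
   v = torch.randn_like(q)

   out = flash_attn_func(q, k, v, causal=True)  # output has the same shape as q
   print(out.shape, out.dtype)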

Benchmarking Flash Attention 2
------------------------------

Benchmark scripts to evaluate the performance of Flash Attention 2 are stored in the ``flash-attention/benchmarks/`` directory.

To benchmark the CK backend:

.. code-block:: shell

   cd flash-attention/benchmarks
   pip install transformers einops ninja
   python3 benchmark_flash_attention.py

To benchmark the Triton backend:

.. code-block:: shell

   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python3 benchmark_flash_attention.py
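
Besides the bundled scripts, an individual attention call can be timed with CUDA events, which PyTorch maps to HIP events on ROCm. This is a rough micro-benchmark sketch rather than a replacement for the official benchmark scripts; the shapes and iteration counts are arbitrary.

.. code-block:: python

   # Rough micro-benchmark of a single flash_attn_func call using CUDA/HIP events.
   import torch
   from flash_attn import flash_attn_func

   batch, seqlen, nheads, headdim = 4, 4096, 16, 128
   q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
   k = torch.randn_like(q)
   v = torch.randn_like(q)

   # Warm up so kernel compilation and caching do not skew the measurement.
   for _ in range(10):
       flash_attn_func(q, k, v, causal=True)
   torch.cuda.synchronize()

   start = torch.cuda.Event(enable_timing=True)
   end = torch.cuda.Event(enable_timing=True)
   start.record()
   for _ in range(100):
       flash_attn_func(q, k, v, causal=True)
   end.record()
   torch.cuda.synchronize()
   print(f"Average forward time: {start.elapsed_time(end) / 100:.3f} ms")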

Using Flash Attention 2
-----------------------

Hugging Face Transformers can use Flash Attention 2 by passing ``attn_implementation="flash_attention_2"`` to the ``from_pretrained`` method. The following example compares it against the default eager attention implementation.

.. code-block:: python

   import torch
   from transformers import AutoModelForCausalLM, AutoTokenizer

   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   model_name = "NousResearch/Llama-3.2-1B"

   tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
   inputs = tokenizer('Today is', return_tensors='pt').to(device)

   model_eager = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="eager").cuda(device)
   model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").cuda(device)
   model_eager.generation_config.pad_token_id = model_eager.generation_config.eos_token_id
   model_ckFAv2.generation_config.pad_token_id = model_ckFAv2.generation_config.eos_token_id

   print("eager\n GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=22, do_sample=False, num_beams=1)[0], skip_special_tokens=True))
   print("ckFAv2\n GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=22, do_sample=False, num_beams=1)[0], skip_special_tokens=True))

The outputs from eager mode and Flash Attention 2 are identical, although their performance behavior differs.

.. code-block:: shell

   eager
    GQA:  Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
   ckFAv2
    GQA:  Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
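
To observe the performance difference rather than just the identical outputs, you can time generation for both models. The sketch below reuses ``model_eager``, ``model_ckFAv2``, and ``inputs`` from the example above; longer prompts and higher ``max_new_tokens`` values make the gap more visible.

.. code-block:: python

   # Compare end-to-end generation latency for the eager and Flash Attention 2 models.
   import time
   import torch

   def time_generate(model, inputs, n_runs=5, max_new_tokens=128):
       # Warm-up run, then average wall-clock time over n_runs generations.
       model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
       torch.cuda.synchronize()
       start = time.perf_counter()
       for _ in range(n_runs):
           model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
       torch.cuda.synchronize()
       return (time.perf_counter() - start) / n_runs

   print(f"eager : {time_generate(model_eager, inputs):.3f} s per generation")
   print(f"ckFAv2: {time_generate(model_ckFAv2, inputs):.3f} s per generation")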

xFormers
========