diff --git a/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst b/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
old mode 100644
new mode 100755
index e3b9a761a..4ea3aa5f7
--- a/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
+++ b/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
@@ -24,94 +24,102 @@ performance.
:alt: Attention module of a large language module utilizing tiling
:align: center
+Installation prerequisites
+--------------------------
+
+Before installing Flash Attention 2, ensure the following are available:
+
+* ROCm-enabled PyTorch
+* Triton
+
+These can be installed by following the official
+`PyTorch installation guide <https://pytorch.org/get-started/locally/>`_. Alternatively, for a simpler setup, you can use a preconfigured
+`ROCm PyTorch Docker image <https://hub.docker.com/r/rocm/pytorch>`_, which already includes the required libraries.
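+
+To confirm the prerequisites are in place, you can run a quick sanity check (a minimal sketch; ``torch.version.hip`` is populated only on ROCm builds of PyTorch and is ``None`` otherwise):
+
+.. code-block:: python
+
+   import torch
+   import triton
+
+   # A ROCm-enabled PyTorch reports a HIP version here; a CUDA build reports None
+   print(torch.__version__, torch.version.hip)
+   print(triton.__version__)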
+
Installing Flash Attention 2
----------------------------
-ROCm provides two different implementations of Flash Attention 2 modules. They can be deployed interchangeably:
+`Flash Attention <https://github.com/Dao-AILab/flash-attention>`_ supports two backend implementations on AMD GPUs:
+
-* ROCm `Composable Kernel `_
- (CK) Flash Attention 2
+* `Composable Kernel (CK) <https://github.com/ROCm/composable_kernel>`_ - the default backend
+* `OpenAI Triton <https://github.com/triton-lang/triton>`_ - an alternative backend
-* `OpenAI Triton `_ Flash Attention 2
+
+You can switch between these backends using the ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` environment variable (see the sketch after this list):
-.. tab-set::
+
+* ``FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"`` - use the Composable Kernel (CK) backend (the default)
- .. tab-item:: CK Flash Attention 2
+* ``FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"`` - use the OpenAI Triton backend
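+
+A minimal sketch of selecting the Triton backend from Python, assuming both backends were built and that the variable is read when ``flash_attn`` is imported (verify this against your installed version):
+
+.. code-block:: python
+
+   import os
+
+   # Must be set before flash_attn is imported
+   os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
+
+   import flash_attn
+   print(flash_attn.__version__)
+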
- To install CK Flash Attention 2, use the following commands.
+To install Flash Attention 2, use the following commands:
- .. code-block:: shell
+
+.. code-block:: shell
+
- # Install from source
- git clone https://github.com/ROCm/flash-attention.git
- cd flash-attention/
- GPU_ARCHS=gfx942 python setup.py install #MI300 Series
+ git clone https://github.com/Dao-AILab/flash-attention.git
+ cd flash-attention/
+ pip install ninja
- Hugging Face Transformers can easily deploy the CK Flash Attention 2 module by passing an argument
- ``attn_implementation="flash_attention_2"`` in the ``from_pretrained`` class.
+
+   # To install the CK backend flash attention
+ python setup.py install
- .. code-block:: python
+
+   # To install the Triton backend flash attention
+ FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- model_name = "NousResearch/Meta-Llama-3-8B"
+
+   # To install both the CK and Triton backend flash attention
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_SKIP_CK_BUILD="FALSE" python setup.py install
- tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=torch.float16, use_fast=False)
- inputs = tokenizer('Today is', return_tensors='pt').to(device)
+
+For detailed installation instructions, see the `Flash Attention repository <https://github.com/Dao-AILab/flash-attention>`_.
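+
+After installation, a quick import check can confirm that the package is usable:
+
+.. code-block:: python
+
+   import flash_attn
+
+   # Prints the installed Flash Attention version
+   print(flash_attn.__version__)
+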
- model_eager = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="eager").cuda(device)
- model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda(device)
+Benchmarking Flash Attention 2
+------------------------------
- print("eager GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
- print("ckFAv2 GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
+
+Benchmark scripts for evaluating Flash Attention 2 performance are located in the ``flash-attention/benchmarks/`` directory.
- # eager GQA: Today is the day of the Lord, and we are the
- # ckFAv2 GQA: Today is the day of the Lord, and we are the
+
+To benchmark the CK backend, run:
- .. tab-item:: Triton Flash Attention 2
+
+.. code-block:: shell
+
- The Triton Flash Attention 2 module is implemented in Python and uses OpenAI’s JIT compiler. This module has been
- upstreamed into the vLLM serving toolkit, discussed in :doc:'llm-inference-frameworks'.
+ cd flash-attention/benchmarks
+ pip install transformers einops ninja
- 1. To install Triton Flash Attention 2 and run the benchmark, use the following commands.
+ python3 benchmark_flash_attention.py
- .. code-block:: shell
+
+To benchmark the Triton backend, run:
- # Install from the source
- pip uninstall pytorch-triton-rocm triton -y
- git clone https://github.com/ROCm/triton.git
- cd triton/python
- GPU_ARCHS=gfx942 python setup.py install #MI300 series
- pip install matplotlib pandas
+
+.. code-block:: shell
+
- 2. To test, run the Triton Flash Attention 2 performance benchmark.
+ FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python3 benchmark_flash_attention.py
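+
+Beyond the bundled scripts, you can also exercise the kernel directly through the ``flash_attn_func`` API (a minimal sketch; inputs use the ``(batch, seqlen, heads, head_dim)`` layout that the function expects):
+
+.. code-block:: python
+
+   import torch
+   from flash_attn import flash_attn_func
+
+   # Random half-precision Q, K, V: batch=2, seqlen=1024, 16 heads, head_dim=64
+   q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
+   k = torch.randn_like(q)
+   v = torch.randn_like(q)
+
+   # Causal flash attention; the output has the same shape as q
+   out = flash_attn_func(q, k, v, causal=True)
+   print(out.shape)
+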
- .. code-block:: shell
-
- # Test the triton FA v2 kernel
- python https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-attention.py
- # Results (Okay to release TFLOPS number ???)
- fused-attention-fwd-d128:
- BATCH HQ HK N_CTX_Q N_CTX_K TFLOPS
- 0 16.0 16.0 16.0 1024.0 1024.0 287.528411
- 1 8.0 16.0 16.0 2048.0 2048.0 287.490806
- 2 4.0 16.0 16.0 4096.0 4096.0 345.966031
- 3 2.0 16.0 16.0 8192.0 8192.0 361.369510
- 4 1.0 16.0 16.0 16384.0 16384.0 356.873720
- 5 2.0 48.0 48.0 1024.0 1024.0 216.916235
- 6 2.0 48.0 48.0 2048.0 1024.0 271.027578
- 7 2.0 48.0 48.0 4096.0 8192.0 337.367372
- 8 2.0 48.0 48.0 8192.0 4096.0 363.481649
- 9 2.0 48.0 48.0 16384.0 8192.0 375.013622
- 10 8.0 16.0 16.0 1989.0 15344.0 321.791333
- 11 4.0 16.0 16.0 4097.0 163.0 122.104888
- 12 2.0 16.0 16.0 8122.0 2159.0 337.060283
- 13 1.0 16.0 16.0 16281.0 7.0 5.234012
- 14 2.0 48.0 48.0 1021.0 1020.0 214.657425
- 15 2.0 48.0 48.0 2001.0 2048.0 314.429118
- 16 2.0 48.0 48.0 3996.0 9639.0 330.411368
- 17 2.0 48.0 48.0 8181.0 1021.0 324.614980
+Using Flash Attention 2
+-----------------------
+
+.. code-block:: python
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model_name = "NousResearch/Llama-3.2-1B"
+
+   tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+ inputs = tokenizer('Today is', return_tensors='pt').to(device)
+
+   model_eager = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="eager").to(device)
+   model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").to(device)
+ model_eager.generation_config.pad_token_id = model_eager.generation_config.eos_token_id
+ model_ckFAv2.generation_config.pad_token_id = model_ckFAv2.generation_config.eos_token_id
+
+ print("eager\n GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=22)[0], skip_special_tokens=True, do_sample=False, num_beams=1))
+ print("ckFAv2\n GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=22)[0], skip_special_tokens=True, do_sample=False, num_beams=1))
+
+With greedy decoding, the outputs from eager attention and Flash Attention 2 are identical; only their performance characteristics differ.
+
+.. code-block:: shell
+
+ eager
+ GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
+ ckFAv2
+ GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
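+
+To observe the performance difference, you can time generation for both models (a minimal sketch continuing the example above; it assumes the models are on a GPU, and absolute numbers vary with hardware, model, and sequence length):
+
+.. code-block:: python
+
+   import time
+
+   def avg_generate_time(model, inputs, runs=5):
+       # One warmup run, then average wall-clock time over several timed runs
+       model.generate(**inputs, max_new_tokens=128, do_sample=False)
+       torch.cuda.synchronize()
+       start = time.perf_counter()
+       for _ in range(runs):
+           model.generate(**inputs, max_new_tokens=128, do_sample=False)
+       torch.cuda.synchronize()
+       return (time.perf_counter() - start) / runs
+
+   print(f"eager:  {avg_generate_time(model_eager, inputs):.3f} s")
+   print(f"ckFAv2: {avg_generate_time(model_ckFAv2, inputs):.3f} s")
+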
xFormers
========