mirror of
https://github.com/ROCm/ROCm.git
synced 2026-01-08 22:28:06 -05:00
Update Flash Attention guidance in "Model acceleration libraries" (#5793)
* flash attention update Signed-off-by: seungrok.jung <seungrok.jung@amd.com> flash attention update Signed-off-by: seungrok.jung <seungrok.jung@amd.com> flash attention update Signed-off-by: seungrok.jung <seungrok.jung@amd.com> sentence-case heading * Update docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com> --------- Co-authored-by: seungrok.jung <seungrok.jung@amd.com> Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>
This commit is contained in:
136
docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
Normal file → Executable file
136
docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
Normal file → Executable file
@@ -24,94 +24,102 @@ performance.
|
||||
:alt: Attention module of a large language module utilizing tiling
|
||||
:align: center
|
||||
|
||||
Installation prerequisites
|
||||
----------------------------
|
||||
|
||||
Before installing Flash Attention 2, ensure the following are available:
|
||||
|
||||
* ROCm-enabled PyTorch
|
||||
* Triton
|
||||
|
||||
These can be installed by following the official
|
||||
`PyTorch installation guide <https://pytorch.org/get-started/locally/>`_. Alternatively, for a simpler setup, you can use a preconfigured
|
||||
`ROCm PyTorch Docker image <https://github.com/ROCm/rocm-docker-images>`_, which already includes the required libraries.
|
||||
|
||||
Installing Flash Attention 2
|
||||
----------------------------
|
||||
|
||||
ROCm provides two different implementations of Flash Attention 2 modules. They can be deployed interchangeably:
|
||||
`Flash Attention <https://github.com/Dao-AILab/flash-attention>`_ supports two backend implementations on AMD GPUs.
|
||||
|
||||
* ROCm `Composable Kernel <https://github.com/ROCm/composable_kernel/tree/develop/example/01_gemm>`_
|
||||
(CK) Flash Attention 2
|
||||
* `Composable Kernel (CK) <https://github.com/ROCm/composable_kernel>`_ - the default backend
|
||||
* `OpenAI Triton <https://github.com/triton-lang/triton>`_ - an alternative backend
|
||||
|
||||
* `OpenAI Triton <https://triton-lang.org/main/index.html>`_ Flash Attention 2
|
||||
You can switch between these backends using the environment variable FLASH_ATTENTION_TRITON_AMD_ENABLE:
|
||||
|
||||
.. tab-set::
|
||||
FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"
|
||||
→ Use Composable Kernel (CK) backend (FlashAttention 2)
|
||||
|
||||
.. tab-item:: CK Flash Attention 2
|
||||
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
|
||||
→ Use OpenAI Triton backend (FlashAttention 2)
|
||||
|
||||
To install CK Flash Attention 2, use the following commands.
|
||||
To install Flash Attention 2, use the following commands:
|
||||
|
||||
.. code-block:: shell
|
||||
.. code-block:: shell
|
||||
|
||||
# Install from source
|
||||
git clone https://github.com/ROCm/flash-attention.git
|
||||
cd flash-attention/
|
||||
GPU_ARCHS=gfx942 python setup.py install #MI300 Series
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/
|
||||
pip install ninja
|
||||
|
||||
Hugging Face Transformers can easily deploy the CK Flash Attention 2 module by passing an argument
|
||||
``attn_implementation="flash_attention_2"`` in the ``from_pretrained`` class.
|
||||
# To install the CK backend flash attention
|
||||
python setup.py install
|
||||
|
||||
.. code-block:: python
|
||||
# To install the Triton backend flash attention
|
||||
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model_name = "NousResearch/Meta-Llama-3-8B"
|
||||
# To install both CK and Triton backend flash attention
|
||||
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE && FLASH_ATTENTION_SKIP_CK_BUILD=FALSE python setup.py install
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=torch.float16, use_fast=False)
|
||||
inputs = tokenizer('Today is', return_tensors='pt').to(device)
|
||||
For detailed installation instructions, see `Flash Attention <https://github.com/Dao-AILab/flash-attention>`_.
|
||||
|
||||
model_eager = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="eager").cuda(device)
|
||||
model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda(device)
|
||||
Benchmarking Flash Attention 2
|
||||
------------------------------
|
||||
|
||||
print("eager GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
|
||||
print("ckFAv2 GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
|
||||
Benchmark scripts to evaluate the performance of Flash Attention 2 are stored in the `flash-attention/benchmarks/` directory.
|
||||
|
||||
# eager GQA: Today is the day of the Lord, and we are the
|
||||
# ckFAv2 GQA: Today is the day of the Lord, and we are the
|
||||
To benchmark the CK backend
|
||||
|
||||
.. tab-item:: Triton Flash Attention 2
|
||||
.. code-block:: shell
|
||||
|
||||
The Triton Flash Attention 2 module is implemented in Python and uses OpenAI’s JIT compiler. This module has been
|
||||
upstreamed into the vLLM serving toolkit, discussed in :doc:'llm-inference-frameworks'.
|
||||
cd flash-attention/benchmarks
|
||||
pip install transformers einops ninja
|
||||
|
||||
1. To install Triton Flash Attention 2 and run the benchmark, use the following commands.
|
||||
python3 benchmark_flash_attention.py
|
||||
|
||||
.. code-block:: shell
|
||||
To benchmark the Triton backend
|
||||
|
||||
# Install from the source
|
||||
pip uninstall pytorch-triton-rocm triton -y
|
||||
git clone https://github.com/ROCm/triton.git
|
||||
cd triton/python
|
||||
GPU_ARCHS=gfx942 python setup.py install #MI300 series
|
||||
pip install matplotlib pandas
|
||||
.. code-block:: shell
|
||||
|
||||
2. To test, run the Triton Flash Attention 2 performance benchmark.
|
||||
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python3 benchmark_flash_attention.py
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
# Test the triton FA v2 kernel
|
||||
python https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-attention.py
|
||||
# Results (Okay to release TFLOPS number ???)
|
||||
fused-attention-fwd-d128:
|
||||
BATCH HQ HK N_CTX_Q N_CTX_K TFLOPS
|
||||
0 16.0 16.0 16.0 1024.0 1024.0 287.528411
|
||||
1 8.0 16.0 16.0 2048.0 2048.0 287.490806
|
||||
2 4.0 16.0 16.0 4096.0 4096.0 345.966031
|
||||
3 2.0 16.0 16.0 8192.0 8192.0 361.369510
|
||||
4 1.0 16.0 16.0 16384.0 16384.0 356.873720
|
||||
5 2.0 48.0 48.0 1024.0 1024.0 216.916235
|
||||
6 2.0 48.0 48.0 2048.0 1024.0 271.027578
|
||||
7 2.0 48.0 48.0 4096.0 8192.0 337.367372
|
||||
8 2.0 48.0 48.0 8192.0 4096.0 363.481649
|
||||
9 2.0 48.0 48.0 16384.0 8192.0 375.013622
|
||||
10 8.0 16.0 16.0 1989.0 15344.0 321.791333
|
||||
11 4.0 16.0 16.0 4097.0 163.0 122.104888
|
||||
12 2.0 16.0 16.0 8122.0 2159.0 337.060283
|
||||
13 1.0 16.0 16.0 16281.0 7.0 5.234012
|
||||
14 2.0 48.0 48.0 1021.0 1020.0 214.657425
|
||||
15 2.0 48.0 48.0 2001.0 2048.0 314.429118
|
||||
16 2.0 48.0 48.0 3996.0 9639.0 330.411368
|
||||
17 2.0 48.0 48.0 8181.0 1021.0 324.614980
|
||||
Using Flash Attention 2
|
||||
-----------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model_name = "NousResearch/Llama-3.2-1B"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, dtype=torch.bfloat16, use_fast=False)
|
||||
inputs = tokenizer('Today is', return_tensors='pt').to(device)
|
||||
|
||||
model_eager = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="eager").cuda(device)
|
||||
model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").cuda(device)
|
||||
model_eager.generation_config.pad_token_id = model_eager.generation_config.eos_token_id
|
||||
model_ckFAv2.generation_config.pad_token_id = model_ckFAv2.generation_config.eos_token_id
|
||||
|
||||
print("eager\n GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=22)[0], skip_special_tokens=True, do_sample=False, num_beams=1))
|
||||
print("ckFAv2\n GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=22)[0], skip_special_tokens=True, do_sample=False, num_beams=1))
|
||||
|
||||
The outputs from eager mode and FlashAttention-2 are identical, although their performance behavior differs.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
eager
|
||||
GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
|
||||
ckFAv2
|
||||
GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
|
||||
|
||||
xFormers
|
||||
========
|
||||
|
||||
Reference in New Issue
Block a user