Update Flash Attention guidance in "Model acceleration libraries" (#5793)
* flash attention update

  Signed-off-by: seungrok.jung <seungrok.jung@amd.com>

* sentence-case heading

* Update docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst

  Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

* Apply suggestions from code review

  Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

---------

Co-authored-by: seungrok.jung <seungrok.jung@amd.com>
Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>
docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst (136 changes, Normal file → Executable file)
@@ -24,94 +24,102 @@ performance.
    :alt: Attention module of a large language model utilizing tiling
    :align: center
 
+Installation prerequisites
+--------------------------
+
+Before installing Flash Attention 2, ensure the following are available:
+
+* ROCm-enabled PyTorch
+* Triton
+
+These can be installed by following the official
+`PyTorch installation guide <https://pytorch.org/get-started/locally/>`_. Alternatively, for a simpler setup, you can use a preconfigured
+`ROCm PyTorch Docker image <https://github.com/ROCm/rocm-docker-images>`_, which already includes the required libraries.
+
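The prerequisites above can be checked before building. This is a minimal sketch, assuming a ROCm build of PyTorch (which reports a HIP version through ``torch.version.hip``) and a Triton package that only needs to be importable:

.. code-block:: python

   import importlib.util

   import torch

   # A ROCm build of PyTorch reports a HIP version string here;
   # CUDA or CPU-only builds report None.
   print("PyTorch:", torch.__version__)
   print("HIP runtime:", torch.version.hip)
   print("GPU visible:", torch.cuda.is_available())

   # The Triton backend only requires that the triton package is importable.
   print("Triton available:", importlib.util.find_spec("triton") is not None)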
 Installing Flash Attention 2
 ----------------------------
 
-ROCm provides two different implementations of Flash Attention 2 modules. They can be deployed interchangeably:
-
-* ROCm `Composable Kernel <https://github.com/ROCm/composable_kernel/tree/develop/example/01_gemm>`_
-  (CK) Flash Attention 2
-* `OpenAI Triton <https://triton-lang.org/main/index.html>`_ Flash Attention 2
+`Flash Attention <https://github.com/Dao-AILab/flash-attention>`_ supports two backend implementations on AMD GPUs.
+
+* `Composable Kernel (CK) <https://github.com/ROCm/composable_kernel>`_ - the default backend
+* `OpenAI Triton <https://github.com/triton-lang/triton>`_ - an alternative backend
+
+You can switch between these backends using the environment variable FLASH_ATTENTION_TRITON_AMD_ENABLE:
+
+FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"
+   → Use Composable Kernel (CK) backend (FlashAttention 2)
+
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+   → Use OpenAI Triton backend (FlashAttention 2)
+
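The same switch can also be applied per process from Python once the package is installed (see the install commands below). This is an illustrative sketch that assumes FLASH_ATTENTION_TRITON_AMD_ENABLE is read when ``flash_attn`` is first imported, so it is set before the import:

.. code-block:: python

   import os

   # Unset or "FALSE" keeps the default CK backend; "TRUE" selects the
   # OpenAI Triton backend. Assumption: the flag must be set before
   # flash_attn is imported for the first time in this process.
   os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

   import flash_attn

   print("flash_attn", flash_attn.__version__,
         "| Triton backend requested:", os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"])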
-.. tab-set::
-
-   .. tab-item:: CK Flash Attention 2
-
-      To install CK Flash Attention 2, use the following commands.
-
-      .. code-block:: shell
-
-         # Install from source
-         git clone https://github.com/ROCm/flash-attention.git
-         cd flash-attention/
-         GPU_ARCHS=gfx942 python setup.py install #MI300 Series
+To install Flash Attention 2, use the following commands:
+
+.. code-block:: shell
+
+   git clone https://github.com/Dao-AILab/flash-attention.git
+   cd flash-attention/
+   pip install ninja
+
+   # To install the CK backend flash attention
+   python setup.py install
+
+   # To install the Triton backend flash attention
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
+
+   # To install both the CK and Triton backend flash attention
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_SKIP_CK_BUILD="FALSE" python setup.py install
+
+For detailed installation instructions, see `Flash Attention <https://github.com/Dao-AILab/flash-attention>`_.
+
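After installation, a quick functional check can be run against the package's ``flash_attn_func`` API. The tensor shapes and dtype below are illustrative assumptions rather than values from the guide:

.. code-block:: python

   import torch
   from flash_attn import flash_attn_func

   # flash_attn_func expects (batch, seqlen, num_heads, head_dim) tensors in
   # fp16 or bf16 on the GPU; the output has the same shape as the queries.
   q = torch.randn(2, 512, 16, 64, dtype=torch.float16, device="cuda")
   k = torch.randn(2, 512, 16, 64, dtype=torch.float16, device="cuda")
   v = torch.randn(2, 512, 16, 64, dtype=torch.float16, device="cuda")

   out = flash_attn_func(q, k, v, causal=True)
   print(out.shape, out.dtype)  # torch.Size([2, 512, 16, 64]) torch.float16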
-
-      Hugging Face Transformers can easily deploy the CK Flash Attention 2 module by passing the argument
-      ``attn_implementation="flash_attention_2"`` to ``from_pretrained``.
-
-      .. code-block:: python
-
-         import torch
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-
-         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         model_name = "NousResearch/Meta-Llama-3-8B"
-
-         tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=torch.float16, use_fast=False)
-         inputs = tokenizer('Today is', return_tensors='pt').to(device)
-
-         model_eager = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="eager").cuda(device)
-         model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda(device)
-
-         print("eager GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
-         print("ckFAv2 GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
-
-         # eager GQA: Today is the day of the Lord, and we are the
-         # ckFAv2 GQA: Today is the day of the Lord, and we are the
-
-   .. tab-item:: Triton Flash Attention 2
-
-      The Triton Flash Attention 2 module is implemented in Python and uses OpenAI's JIT compiler. This module has been
-      upstreamed into the vLLM serving toolkit, discussed in :doc:`llm-inference-frameworks`.
-
-      1. To install Triton Flash Attention 2 and run the benchmark, use the following commands.
-
-         .. code-block:: shell
-
-            # Install from the source
-            pip uninstall pytorch-triton-rocm triton -y
-            git clone https://github.com/ROCm/triton.git
-            cd triton/python
-            GPU_ARCHS=gfx942 python setup.py install #MI300 series
-            pip install matplotlib pandas
+Benchmarking Flash Attention 2
+------------------------------
+
+Benchmark scripts to evaluate the performance of Flash Attention 2 are stored in the `flash-attention/benchmarks/` directory.
+
+To benchmark the CK backend:
+
+.. code-block:: shell
+
+   cd flash-attention/benchmarks
+   pip install transformers einops ninja
+   python3 benchmark_flash_attention.py
+
+To benchmark the Triton backend:
+
+.. code-block:: shell
+
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python3 benchmark_flash_attention.py
+
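Beyond the bundled script, the kernel itself can be timed with CUDA events. The comparison against PyTorch's ``scaled_dot_product_attention`` and the problem sizes below are assumptions for illustration, not the project's benchmarking methodology:

.. code-block:: python

   import torch
   import torch.nn.functional as F
   from flash_attn import flash_attn_func

   def time_fn(fn, iters=50):
       # Warm up, then time with CUDA events so GPU work is actually measured.
       for _ in range(10):
           fn()
       start = torch.cuda.Event(enable_timing=True)
       end = torch.cuda.Event(enable_timing=True)
       start.record()
       for _ in range(iters):
           fn()
       end.record()
       torch.cuda.synchronize()
       return start.elapsed_time(end) / iters  # milliseconds per call

   b, s, h, d = 2, 4096, 16, 128
   q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
   k = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
   v = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")

   # flash_attn_func uses (batch, seqlen, heads, head_dim); SDPA expects
   # (batch, heads, seqlen, head_dim), hence the transposes.
   print("flash_attn_func:", time_fn(lambda: flash_attn_func(q, k, v, causal=True)), "ms")
   print("sdpa:", time_fn(lambda: F.scaled_dot_product_attention(
       q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)), "ms")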
-      2. To test, run the Triton Flash Attention 2 performance benchmark.
-
-         .. code-block:: shell
-
-            # Test the triton FA v2 kernel
-            python https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-attention.py
-            # Results (Okay to release TFLOPS number ???)
-            fused-attention-fwd-d128:
-                BATCH    HQ    HK  N_CTX_Q  N_CTX_K      TFLOPS
-            0    16.0  16.0  16.0   1024.0   1024.0  287.528411
-            1     8.0  16.0  16.0   2048.0   2048.0  287.490806
-            2     4.0  16.0  16.0   4096.0   4096.0  345.966031
-            3     2.0  16.0  16.0   8192.0   8192.0  361.369510
-            4     1.0  16.0  16.0  16384.0  16384.0  356.873720
-            5     2.0  48.0  48.0   1024.0   1024.0  216.916235
-            6     2.0  48.0  48.0   2048.0   1024.0  271.027578
-            7     2.0  48.0  48.0   4096.0   8192.0  337.367372
-            8     2.0  48.0  48.0   8192.0   4096.0  363.481649
-            9     2.0  48.0  48.0  16384.0   8192.0  375.013622
-            10    8.0  16.0  16.0   1989.0  15344.0  321.791333
-            11    4.0  16.0  16.0   4097.0    163.0  122.104888
-            12    2.0  16.0  16.0   8122.0   2159.0  337.060283
-            13    1.0  16.0  16.0  16281.0      7.0    5.234012
-            14    2.0  48.0  48.0   1021.0   1020.0  214.657425
-            15    2.0  48.0  48.0   2001.0   2048.0  314.429118
-            16    2.0  48.0  48.0   3996.0   9639.0  330.411368
-            17    2.0  48.0  48.0   8181.0   1021.0  324.614980
+Using Flash Attention 2
+-----------------------
+
+.. code-block:: python
+
+   import torch
+   from transformers import AutoModelForCausalLM, AutoTokenizer
+
+   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+   model_name = "NousResearch/Llama-3.2-1B"
+
+   tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+   inputs = tokenizer('Today is', return_tensors='pt').to(device)
+
+   model_eager = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="eager").cuda(device)
+   model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").cuda(device)
+
+   model_eager.generation_config.pad_token_id = model_eager.generation_config.eos_token_id
+   model_ckFAv2.generation_config.pad_token_id = model_ckFAv2.generation_config.eos_token_id
+
+   print("eager\n GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=22, do_sample=False, num_beams=1)[0], skip_special_tokens=True))
+   print("ckFAv2\n GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=22, do_sample=False, num_beams=1)[0], skip_special_tokens=True))
+
+The outputs from eager mode and FlashAttention-2 are identical, although their performance behavior differs.
+
+.. code-block:: shell
+
+   eager
+    GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
+   ckFAv2
+    GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
+
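One way to see the performance difference noted above is to time greedy decoding with the two attention implementations from the example. The model name follows the snippet above; the token count and the timing approach are illustrative assumptions:

.. code-block:: python

   import time

   import torch
   from transformers import AutoModelForCausalLM, AutoTokenizer

   model_name = "NousResearch/Llama-3.2-1B"
   tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
   inputs = tokenizer("Today is", return_tensors="pt").to("cuda")

   def load(attn_impl):
       return AutoModelForCausalLM.from_pretrained(
           model_name, dtype=torch.bfloat16, attn_implementation=attn_impl
       ).cuda()

   def time_generate(model, max_new_tokens=256):
       # Warm-up pass, then a synchronized timing of greedy decoding.
       model.generate(**inputs, max_new_tokens=8, do_sample=False)
       torch.cuda.synchronize()
       start = time.perf_counter()
       model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
       torch.cuda.synchronize()
       return time.perf_counter() - start

   print("eager:", time_generate(load("eager")), "s")
   print("flash_attention_2:", time_generate(load("flash_attention_2")), "s")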
 
 xFormers
 ========