diff --git a/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst b/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
old mode 100644
new mode 100755
index e3b9a761a..4ea3aa5f7
--- a/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
+++ b/docs/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.rst
@@ -24,94 +24,102 @@ performance.
    :alt: Attention module of a large language module utilizing tiling
    :align: center
 
+Installation prerequisites
+--------------------------
+
+Before installing Flash Attention 2, ensure the following are available:
+
+* ROCm-enabled PyTorch
+* Triton
+
+These can be installed by following the official
+`PyTorch installation guide <https://pytorch.org/get-started/locally/>`_. Alternatively, for a simpler setup, you can use a preconfigured
+`ROCm PyTorch Docker image <https://hub.docker.com/r/rocm/pytorch>`_, which already includes the required libraries.
+
 Installing Flash Attention 2
 ----------------------------
 
-ROCm provides two different implementations of Flash Attention 2 modules. They can be deployed interchangeably:
+`Flash Attention <https://github.com/Dao-AILab/flash-attention>`_ supports two backend implementations on AMD GPUs:
 
-* ROCm `Composable Kernel `_
-  (CK) Flash Attention 2
+* `Composable Kernel (CK) <https://github.com/ROCm/composable_kernel>`_ - the default backend
+* `OpenAI Triton <https://github.com/triton-lang/triton>`_ - an alternative backend
 
-* `OpenAI Triton `_ Flash Attention 2
 
+You can switch between these backends with the ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` environment variable:
 
-.. tab-set::
 
+* ``FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"`` - use the Composable Kernel (CK) backend (FlashAttention 2)
 
-   .. tab-item:: CK Flash Attention 2
 
+* ``FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"`` - use the OpenAI Triton backend (FlashAttention 2)
 
-      To install CK Flash Attention 2, use the following commands.
+To install Flash Attention 2, use the following commands:
 
-      .. code-block:: shell
+.. code-block:: shell
 
-         # Install from source
-         git clone https://github.com/ROCm/flash-attention.git
-         cd flash-attention/
-         GPU_ARCHS=gfx942 python setup.py install #MI300 Series
+   git clone https://github.com/Dao-AILab/flash-attention.git
+   cd flash-attention/
+   pip install ninja
 
-      Hugging Face Transformers can easily deploy the CK Flash Attention 2 module by passing an argument
-      ``attn_implementation="flash_attention_2"`` in the ``from_pretrained`` class.
+   # To install the CK backend flash attention
+   python setup.py install
 
-      .. code-block:: python
+   # To install the Triton backend flash attention
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 
-         import torch
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         model_name = "NousResearch/Meta-Llama-3-8B"
+   # To install both CK and Triton backend flash attention
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_SKIP_CK_BUILD="FALSE" python setup.py install
 
-         tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=torch.float16, use_fast=False)
-         inputs = tokenizer('Today is', return_tensors='pt').to(device)
+For detailed installation instructions, see `Flash Attention <https://github.com/Dao-AILab/flash-attention>`_.
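+
+As a quick sanity check of the installation, you can call the Flash Attention kernel directly. The following is a
+minimal sketch, assuming a ROCm-capable GPU is visible to PyTorch; the tensor shapes are illustrative. As described
+above, the active backend follows the ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` environment variable.
+
+.. code-block:: python
+
+   import torch
+   from flash_attn import flash_attn_func
+
+   # flash_attn_func expects (batch, seqlen, nheads, headdim) tensors in float16 or bfloat16 on the GPU.
+   batch, seqlen, nheads, headdim = 2, 1024, 16, 128
+   q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
+   k = torch.randn_like(q)
+   v = torch.randn_like(q)
+
+   # Causal self-attention through the installed backend.
+   out = flash_attn_func(q, k, v, causal=True)
+   print(out.shape)  # torch.Size([2, 1024, 16, 128])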
-         model_eager = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="eager").cuda(device)
-         model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda(device)
 
+Benchmarking Flash Attention 2
+------------------------------
 
-         print("eager GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
-         print("ckFAv2 GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
 
+Benchmark scripts to evaluate the performance of Flash Attention 2 are stored in the ``flash-attention/benchmarks/`` directory.
 
-         # eager GQA: Today is the day of the Lord, and we are the
-         # ckFAv2 GQA: Today is the day of the Lord, and we are the
 
+To benchmark the CK backend:
 
-   .. tab-item:: Triton Flash Attention 2
 
+.. code-block:: shell
 
-      The Triton Flash Attention 2 module is implemented in Python and uses OpenAI’s JIT compiler. This module has been
-      upstreamed into the vLLM serving toolkit, discussed in :doc:'llm-inference-frameworks'.
 
+   cd flash-attention/benchmarks
+   pip install transformers einops ninja
 
-      1. To install Triton Flash Attention 2 and run the benchmark, use the following commands.
 
+   python3 benchmark_flash_attention.py
 
-         .. code-block:: shell
 
+To benchmark the Triton backend:
 
-            # Install from the source
-            pip uninstall pytorch-triton-rocm triton -y
-            git clone https://github.com/ROCm/triton.git
-            cd triton/python
-            GPU_ARCHS=gfx942 python setup.py install #MI300 series
-            pip install matplotlib pandas
 
+.. code-block:: shell
 
-      2. To test, run the Triton Flash Attention 2 performance benchmark.
 
+   FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python3 benchmark_flash_attention.py
 
-         .. code-block:: shell
-
-            # Test the triton FA v2 kernel
-            python https://github.com/ROCm/triton/blob/triton-mlir/python/perf-kernels/flash-attention.py
-            # Results (Okay to release TFLOPS number ???)
-            fused-attention-fwd-d128:
-                BATCH    HQ    HK  N_CTX_Q  N_CTX_K      TFLOPS
-            0    16.0  16.0  16.0   1024.0   1024.0  287.528411
-            1     8.0  16.0  16.0   2048.0   2048.0  287.490806
-            2     4.0  16.0  16.0   4096.0   4096.0  345.966031
-            3     2.0  16.0  16.0   8192.0   8192.0  361.369510
-            4     1.0  16.0  16.0  16384.0  16384.0  356.873720
-            5     2.0  48.0  48.0   1024.0   1024.0  216.916235
-            6     2.0  48.0  48.0   2048.0   1024.0  271.027578
-            7     2.0  48.0  48.0   4096.0   8192.0  337.367372
-            8     2.0  48.0  48.0   8192.0   4096.0  363.481649
-            9     2.0  48.0  48.0  16384.0   8192.0  375.013622
-            10    8.0  16.0  16.0   1989.0  15344.0  321.791333
-            11    4.0  16.0  16.0   4097.0    163.0  122.104888
-            12    2.0  16.0  16.0   8122.0   2159.0  337.060283
-            13    1.0  16.0  16.0  16281.0      7.0    5.234012
-            14    2.0  48.0  48.0   1021.0   1020.0  214.657425
-            15    2.0  48.0  48.0   2001.0   2048.0  314.429118
-            16    2.0  48.0  48.0   3996.0   9639.0  330.411368
-            17    2.0  48.0  48.0   8181.0   1021.0  324.614980
 
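+
+If you want to time the kernels from your own script rather than using the bundled benchmark, a minimal sketch such as
+the following can be used. The tensor shapes, iteration count, and the comparison against PyTorch's
+``torch.nn.functional.scaled_dot_product_attention`` are illustrative choices and are not part of
+``benchmark_flash_attention.py``.
+
+.. code-block:: python
+
+   import torch
+   import torch.nn.functional as F
+   from flash_attn import flash_attn_func
+
+   def time_fn(fn, iters=30):
+       # Warm up, then time with CUDA events so asynchronous kernel launches are measured correctly.
+       for _ in range(5):
+           fn()
+       start = torch.cuda.Event(enable_timing=True)
+       end = torch.cuda.Event(enable_timing=True)
+       start.record()
+       for _ in range(iters):
+           fn()
+       end.record()
+       torch.cuda.synchronize()
+       return start.elapsed_time(end) / iters  # milliseconds per call
+
+   batch, seqlen, nheads, headdim = 4, 4096, 16, 128
+   q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
+   k, v = torch.randn_like(q), torch.randn_like(q)
+
+   # flash_attn_func takes (batch, seqlen, nheads, headdim);
+   # scaled_dot_product_attention takes (batch, nheads, seqlen, headdim).
+   qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
+
+   fa_ms = time_fn(lambda: flash_attn_func(q, k, v, causal=True))
+   sdpa_ms = time_fn(lambda: F.scaled_dot_product_attention(qt, kt, vt, is_causal=True))
+   print(f"flash_attn_func: {fa_ms:.3f} ms/iter, SDPA baseline: {sdpa_ms:.3f} ms/iter")
+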
+Using Flash Attention 2
+-----------------------
+
+Hugging Face Transformers uses the installed Flash Attention 2 backend when you pass ``attn_implementation="flash_attention_2"`` to ``from_pretrained``:
+
+.. code-block:: python
+
+   import torch
+   from transformers import AutoModelForCausalLM, AutoTokenizer
+
+   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+   model_name = "NousResearch/Llama-3.2-1B"
+
+   tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+   inputs = tokenizer('Today is', return_tensors='pt').to(device)
+
+   model_eager = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="eager").cuda(device)
+   model_ckFAv2 = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, attn_implementation="flash_attention_2").cuda(device)
+   model_eager.generation_config.pad_token_id = model_eager.generation_config.eos_token_id
+   model_ckFAv2.generation_config.pad_token_id = model_ckFAv2.generation_config.eos_token_id
+
+   print("eager\n GQA: ", tokenizer.decode(model_eager.generate(**inputs, max_new_tokens=22, do_sample=False, num_beams=1)[0], skip_special_tokens=True))
+   print("ckFAv2\n GQA: ", tokenizer.decode(model_ckFAv2.generate(**inputs, max_new_tokens=22, do_sample=False, num_beams=1)[0], skip_special_tokens=True))
+
+With greedy decoding, the eager and Flash Attention 2 implementations produce identical output; only their performance differs.
+
+.. code-block:: shell
+
+   eager
+    GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
+   ckFAv2
+    GQA: Today is the 10th anniversary of the 9/11 attacks. I remember that day like it was yesterday.
 
 xFormers
 ========