Mirror of https://github.com/ROCm/ROCm.git, synced 2026-01-09 22:58:17 -05:00
Updates to the vLLM optimization guide for MI300X/MI355X (#5554)
* Expand the vLLM optimization guide for MI300X/MI355X with comprehensive AITER coverage: attention backend selection, environment variables (HIP/RCCL/Quick Reduce), parallelism strategies, quantization (FP8/FP4), engine tuning, CUDA graph modes, and multi-node scaling.

Co-authored-by: PinSiang <pinsiang.tan@embeddedllm.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: pinsiangamd <pinsiang.tan@amd.com>
Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
@@ -27,6 +27,7 @@ ASICs
ASan
ASAN
ASm
+Async
ATI
atomicRMW
AddressSanitizer
@@ -133,6 +134,7 @@ ELMo
ENDPGM
EPYC
ESXi
+EP
EoS
etcd
fas
@@ -184,6 +186,7 @@ GPR
GPT
GPU
GPU's
+GPUDirect
GPUs
GraphBolt
GraphSage
@@ -302,6 +305,7 @@ Makefiles
Matplotlib
Matrox
MaxText
+MBT
Megablocks
Megatrends
Megatron
@@ -311,6 +315,7 @@ Meta's
Miniconda
MirroredStrategy
Mixtral
+MLA
MosaicML
MoEs
Mooncake
@@ -353,6 +358,7 @@ OFED
OMM
OMP
OMPI
+OOM
OMPT
OMPX
ONNX
@@ -398,6 +404,7 @@ Profiler's
PyPi
Pytest
PyTorch
+QPS
Qcycles
Qwen
RAII
@@ -673,6 +680,7 @@ denoised
denoises
denormalize
dequantization
+dequantized
dequantizes
deserializers
detections
@@ -788,6 +796,7 @@ linalg
linearized
linter
linux
+llm
llvm
lm
localscratch
@@ -838,6 +847,7 @@ passthrough
pe
perfcounter
performant
+piecewise
perl
pragma
pre
@@ -984,6 +994,7 @@ tokenizer
tokenizes
toolchain
toolchains
+topk
toolset
toolsets
torchtitan
@@ -1011,6 +1022,7 @@ USM
UTCL
UTIL
utils
+UX
vL
variational
vdi
New file: docs/how-to/rocm-for-ai/inference-optimization/vllm-optimization.rst (1139 lines added; file diff suppressed because it is too large).
@@ -15,10 +15,9 @@ using PyTorch. It delves into specific workloads such as
:ref:`model inference <mi300x-vllm-optimization>`, offering strategies to
enhance efficiency.

-The following topics highlight :ref:`auto-tunable configurations <mi300x-auto-tune>`
-that streamline optimization as well as advanced techniques like
-:ref:`Triton kernel optimization <mi300x-triton-kernel-performance-optimization>` for
-meticulous tuning.
+The following topics highlight :ref:`auto-tunable configurations <mi300x-auto-tune>` as
+well as :ref:`Triton kernel optimization <mi300x-triton-kernel-performance-optimization>`
+for meticulous tuning.

Workload tuning strategy
========================
@@ -86,23 +85,22 @@ Optimize model inference with vLLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

vLLM provides tools and techniques specifically designed for efficient model
-inference on AMD Instinct MI300X GPUs. See :ref:`fine-tuning-llms-vllm`
-for installation guidance. Optimizing performance with vLLM
-involves configuring tensor parallelism, leveraging advanced features, and
-ensuring efficient execution. Here’s how to optimize vLLM performance:
-
-* Tensor parallelism: Configure the
-  :ref:`tensor-parallel-size parameter <mi300x-vllm-multiple-gpus>` to distribute
-  tensor computations across multiple GPUs. Adjust parameters such as
-  ``batch-size``, ``input-len``, and ``output-len`` based on your workload.
-
-* Configuration for vLLM: Set :ref:`parameters <mi300x-vllm-optimization>`
-  according to workload requirements. Benchmark performance to understand
-  characteristics and identify bottlenecks.
+inference on AMD Instinct GPUs. See the official `vLLM installation docs
+<https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html>`__ for
+installation guidance. Optimizing performance with vLLM involves configuring
+tensor parallelism, leveraging advanced features, and ensuring efficient
+execution.
+
+* Configuration for vLLM: Set engine arguments according to workload
+  requirements.

* Benchmarking and performance metrics: Measure latency and throughput to
  evaluate performance.

+.. seealso::
+
+   See :doc:`vllm-optimization`.
+
.. _mi300x-auto-tune:

Auto-tunable configurations
@@ -120,8 +118,7 @@ characteristics. For example:
  your specific hardware.

* Triton: Use :ref:`Triton’s auto-tuning features <mi300x-autotunable-kernel-config>`
-  to explore various kernel configurations and automatically select the
-  best-performing ones.
+  to explore various kernel configurations and select the best-performing ones.

Manual tuning
^^^^^^^^^^^^^
@@ -328,381 +325,6 @@ hardware counters are also included.

   ROCm Systems Profiler timeline trace example.

.. _mi300x-vllm-optimization:

vLLM performance optimization
=============================

vLLM is a high-throughput, memory-efficient inference and serving engine for large language models that has gained traction in the AI community for
its performance and ease of use. See :ref:`fine-tuning-llms-vllm` for a primer on vLLM with ROCm.

Performance environment variables
---------------------------------

The following performance tips are not *specific* to vLLM -- they are general
but relevant in this context. You can tune the following settings to
achieve optimal request latency and throughput.

* As described in `Environment variables (MI300X)
  <https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/system-optimization/mi300x.html#environment-variables>`_,
  the environment variable ``HIP_FORCE_DEV_KERNARG`` can improve vLLM
  performance. Set it to ``export HIP_FORCE_DEV_KERNARG=1``.

* Set the :ref:`RCCL environment variable <mi300x-rccl>` ``NCCL_MIN_NCHANNELS``
  to ``112`` to increase the number of channels on MI300X and potentially improve
  performance.

* Set the environment variable ``TORCH_BLAS_PREFER_HIPBLASLT=1`` to use hipBLASLt for better performance.
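For reference, a minimal sketch that applies all three settings before a throughput run (the benchmark path is the one used elsewhere in this guide; adjust it to your installation):

.. code-block:: shell

   export HIP_FORCE_DEV_KERNARG=1
   export NCCL_MIN_NCHANNELS=112
   export TORCH_BLAS_PREFER_HIPBLASLT=1
   python3 /vllm-workspace/benchmarks/benchmark_throughput.py --model /path/to/model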
Auto-tuning using PyTorch TunableOp
-----------------------------------

Since vLLM is based on the PyTorch framework, PyTorch TunableOp can be used for auto-tuning.
You can run auto-tuning with TunableOp in two simple steps without modifying your code:

* Enable TunableOp and tuning. Optionally, enable verbose mode:

  .. code-block:: shell

     PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 your_vllm_script.sh

* Enable TunableOp, disable tuning, and measure:

  .. code-block:: shell

     PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_vllm_script.sh

Learn more about TunableOp in the :ref:`PyTorch TunableOp <mi300x-tunableop>` section.
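Combining the two steps, a sketch of a tune-once-then-measure flow (``your_vllm_script.sh`` is a placeholder for your own launch script):

.. code-block:: shell

   # First pass: tune GEMMs and record the selected solutions.
   # By default, TunableOp writes its selections to a CSV file in the working directory.
   PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 ./your_vllm_script.sh

   # Second pass: reuse the tuned results without re-tuning and collect numbers.
   PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 ./your_vllm_script.sh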
Performance tuning based on vLLM engine configurations
-------------------------------------------------------

The following subsections describe vLLM-specific configurations for performance tuning.
You can tune the following vLLM parameters to achieve optimal performance.

* ``tensor_parallel_size``

* ``gpu_memory_utilization``

* ``dtype``

* ``enforce_eager``

* ``kv_cache_dtype``

* ``input_len``

* ``output_len``

* ``max_num_seqs``

* ``num_scheduler_steps``

* ``max_model_len``

* ``enable_chunked_prefill``

* ``distributed_executor_backend``

* ``max_seq_len_to_capture``

Refer to the `vLLM documentation <https://docs.vllm.ai/en/latest/models/performance.html>`_
for additional performance tips. :ref:`fine-tuning-llms-vllm` describes vLLM
usage with ROCm.

ROCm provides a prebuilt, optimized Docker image for validating the performance
of LLM inference with vLLM on MI300X Series GPUs. The Docker image includes
ROCm, vLLM, and PyTorch. For more information, see
:doc:`/how-to/rocm-for-ai/inference/benchmark-docker/vllm`.
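As a rough illustration of how several of these parameters map to command-line flags, here is a sketch of a throughput run (the values are illustrative only, and the exact set of supported flags depends on your vLLM version):

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py \
       --model /path/to/model \
       --tensor-parallel-size 1 \
       --gpu-memory-utilization 0.9 \
       --dtype float16 \
       --kv-cache-dtype fp8 \
       --max-model-len 8192 \
       --input-len 1024 --output-len 128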
.. _mi300x-vllm-throughput-measurement:

Evaluating performance by throughput measurement
-------------------------------------------------

This tuning guide evaluates the performance of LLM inference workloads by measuring throughput in tokens per second (TPS). Throughput can be assessed using both real-world and synthetic data, depending on your evaluation goals.

Refer to the benchmarking script located at ``benchmarks/benchmark_throughput.py`` in the `vLLM repository <https://github.com/ROCm/vllm/blob/main/benchmarks/benchmark_throughput.py>`_
and use it to measure throughput.

* For realistic performance evaluation, you can use datasets like Hugging Face's
  ``ShareGPT_V3_unfiltered_cleaned_split.json``. This dataset includes real-world conversational
  data, making it a good representation of typical use cases for language models. Download it using
  the following command:

  .. code-block:: shell

     wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

* For standardized benchmarking, you can set fixed input and output token
  lengths. Synthetic prompts provide consistent benchmarking runs, making it
  easier to compare performance across different models or configurations.
  Additionally, a controlled environment simplifies analysis.

By balancing real-world and synthetic data approaches, you can get a well-rounded understanding of model performance in varied scenarios.
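For the synthetic case, a minimal sketch of a fixed-length run (lengths and prompt count are illustrative):

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py \
       --model /path/to/model \
       --input-len 1024 --output-len 128 --num-prompts 500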
.. _mi300x-vllm-single-node:

Maximizing vLLM instances on a single node
------------------------------------------

The general guideline is to maximize per-node throughput by running as many vLLM instances as possible.
However, running too many instances might leave insufficient memory for the KV cache, which can affect performance.

The Instinct MI300X GPU is equipped with 192 GB of HBM3, providing both high memory capacity and bandwidth.
For models that fit in one GPU -- to maximize the accumulated throughput -- you can run as many as eight vLLM instances
simultaneously on one MI300X node (with eight GPUs). To do so, use the GPU isolation environment
variable ``CUDA_VISIBLE_DEVICES``.

For example, this script runs eight instances of vLLM for throughput benchmarking at the same time
with a model that can fit in one GPU:

.. code-block:: shell

   for i in $(seq 0 7);
   do
       CUDA_VISIBLE_DEVICES="$i" python3 /app/vllm/benchmarks/benchmark_throughput.py -tp 1 --dataset "/path/to/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" --model /path/to/model &
   done

The total throughput achieved by running ``N`` instances of vLLM is generally much higher than running a
single vLLM instance across ``N`` GPUs (that is, configuring ``tensor_parallel_size`` as ``N`` or
using the ``-tp N`` option, where ``1 < N ≤ 8``).

vLLM on MI300X GPUs can run a variety of model weights, including Llama 2 (7b, 13b, 70b), Llama 3 (8b, 70b), Qwen2 (7b, 72b), Mixtral-8x7b, Mixtral-8x22b, and so on.
Notably, the Llama2-70b and Llama3-70b models fit on a single MI300X GPU, and the Llama3.1 405b model fits on a single node with eight MI300X GPUs.
.. _mi300x-vllm-gpu-memory-utilization:

Configure the gpu_memory_utilization parameter
----------------------------------------------

There are two ways to increase throughput by configuring the ``gpu-memory-utilization`` parameter.

1. Increase ``gpu-memory-utilization`` to improve the throughput of a single instance, as long as
   it does not cause a HIP or CUDA out-of-memory error. The default ``gpu-memory-utilization`` is 0.9.
   You can set it to a value greater than 0.9 and less than 1.

   For example, the following benchmarking command sets ``gpu-memory-utilization`` to 0.98, or 98%.

   .. code-block:: shell

      python3 /vllm-workspace/benchmarks/benchmark_throughput.py --gpu-memory-utilization 0.98 --input-len 1024 --output-len 128 --model /path/to/model

2. Decrease ``gpu-memory-utilization`` to maximize the number of vLLM instances on the same GPU.

   Specify GPU memory utilization to run as many instances of vLLM as possible on a single
   GPU. However, too many instances can leave no memory for the KV cache. For small models, run
   multiple instances of vLLM on the same GPU by specifying a smaller ``gpu-memory-utilization`` -- as
   long as it does not cause a HIP out-of-memory error.

   For example, run two instances of the Llama3-8b model at the same time on a single GPU by setting
   ``--gpu-memory-utilization`` to 0.4 (40%) as follows (on GPU ``0``):

   .. code-block:: shell

      CUDA_VISIBLE_DEVICES=0 python3 /vllm-workspace/benchmarks/benchmark_throughput.py --gpu-memory-utilization 0.4 --dataset "/path/to/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" --model /path/to/model &

      CUDA_VISIBLE_DEVICES=0 python3 /vllm-workspace/benchmarks/benchmark_throughput.py --gpu-memory-utilization 0.4 --dataset "/path/to/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" --model /path/to/model &

See :ref:`vllm-engine-args` for other performance suggestions.
.. _mi300x-vllm-multiple-gpus:

Run vLLM on multiple GPUs
-------------------------

The two main reasons to use multiple GPUs are:

* The model is too big to run vLLM on one GPU, resulting in HIP out-of-memory errors.

* You want to achieve better latency than a single GPU can provide.

To run one vLLM instance on multiple GPUs, use the ``-tp`` or ``--tensor-parallel-size`` option to
specify multiple GPUs. Optionally, use the ``CUDA_VISIBLE_DEVICES`` environment variable to specify
the GPUs.

For example, you can use two GPUs to start an API server on port 8000:

.. code-block:: shell

   python -m vllm.entrypoints.api_server --model /path/to/model --dtype float16 -tp 2 --port 8000 &

To achieve both latency and throughput performance for serving, you can run multiple API servers on
different GPUs by specifying different ports for each server and using ``CUDA_VISIBLE_DEVICES`` to
specify the GPUs for each server, for example:

.. code-block:: shell

   CUDA_VISIBLE_DEVICES=0,1 python -m vllm.entrypoints.api_server --model /path/to/model --dtype float16 -tp 2 --port 8000 &

   CUDA_VISIBLE_DEVICES=2,3 python -m vllm.entrypoints.api_server --model /path/to/model --dtype float16 -tp 2 --port 8001 &
Choose an attention backend
---------------------------

vLLM on ROCm supports two attention backends, each suitable for different use cases and performance
requirements:

- **Triton Flash Attention** - This is the default setting. For benchmarking, run vLLM scripts at
  least once as a warm-up step so Triton can perform auto-tuning before
  collecting benchmarking numbers.

- **Composable Kernel (CK) Flash Attention** - To use CK Flash Attention, specify
  the environment variable as ``export VLLM_USE_TRITON_FLASH_ATTN=0``.

Refer to :ref:`Model acceleration libraries <acceleration-flash-attention>`
to learn more about Flash Attention with Triton or CK backends.
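For example, a minimal sketch of switching to the CK backend for one run (unset the variable, or set it to ``1``, to return to the default Triton backend):

.. code-block:: shell

   export VLLM_USE_TRITON_FLASH_ATTN=0
   python3 /vllm-workspace/benchmarks/benchmark_throughput.py --model /path/to/model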
.. _vllm-engine-args:

vLLM engine arguments
---------------------

The following are configuration suggestions to potentially improve performance with vLLM. See
`vLLM's engine arguments documentation <https://docs.vllm.ai/en/latest/serving/engine_args.html>`_
for a full list of configurable engine arguments.

Configure the max-num-seqs parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Increase the ``max-num-seqs`` parameter from the default ``256`` to ``512`` (``--max-num-seqs 512``).
This increases the maximum number of sequences per iteration and can improve throughput.
Use the float16 dtype
^^^^^^^^^^^^^^^^^^^^^

The default data type (``dtype``) is specified in the model’s configuration file. For instance, some models use ``torch.bfloat16`` as their default ``dtype``.
Use float16 (``--dtype float16``) for better performance.

Multi-step scheduling
^^^^^^^^^^^^^^^^^^^^^

Setting ``num-scheduler-steps`` for multi-step scheduling can increase performance. Set it between 10 and 15 (for example, ``--num-scheduler-steps 10``).

Distributed executor backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

vLLM supports two distributed executor backends: ``ray`` and ``mp``. When using the `ROCm fork of vLLM <https://github.com/ROCm/vllm>`__, the ``mp``
backend (``--distributed_executor_backend mp``) is recommended.

Graph mode max-seq-len-to-capture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``max-seq-len-to-capture`` is the maximum sequence length covered by CUDA graphs. In the default mode (where ``enforce_eager`` is ``False``), when a sequence has a context length
larger than this, the vLLM engine falls back to eager mode. The default is 8192.

When working with models that support long context lengths, set the parameter ``--max-seq-len-to-capture`` to 16384.
See this `vLLM blog <https://blog.vllm.ai/2024/10/23/vllm-serving-amd.html>`__ for details.

An example of a long-context model is Qwen2-7b.
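Putting the preceding engine arguments together, a sketch of an API server launch (values are illustrative; verify the flags against your vLLM version):

.. code-block:: shell

   python -m vllm.entrypoints.api_server --model /path/to/model \
       --dtype float16 \
       --max-num-seqs 512 \
       --num-scheduler-steps 10 \
       --distributed_executor_backend mp \
       --max-seq-len-to-capture 16384 \
       -tp 2 --port 8000 &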
Whether to enable chunked prefill
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Another vLLM performance tip is to enable chunked prefill to improve
throughput. Chunked prefill allows large prefills to be split into
smaller chunks and batched together with decode requests.

You can enable the feature by specifying ``--enable-chunked-prefill`` on the
command line or setting ``enable_chunked_prefill=True`` in the LLM
constructor.

As stated in `vLLM's documentation <https://docs.vllm.ai/en/latest/models/performance.html#chunked-prefill>`__,
you can tune the performance by changing ``max_num_batched_tokens``. By
default, it is set to 512, which is optimized for ITL (inter-token latency).
A smaller ``max_num_batched_tokens`` achieves better ITL because there are
fewer prefills interrupting decodes.
A higher ``max_num_batched_tokens`` achieves better TTFT (time to first
token) because you can put more prefill work into each batch.

You might experience noticeable throughput improvements when
benchmarking on a single GPU or on eight GPUs using the vLLM throughput
benchmarking script with the ShareGPT dataset as input.

In the case of fixed ``input-len``/``output-len``, enabling chunked prefill increases
throughput for some configurations. For other configurations, throughput may be
worse and require tuning the ``max_num_batched_tokens`` parameter (for example,
increasing ``max_num_batched_tokens`` to 4096 or larger).

.. note::

   Chunked prefill is no longer recommended. See the vLLM blog post `Serving LLMs on AMD MI300X: Best Practices <https://blog.vllm.ai/2024/10/23/vllm-serving-amd.html>`_ (October 2024).
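If you do experiment with it, a minimal sketch of enabling chunked prefill and raising the batched-token budget (values are illustrative):

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py --model /path/to/model \
       --enable-chunked-prefill \
       --max-num-batched-tokens 4096 \
       --input-len 1024 --output-len 128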
Quantization support
--------------------

Quantization reduces the precision of the model’s weights and activations, which significantly decreases the memory footprint.
``fp8 (w8a8)`` and ``AWQ`` quantization are supported on ROCm.

FP8 quantization
^^^^^^^^^^^^^^^^

The `ROCm fork of vLLM <https://github.com/ROCm/vllm>`__ supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on the Instinct MI300X.
FP8 quantization allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.

AMD publishes Quark-quantized OCP FP8 models on Hugging Face. For example:

* `Llama-3.1-8B-Instruct-FP8-KV <https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV>`__
* `Llama-3.1-70B-Instruct-FP8-KV <https://huggingface.co/amd/Llama-3.1-70B-Instruct-FP8-KV>`__
* `Llama-3.1-405B-Instruct-FP8-KV <https://huggingface.co/amd/Llama-3.1-405B-Instruct-FP8-KV>`__
* `Mixtral-8x7B-Instruct-v0.1-FP8-KV <https://huggingface.co/amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV>`__
* `Mixtral-8x22B-Instruct-v0.1-FP8-KV <https://huggingface.co/amd/Mixtral-8x22B-Instruct-v0.1-FP8-KV>`__

To run vLLM benchmarking on FP8-quantized models, use the ``--quantization`` parameter with the value ``fp8`` (``--quantization fp8``).
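For example, a sketch of a throughput run against one of the published FP8 checkpoints listed above (adjust paths and parallelism for your setup):

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py \
       --model amd/Llama-3.1-70B-Instruct-FP8-KV \
       --quantization fp8 --kv-cache-dtype fp8 \
       -tp 1 --input-len 1024 --output-len 128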
AWQ quantization
^^^^^^^^^^^^^^^^

You can quantize your own models by installing AutoAWQ, or pick one of the 400+ AWQ models available on Hugging Face. Be aware
that AWQ support in vLLM is currently underoptimized.

To enable vLLM to run ``awq``-quantized models, use the ``--quantization`` parameter with ``awq`` (``--quantization awq``).

You can find more specifics in the `vLLM AutoAWQ documentation <https://docs.vllm.ai/en/stable/quantization/auto_awq.html>`_.
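A minimal sketch, using a hypothetical AWQ checkpoint path as a placeholder:

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py \
       --model /path/to/awq-quantized-model \
       --quantization awq --dtype float16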
fp8 kv-cache-dtype
^^^^^^^^^^^^^^^^^^

Using the ``fp8`` kv-cache dtype can improve performance because it reduces the size
of the KV cache and, as a result, the cost of reading and writing it.

To use this feature, specify ``--kv-cache-dtype`` as ``fp8``.

To specify the quantization scaling config, use the
``--quantization-param-path`` parameter. If the parameter is not specified,
the default scaling factor of ``1`` is used, which can lead to less accurate
results. To generate the ``kv-cache`` scaling JSON file, see `FP8 KV
Cache <https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_kv_cache/README.md>`__
in the vLLM project's llm-compressor repository.

Two sample Llama scaling configuration files are included in vLLM for ``llama2-70b`` and
``llama2-7b``.

If you build vLLM using
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm>`_,
the ``llama2-70b`` scale config is available at
``/vllm-workspace/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json`` at
runtime.

Below is a sample command to run benchmarking with this feature enabled
for the ``llama2-70b`` model:

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py --model \
       /path/to/llama2-70b-model --kv-cache-dtype "fp8" \
       --quantization-param-path \
       "/vllm-workspace/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json" \
       --input-len 512 --output-len 256 --num-prompts 500

.. _mi300x-tunableop:

PyTorch TunableOp
@@ -946,33 +568,33 @@ for details.

  .. code-block:: shell

     HIP_FORCE_DEV_KERNARG=1 hipblaslt-bench --alpha 1 --beta 0 -r f16_r \
       --a_type f16_r --b_type f8_r --compute_type f32_f16_r \
       --initialization trig_float --cold_iters 100 --iters 1000 --rotating 256

* Example 2: Benchmark forward epilogues and backward epilogues

  * ``HIPBLASLT_EPILOGUE_RELU: "--activation_type relu";``
  * ``HIPBLASLT_EPILOGUE_BIAS: "--bias_vector";``
  * ``HIPBLASLT_EPILOGUE_RELU_BIAS: "--activation_type relu --bias_vector";``
  * ``HIPBLASLT_EPILOGUE_GELU: "--activation_type gelu";``
  * ``HIPBLASLT_EPILOGUE_DGELU: "--activation_type gelu --gradient";``
  * ``HIPBLASLT_EPILOGUE_GELU_BIAS: "--activation_type gelu --bias_vector";``
  * ``HIPBLASLT_EPILOGUE_GELU_AUX: "--activation_type gelu --use_e";``
  * ``HIPBLASLT_EPILOGUE_GELU_AUX_BIAS: "--activation_type gelu --bias_vector --use_e";``
  * ``HIPBLASLT_EPILOGUE_DGELU_BGRAD: "--activation_type gelu --bias_vector --gradient";``
  * ``HIPBLASLT_EPILOGUE_BGRADA: "--bias_vector --gradient --bias_source a";``
  * ``HIPBLASLT_EPILOGUE_BGRADB: "--bias_vector --gradient --bias_source b";``
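For instance, a sketch of combining the Example 1 command with the GELU + bias epilogue flags listed above:

.. code-block:: shell

   HIP_FORCE_DEV_KERNARG=1 hipblaslt-bench --alpha 1 --beta 0 -r f16_r \
     --a_type f16_r --b_type f8_r --compute_type f32_f16_r \
     --initialization trig_float --cold_iters 100 --iters 1000 --rotating 256 \
     --activation_type gelu --bias_vector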
hipBLASLt auto-tuning using hipblaslt-bench
@@ -1031,26 +653,26 @@ The tuning tool is a two-step tool. It first runs the benchmark, then it creates

.. code-block:: python

   defaultBenchOptions = {"ProblemType": {
       "TransposeA": 0,
       "TransposeB": 0,
       "ComputeInputDataType": "s",
       "ComputeDataType": "s",
       "DataTypeC": "s",
       "DataTypeD": "s",
       "UseBias": False
   }, "TestConfig": {
       "ColdIter": 20,
       "Iter": 100,
       "AlgoMethod": "all",
       "RequestedSolutions": 2,  # Only works in AlgoMethod heuristic
       "SolutionIndex": None,  # Only works in AlgoMethod index
       "ApiMethod": "cpp",
       "RotatingBuffer": 0,
   }, "TuningParameters": {
       "SplitK": [0]
   }, "ProblemSizes": []}
   defaultCreateLogicOptions = {}  # Currently unused

* ``TestConfig``

  1. ``ColdIter``: This is the number of warm-up iterations to run before starting the kernel benchmark.
@@ -1230,7 +852,7 @@ command:

.. code-block:: shell

   merge.py original_dir new_tuned_yaml_dir output_dir

The following table describes the logic YAML files.
@@ -1833,7 +1455,7 @@ de-quantize the ``int4`` key-value from the ``int4`` data type to ``fp16``.

From the IR snippet, you can see ``i32`` data is loaded from global memory to
registers (``%190``). With a few element-wise operations in registers, it is
stored in shared memory (``%269``) for the transpose operation (``%270``), which
needs data movement across different threads. With the transpose done, it is
loaded from LDS to registers again (``%276``), and with a few more
element-wise operations, it is stored to LDS again (``%298``). The last step
@@ -1967,7 +1589,7 @@ something similar to the following:

      loaded at: [0x7fd4f100c000-0x7fd4f100e070]

The kernel name and the code object file should be listed. In the
example above, the kernel name is ``vector_add_assert_trap``, but this might
also look like:

.. code-block:: text
@@ -2081,3 +1703,8 @@ Hardware efficiency is maximized with 4 or fewer HIP streams. These environment
configuration to two compute streams and two RCCL streams, aligning with this best practice.
Additionally, RCCL is often pre-optimized for MI300 systems in production by querying the node
topology during startup, reducing the need for extensive manual tuning.
+
+Further reading
+===============
+
+* :doc:`vllm-optimization`
@@ -134,6 +134,8 @@ subtrees:
        title: Profile and debug
      - file: how-to/rocm-for-ai/inference-optimization/workload.rst
        title: Workload optimization
+     - file: how-to/rocm-for-ai/inference-optimization/vllm-optimization.rst
+       title: vLLM V1 performance optimization
      - url: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/
        title: AI tutorials