mirror of
https://github.com/ROCm/ROCm.git
synced 2026-01-09 06:38:00 -05:00
Updates to the vLLM optimization guide for MI300X/MI355X (#5554)
* Expand vLLM optimization guide for MI300X/MI355X with comprehensive AITER coverage, attention backend selection, environment variables (HIP/RCCL/Quick Reduce), parallelism strategies, quantization (FP8/FP4), engine tuning, CUDA graph modes, and multi-node scaling.

Co-authored-by: PinSiang <pinsiang.tan@embeddedllm.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: pinsiangamd <pinsiang.tan@amd.com>
Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
@@ -27,6 +27,7 @@ ASICs
ASan
ASAN
ASm
Async
ATI
atomicRMW
AddressSanitizer
@@ -133,6 +134,7 @@ ELMo
ENDPGM
EPYC
ESXi
EP
EoS
etcd
fas
@@ -184,6 +186,7 @@ GPR
GPT
GPU
GPU's
GPUDirect
GPUs
GraphBolt
GraphSage
@@ -302,6 +305,7 @@ Makefiles
Matplotlib
Matrox
MaxText
MBT
Megablocks
Megatrends
Megatron
@@ -311,6 +315,7 @@ Meta's
Miniconda
MirroredStrategy
Mixtral
MLA
MosaicML
MoEs
Mooncake
@@ -353,6 +358,7 @@ OFED
OMM
OMP
OMPI
OOM
OMPT
OMPX
ONNX
@@ -398,6 +404,7 @@ Profiler's
PyPi
Pytest
PyTorch
QPS
Qcycles
Qwen
RAII
@@ -673,6 +680,7 @@ denoised
denoises
denormalize
dequantization
dequantized
dequantizes
deserializers
detections
@@ -788,6 +796,7 @@ linalg
linearized
linter
linux
llm
llvm
lm
localscratch
@@ -838,6 +847,7 @@ passthrough
pe
perfcounter
performant
piecewise
perl
pragma
pre
@@ -984,6 +994,7 @@ tokenizer
tokenizes
toolchain
toolchains
topk
toolset
toolsets
torchtitan
@@ -1011,6 +1022,7 @@ USM
UTCL
UTIL
utils
UX
vL
variational
vdi
1139
docs/how-to/rocm-for-ai/inference-optimization/vllm-optimization.rst
Normal file
File diff suppressed because it is too large
@@ -15,10 +15,9 @@ using PyTorch. It delves into specific workloads such as
:ref:`model inference <mi300x-vllm-optimization>`, offering strategies to
enhance efficiency.

The following topics highlight :ref:`auto-tunable configurations <mi300x-auto-tune>`
that streamline optimization as well as advanced techniques like
:ref:`Triton kernel optimization <mi300x-triton-kernel-performance-optimization>` for
meticulous tuning.
The following topics highlight :ref:`auto-tunable configurations <mi300x-auto-tune>` as
well as :ref:`Triton kernel optimization <mi300x-triton-kernel-performance-optimization>`
for meticulous tuning.

Workload tuning strategy
========================
@@ -86,23 +85,22 @@ Optimize model inference with vLLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

vLLM provides tools and techniques specifically designed for efficient model
inference on AMD Instinct MI300X GPUs. See :ref:`fine-tuning-llms-vllm`
for installation guidance. Optimizing performance with vLLM
involves configuring tensor parallelism, leveraging advanced features, and
ensuring efficient execution. Here’s how to optimize vLLM performance:
inference on AMD Instinct GPUs. See the official `vLLM installation docs
<https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html>`__ for
installation guidance. Optimizing performance with vLLM involves configuring
tensor parallelism, leveraging advanced features, and ensuring efficient
execution.

* Tensor parallelism: Configure the
  :ref:`tensor-parallel-size parameter <mi300x-vllm-multiple-gpus>` to distribute
  tensor computations across multiple GPUs. Adjust parameters such as
  ``batch-size``, ``input-len``, and ``output-len`` based on your workload.

* Configuration for vLLM: Set :ref:`parameters <mi300x-vllm-optimization>`
  according to workload requirements. Benchmark performance to understand
  characteristics and identify bottlenecks.
* Configuration for vLLM: Set engine arguments according to workload
  requirements.

* Benchmarking and performance metrics: Measure latency and throughput to
  evaluate performance.

.. seealso::

   See :doc:`vllm-optimization`.

.. _mi300x-auto-tune:

Auto-tunable configurations
@@ -120,8 +118,7 @@ characteristics. For example:
  your specific hardware.

* Triton: Use :ref:`Triton’s auto-tuning features <mi300x-autotunable-kernel-config>`
  to explore various kernel configurations and automatically select the
  best-performing ones.
  to explore various kernel configurations and select the best-performing ones.

Manual tuning
^^^^^^^^^^^^^
@@ -328,381 +325,6 @@ hardware counters are also included.

ROCm Systems Profiler timeline trace example.

.. _mi300x-vllm-optimization:

vLLM performance optimization
=============================

vLLM is a high-throughput and memory-efficient inference and serving engine for large language models that has gained traction in the AI community for
its performance and ease of use. See :ref:`fine-tuning-llms-vllm` for a primer on vLLM with ROCm.

Performance environment variables
---------------------------------

The following performance tips are general rather than specific to vLLM, but
they are relevant in this context. Set the following environment variables to
improve request latency and throughput.

* As described in `Environment variables (MI300X)
  <https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/system-optimization/mi300x.html#environment-variables>`_,
  the environment variable ``HIP_FORCE_DEV_KERNARG`` can improve vLLM
  performance. Set it with ``export HIP_FORCE_DEV_KERNARG=1``.

* Set the :ref:`RCCL environment variable <mi300x-rccl>` ``NCCL_MIN_NCHANNELS``
  to ``112`` to increase the number of channels on MI300X, which can improve
  performance.

* Set the environment variable ``TORCH_BLAS_PREFER_HIPBLASLT=1`` so that PyTorch uses hipBLASLt, which can improve performance.
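As a minimal sketch, the three variables from the list above can be exported together before launching vLLM. The values are the ones given in the list; printing them afterwards is only an illustrative habit for recording what was in effect, not something the guide requires.

```shell
# Environment variables named in the list above.
export HIP_FORCE_DEV_KERNARG=1
export NCCL_MIN_NCHANNELS=112
export TORCH_BLAS_PREFER_HIPBLASLT=1

# Print the values so the launch log records the configuration in effect.
env | grep -E '^(HIP_FORCE_DEV_KERNARG|NCCL_MIN_NCHANNELS|TORCH_BLAS_PREFER_HIPBLASLT)=' | sort
```

Variables exported this way apply to every process started from the same shell, so they only need to be set once per session.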

Auto-tuning using PyTorch TunableOp
------------------------------------

Since vLLM is based on the PyTorch framework, PyTorch TunableOp can be used for auto-tuning.
You can run auto-tuning with TunableOp in two simple steps without modifying your code:

* Enable TunableOp and tuning. Optionally, enable verbose mode:

  .. code-block:: shell

     PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 your_vllm_script.sh

* Enable TunableOp, disable tuning, and measure performance:

  .. code-block:: shell

     PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_vllm_script.sh

Learn more about TunableOp in the :ref:`PyTorch TunableOp <mi300x-tunableop>` section.

Performance tuning based on vLLM engine configurations
-------------------------------------------------------

The following subsections describe vLLM-specific configurations for performance tuning.
You can tune the following vLLM parameters to achieve optimal performance.

* ``tensor_parallel_size``

* ``gpu_memory_utilization``

* ``dtype``

* ``enforce_eager``

* ``kv_cache_dtype``

* ``input_len``

* ``output_len``

* ``max_num_seqs``

* ``num_scheduler_steps``

* ``max_model_len``

* ``enable_chunked_prefill``

* ``distributed_executor_backend``

* ``max_seq_len_to_capture``

Refer to the `vLLM documentation <https://docs.vllm.ai/en/latest/models/performance.html>`_
for additional performance tips. :ref:`fine-tuning-llms-vllm` describes vLLM
usage with ROCm.

ROCm provides a prebuilt optimized Docker image for validating the performance
of LLM inference with vLLM on MI300X Series GPUs. The Docker image includes
ROCm, vLLM, and PyTorch. For more information, see
:doc:`/how-to/rocm-for-ai/inference/benchmark-docker/vllm`.

.. _mi300x-vllm-throughput-measurement:

Evaluating performance by throughput measurement
-------------------------------------------------

This tuning guide evaluates the performance of LLM inference workloads by measuring throughput in tokens per second (TPS). Throughput can be assessed using both real-world and synthetic data, depending on your evaluation goals.

Refer to the benchmarking script located at ``benchmarks/benchmark_throughput.py`` in the `vLLM repository <https://github.com/ROCm/vllm/blob/main/benchmarks/benchmark_throughput.py>`_.
Use this script to measure throughput.

* For realistic performance evaluation, you can use a dataset such as
  ``ShareGPT_V3_unfiltered_cleaned_split.json`` from Hugging Face. This dataset includes real-world conversational
  data, making it a good representation of typical use cases for language models. Download it using
  the following command:

  .. code-block:: shell

     wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

* For standardized benchmarking, you can set fixed input and output token
  lengths. Synthetic prompts provide consistent benchmarking runs, making it
  easier to compare performance across different models or configurations.
  Additionally, a controlled environment simplifies analysis.

By combining real-world and synthetic data, you can get a well-rounded understanding of model performance in varied scenarios.

.. _mi300x-vllm-single-node:

Maximizing vLLM instances on a single node
------------------------------------------

The general guideline is to maximize per-node throughput by running as many vLLM instances as possible.
However, running too many instances might lead to insufficient memory for the KV-cache, which can affect performance.

The Instinct MI300X GPU is equipped with 192 GB of HBM3 memory.
To maximize aggregate throughput for models that fit in one GPU, you can run as many as eight vLLM instances
simultaneously on one MI300X node (with eight GPUs). To do so, use the GPU isolation environment
variable ``CUDA_VISIBLE_DEVICES``.

For example, this script runs eight instances of vLLM for throughput benchmarking at the same time
with a model that can fit in one GPU:

.. code-block:: shell

   for i in $(seq 0 7);
   do
       CUDA_VISIBLE_DEVICES="$i" python3 /app/vllm/benchmarks/benchmark_throughput.py -tp 1 --dataset "/path/to/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" --model /path/to/model &
   done

The total throughput achieved by running ``N`` instances of vLLM is generally much higher than running a
single vLLM instance across ``N`` GPUs simultaneously (that is, configuring ``tensor_parallel_size`` as ``N`` or
using the ``-tp N`` option, where ``1 < N ≤ 8``).

vLLM on MI300X GPUs can run a variety of model weights, including Llama 2 (7b, 13b, 70b), Llama 3 (8b, 70b), Qwen2 (7b, 72b), Mixtral-8x7b, and Mixtral-8x22b.
Notably, the Llama2-70b and Llama3-70b models each fit on a single MI300X GPU, and the Llama3.1 405b model fits on a single node with eight MI300X GPUs.

.. _mi300x-vllm-gpu-memory-utilization:

Configure the gpu_memory_utilization parameter
----------------------------------------------

There are two ways to increase throughput by configuring the ``gpu-memory-utilization`` parameter.

1. Increase ``gpu-memory-utilization`` to improve the throughput of a single instance, as long as
   it does not cause a HIP or CUDA out-of-memory error. The default ``gpu-memory-utilization`` is 0.9.
   You can set it to a value greater than 0.9 and less than 1.

   For example, the following benchmarking command sets ``gpu-memory-utilization`` to 0.98, or 98%.

   .. code-block:: shell

      python3 /vllm-workspace/benchmarks/benchmark_throughput.py --gpu-memory-utilization 0.98 --input-len 1024 --output-len 128 --model /path/to/model

2. Decrease ``gpu-memory-utilization`` to maximize the number of vLLM instances on the same GPU.

   Specify GPU memory utilization to run as many instances of vLLM as possible on a single
   GPU. However, too many instances can leave no memory for the KV-cache. For small models, run
   multiple instances of vLLM on the same GPU by specifying a smaller ``gpu-memory-utilization``, as
   long as it does not cause a HIP out-of-memory error.

   For example, run two instances of the Llama3-8b model at the same time on a single GPU by setting
   ``--gpu-memory-utilization`` to 0.4 (40%) as follows (on GPU ``0``):

   .. code-block:: shell

      CUDA_VISIBLE_DEVICES=0 python3 /vllm-workspace/benchmarks/benchmark_throughput.py --gpu-memory-utilization 0.4 \
      --dataset "/path/to/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" --model /path/to/model &

      CUDA_VISIBLE_DEVICES=0 python3 /vllm-workspace/benchmarks/benchmark_throughput.py --gpu-memory-utilization 0.4 \
      --dataset "/path/to/dataset/ShareGPT_V3_unfiltered_cleaned_split.json" --model /path/to/model &
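The relationship between the per-instance fraction and the number of instances on one GPU can be sketched with integer arithmetic. The helper name ``instances_per_gpu`` is hypothetical, and the result is only an upper bound derived from the fractions: it does not check that the model weights and KV-cache actually fit in each share.

```shell
# Hypothetical helper: upper bound on instances per GPU for a given
# per-instance --gpu-memory-utilization, expressed as a whole percentage.
instances_per_gpu() {
    util_percent=$1
    echo $(( 100 / util_percent ))
}

# With 40% per instance, as in the example above.
count=$(instances_per_gpu 40)
echo "instances at 40% each: $count"
```

For the 0.4 setting used above this gives two instances, with the remaining 20% of memory left unallocated.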

See :ref:`vllm-engine-args` for other performance suggestions.

.. _mi300x-vllm-multiple-gpus:

Run vLLM on multiple GPUs
-------------------------

The two main reasons to use multiple GPUs are:

* The model is too large to run on one GPU and causes a HIP out-of-memory error.

* You need lower latency than a single GPU can deliver.

To run one vLLM instance on multiple GPUs, use the ``-tp`` or ``--tensor-parallel-size`` option to
specify the number of GPUs. Optionally, use the ``CUDA_VISIBLE_DEVICES`` environment variable to select
which GPUs to use.

For example, you can use two GPUs to start an API server on port 8000:

.. code-block:: shell

   python -m vllm.entrypoints.api_server --model /path/to/model --dtype \
   float16 -tp 2 --port 8000 &

To balance latency and throughput for serving, you can run multiple API servers on
different GPUs. Give each server its own port and use ``CUDA_VISIBLE_DEVICES`` to
assign its GPUs, for example:

.. code-block:: shell

   CUDA_VISIBLE_DEVICES=0,1 python -m vllm.entrypoints.api_server --model \
   /path/to/model --dtype float16 -tp 2 --port 8000 &

   CUDA_VISIBLE_DEVICES=2,3 python -m vllm.entrypoints.api_server --model \
   /path/to/model --dtype float16 -tp 2 --port 8001 &
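The same pattern can be extended to fill a whole node. The following dry-run sketch prints one launch command per GPU group instead of executing it; ``TP``, ``NUM_GPUS``, ``BASE_PORT``, and the ``plan_servers`` function are illustrative names and values, and ``/path/to/model`` is a placeholder.

```shell
# Illustrative settings: 8 GPUs split into groups of 2, ports from 8000 up.
TP=2
NUM_GPUS=8
BASE_PORT=8000

# Dry run: print the launch command for each GPU group.
plan_servers() {
    i=0
    while [ $(( i * TP )) -lt "$NUM_GPUS" ]; do
        first=$(( i * TP ))
        # Build the comma-separated device list for this group.
        devices=$first
        g=$(( first + 1 ))
        while [ "$g" -lt $(( first + TP )) ]; do
            devices="$devices,$g"
            g=$(( g + 1 ))
        done
        echo "CUDA_VISIBLE_DEVICES=$devices python -m vllm.entrypoints.api_server --model /path/to/model --dtype float16 -tp $TP --port $(( BASE_PORT + i )) &"
        i=$(( i + 1 ))
    done
}

plan_servers
```

Printing the commands first makes it easy to confirm that no GPU or port is assigned twice before starting the servers.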

Choose an attention backend
---------------------------

vLLM on ROCm supports two attention backends, each suitable for different use cases and performance
requirements:

- **Triton Flash Attention** - This is the default setting. For benchmarking, run vLLM scripts at
  least once as a warm-up step so Triton can perform auto-tuning before
  collecting benchmarking numbers.

- **Composable Kernel (CK) Flash Attention** - To use CK Flash Attention, set
  the environment variable with ``export VLLM_USE_TRITON_FLASH_ATTN=0``.


Refer to :ref:`Model acceleration libraries <acceleration-flash-attention>`
to learn more about Flash Attention with Triton or CK backends.

.. _vllm-engine-args:

vLLM engine arguments
---------------------

The following configuration suggestions can potentially improve performance with vLLM. See
`vLLM's engine arguments documentation <https://docs.vllm.ai/en/latest/serving/engine_args.html>`_
for a full list of configurable engine arguments.

Configure the max-num-seqs parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Increase the ``max-num-seqs`` parameter from the default ``256`` to ``512``
(``--max-num-seqs 512``). This increases the maximum number of sequences per iteration and can improve throughput.

Use the float16 dtype
^^^^^^^^^^^^^^^^^^^^^

The default data type (``dtype``) is specified in the model’s configuration file. For instance, some models use ``torch.bfloat16`` as their default ``dtype``.
Use float16 (``--dtype float16``) for better performance.

Multi-step scheduling
^^^^^^^^^^^^^^^^^^^^^

Setting ``num-scheduler-steps`` for multi-step scheduling can increase performance. Set it to a value between 10 and 15 (for example, ``--num-scheduler-steps 10``).

Distributed executor backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

vLLM supports two distributed executor backends: ``ray`` and ``mp``. When using the `<https://github.com/ROCm/vllm>`__ fork, the ``mp``
backend (``--distributed_executor_backend mp``) is recommended.

Graph mode max-seq-len-to-capture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This parameter sets the maximum sequence length covered by CUDA graphs. In the default mode (where ``enforce_eager`` is ``False``), when a sequence has a context length
larger than this value, the vLLM engine falls back to eager mode. The default is 8192.

When working with models that support long context lengths, set the parameter ``--max-seq-len-to-capture`` to 16384.
See this `vLLM blog <https://blog.vllm.ai/2024/10/23/vllm-serving-amd.html>`__ for details.

An example of a model with a long context length is Qwen2-7b.
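The engine arguments from the preceding subsections can be combined in one benchmarking command. The sketch below only assembles and prints the command (a dry run); the flag values are the ones suggested above, the script path follows the other examples in this section, and ``/path/to/model`` is a placeholder. Whether all of these settings help together depends on the model and workload, so treat this as a starting point to measure rather than a recommended configuration.

```shell
# Placeholder model path; flags and values are taken from the subsections above.
MODEL=/path/to/model
ENGINE_ARGS="--max-num-seqs 512 --dtype float16 --num-scheduler-steps 10 --distributed_executor_backend mp --max-seq-len-to-capture 16384"

# Dry run: show the assembled command rather than launching it.
CMD="python3 /vllm-workspace/benchmarks/benchmark_throughput.py --model $MODEL $ENGINE_ARGS"
echo "$CMD"
```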

Whether to enable chunked prefill
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Another vLLM performance tip is to enable chunked prefill to improve
throughput. Chunked prefill allows large prefills to be split into
smaller chunks and batched together with decode requests.

You can enable the feature by specifying ``--enable-chunked-prefill`` in the
command line or setting ``enable_chunked_prefill=True`` in the LLM
constructor.

As stated in `vLLM's documentation <https://docs.vllm.ai/en/latest/models/performance.html#chunked-prefill>`__,
you can tune the performance by changing ``max_num_batched_tokens``. By
default, it is set to 512 and optimized for ITL (inter-token latency).
A smaller ``max_num_batched_tokens`` achieves better ITL because there are
fewer prefills interrupting decodes.
A higher ``max_num_batched_tokens`` achieves better TTFT (time to first
token) because more prefill tokens fit in the batch.

You might experience noticeable throughput improvements when
benchmarking on a single GPU or 8 GPUs using the vLLM throughput
benchmarking script along with the ShareGPT dataset as input.

With fixed ``input-len``/``output-len``, enabling chunked prefill increases
throughput for some configurations. For other configurations, throughput
may be worse, in which case you need to tune the
``max_num_batched_tokens`` parameter (for example, by increasing its value to 4096 or larger).

.. note::

   Chunked prefill is no longer recommended. See the vLLM blog: `Serving LLMs on AMD MI300X: Best Practices <https://blog.vllm.ai/2024/10/23/vllm-serving-amd.html>`_ (October 2024).

Quantization support
---------------------

Quantization reduces the precision of the model’s weights and activations, which significantly decreases the memory footprint.
``fp8(w8a8)`` and ``AWQ`` quantization are supported for ROCm.

FP8 quantization
^^^^^^^^^^^^^^^^^

`<https://github.com/ROCm/vllm>`__ supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on the Instinct MI300X.
Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.

AMD publishes Quark-quantized OCP FP8 models on Hugging Face. For example:

* `Llama-3.1-8B-Instruct-FP8-KV <https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV>`__
* `Llama-3.1-70B-Instruct-FP8-KV <https://huggingface.co/amd/Llama-3.1-70B-Instruct-FP8-KV>`__
* `Llama-3.1-405B-Instruct-FP8-KV <https://huggingface.co/amd/Llama-3.1-405B-Instruct-FP8-KV>`__
* `Mixtral-8x7B-Instruct-v0.1-FP8-KV <https://huggingface.co/amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV>`__
* `Mixtral-8x22B-Instruct-v0.1-FP8-KV <https://huggingface.co/amd/Mixtral-8x22B-Instruct-v0.1-FP8-KV>`__

To run vLLM benchmarking on FP8-quantized models, use the ``--quantization`` parameter with the value ``fp8`` (``--quantization fp8``).

AWQ quantization
^^^^^^^^^^^^^^^^

You can quantize your own models with AutoAWQ or pick one of the 400+ models on Hugging Face. Be aware
that AWQ support in vLLM is currently underoptimized.

To run vLLM on AWQ-quantized models, use the ``--quantization`` parameter with the value ``awq`` (``--quantization awq``).

You can find more specifics in the `vLLM AutoAWQ documentation <https://docs.vllm.ai/en/stable/quantization/auto_awq.html>`_.

fp8 kv-cache-dtype
^^^^^^^^^^^^^^^^^^^^^^^

Using the ``fp8`` ``kv-cache`` dtype can improve performance because it reduces the size
of the ``kv-cache``. As a result, it reduces the cost of reading and
writing the ``kv-cache``.

To use this feature, specify ``--kv-cache-dtype`` as ``fp8``.

To specify the quantization scaling config, use the
``--quantization-param-path`` parameter. If the parameter is not specified,
the default scaling factor of ``1`` is used, which can lead to less accurate
results. To generate the ``kv-cache`` scaling JSON file, see `FP8 KV
Cache <https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_kv_cache/README.md>`__
in the ``llm-compressor`` GitHub repository.

Two sample Llama scaling configuration files are in vLLM for ``llama2-70b`` and
``llama2-7b``.

If you build vLLM using
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm>`_,
the ``llama2-70b`` scale config is available at
``/vllm-workspace/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json`` at
runtime.

Below is a sample command to run benchmarking with this feature enabled
for the ``llama2-70b`` model:

.. code-block:: shell

   python3 /vllm-workspace/benchmarks/benchmark_throughput.py --model \
   /path/to/llama2-70b-model --kv-cache-dtype "fp8" \
   --quantization-param-path \
   "/vllm-workspace/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json" \
   --input-len 512 --output-len 256 --num-prompts 500


.. _mi300x-tunableop:

PyTorch TunableOp
|
||||
@@ -946,33 +568,33 @@ for details.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
HIP_FORCE_DEV_KERNARG=1 hipblaslt-bench --alpha 1 --beta 0 -r f16_r \
|
||||
HIP_FORCE_DEV_KERNARG=1 hipblaslt-bench --alpha 1 --beta 0 -r f16_r \
|
||||
--a_type f16_r --b_type f8_r --compute_type f32_f16_r \
|
||||
--initialization trig_float --cold_iters 100 --iters 1000 --rotating 256
|
||||
--initialization trig_float --cold_iters 100 --iters 1000 --rotating 256
|
||||
|
||||
* Example 2: Benchmark forward epilogues and backward epilogues
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_RELU: "--activation_type relu";``
|
||||
* ``HIPBLASLT_EPILOGUE_RELU: "--activation_type relu";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_BIAS: "--bias_vector";``
|
||||
* ``HIPBLASLT_EPILOGUE_BIAS: "--bias_vector";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_RELU_BIAS: "--activation_type relu --bias_vector";``
|
||||
* ``HIPBLASLT_EPILOGUE_RELU_BIAS: "--activation_type relu --bias_vector";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_GELU: "--activation_type gelu";``
|
||||
* ``HIPBLASLT_EPILOGUE_GELU: "--activation_type gelu";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_DGELU: "--activation_type gelu --gradient";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_GELU_BIAS: "--activation_type gelu --bias_vector";``
|
||||
* ``HIPBLASLT_EPILOGUE_GELU_BIAS: "--activation_type gelu --bias_vector";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_GELU_AUX: "--activation_type gelu --use_e";``
|
||||
* ``HIPBLASLT_EPILOGUE_GELU_AUX: "--activation_type gelu --use_e";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_GELU_AUX_BIAS: "--activation_type gelu --bias_vector --use_e";``
|
||||
* ``HIPBLASLT_EPILOGUE_GELU_AUX_BIAS: "--activation_type gelu --bias_vector --use_e";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_DGELU_BGRAD: "--activation_type gelu --bias_vector --gradient";``
|
||||
* ``HIPBLASLT_EPILOGUE_DGELU_BGRAD: "--activation_type gelu --bias_vector --gradient";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_BGRADA: "--bias_vector --gradient --bias_source a";``
|
||||
* ``HIPBLASLT_EPILOGUE_BGRADA: "--bias_vector --gradient --bias_source a";``
|
||||
|
||||
* ``HIPBLASLT_EPILOGUE_BGRADB: "--bias_vector --gradient --bias_source b";``
|
||||
* ``HIPBLASLT_EPILOGUE_BGRADB: "--bias_vector --gradient --bias_source b";``
|
||||
|
||||
|
||||
hipBLASLt auto-tuning using hipblaslt-bench
|
||||
@@ -1031,26 +653,26 @@ The tuning tool is a two-step tool. It first runs the benchmark, then it creates
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
defaultBenchOptions = {"ProblemType": {
|
||||
"TransposeA": 0,
|
||||
"TransposeB": 0,
|
||||
"ComputeInputDataType": "s",
|
||||
"ComputeDataType": "s",
|
||||
"DataTypeC": "s",
|
||||
"DataTypeD": "s",
|
||||
"UseBias": False
|
||||
}, "TestConfig": {
|
||||
"ColdIter": 20,
|
||||
"Iter": 100,
|
||||
"AlgoMethod": "all",
|
||||
"RequestedSolutions": 2, # Only works in AlgoMethod heuristic
|
||||
"SolutionIndex": None, # Only works in AlgoMethod index
|
||||
"ApiMethod": "cpp",
|
||||
"RotatingBuffer": 0,
|
||||
}, "TuningParameters": {
|
||||
"SplitK": [0]
|
||||
}, "ProblemSizes": []}
|
||||
defaultCreateLogicOptions = {} # Currently unused
|
||||
defaultBenchOptions = {"ProblemType": {
|
||||
"TransposeA": 0,
|
||||
"TransposeB": 0,
|
||||
"ComputeInputDataType": "s",
|
||||
"ComputeDataType": "s",
|
||||
"DataTypeC": "s",
|
||||
"DataTypeD": "s",
|
||||
"UseBias": False
|
||||
}, "TestConfig": {
|
||||
"ColdIter": 20,
|
||||
"Iter": 100,
|
||||
"AlgoMethod": "all",
|
||||
"RequestedSolutions": 2, # Only works in AlgoMethod heuristic
|
||||
"SolutionIndex": None, # Only works in AlgoMethod index
|
||||
"ApiMethod": "cpp",
|
||||
"RotatingBuffer": 0,
|
||||
}, "TuningParameters": {
|
||||
"SplitK": [0]
|
||||
}, "ProblemSizes": []}
|
||||
defaultCreateLogicOptions = {} # Currently unused
|
||||
|
||||
* ``TestConfig``
|
||||
1. ``ColdIter``: This is the number of warm-up iterations before starting the kernel benchmark.
|
||||
@@ -1230,7 +852,7 @@ command:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
merge.py original_dir new_tuned_yaml_dir output_dir
|
||||
merge.py original_dir new_tuned_yaml_dir output_dir
|
||||
|
||||
The following table describes the logic YAML files.
|
||||
|
||||
@@ -1833,7 +1455,7 @@ de-quantize the ``int4`` key-value from the ``int4`` data type to ``fp16``.
|
||||
|
||||
From the IR snippet, you can see ``i32`` data is loaded from global memory to
|
||||
registers (``%190``). With a few element-wise operations in registers, it is
|
||||
stored in shared memory (``%269``) for the transpose operation (``%270``), which
|
||||
stored in shared memory (``%269``) for the transpose operation (``%270``), which
|
||||
needs data movement across different threads. With the transpose done, it is
|
||||
loaded from LDS to register again (``%276``), and with a few more
|
||||
element-wise operations, it is stored to LDS again (``%298``). The last step
|
||||
@@ -1967,7 +1589,7 @@ something similar to the following:
|
||||
loaded at: [0x7fd4f100c000-0x7fd4f100e070]
|
||||
|
||||
The kernel name and the code object file should be listed. In the
|
||||
example above, the kernel name is vector_add_assert_trap, but this might
|
||||
example above, the kernel name is vector_add_assert_trap, but this might
|
||||
also look like:
|
||||
|
||||
.. code-block:: text
|
||||
@@ -2081,3 +1703,8 @@ Hardware efficiency is maximized with 4 or fewer HIP streams. These environment
|
||||
configuration to two compute streams and two RCCL streams, aligning with this best practice.
|
||||
Additionally, RCCL is often pre-optimized for MI300 systems in production by querying the node
|
||||
topology during startup, reducing the need for extensive manual tuning.
|
||||
|
||||
Further reading
|
||||
===============
|
||||
|
||||
* :doc:`vllm-optimization`
|
||||
|
||||
@@ -134,6 +134,8 @@ subtrees:
|
||||
title: Profile and debug
|
||||
- file: how-to/rocm-for-ai/inference-optimization/workload.rst
|
||||
title: Workload optimization
|
||||
- file: how-to/rocm-for-ai/inference-optimization/vllm-optimization.rst
|
||||
title: vLLM V1 performance optimization
|
||||
|
||||
- url: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/
|
||||
title: AI tutorials
|
||||
|
||||