From d1debc7e45c4668e3b2539e966eb91d5dcc2d48d Mon Sep 17 00:00:00 2001
From: Wei Luo
Date: Thu, 8 May 2025 14:28:51 +0800
Subject: [PATCH] [doc]: Add quark in model-quantization.rst (#374)

* Add quark in model-quantization.rst

---------

Co-authored-by: Peter Park
Co-authored-by: Peter Park
---
 .../model-quantization.rst | 189 ++++++++++++++++--
 1 file changed, 176 insertions(+), 13 deletions(-)

diff --git a/docs/how-to/rocm-for-ai/inference-optimization/model-quantization.rst b/docs/how-to/rocm-for-ai/inference-optimization/model-quantization.rst
index 9de27c20e..8a14f42cf 100644
--- a/docs/how-to/rocm-for-ai/inference-optimization/model-quantization.rst
+++ b/docs/how-to/rocm-for-ai/inference-optimization/model-quantization.rst
@@ -1,15 +1,178 @@
 .. meta::
    :description: How to use model quantization techniques to speed up inference.
-   :keywords: ROCm, LLM, fine-tuning, usage, tutorial, quantization, GPTQ, transformers, bitsandbytes
+   :keywords: ROCm, LLM, fine-tuning, usage, tutorial, quantization, Quark, GPTQ, transformers, bitsandbytes
 
 *****************************
 Model quantization techniques
 *****************************
 
 Quantization reduces the model size compared to its native full-precision version, making it easier to fit large models
-onto accelerators or GPUs with limited memory usage. This section explains how to perform LLM quantization using GPTQ
+onto accelerators or GPUs with limited memory usage. This section explains how to perform LLM quantization using AMD Quark, GPTQ,
 and bitsandbytes on AMD Instinct hardware.
 
+.. _quantize-llms-quark:
+
+AMD Quark
+=========
+
+`AMD Quark `_ provides an efficient, scalable quantization solution tailored to AMD Instinct GPUs. It supports ``FP8`` and ``INT8`` quantization for activations, weights, and the KV cache,
+including ``FP8`` attention. For very large models, it offers a two-level ``INT4-FP8`` scheme that stores weights in ``INT4`` while computing in ``FP8``, achieving nearly 4× compression without sacrificing accuracy.
+Quark also scales across multiple GPUs, so it can handle ultra-large models such as Llama-3.1-405B. Quantized ``FP8`` models such as Llama, Mixtral, and Grok-1 are available under the `AMD organization on Hugging Face `_ and can be deployed directly with `vLLM `_.
+
+Installing Quark
+----------------
+
+The latest release of Quark can be installed with pip:
+
+.. code-block:: shell
+
+   pip install amd-quark
+
+For detailed installation instructions, refer to the `Quark documentation `_.
+
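+You can optionally confirm that the package is visible to your Python environment before continuing. The following is a minimal, standard-library-only sketch; it assumes the pip package name ``amd-quark`` used above.
+
+.. code-block:: python
+
+   # Optional sanity check: report the installed amd-quark version, if any.
+   from importlib.metadata import PackageNotFoundError, version
+
+   try:
+       print("amd-quark version:", version("amd-quark"))
+   except PackageNotFoundError:
+       print("amd-quark is not installed in this environment.")
+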
+Using Quark for quantization
+----------------------------
+
+#. First, load the pre-trained model and its corresponding tokenizer using the Hugging Face ``transformers`` library.
+
+   .. code-block:: python
+
+      from transformers import AutoTokenizer, AutoModelForCausalLM
+
+      MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
+      MAX_SEQ_LEN = 512
+
+      model = AutoModelForCausalLM.from_pretrained(
+          MODEL_ID, device_map="auto", torch_dtype="auto",
+      )
+      model.eval()
+
+      tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, model_max_length=MAX_SEQ_LEN)
+      tokenizer.pad_token = tokenizer.eos_token
+
+#. Prepare the calibration DataLoader (static quantization requires calibration data).
+
+   .. code-block:: python
+
+      from datasets import load_dataset
+      from torch.utils.data import DataLoader
+
+      BATCH_SIZE = 1
+      NUM_CALIBRATION_DATA = 512
+
+      dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
+      text_data = dataset["text"][:NUM_CALIBRATION_DATA]
+
+      tokenized_outputs = tokenizer(
+          text_data, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN
+      )
+      calib_dataloader = DataLoader(
+          tokenized_outputs['input_ids'], batch_size=BATCH_SIZE, drop_last=True
+      )
+
+#. Define the quantization configuration. See the comments in the following code snippet for descriptions of each configuration option.
+
+   .. code-block:: python
+
+      from quark.torch.quantization import (Config, QuantizationConfig,
+                                            FP8E4M3PerTensorSpec)
+
+      # Define the FP8/per-tensor/static quantization spec.
+      FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
+                                                 is_dynamic=False).to_quantization_spec()
+
+      # Define the global quantization config: input tensors and weights use FP8_PER_TENSOR_SPEC.
+      global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
+                                               weight=FP8_PER_TENSOR_SPEC)
+
+      # Define the quantization config for KV cache layers: their output tensors use KV_CACHE_SPEC.
+      KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC
+      kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
+      kv_cache_quant_config = {name:
+          QuantizationConfig(input_tensors=global_quant_config.input_tensors,
+                             weight=global_quant_config.weight,
+                             output_tensors=KV_CACHE_SPEC)
+          for name in kv_cache_layer_names_for_llama}
+      layer_quant_config = kv_cache_quant_config.copy()
+
+      EXCLUDE_LAYERS = ["lm_head"]
+      quant_config = Config(
+          global_quant_config=global_quant_config,
+          layer_quant_config=layer_quant_config,
+          kv_cache_quant_config=kv_cache_quant_config,
+          exclude=EXCLUDE_LAYERS)
+
+#. Quantize the model and export it. (An optional check of the exported files is sketched after this list.)
+
+   .. code-block:: python
+
+      import torch
+      from quark.torch import ModelQuantizer, ModelExporter
+      from quark.torch.export import ExporterConfig, JsonExporterConfig
+
+      # Apply quantization.
+      quantizer = ModelQuantizer(quant_config)
+      quant_model = quantizer.quantize_model(model, calib_dataloader)
+
+      # Freeze the quantized model for export.
+      freezed_model = quantizer.freeze(model)
+
+      # Define the export config.
+      LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"]
+      export_config = ExporterConfig(json_export_config=JsonExporterConfig())
+      export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP
+
+      EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor"
+      exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR)
+      with torch.no_grad():
+          exporter.export_safetensors_model(freezed_model,
+              quant_config=quant_config, tokenizer=tokenizer)
+
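+After the export step completes, you can optionally confirm that the checkpoint was written and see how large it is on disk. This is a minimal, standard-library-only sketch; ``EXPORT_DIR`` is the directory defined in the export step above.
+
+.. code-block:: python
+
+   from pathlib import Path
+
+   # List the exported files and report the total checkpoint size on disk.
+   export_files = sorted(p for p in Path(EXPORT_DIR).iterdir() if p.is_file())
+   total_bytes = sum(p.stat().st_size for p in export_files)
+
+   for p in export_files:
+       print(f"{p.name}: {p.stat().st_size / 1e6:.1f} MB")
+   print(f"Total: {total_bytes / 1e9:.2f} GB across {len(export_files)} files")
+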
+Evaluating the quantized model with vLLM
+----------------------------------------
+
+The exported Quark-quantized model can be loaded directly by vLLM for inference. You need to specify the model path and inform vLLM of the quantization method (``quantization='quark'``) and the KV cache data type (``kv_cache_dtype='fp8'``).
+Use the ``LLM`` interface to load the model:
+
+.. code-block:: python
+
+   from vllm import LLM, SamplingParams
+
+   # Sample prompts.
+   prompts = [
+       "Hello, my name is",
+       "The president of the United States is",
+       "The capital of France is",
+       "The future of AI is",
+   ]
+   # Create a sampling params object.
+   sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+   # Create an LLM.
+   llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor",
+             kv_cache_dtype='fp8', quantization='quark')
+   # Generate texts from the prompts. The output is a list of RequestOutput objects
+   # that contain the prompt, generated text, and other information.
+   outputs = llm.generate(prompts, sampling_params)
+   # Print the outputs.
+   print("\nGenerated Outputs:\n" + "-" * 60)
+   for output in outputs:
+       prompt = output.prompt
+       generated_text = output.outputs[0].text
+       print(f"Prompt: {prompt!r}")
+       print(f"Output: {generated_text!r}")
+       print("-" * 60)
+
+You can also evaluate the quantized model's accuracy on standard benchmarks using the `lm-evaluation-harness `_. Pass the necessary vLLM arguments to ``lm_eval`` via ``--model_args``.
+
+.. code-block:: shell
+
+   lm_eval --model vllm \
+     --model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor,kv_cache_dtype='fp8',quantization='quark' \
+     --tasks gsm8k
+
+This provides a standardized way to measure the performance impact of quantization.
+
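+If you prefer to drive the same evaluation from Python, recent releases of lm-evaluation-harness (v0.4 and later) expose a ``simple_evaluate`` helper. The following sketch assumes such a version and reuses the model path and vLLM arguments from the CLI example above.
+
+.. code-block:: python
+
+   import lm_eval
+
+   # Run the same GSM8K evaluation through the vLLM backend, but from Python.
+   results = lm_eval.simple_evaluate(
+       model="vllm",
+       model_args=(
+           "pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor,"
+           "kv_cache_dtype=fp8,quantization=quark"
+       ),
+       tasks=["gsm8k"],
+   )
+   print(results["results"]["gsm8k"])
+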
 .. _fine-tune-llms-gptq:
 
 GPTQ
 ====
@@ -33,7 +196,7 @@ The AutoGPTQ library implements the GPTQ algorithm.
    .. code-block:: shell
 
       # This will install pre-built wheel for a specific ROCm version.
-
+
      pip install auto-gptq --no-build-isolation --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm573/
 
 Or, install AutoGPTQ from source for the appropriate ROCm version (for example, ROCm 6.1).
@@ -43,10 +206,10 @@ The AutoGPTQ library implements the GPTQ algorithm.
      # Clone the source code.
      git clone https://github.com/AutoGPTQ/AutoGPTQ.git
      cd AutoGPTQ
-
+
      # Speed up the compilation by specifying PYTORCH_ROCM_ARCH to target device.
      PYTORCH_ROCM_ARCH=gfx942 ROCM_VERSION=6.1 pip install .
-
+
      # Show the package after the installation
 
 #. Run ``pip show auto-gptq`` to print information for the installed ``auto-gptq`` package. Its output should look like
@@ -112,7 +275,7 @@ Using GPTQ with Hugging Face Transformers
    .. code-block:: python
 
      from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
-
+
      base_model_name = "NousResearch/Llama-2-7b-hf"
      tokenizer = AutoTokenizer.from_pretrained(base_model_name)
      gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
@@ -212,10 +375,10 @@ To get started with bitsandbytes primitives, use the following code as reference
 .. code-block:: python
 
    import bitsandbytes as bnb
-
+
    # Use Int8 Matrix Multiplication
   bnb.matmul(..., threshold=6.0)
-
+
   # Use bitsandbytes 8-bit Optimizers
   adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995))
 
@@ -227,14 +390,14 @@ To load a Transformers model in 4-bit, set ``load_in_4bit=true`` in ``BitsAndByt
 .. code-block:: python
 
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
-
+
   base_model_name = "NousResearch/Llama-2-7b-hf"
 
   quantization_config = BitsAndBytesConfig(load_in_4bit=True)
 
   bnb_model_4bit = AutoModelForCausalLM.from_pretrained(
      base_model_name,
      device_map="auto",
      quantization_config=quantization_config)
-
+
   # Check the memory footprint with get_memory_footprint method
   print(bnb_model_4bit.get_memory_footprint())
 
 To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
@@ -243,9 +406,9 @@ To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
 .. code-block:: python
 
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-
+
   base_model_name = "NousResearch/Llama-2-7b-hf"
-
+
   tokenizer = AutoTokenizer.from_pretrained(base_model_name)
   quantization_config = BitsAndBytesConfig(load_in_8bit=True)
   model_8bit = AutoModelForCausalLM.from_pretrained(
@@ -253,7 +416,7 @@ To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
      base_model_name,
      device_map="auto",
      quantization_config=quantization_config)
-
+
   prompt = "What is a large language model?"
   inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
   generated_ids = model_8bit.generate(**inputs)