mirror of
https://github.com/ROCm/ROCm.git
synced 2026-01-08 22:28:06 -05:00
JAX Maxtext v25.9 doc update (#5532)
* archive previous version (25.7) * update docker components list for 25.9 * update template * update docker pull tag * update * fix intro
This commit is contained in:
@@ -1,47 +1,16 @@
|
||||
dockers:
|
||||
- pull_tag: rocm/jax-training:maxtext-v25.7-jax060
|
||||
- pull_tag: rocm/jax-training:maxtext-v25.9
|
||||
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
|
||||
components:
|
||||
ROCm: 6.4.1
|
||||
JAX: 0.6.0
|
||||
Python: 3.10.12
|
||||
Transformer Engine: 2.1.0+90d703dd
|
||||
hipBLASLt: 1.1.0-499ece1c21
|
||||
- pull_tag: rocm/jax-training:maxtext-v25.7
|
||||
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
|
||||
components:
|
||||
ROCm: 6.4.1
|
||||
JAX: 0.5.0
|
||||
Python: 3.10.12
|
||||
Transformer Engine: 2.1.0+90d703dd
|
||||
ROCm: 7.0.0
|
||||
JAX: 0.6.2
|
||||
Python: 3.10.18
|
||||
Transformer Engine: 2.2.0.dev0+c91bac54
|
||||
hipBLASLt: 1.x.x
|
||||
model_groups:
|
||||
- group: Meta Llama
|
||||
tag: llama
|
||||
models:
|
||||
- model: Llama 3.3 70B
|
||||
mad_tag: jax_maxtext_train_llama-3.3-70b
|
||||
model_repo: Llama-3.3-70B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3.1 8B
|
||||
mad_tag: jax_maxtext_train_llama-3.1-8b
|
||||
model_repo: Llama-3.1-8B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3.1 70B
|
||||
mad_tag: jax_maxtext_train_llama-3.1-70b
|
||||
model_repo: Llama-3.1-70B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3 8B
|
||||
mad_tag: jax_maxtext_train_llama-3-8b
|
||||
multinode_training_script: llama3_8b_multinode.sh
|
||||
doc_options: ["multi-node"]
|
||||
- model: Llama 3 70B
|
||||
mad_tag: jax_maxtext_train_llama-3-70b
|
||||
multinode_training_script: llama3_70b_multinode.sh
|
||||
doc_options: ["multi-node"]
|
||||
- model: Llama 2 7B
|
||||
mad_tag: jax_maxtext_train_llama-2-7b
|
||||
model_repo: Llama-2-7B
|
||||
@@ -54,6 +23,29 @@ model_groups:
|
||||
precision: bf16
|
||||
multinode_training_script: llama2_70b_multinode.sh
|
||||
doc_options: ["single-node", "multi-node"]
|
||||
- model: Llama 3 8B (multi-node)
|
||||
mad_tag: jax_maxtext_train_llama-3-8b
|
||||
multinode_training_script: llama3_8b_multinode.sh
|
||||
doc_options: ["multi-node"]
|
||||
- model: Llama 3 70B (multi-node)
|
||||
mad_tag: jax_maxtext_train_llama-3-70b
|
||||
multinode_training_script: llama3_70b_multinode.sh
|
||||
doc_options: ["multi-node"]
|
||||
- model: Llama 3.1 8B
|
||||
mad_tag: jax_maxtext_train_llama-3.1-8b
|
||||
model_repo: Llama-3.1-8B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3.1 70B
|
||||
mad_tag: jax_maxtext_train_llama-3.1-70b
|
||||
model_repo: Llama-3.1-70B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3.3 70B
|
||||
mad_tag: jax_maxtext_train_llama-3.3-70b
|
||||
model_repo: Llama-3.3-70B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- group: DeepSeek
|
||||
tag: deepseek
|
||||
models:
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
dockers:
|
||||
- pull_tag: rocm/jax-training:maxtext-v25.7-jax060
|
||||
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
|
||||
components:
|
||||
ROCm: 6.4.1
|
||||
JAX: 0.6.0
|
||||
Python: 3.10.12
|
||||
Transformer Engine: 2.1.0+90d703dd
|
||||
hipBLASLt: 1.1.0-499ece1c21
|
||||
- pull_tag: rocm/jax-training:maxtext-v25.7
|
||||
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
|
||||
components:
|
||||
ROCm: 6.4.1
|
||||
JAX: 0.5.0
|
||||
Python: 3.10.12
|
||||
Transformer Engine: 2.1.0+90d703dd
|
||||
hipBLASLt: 1.x.x
|
||||
model_groups:
|
||||
- group: Meta Llama
|
||||
tag: llama
|
||||
models:
|
||||
- model: Llama 3.3 70B
|
||||
mad_tag: jax_maxtext_train_llama-3.3-70b
|
||||
model_repo: Llama-3.3-70B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3.1 8B
|
||||
mad_tag: jax_maxtext_train_llama-3.1-8b
|
||||
model_repo: Llama-3.1-8B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3.1 70B
|
||||
mad_tag: jax_maxtext_train_llama-3.1-70b
|
||||
model_repo: Llama-3.1-70B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- model: Llama 3 8B
|
||||
mad_tag: jax_maxtext_train_llama-3-8b
|
||||
multinode_training_script: llama3_8b_multinode.sh
|
||||
doc_options: ["multi-node"]
|
||||
- model: Llama 3 70B
|
||||
mad_tag: jax_maxtext_train_llama-3-70b
|
||||
multinode_training_script: llama3_70b_multinode.sh
|
||||
doc_options: ["multi-node"]
|
||||
- model: Llama 2 7B
|
||||
mad_tag: jax_maxtext_train_llama-2-7b
|
||||
model_repo: Llama-2-7B
|
||||
precision: bf16
|
||||
multinode_training_script: llama2_7b_multinode.sh
|
||||
doc_options: ["single-node", "multi-node"]
|
||||
- model: Llama 2 70B
|
||||
mad_tag: jax_maxtext_train_llama-2-70b
|
||||
model_repo: Llama-2-70B
|
||||
precision: bf16
|
||||
multinode_training_script: llama2_70b_multinode.sh
|
||||
doc_options: ["single-node", "multi-node"]
|
||||
- group: DeepSeek
|
||||
tag: deepseek
|
||||
models:
|
||||
- model: DeepSeek-V2-Lite (16B)
|
||||
mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
|
||||
model_repo: DeepSeek-V2-lite
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
- group: Mistral AI
|
||||
tag: mistral
|
||||
models:
|
||||
- model: Mixtral 8x7B
|
||||
mad_tag: jax_maxtext_train_mixtral-8x7b
|
||||
model_repo: Mixtral-8x7B
|
||||
precision: bf16
|
||||
doc_options: ["single-node"]
|
||||
@@ -6,14 +6,8 @@
|
||||
Training a model with JAX MaxText on ROCm
|
||||
******************************************
|
||||
|
||||
MaxText is a high-performance, open-source framework built on the Google JAX
|
||||
machine learning library to train LLMs at scale. The MaxText framework for
|
||||
ROCm is an optimized fork of the upstream
|
||||
`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
|
||||
on AMD MI300X Series GPUs.
|
||||
|
||||
The MaxText for ROCm training Docker image
|
||||
provides a prebuilt environment for training on AMD Instinct MI300X and MI325X GPUs,
|
||||
provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs,
|
||||
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
|
||||
It includes the following software components:
|
||||
|
||||
@@ -61,15 +55,15 @@ MaxText with on ROCm provides the following key features to train large language
|
||||
|
||||
- Multi-node support
|
||||
|
||||
- NANOO FP8 quantization support
|
||||
- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support
|
||||
|
||||
.. _amd-maxtext-model-support-v257:
|
||||
.. _amd-maxtext-model-support-v259:
|
||||
|
||||
Supported models
|
||||
================
|
||||
|
||||
The following models are pre-optimized for performance on AMD Instinct MI300
|
||||
Series GPUs. Some instructions, commands, and available training
|
||||
The following models are pre-optimized for performance on AMD Instinct
|
||||
GPUs. Some instructions, commands, and available training
|
||||
configurations in this documentation might vary by model -- select one to get
|
||||
started.
|
||||
|
||||
@@ -139,22 +133,13 @@ Use the following command to pull the Docker image from Docker Hub.
|
||||
|
||||
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
|
||||
|
||||
{% set dockers = data.dockers %}
|
||||
.. tab-set::
|
||||
{% set docker = data.dockers[0] %}
|
||||
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
.. code-block:: shell
|
||||
|
||||
.. tab-item:: JAX {{ jax_version }}
|
||||
:sync: {{ docker.pull_tag }}
|
||||
docker pull {{ docker.pull_tag }}
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
docker pull {{ docker.pull_tag }}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
.. _amd-maxtext-multi-node-setup-v257:
|
||||
.. _amd-maxtext-multi-node-setup-v259:
|
||||
|
||||
Multi-node configuration
|
||||
------------------------
|
||||
@@ -162,7 +147,7 @@ Multi-node configuration
|
||||
See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
|
||||
environment for multi-node training.
|
||||
|
||||
.. _amd-maxtext-get-started-v257:
|
||||
.. _amd-maxtext-get-started-v259:
|
||||
|
||||
Benchmarking
|
||||
============
|
||||
@@ -174,7 +159,7 @@ benchmark results:
|
||||
|
||||
.. _vllm-benchmark-mad:
|
||||
|
||||
{% set dockers = data.dockers %}
|
||||
{% set docker = data.dockers[0] %}
|
||||
{% set model_groups = data.model_groups %}
|
||||
{% for model_group in model_groups %}
|
||||
{% for model in model_group.models %}
|
||||
@@ -186,6 +171,9 @@ benchmark results:
|
||||
{% if model.mad_tag and "single-node" in model.doc_options %}
|
||||
.. tab-item:: MAD-integrated benchmarking
|
||||
|
||||
The following run command is tailored to {{ model.model }}.
|
||||
See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
|
||||
|
||||
1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
|
||||
directory and install the required packages on the host machine.
|
||||
|
||||
@@ -214,22 +202,19 @@ benchmark results:
|
||||
|
||||
.. tab-item:: Standalone benchmarking
|
||||
|
||||
The following commands are optimized for {{ model.model }}. See
|
||||
:ref:`amd-maxtext-model-support-v259` to switch to another
|
||||
available model. Some instructions and resources might not be
|
||||
available for all models and configurations.
|
||||
|
||||
.. rubric:: Download the Docker image and required scripts
|
||||
|
||||
Run the JAX MaxText benchmark tool independently by starting the
|
||||
Docker container as shown in the following snippet.
|
||||
|
||||
.. tab-set::
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
.. code-block:: shell
|
||||
|
||||
.. tab-item:: JAX {{ jax_version }}
|
||||
:sync: {{ docker.pull_tag }}
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
docker pull {{ docker.pull_tag }}
|
||||
{% endfor %}
|
||||
docker pull {{ docker.pull_tag }}
|
||||
|
||||
{% if model.model_repo and "single-node" in model.doc_options %}
|
||||
.. rubric:: Single node training
|
||||
@@ -250,33 +235,25 @@ benchmark results:
|
||||
|
||||
2. Launch the Docker container.
|
||||
|
||||
.. tab-set::
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
.. code-block:: shell
|
||||
|
||||
.. tab-item:: JAX {{ jax_version }}
|
||||
:sync: {{ docker.pull_tag }}
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
docker run -it \
|
||||
--device=/dev/dri \
|
||||
--device=/dev/kfd \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--privileged \
|
||||
-v $HOME:$HOME \
|
||||
-v $HOME/.ssh:/root/.ssh \
|
||||
-v $HF_HOME:/hf_cache \
|
||||
-e HF_HOME=/hf_cache \
|
||||
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
|
||||
--shm-size 64G \
|
||||
--name training_env \
|
||||
{{ docker.pull_tag }}
|
||||
{% endfor %}
|
||||
docker run -it \
|
||||
--device=/dev/dri \
|
||||
--device=/dev/kfd \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--privileged \
|
||||
-v $HOME:$HOME \
|
||||
-v $HOME/.ssh:/root/.ssh \
|
||||
-v $HF_HOME:/hf_cache \
|
||||
-e HF_HOME=/hf_cache \
|
||||
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
|
||||
--shm-size 64G \
|
||||
--name training_env \
|
||||
{{ docker.pull_tag }}
|
||||
|
||||
3. In the Docker container, clone the ROCm MAD repository and navigate to the
|
||||
benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
|
||||
@@ -299,11 +276,27 @@ benchmark results:
|
||||
|
||||
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
|
||||
|
||||
For quantized training, use the following command:
|
||||
For quantized training, run the script with the appropriate option for your Instinct GPU.
|
||||
|
||||
.. code-block:: shell
|
||||
.. tab-set::
|
||||
|
||||
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
|
||||
.. tab-item:: MI355X and MI350X
|
||||
|
||||
For ``fp8`` quantized training on MI355X and MI350X GPUs, use the following command:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q fp8
|
||||
|
||||
{% if model.model_repo not in ["Llama-3.1-70B", "Llama-3.3-70B"] %}
|
||||
.. tab-item:: MI325X and MI300X
|
||||
|
||||
For ``nanoo_fp8`` quantized training on MI300X series GPUs, use the following command:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
|
||||
{% endif %}
|
||||
|
||||
{% endif %}
|
||||
{% if model.multinode_training_script and "multi-node" in model.doc_options %}
|
||||
@@ -335,7 +328,7 @@ benchmark results:
|
||||
{% else %}
|
||||
.. rubric:: Multi-node training
|
||||
|
||||
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v257`
|
||||
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
|
||||
with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
@@ -17,27 +17,35 @@ previous releases of the ``ROCm/jax-training`` Docker image on `Docker Hub <http
|
||||
- Components
|
||||
- Resources
|
||||
|
||||
* - 25.7 (latest)
|
||||
-
|
||||
* - 25.9 (latest)
|
||||
-
|
||||
* ROCm 7.0.0
|
||||
* JAX 0.6.2
|
||||
-
|
||||
* :doc:`Documentation <../jax-maxtext>`
|
||||
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7-jax060/images/sha256-7352212ae033a76dca2b9dceffc23c1b5f1a61a7a560082cf747a9bf1acfc9ce>`__
|
||||
|
||||
* - 25.7
|
||||
-
|
||||
* ROCm 6.4.1
|
||||
* JAX 0.6.0, 0.5.0
|
||||
-
|
||||
* :doc:`Documentation <../jax-maxtext>`
|
||||
-
|
||||
* :doc:`Documentation <jax-maxtext-v25.7>`
|
||||
* `Docker Hub (JAX 0.6.0) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7-jax060/images/sha256-7352212ae033a76dca2b9dceffc23c1b5f1a61a7a560082cf747a9bf1acfc9ce>`__
|
||||
* `Docker Hub (JAX 0.5.0) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025>`__
|
||||
|
||||
* - 25.5
|
||||
-
|
||||
-
|
||||
* ROCm 6.3.4
|
||||
* JAX 0.4.35
|
||||
-
|
||||
-
|
||||
* :doc:`Documentation <jax-maxtext-v25.5>`
|
||||
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.5/images/sha256-4e0516358a227cae8f552fb866ec07e2edcf244756f02e7b40212abfbab5217b>`__
|
||||
|
||||
* - 25.4
|
||||
-
|
||||
-
|
||||
* ROCm 6.3.0
|
||||
* JAX 0.4.31
|
||||
-
|
||||
-
|
||||
* :doc:`Documentation <jax-maxtext-v25.4>`
|
||||
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.4/images/sha256-fb3eb71cd74298a7b3044b7130cf84113f14d518ff05a2cd625c11ea5f6a7b01>`__
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
:orphan:
|
||||
|
||||
.. meta::
|
||||
:description: How to train a model using JAX MaxText for ROCm.
|
||||
:keywords: ROCm, AI, LLM, train, jax, torch, Llama, flux, tutorial, docker
|
||||
|
||||
******************************************
|
||||
Training a model with JAX MaxText on ROCm
|
||||
******************************************
|
||||
|
||||
.. caution::
|
||||
|
||||
This documentation does not reflect the latest version of ROCm JAX MaxText
|
||||
training performance documentation. See :doc:`../jax-maxtext` for the latest version.
|
||||
|
||||
MaxText is a high-performance, open-source framework built on the Google JAX
|
||||
machine learning library to train LLMs at scale. The MaxText framework for
|
||||
ROCm is an optimized fork of the upstream
|
||||
`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
|
||||
on AMD MI300X series GPUs.
|
||||
|
||||
The MaxText for ROCm training Docker image
|
||||
provides a prebuilt environment for training on AMD Instinct MI300X and MI325X GPUs,
|
||||
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
|
||||
It includes the following software components:
|
||||
|
||||
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
|
||||
|
||||
{% set dockers = data.dockers %}
|
||||
.. tab-set::
|
||||
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
|
||||
.. tab-item:: ``{{ docker.pull_tag }}``
|
||||
:sync: {{ docker.pull_tag }}
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - Software component
|
||||
- Version
|
||||
|
||||
{% for component_name, component_version in docker.components.items() %}
|
||||
* - {{ component_name }}
|
||||
- {{ component_version }}
|
||||
|
||||
{% endfor %}
|
||||
{% if jax_version == "0.6.0" %}
|
||||
.. note::
|
||||
|
||||
Shardy is a new config in JAX 0.6.0. You might get related errors if it's
|
||||
not configured correctly. For now you can turn it off by setting
|
||||
``shardy=False`` during the training run. You can also follow the `migration
|
||||
guide <https://docs.jax.dev/en/latest/shardy_jax_migration.html>`__ to enable
|
||||
it.
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
MaxText with on ROCm provides the following key features to train large language models efficiently:
|
||||
|
||||
- Transformer Engine (TE)
|
||||
|
||||
- Flash Attention (FA) 3 -- with or without sequence input packing
|
||||
|
||||
- GEMM tuning
|
||||
|
||||
- Multi-node support
|
||||
|
||||
- NANOO FP8 quantization support
|
||||
|
||||
.. _amd-maxtext-model-support-v257:
|
||||
|
||||
Supported models
|
||||
================
|
||||
|
||||
The following models are pre-optimized for performance on AMD Instinct MI300
|
||||
series GPUs. Some instructions, commands, and available training
|
||||
configurations in this documentation might vary by model -- select one to get
|
||||
started.
|
||||
|
||||
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
|
||||
|
||||
{% set model_groups = data.model_groups %}
|
||||
.. raw:: html
|
||||
|
||||
<div id="vllm-benchmark-ud-params-picker" class="container-fluid">
|
||||
<div class="row gx-0">
|
||||
<div class="col-2 me-1 px-2 model-param-head">Model</div>
|
||||
<div class="row col-10 pe-0">
|
||||
{% for model_group in model_groups %}
|
||||
<div class="col-4 px-2 model-param" data-param-k="model-group" data-param-v="{{ model_group.tag }}" tabindex="0">{{ model_group.group }}</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row gx-0 pt-1">
|
||||
<div class="col-2 me-1 px-2 model-param-head">Variant</div>
|
||||
<div class="row col-10 pe-0">
|
||||
{% for model_group in model_groups %}
|
||||
{% set models = model_group.models %}
|
||||
{% for model in models %}
|
||||
{% if models|length % 3 == 0 %}
|
||||
<div class="col-4 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
|
||||
{% else %}
|
||||
<div class="col-6 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
.. note::
|
||||
|
||||
Some models, such as Llama 3, require an external license agreement through
|
||||
a third party (for example, Meta).
|
||||
|
||||
System validation
|
||||
=================
|
||||
|
||||
Before running AI workloads, it's important to validate that your AMD hardware is configured
|
||||
correctly and performing optimally.
|
||||
|
||||
If you have already validated your system settings, including aspects like NUMA auto-balancing, you
|
||||
can skip this step. Otherwise, complete the procedures in the :ref:`System validation and
|
||||
optimization <rocm-for-ai-system-optimization>` guide to properly configure your system settings
|
||||
before starting training.
|
||||
|
||||
To test for optimal performance, consult the recommended :ref:`System health benchmarks
|
||||
<rocm-for-ai-system-health-bench>`. This suite of tests will help you verify and fine-tune your
|
||||
system's configuration.
|
||||
|
||||
Environment setup
|
||||
=================
|
||||
|
||||
This Docker image is optimized for specific model configurations outlined
|
||||
as follows. Performance can vary for other training workloads, as AMD
|
||||
doesn’t validate configurations and run conditions outside those described.
|
||||
|
||||
Pull the Docker image
|
||||
---------------------
|
||||
|
||||
Use the following command to pull the Docker image from Docker Hub.
|
||||
|
||||
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
|
||||
|
||||
{% set dockers = data.dockers %}
|
||||
.. tab-set::
|
||||
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
|
||||
.. tab-item:: JAX {{ jax_version }}
|
||||
:sync: {{ docker.pull_tag }}
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
docker pull {{ docker.pull_tag }}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
.. _amd-maxtext-multi-node-setup-v257:
|
||||
|
||||
Multi-node configuration
|
||||
------------------------
|
||||
|
||||
See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
|
||||
environment for multi-node training.
|
||||
|
||||
.. _amd-maxtext-get-started-v257:
|
||||
|
||||
Benchmarking
|
||||
============
|
||||
|
||||
Once the setup is complete, choose between two options to reproduce the
|
||||
benchmark results:
|
||||
|
||||
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
|
||||
|
||||
.. _vllm-benchmark-mad:
|
||||
|
||||
{% set dockers = data.dockers %}
|
||||
{% set model_groups = data.model_groups %}
|
||||
{% for model_group in model_groups %}
|
||||
{% for model in model_group.models %}
|
||||
|
||||
.. container:: model-doc {{model.mad_tag}}
|
||||
|
||||
.. tab-set::
|
||||
|
||||
{% if model.mad_tag and "single-node" in model.doc_options %}
|
||||
.. tab-item:: MAD-integrated benchmarking
|
||||
|
||||
1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
|
||||
directory and install the required packages on the host machine.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
git clone https://github.com/ROCm/MAD
|
||||
cd MAD
|
||||
pip install -r requirements.txt
|
||||
|
||||
2. Use this command to run the performance benchmark test on the {{ model.model }} model
|
||||
using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
|
||||
madengine run \
|
||||
--tags {{model.mad_tag}} \
|
||||
--keep-model-dir \
|
||||
--live-output \
|
||||
--timeout 28800
|
||||
|
||||
MAD launches a Docker container with the name
|
||||
``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
|
||||
model are collected in the following path: ``~/MAD/perf.csv/``.
|
||||
{% endif %}
|
||||
|
||||
.. tab-item:: Standalone benchmarking
|
||||
|
||||
.. rubric:: Download the Docker image and required scripts
|
||||
|
||||
Run the JAX MaxText benchmark tool independently by starting the
|
||||
Docker container as shown in the following snippet.
|
||||
|
||||
.. tab-set::
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
|
||||
.. tab-item:: JAX {{ jax_version }}
|
||||
:sync: {{ docker.pull_tag }}
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
docker pull {{ docker.pull_tag }}
|
||||
{% endfor %}
|
||||
|
||||
{% if model.model_repo and "single-node" in model.doc_options %}
|
||||
.. rubric:: Single node training
|
||||
|
||||
1. Set up environment variables.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
|
||||
export HF_HOME=<Location of saved/cached Hugging Face models>
|
||||
|
||||
``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
|
||||
See `User access tokens <https://huggingface.co/docs/hub/en/security-tokens>`__.
|
||||
|
||||
``HF_HOME`` is where ``huggingface_hub`` will store local data. See `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-download>`__.
|
||||
If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
|
||||
Downloaded files typically get cached to ``~/.cache/huggingface``.
|
||||
|
||||
2. Launch the Docker container.
|
||||
|
||||
.. tab-set::
|
||||
{% for docker in dockers %}
|
||||
{% set jax_version = docker.components["JAX"] %}
|
||||
|
||||
.. tab-item:: JAX {{ jax_version }}
|
||||
:sync: {{ docker.pull_tag }}
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
docker run -it \
|
||||
--device=/dev/dri \
|
||||
--device=/dev/kfd \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--privileged \
|
||||
-v $HOME:$HOME \
|
||||
-v $HOME/.ssh:/root/.ssh \
|
||||
-v $HF_HOME:/hf_cache \
|
||||
-e HF_HOME=/hf_cache \
|
||||
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
|
||||
--shm-size 64G \
|
||||
--name training_env \
|
||||
{{ docker.pull_tag }}
|
||||
{% endfor %}
|
||||
|
||||
3. In the Docker container, clone the ROCm MAD repository and navigate to the
|
||||
benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
git clone https://github.com/ROCm/MAD
|
||||
cd MAD/scripts/jax-maxtext
|
||||
|
||||
4. Run the setup scripts to install libraries and datasets needed
|
||||
for benchmarking.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
|
||||
|
||||
5. To run the training benchmark without quantization, use the following command:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
|
||||
|
||||
For quantized training, use the following command:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
|
||||
|
||||
{% endif %}
|
||||
{% if model.multinode_training_script and "multi-node" in model.doc_options %}
|
||||
.. rubric:: Multi-node training
|
||||
|
||||
The following examples use SLURM to run on multiple nodes.
|
||||
|
||||
.. note::
|
||||
|
||||
The following scripts will launch the Docker container and run the
|
||||
benchmark. Run them outside of any Docker container.
|
||||
|
||||
1. Make sure ``$HF_HOME`` is set before running the test. See
|
||||
`ROCm benchmarking <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/readme.md>`__
|
||||
for more details on downloading the Llama models before running the
|
||||
benchmark.
|
||||
|
||||
2. To run multi-node training for {{ model.model }},
|
||||
use the
|
||||
`multi-node training script <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/{{ model.multinode_training_script }}>`__
|
||||
under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
|
||||
|
||||
3. Run the multi-node training benchmark script.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
sbatch -N <num_nodes> {{ model.multinode_training_script }}
|
||||
|
||||
{% else %}
|
||||
.. rubric:: Multi-node training
|
||||
|
||||
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v257`
|
||||
with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
|
||||
Further reading
|
||||
===============
|
||||
|
||||
- To learn more about MAD and the ``madengine`` CLI, see the `MAD usage guide <https://github.com/ROCm/MAD?tab=readme-ov-file#usage-guide>`__.
|
||||
|
||||
- To learn more about system settings and management practices to configure your system for
|
||||
AMD Instinct MI300X series GPUs, see `AMD Instinct MI300X system optimization <https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/system-optimization/mi300x.html>`_.
|
||||
|
||||
- For a list of other ready-made Docker images for AI with ROCm, see
|
||||
`AMD Infinity Hub <https://www.amd.com/en/developer/resources/infinity-hub.html#f-amd_hub_category=AI%20%26%20ML%20Models>`_.
|
||||
|
||||
Previous versions
|
||||
=================
|
||||
|
||||
See :doc:`jax-maxtext-history` to find documentation for previous releases
|
||||
of the ``ROCm/jax-training`` Docker image.
|
||||
Reference in New Issue
Block a user