Update documentation for JAX training MaxText 25.11 release (#5789)

This commit is contained in:
peterjunpark
2025-12-18 11:23:58 -05:00
committed by GitHub
parent 459283da3c
commit cbab9a465d
7 changed files with 536 additions and 33 deletions

View File

@@ -135,19 +135,29 @@ article_pages = [
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.5", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.6", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.7", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.8", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.9", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.10", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-primus-migration-guide", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.7", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/primus-megatron", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.7", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.8", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.9", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.10", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/pytorch-training", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-history", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.3", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.4", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.5", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.6", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/xdit-diffusion-inference", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.7", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.8", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.9", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.10", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/primus-pytorch", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/pytorch-training", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-pytorch-v25.8", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-pytorch-v25.9", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-pytorch-v25.10", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.4", "os": ["linux"]},
@@ -177,8 +187,12 @@ article_pages = [
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.9.1-20250702", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.9.1-20250715", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.10.0-20250812", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.10.1-20250909", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.10.2-20251006", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.11.1-20251103", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/sglang-history", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/benchmark-docker/pytorch-inference", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/xdit-diffusion-inference", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference/deploy-your-model", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/inference-optimization/index", "os": ["linux"]},

View File

@@ -1,12 +1,12 @@
dockers:
- pull_tag: rocm/jax-training:maxtext-v25.9
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
- pull_tag: rocm/jax-training:maxtext-v25.11
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a
components:
ROCm: 7.0.0
JAX: 0.6.2
Python: 3.10.18
Transformer Engine: 2.2.0.dev0+c91bac54
hipBLASLt: 1.x.x
ROCm: 7.1.0
JAX: 0.7.1
Python: 3.12
Transformer Engine: 2.4.0.dev0+281042de
hipBLASLt: 1.2.x
model_groups:
- group: Meta Llama
tag: llama

View File

@@ -0,0 +1,64 @@
dockers:
- pull_tag: rocm/jax-training:maxtext-v25.9.1
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.9.1/images/sha256-60946cfbd470f6ee361fc9da740233a4fb2e892727f01719145b1f7627a1cff6
components:
ROCm: 7.0.0
JAX: 0.6.2
Python: 3.10.18
Transformer Engine: 2.2.0.dev0+c91bac54
hipBLASLt: 1.x.x
model_groups:
- group: Meta Llama
tag: llama
models:
- model: Llama 2 7B
mad_tag: jax_maxtext_train_llama-2-7b
model_repo: Llama-2-7B
precision: bf16
multinode_training_script: llama2_7b_multinode.sh
doc_options: ["single-node", "multi-node"]
- model: Llama 2 70B
mad_tag: jax_maxtext_train_llama-2-70b
model_repo: Llama-2-70B
precision: bf16
multinode_training_script: llama2_70b_multinode.sh
doc_options: ["single-node", "multi-node"]
- model: Llama 3 8B (multi-node)
mad_tag: jax_maxtext_train_llama-3-8b
multinode_training_script: llama3_8b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 3 70B (multi-node)
mad_tag: jax_maxtext_train_llama-3-70b
multinode_training_script: llama3_70b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 3.1 8B
mad_tag: jax_maxtext_train_llama-3.1-8b
model_repo: Llama-3.1-8B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.1 70B
mad_tag: jax_maxtext_train_llama-3.1-70b
model_repo: Llama-3.1-70B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.3 70B
mad_tag: jax_maxtext_train_llama-3.3-70b
model_repo: Llama-3.3-70B
precision: bf16
doc_options: ["single-node"]
- group: DeepSeek
tag: deepseek
models:
- model: DeepSeek-V2-Lite (16B)
mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
model_repo: DeepSeek-V2-lite
precision: bf16
doc_options: ["single-node"]
- group: Mistral AI
tag: mistral
models:
- model: Mixtral 8x7B
mad_tag: jax_maxtext_train_mixtral-8x7b
model_repo: Mixtral-8x7B
precision: bf16
doc_options: ["single-node"]

View File

@@ -33,18 +33,15 @@ It includes the following software components:
- {{ component_version }}
{% endfor %}
{% if jax_version == "0.6.0" %}
.. note::
Shardy is a new config in JAX 0.6.0. You might get related errors if it's
not configured correctly. For now you can turn it off by setting
``shardy=False`` during the training run. You can also follow the `migration
guide <https://docs.jax.dev/en/latest/shardy_jax_migration.html>`__ to enable
it.
{% endif %}
{% endfor %}
.. note::
The ``rocm/jax-training:maxtext-v25.9`` has been updated to
``rocm/jax-training:maxtext-v25.9.1``. This revision should include
a fix to address segmentation fault issues during launch. See the
:doc:`versioned documentation <previous-versions/jax-maxtext-v25.9>`.
MaxText with on ROCm provides the following key features to train large language models efficiently:
- Transformer Engine (TE)
@@ -57,7 +54,7 @@ MaxText with on ROCm provides the following key features to train large language
- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support
.. _amd-maxtext-model-support-v259:
.. _amd-maxtext-model-support-v25.11:
Supported models
================
@@ -139,7 +136,7 @@ Use the following command to pull the Docker image from Docker Hub.
docker pull {{ docker.pull_tag }}
.. _amd-maxtext-multi-node-setup-v259:
.. _amd-maxtext-multi-node-setup-v25.11:
Multi-node configuration
------------------------
@@ -147,7 +144,7 @@ Multi-node configuration
See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
environment for multi-node training.
.. _amd-maxtext-get-started-v259:
.. _amd-maxtext-get-started-v25.11:
Benchmarking
============
@@ -172,7 +169,7 @@ benchmark results:
.. tab-item:: MAD-integrated benchmarking
The following run command is tailored to {{ model.model }}.
See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
See :ref:`amd-maxtext-model-support-v25.11` to switch to another available model.
1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
directory and install the required packages on the host machine.
@@ -203,7 +200,7 @@ benchmark results:
.. tab-item:: Standalone benchmarking
The following commands are optimized for {{ model.model }}. See
:ref:`amd-maxtext-model-support-v259` to switch to another
:ref:`amd-maxtext-model-support-v25.11` to switch to another
available model. Some instructions and resources might not be
available for all models and configurations.
@@ -325,15 +322,69 @@ benchmark results:
sbatch -N <num_nodes> {{ model.multinode_training_script }}
.. _maxtext-rocprofv3:
.. rubric:: Profiling with rocprofv3
If you need to collect a trace and the JAX profiler isn't working, use ``rocprofv3`` provided by the :doc:`ROCprofiler-SDK <rocprofiler-sdk:index>` as a workaround. For example:
.. code-block:: bash
rocprofv3 \
--hip-trace \
--kernel-trace \
--memory-copy-trace \
--rccl-trace \
--output-format pftrace \
-d ./v3_traces \ # output directory
-- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} # or desired command
You can set the directory where you want the .json traces to be
saved using ``-d <TRACE_DIRECTORY>``. The resulting traces can be
opened in Perfetto: `<https://ui.perfetto.dev/>`__.
{% else %}
.. rubric:: Multi-node training
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v25.11`
with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
{% endif %}
{% endfor %}
{% endfor %}
Known issues
============
- Minor performance regression (< 4%) for BF16 quantization in Llama models and Mixtral 8x7b.
- You might see minor loss spikes, or loss curve may have slightly higher
convergence end values compared to the previous ``jax-training`` image.
- For FP8 training on MI355, many models will display a warning message like:
``Warning: Latency not found for MI_M=16, MI_N=16, MI_K=128,
mi_input_type=BFloat8Float8_fnuz. Returning latency value of 32 (really
slow).`` The compile step may take longer than usual, but training will run.
This will be fixed in a future release.
- The built-in JAX profiler isn't working. If you need to collect a trace and
the JAX profiler isn't working, use ``rocprofv3`` provided by the
:doc:`ROCprofiler-SDK <rocprofiler-sdk:index>` as a workaround. For example:
.. code-block:: bash
rocprofv3 \
--hip-trace \
--kernel-trace \
--memory-copy-trace \
--rccl-trace \
--output-format pftrace \
-d ./v3_traces \ # output directory
-- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} # or desired command
You can set the directory where you want the .json traces to be
saved using ``-d <TRACE_DIRECTORY>``. The resulting traces can be
opened in Perfetto: `<https://ui.perfetto.dev/>`__.
Further reading
===============

View File

@@ -17,13 +17,22 @@ previous releases of the ``ROCm/jax-training`` Docker image on `Docker Hub <http
- Components
- Resources
* - 25.9 (latest)
* - 25.11
-
* ROCm 7.1.0
* JAX 0.7.1
-
* :doc:`Documentation <../jax-maxtext>`
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a>`__
* - 25.9.1
-
* ROCm 7.0.0
* JAX 0.6.2
-
* :doc:`Documentation <../jax-maxtext>`
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7-jax060/images/sha256-7352212ae033a76dca2b9dceffc23c1b5f1a61a7a560082cf747a9bf1acfc9ce>`__
* :doc:`Documentation <jax-maxtext-v25.9>`
* `Docker Hub (25.9.1) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.9.1/images/sha256-60946cfbd470f6ee361fc9da740233a4fb2e892727f01719145b1f7627a1cff6>`__
* `Docker Hub (25.9) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.9/images/sha256-4bb16ab58279ef09cb7a5e362c38e3fe3f901de44d8dbac5d0cb3bac5686441e>`__
* - 25.7
-

View File

@@ -24,7 +24,7 @@ provides a prebuilt environment for training on AMD Instinct MI300X and MI325X G
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
It includes the following software components:
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
{% set dockers = data.dockers %}
.. tab-set::
@@ -80,7 +80,7 @@ series GPUs. Some instructions, commands, and available training
configurations in this documentation might vary by model -- select one to get
started.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
{% set model_groups = data.model_groups %}
.. raw:: html
@@ -144,7 +144,7 @@ Pull the Docker image
Use the following command to pull the Docker image from Docker Hub.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
{% set dockers = data.dockers %}
.. tab-set::
@@ -177,7 +177,7 @@ Benchmarking
Once the setup is complete, choose between two options to reproduce the
benchmark results:
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
.. _vllm-benchmark-mad:

View File

@@ -0,0 +1,365 @@
:orphan:
.. meta::
:description: How to train a model using JAX MaxText for ROCm.
:keywords: ROCm, AI, LLM, train, jax, torch, Llama, flux, tutorial, docker
******************************************
Training a model with JAX MaxText on ROCm
******************************************
.. caution::
This documentation does not reflect the latest version of ROCm JAX MaxText
training performance documentation. See :doc:`../jax-maxtext` for the latest version.
.. note::
We have refreshed the ``rocm/jax-training:maxtext-v25.9`` image as
`rocm/jax-training:maxtext-v25.9.1`. This should include a fix to address
segmentation fault issues during launch.
The MaxText for ROCm training Docker image
provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs,
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
It includes the following software components:
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
{% set dockers = data.dockers %}
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. tab-item:: ``{{ docker.pull_tag }}``
:sync: {{ docker.pull_tag }}
.. list-table::
:header-rows: 1
* - Software component
- Version
{% for component_name, component_version in docker.components.items() %}
* - {{ component_name }}
- {{ component_version }}
{% endfor %}
{% if jax_version == "0.6.0" %}
.. note::
Shardy is a new config in JAX 0.6.0. You might get related errors if it's
not configured correctly. For now you can turn it off by setting
``shardy=False`` during the training run. You can also follow the `migration
guide <https://docs.jax.dev/en/latest/shardy_jax_migration.html>`__ to enable
it.
{% endif %}
{% endfor %}
MaxText with on ROCm provides the following key features to train large language models efficiently:
- Transformer Engine (TE)
- Flash Attention (FA) 3 -- with or without sequence input packing
- GEMM tuning
- Multi-node support
- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support
.. _amd-maxtext-model-support-v259:
Supported models
================
The following models are pre-optimized for performance on AMD Instinct
GPUs. Some instructions, commands, and available training
configurations in this documentation might vary by model -- select one to get
started.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
{% set model_groups = data.model_groups %}
.. raw:: html
<div id="vllm-benchmark-ud-params-picker" class="container-fluid">
<div class="row gx-0">
<div class="col-2 me-1 px-2 model-param-head">Model</div>
<div class="row col-10 pe-0">
{% for model_group in model_groups %}
<div class="col-4 px-2 model-param" data-param-k="model-group" data-param-v="{{ model_group.tag }}" tabindex="0">{{ model_group.group }}</div>
{% endfor %}
</div>
</div>
<div class="row gx-0 pt-1">
<div class="col-2 me-1 px-2 model-param-head">Variant</div>
<div class="row col-10 pe-0">
{% for model_group in model_groups %}
{% set models = model_group.models %}
{% for model in models %}
{% if models|length % 3 == 0 %}
<div class="col-4 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
{% else %}
<div class="col-6 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
{% endif %}
{% endfor %}
{% endfor %}
</div>
</div>
</div>
.. note::
Some models, such as Llama 3, require an external license agreement through
a third party (for example, Meta).
System validation
=================
Before running AI workloads, it's important to validate that your AMD hardware is configured
correctly and performing optimally.
If you have already validated your system settings, including aspects like NUMA auto-balancing, you
can skip this step. Otherwise, complete the procedures in the :ref:`System validation and
optimization <rocm-for-ai-system-optimization>` guide to properly configure your system settings
before starting training.
To test for optimal performance, consult the recommended :ref:`System health benchmarks
<rocm-for-ai-system-health-bench>`. This suite of tests will help you verify and fine-tune your
system's configuration.
Environment setup
=================
This Docker image is optimized for specific model configurations outlined
as follows. Performance can vary for other training workloads, as AMD
doesnt validate configurations and run conditions outside those described.
Pull the Docker image
---------------------
Use the following command to pull the Docker image from Docker Hub.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
{% set docker = data.dockers[0] %}
.. code-block:: shell
docker pull {{ docker.pull_tag }}
.. _amd-maxtext-multi-node-setup-v259:
Multi-node configuration
------------------------
See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
environment for multi-node training.
.. _amd-maxtext-get-started-v259:
Benchmarking
============
Once the setup is complete, choose between two options to reproduce the
benchmark results:
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
.. _vllm-benchmark-mad:
{% set docker = data.dockers[0] %}
{% set model_groups = data.model_groups %}
{% for model_group in model_groups %}
{% for model in model_group.models %}
.. container:: model-doc {{model.mad_tag}}
.. tab-set::
{% if model.mad_tag and "single-node" in model.doc_options %}
.. tab-item:: MAD-integrated benchmarking
The following run command is tailored to {{ model.model }}.
See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
directory and install the required packages on the host machine.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
2. Use this command to run the performance benchmark test on the {{ model.model }} model
using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
.. code-block:: shell
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
--tags {{model.mad_tag}} \
--keep-model-dir \
--live-output \
--timeout 28800
MAD launches a Docker container with the name
``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
model are collected in the following path: ``~/MAD/perf.csv/``.
{% endif %}
.. tab-item:: Standalone benchmarking
The following commands are optimized for {{ model.model }}. See
:ref:`amd-maxtext-model-support-v259` to switch to another
available model. Some instructions and resources might not be
available for all models and configurations.
.. rubric:: Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the
Docker container as shown in the following snippet.
.. code-block:: shell
docker pull {{ docker.pull_tag }}
{% if model.model_repo and "single-node" in model.doc_options %}
.. rubric:: Single node training
1. Set up environment variables.
.. code-block:: shell
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
See `User access tokens <https://huggingface.co/docs/hub/en/security-tokens>`__.
``HF_HOME`` is where ``huggingface_hub`` will store local data. See `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-download>`__.
If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
Downloaded files typically get cached to ``~/.cache/huggingface``.
2. Launch the Docker container.
.. code-block:: shell
docker run -it \
--device=/dev/dri \
--device=/dev/kfd \
--network host \
--ipc host \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v $HOME:$HOME \
-v $HOME/.ssh:/root/.ssh \
-v $HF_HOME:/hf_cache \
-e HF_HOME=/hf_cache \
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
--shm-size 64G \
--name training_env \
{{ docker.pull_tag }}
3. In the Docker container, clone the ROCm MAD repository and navigate to the
benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
4. Run the setup scripts to install libraries and datasets needed
for benchmarking.
.. code-block:: shell
./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
5. To run the training benchmark without quantization, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
For quantized training, run the script with the appropriate option for your Instinct GPU.
.. tab-set::
.. tab-item:: MI355X and MI350X
For ``fp8`` quantized training on MI355X and MI350X GPUs, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q fp8
{% if model.model_repo not in ["Llama-3.1-70B", "Llama-3.3-70B"] %}
.. tab-item:: MI325X and MI300X
For ``nanoo_fp8`` quantized training on MI300X series GPUs, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
{% endif %}
{% endif %}
{% if model.multinode_training_script and "multi-node" in model.doc_options %}
.. rubric:: Multi-node training
The following examples use SLURM to run on multiple nodes.
.. note::
The following scripts will launch the Docker container and run the
benchmark. Run them outside of any Docker container.
1. Make sure ``$HF_HOME`` is set before running the test. See
`ROCm benchmarking <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/readme.md>`__
for more details on downloading the Llama models before running the
benchmark.
2. To run multi-node training for {{ model.model }},
use the
`multi-node training script <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/{{ model.multinode_training_script }}>`__
under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
3. Run the multi-node training benchmark script.
.. code-block:: shell
sbatch -N <num_nodes> {{ model.multinode_training_script }}
{% else %}
.. rubric:: Multi-node training
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
{% endif %}
{% endfor %}
{% endfor %}
Further reading
===============
- To learn more about MAD and the ``madengine`` CLI, see the `MAD usage guide <https://github.com/ROCm/MAD?tab=readme-ov-file#usage-guide>`__.
- To learn more about system settings and management practices to configure your system for
AMD Instinct MI300X Series GPUs, see `AMD Instinct MI300X system optimization <https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/system-optimization/mi300x.html>`_.
- For a list of other ready-made Docker images for AI with ROCm, see
`AMD Infinity Hub <https://www.amd.com/en/developer/resources/infinity-hub.html#f-amd_hub_category=AI%20%26%20ML%20Models>`_.
Previous versions
=================
See :doc:`jax-maxtext-history` to find documentation for previous releases
of the ``ROCm/jax-training`` Docker image.