JAX Maxtext v25.9 doc update (#5532)

* archive previous version (25.7)

* update docker components list for 25.9

* update template

* update docker pull tag

* update

* fix intro
This commit is contained in:
peterjunpark
2025-10-17 11:31:06 -04:00
committed by GitHub
parent 14bb59fca9
commit a613bd6824
5 changed files with 541 additions and 110 deletions

View File

@@ -1,47 +1,16 @@
dockers:
- pull_tag: rocm/jax-training:maxtext-v25.7-jax060
- pull_tag: rocm/jax-training:maxtext-v25.9
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
components:
ROCm: 6.4.1
JAX: 0.6.0
Python: 3.10.12
Transformer Engine: 2.1.0+90d703dd
hipBLASLt: 1.1.0-499ece1c21
- pull_tag: rocm/jax-training:maxtext-v25.7
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
components:
ROCm: 6.4.1
JAX: 0.5.0
Python: 3.10.12
Transformer Engine: 2.1.0+90d703dd
ROCm: 7.0.0
JAX: 0.6.2
Python: 3.10.18
Transformer Engine: 2.2.0.dev0+c91bac54
hipBLASLt: 1.x.x
model_groups:
- group: Meta Llama
tag: llama
models:
- model: Llama 3.3 70B
mad_tag: jax_maxtext_train_llama-3.3-70b
model_repo: Llama-3.3-70B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.1 8B
mad_tag: jax_maxtext_train_llama-3.1-8b
model_repo: Llama-3.1-8B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.1 70B
mad_tag: jax_maxtext_train_llama-3.1-70b
model_repo: Llama-3.1-70B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3 8B
mad_tag: jax_maxtext_train_llama-3-8b
multinode_training_script: llama3_8b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 3 70B
mad_tag: jax_maxtext_train_llama-3-70b
multinode_training_script: llama3_70b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 2 7B
mad_tag: jax_maxtext_train_llama-2-7b
model_repo: Llama-2-7B
@@ -54,6 +23,29 @@ model_groups:
precision: bf16
multinode_training_script: llama2_70b_multinode.sh
doc_options: ["single-node", "multi-node"]
- model: Llama 3 8B (multi-node)
mad_tag: jax_maxtext_train_llama-3-8b
multinode_training_script: llama3_8b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 3 70B (multi-node)
mad_tag: jax_maxtext_train_llama-3-70b
multinode_training_script: llama3_70b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 3.1 8B
mad_tag: jax_maxtext_train_llama-3.1-8b
model_repo: Llama-3.1-8B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.1 70B
mad_tag: jax_maxtext_train_llama-3.1-70b
model_repo: Llama-3.1-70B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.3 70B
mad_tag: jax_maxtext_train_llama-3.3-70b
model_repo: Llama-3.3-70B
precision: bf16
doc_options: ["single-node"]
- group: DeepSeek
tag: deepseek
models:

View File

@@ -0,0 +1,72 @@
dockers:
- pull_tag: rocm/jax-training:maxtext-v25.7-jax060
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
components:
ROCm: 6.4.1
JAX: 0.6.0
Python: 3.10.12
Transformer Engine: 2.1.0+90d703dd
hipBLASLt: 1.1.0-499ece1c21
- pull_tag: rocm/jax-training:maxtext-v25.7
docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
components:
ROCm: 6.4.1
JAX: 0.5.0
Python: 3.10.12
Transformer Engine: 2.1.0+90d703dd
hipBLASLt: 1.x.x
model_groups:
- group: Meta Llama
tag: llama
models:
- model: Llama 3.3 70B
mad_tag: jax_maxtext_train_llama-3.3-70b
model_repo: Llama-3.3-70B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.1 8B
mad_tag: jax_maxtext_train_llama-3.1-8b
model_repo: Llama-3.1-8B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3.1 70B
mad_tag: jax_maxtext_train_llama-3.1-70b
model_repo: Llama-3.1-70B
precision: bf16
doc_options: ["single-node"]
- model: Llama 3 8B
mad_tag: jax_maxtext_train_llama-3-8b
multinode_training_script: llama3_8b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 3 70B
mad_tag: jax_maxtext_train_llama-3-70b
multinode_training_script: llama3_70b_multinode.sh
doc_options: ["multi-node"]
- model: Llama 2 7B
mad_tag: jax_maxtext_train_llama-2-7b
model_repo: Llama-2-7B
precision: bf16
multinode_training_script: llama2_7b_multinode.sh
doc_options: ["single-node", "multi-node"]
- model: Llama 2 70B
mad_tag: jax_maxtext_train_llama-2-70b
model_repo: Llama-2-70B
precision: bf16
multinode_training_script: llama2_70b_multinode.sh
doc_options: ["single-node", "multi-node"]
- group: DeepSeek
tag: deepseek
models:
- model: DeepSeek-V2-Lite (16B)
mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
model_repo: DeepSeek-V2-lite
precision: bf16
doc_options: ["single-node"]
- group: Mistral AI
tag: mistral
models:
- model: Mixtral 8x7B
mad_tag: jax_maxtext_train_mixtral-8x7b
model_repo: Mixtral-8x7B
precision: bf16
doc_options: ["single-node"]

View File

@@ -6,14 +6,8 @@
Training a model with JAX MaxText on ROCm
******************************************
MaxText is a high-performance, open-source framework built on the Google JAX
machine learning library to train LLMs at scale. The MaxText framework for
ROCm is an optimized fork of the upstream
`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
on AMD MI300X Series GPUs.
The MaxText for ROCm training Docker image
provides a prebuilt environment for training on AMD Instinct MI300X and MI325X GPUs,
provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs,
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
It includes the following software components:
@@ -61,15 +55,15 @@ MaxText with on ROCm provides the following key features to train large language
- Multi-node support
- NANOO FP8 quantization support
- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support
.. _amd-maxtext-model-support-v257:
.. _amd-maxtext-model-support-v259:
Supported models
================
The following models are pre-optimized for performance on AMD Instinct MI300
Series GPUs. Some instructions, commands, and available training
The following models are pre-optimized for performance on AMD Instinct
GPUs. Some instructions, commands, and available training
configurations in this documentation might vary by model -- select one to get
started.
@@ -139,22 +133,13 @@ Use the following command to pull the Docker image from Docker Hub.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
{% set dockers = data.dockers %}
.. tab-set::
{% set docker = data.dockers[0] %}
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. code-block:: shell
.. tab-item:: JAX {{ jax_version }}
:sync: {{ docker.pull_tag }}
docker pull {{ docker.pull_tag }}
.. code-block:: shell
docker pull {{ docker.pull_tag }}
{% endfor %}
.. _amd-maxtext-multi-node-setup-v257:
.. _amd-maxtext-multi-node-setup-v259:
Multi-node configuration
------------------------
@@ -162,7 +147,7 @@ Multi-node configuration
See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
environment for multi-node training.
.. _amd-maxtext-get-started-v257:
.. _amd-maxtext-get-started-v259:
Benchmarking
============
@@ -174,7 +159,7 @@ benchmark results:
.. _vllm-benchmark-mad:
{% set dockers = data.dockers %}
{% set docker = data.dockers[0] %}
{% set model_groups = data.model_groups %}
{% for model_group in model_groups %}
{% for model in model_group.models %}
@@ -186,6 +171,9 @@ benchmark results:
{% if model.mad_tag and "single-node" in model.doc_options %}
.. tab-item:: MAD-integrated benchmarking
The following run command is tailored to {{ model.model }}.
See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
directory and install the required packages on the host machine.
@@ -214,22 +202,19 @@ benchmark results:
.. tab-item:: Standalone benchmarking
The following commands are optimized for {{ model.model }}. See
:ref:`amd-maxtext-model-support-v259` to switch to another
available model. Some instructions and resources might not be
available for all models and configurations.
.. rubric:: Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the
Docker container as shown in the following snippet.
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. code-block:: shell
.. tab-item:: JAX {{ jax_version }}
:sync: {{ docker.pull_tag }}
.. code-block:: shell
docker pull {{ docker.pull_tag }}
{% endfor %}
docker pull {{ docker.pull_tag }}
{% if model.model_repo and "single-node" in model.doc_options %}
.. rubric:: Single node training
@@ -250,33 +235,25 @@ benchmark results:
2. Launch the Docker container.
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. code-block:: shell
.. tab-item:: JAX {{ jax_version }}
:sync: {{ docker.pull_tag }}
.. code-block:: shell
docker run -it \
--device=/dev/dri \
--device=/dev/kfd \
--network host \
--ipc host \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v $HOME:$HOME \
-v $HOME/.ssh:/root/.ssh \
-v $HF_HOME:/hf_cache \
-e HF_HOME=/hf_cache \
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
--shm-size 64G \
--name training_env \
{{ docker.pull_tag }}
{% endfor %}
docker run -it \
--device=/dev/dri \
--device=/dev/kfd \
--network host \
--ipc host \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v $HOME:$HOME \
-v $HOME/.ssh:/root/.ssh \
-v $HF_HOME:/hf_cache \
-e HF_HOME=/hf_cache \
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
--shm-size 64G \
--name training_env \
{{ docker.pull_tag }}
3. In the Docker container, clone the ROCm MAD repository and navigate to the
benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
@@ -299,11 +276,27 @@ benchmark results:
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
For quantized training, use the following command:
For quantized training, run the script with the appropriate option for your Instinct GPU.
.. code-block:: shell
.. tab-set::
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
.. tab-item:: MI355X and MI350X
For ``fp8`` quantized training on MI355X and MI350X GPUs, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q fp8
{% if model.model_repo not in ["Llama-3.1-70B", "Llama-3.3-70B"] %}
.. tab-item:: MI325X and MI300X
For ``nanoo_fp8`` quantized training on MI300X series GPUs, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
{% endif %}
{% endif %}
{% if model.multinode_training_script and "multi-node" in model.doc_options %}
@@ -335,7 +328,7 @@ benchmark results:
{% else %}
.. rubric:: Multi-node training
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v257`
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
{% endif %}
{% endfor %}

View File

@@ -17,27 +17,35 @@ previous releases of the ``ROCm/jax-training`` Docker image on `Docker Hub <http
- Components
- Resources
* - 25.7 (latest)
-
* - 25.9 (latest)
-
* ROCm 7.0.0
* JAX 0.6.2
-
* :doc:`Documentation <../jax-maxtext>`
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7-jax060/images/sha256-7352212ae033a76dca2b9dceffc23c1b5f1a61a7a560082cf747a9bf1acfc9ce>`__
* - 25.7
-
* ROCm 6.4.1
* JAX 0.6.0, 0.5.0
-
* :doc:`Documentation <../jax-maxtext>`
-
* :doc:`Documentation <jax-maxtext-v25.7>`
* `Docker Hub (JAX 0.6.0) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7-jax060/images/sha256-7352212ae033a76dca2b9dceffc23c1b5f1a61a7a560082cf747a9bf1acfc9ce>`__
* `Docker Hub (JAX 0.5.0) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025>`__
* - 25.5
-
-
* ROCm 6.3.4
* JAX 0.4.35
-
-
* :doc:`Documentation <jax-maxtext-v25.5>`
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.5/images/sha256-4e0516358a227cae8f552fb866ec07e2edcf244756f02e7b40212abfbab5217b>`__
* - 25.4
-
-
* ROCm 6.3.0
* JAX 0.4.31
-
-
* :doc:`Documentation <jax-maxtext-v25.4>`
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.4/images/sha256-fb3eb71cd74298a7b3044b7130cf84113f14d518ff05a2cd625c11ea5f6a7b01>`__

View File

@@ -0,0 +1,366 @@
:orphan:
.. meta::
:description: How to train a model using JAX MaxText for ROCm.
:keywords: ROCm, AI, LLM, train, jax, torch, Llama, flux, tutorial, docker
******************************************
Training a model with JAX MaxText on ROCm
******************************************
.. caution::
This documentation does not reflect the latest version of ROCm JAX MaxText
training performance documentation. See :doc:`../jax-maxtext` for the latest version.
MaxText is a high-performance, open-source framework built on the Google JAX
machine learning library to train LLMs at scale. The MaxText framework for
ROCm is an optimized fork of the upstream
`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
on AMD MI300X series GPUs.
The MaxText for ROCm training Docker image
provides a prebuilt environment for training on AMD Instinct MI300X and MI325X GPUs,
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
It includes the following software components:
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
{% set dockers = data.dockers %}
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. tab-item:: ``{{ docker.pull_tag }}``
:sync: {{ docker.pull_tag }}
.. list-table::
:header-rows: 1
* - Software component
- Version
{% for component_name, component_version in docker.components.items() %}
* - {{ component_name }}
- {{ component_version }}
{% endfor %}
{% if jax_version == "0.6.0" %}
.. note::
Shardy is a new config in JAX 0.6.0. You might get related errors if it's
not configured correctly. For now you can turn it off by setting
``shardy=False`` during the training run. You can also follow the `migration
guide <https://docs.jax.dev/en/latest/shardy_jax_migration.html>`__ to enable
it.
{% endif %}
{% endfor %}
MaxText with on ROCm provides the following key features to train large language models efficiently:
- Transformer Engine (TE)
- Flash Attention (FA) 3 -- with or without sequence input packing
- GEMM tuning
- Multi-node support
- NANOO FP8 quantization support
.. _amd-maxtext-model-support-v257:
Supported models
================
The following models are pre-optimized for performance on AMD Instinct MI300
series GPUs. Some instructions, commands, and available training
configurations in this documentation might vary by model -- select one to get
started.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
{% set model_groups = data.model_groups %}
.. raw:: html
<div id="vllm-benchmark-ud-params-picker" class="container-fluid">
<div class="row gx-0">
<div class="col-2 me-1 px-2 model-param-head">Model</div>
<div class="row col-10 pe-0">
{% for model_group in model_groups %}
<div class="col-4 px-2 model-param" data-param-k="model-group" data-param-v="{{ model_group.tag }}" tabindex="0">{{ model_group.group }}</div>
{% endfor %}
</div>
</div>
<div class="row gx-0 pt-1">
<div class="col-2 me-1 px-2 model-param-head">Variant</div>
<div class="row col-10 pe-0">
{% for model_group in model_groups %}
{% set models = model_group.models %}
{% for model in models %}
{% if models|length % 3 == 0 %}
<div class="col-4 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
{% else %}
<div class="col-6 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
{% endif %}
{% endfor %}
{% endfor %}
</div>
</div>
</div>
.. note::
Some models, such as Llama 3, require an external license agreement through
a third party (for example, Meta).
System validation
=================
Before running AI workloads, it's important to validate that your AMD hardware is configured
correctly and performing optimally.
If you have already validated your system settings, including aspects like NUMA auto-balancing, you
can skip this step. Otherwise, complete the procedures in the :ref:`System validation and
optimization <rocm-for-ai-system-optimization>` guide to properly configure your system settings
before starting training.
To test for optimal performance, consult the recommended :ref:`System health benchmarks
<rocm-for-ai-system-health-bench>`. This suite of tests will help you verify and fine-tune your
system's configuration.
Environment setup
=================
This Docker image is optimized for specific model configurations outlined
as follows. Performance can vary for other training workloads, as AMD
doesnt validate configurations and run conditions outside those described.
Pull the Docker image
---------------------
Use the following command to pull the Docker image from Docker Hub.
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
{% set dockers = data.dockers %}
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. tab-item:: JAX {{ jax_version }}
:sync: {{ docker.pull_tag }}
.. code-block:: shell
docker pull {{ docker.pull_tag }}
{% endfor %}
.. _amd-maxtext-multi-node-setup-v257:
Multi-node configuration
------------------------
See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
environment for multi-node training.
.. _amd-maxtext-get-started-v257:
Benchmarking
============
Once the setup is complete, choose between two options to reproduce the
benchmark results:
.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
.. _vllm-benchmark-mad:
{% set dockers = data.dockers %}
{% set model_groups = data.model_groups %}
{% for model_group in model_groups %}
{% for model in model_group.models %}
.. container:: model-doc {{model.mad_tag}}
.. tab-set::
{% if model.mad_tag and "single-node" in model.doc_options %}
.. tab-item:: MAD-integrated benchmarking
1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
directory and install the required packages on the host machine.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
2. Use this command to run the performance benchmark test on the {{ model.model }} model
using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
.. code-block:: shell
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
madengine run \
--tags {{model.mad_tag}} \
--keep-model-dir \
--live-output \
--timeout 28800
MAD launches a Docker container with the name
``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
model are collected in the following path: ``~/MAD/perf.csv/``.
{% endif %}
.. tab-item:: Standalone benchmarking
.. rubric:: Download the Docker image and required scripts
Run the JAX MaxText benchmark tool independently by starting the
Docker container as shown in the following snippet.
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. tab-item:: JAX {{ jax_version }}
:sync: {{ docker.pull_tag }}
.. code-block:: shell
docker pull {{ docker.pull_tag }}
{% endfor %}
{% if model.model_repo and "single-node" in model.doc_options %}
.. rubric:: Single node training
1. Set up environment variables.
.. code-block:: shell
export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
export HF_HOME=<Location of saved/cached Hugging Face models>
``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
See `User access tokens <https://huggingface.co/docs/hub/en/security-tokens>`__.
``HF_HOME`` is where ``huggingface_hub`` will store local data. See `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-download>`__.
If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
Downloaded files typically get cached to ``~/.cache/huggingface``.
2. Launch the Docker container.
.. tab-set::
{% for docker in dockers %}
{% set jax_version = docker.components["JAX"] %}
.. tab-item:: JAX {{ jax_version }}
:sync: {{ docker.pull_tag }}
.. code-block:: shell
docker run -it \
--device=/dev/dri \
--device=/dev/kfd \
--network host \
--ipc host \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--privileged \
-v $HOME:$HOME \
-v $HOME/.ssh:/root/.ssh \
-v $HF_HOME:/hf_cache \
-e HF_HOME=/hf_cache \
-e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
--shm-size 64G \
--name training_env \
{{ docker.pull_tag }}
{% endfor %}
3. In the Docker container, clone the ROCm MAD repository and navigate to the
benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD/scripts/jax-maxtext
4. Run the setup scripts to install libraries and datasets needed
for benchmarking.
.. code-block:: shell
./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
5. To run the training benchmark without quantization, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
For quantized training, use the following command:
.. code-block:: shell
./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
{% endif %}
{% if model.multinode_training_script and "multi-node" in model.doc_options %}
.. rubric:: Multi-node training
The following examples use SLURM to run on multiple nodes.
.. note::
The following scripts will launch the Docker container and run the
benchmark. Run them outside of any Docker container.
1. Make sure ``$HF_HOME`` is set before running the test. See
`ROCm benchmarking <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/readme.md>`__
for more details on downloading the Llama models before running the
benchmark.
2. To run multi-node training for {{ model.model }},
use the
`multi-node training script <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/{{ model.multinode_training_script }}>`__
under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
3. Run the multi-node training benchmark script.
.. code-block:: shell
sbatch -N <num_nodes> {{ model.multinode_training_script }}
{% else %}
.. rubric:: Multi-node training
For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v257`
with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
{% endif %}
{% endfor %}
{% endfor %}
Further reading
===============
- To learn more about MAD and the ``madengine`` CLI, see the `MAD usage guide <https://github.com/ROCm/MAD?tab=readme-ov-file#usage-guide>`__.
- To learn more about system settings and management practices to configure your system for
AMD Instinct MI300X series GPUs, see `AMD Instinct MI300X system optimization <https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/system-optimization/mi300x.html>`_.
- For a list of other ready-made Docker images for AI with ROCm, see
`AMD Infinity Hub <https://www.amd.com/en/developer/resources/infinity-hub.html#f-amd_hub_category=AI%20%26%20ML%20Models>`_.
Previous versions
=================
See :doc:`jax-maxtext-history` to find documentation for previous releases
of the ``ROCm/jax-training`` Docker image.