Update documentation for JAX training MaxText 25.11 release (#5789)
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -135,19 +135,29 @@ article_pages = [
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.5", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.6", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.7", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.8", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.9", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-v25.10", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/megatron-lm-primus-migration-guide", "os": ["linux"]},
-    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.7", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/primus-megatron", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.7", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.8", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.9", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-megatron-v25.10", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/pytorch-training", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-history", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.3", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.4", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.5", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.6", "os": ["linux"]},
-    {"file": "how-to/rocm-for-ai/inference/xdit-diffusion-inference", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.7", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.8", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.9", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/pytorch-training-v25.10", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/primus-pytorch", "os": ["linux"]},
-    {"file": "how-to/rocm-for-ai/training/benchmark-docker/pytorch-training", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-pytorch-v25.8", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-pytorch-v25.9", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/primus-pytorch-v25.10", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.4", "os": ["linux"]},
@@ -177,8 +187,12 @@ article_pages = [
     {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.9.1-20250702", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.9.1-20250715", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.10.0-20250812", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.10.1-20250909", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.10.2-20251006", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/vllm-0.11.1-20251103", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/inference/benchmark-docker/previous-versions/sglang-history", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/inference/benchmark-docker/pytorch-inference", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/inference/xdit-diffusion-inference", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/inference/deploy-your-model", "os": ["linux"]},

     {"file": "how-to/rocm-for-ai/inference-optimization/index", "os": ["linux"]},

@@ -1,12 +1,12 @@
 dockers:
-  - pull_tag: rocm/jax-training:maxtext-v25.9
-    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
+  - pull_tag: rocm/jax-training:maxtext-v25.11
+    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a
     components:
-      ROCm: 7.0.0
-      JAX: 0.6.2
-      Python: 3.10.18
-      Transformer Engine: 2.2.0.dev0+c91bac54
-      hipBLASLt: 1.x.x
+      ROCm: 7.1.0
+      JAX: 0.7.1
+      Python: 3.12
+      Transformer Engine: 2.4.0.dev0+281042de
+      hipBLASLt: 1.2.x
 model_groups:
   - group: Meta Llama
     tag: llama

@@ -0,0 +1,64 @@
+dockers:
+  - pull_tag: rocm/jax-training:maxtext-v25.9.1
+    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.9.1/images/sha256-60946cfbd470f6ee361fc9da740233a4fb2e892727f01719145b1f7627a1cff6
+    components:
+      ROCm: 7.0.0
+      JAX: 0.6.2
+      Python: 3.10.18
+      Transformer Engine: 2.2.0.dev0+c91bac54
+      hipBLASLt: 1.x.x
+model_groups:
+  - group: Meta Llama
+    tag: llama
+    models:
+      - model: Llama 2 7B
+        mad_tag: jax_maxtext_train_llama-2-7b
+        model_repo: Llama-2-7B
+        precision: bf16
+        multinode_training_script: llama2_7b_multinode.sh
+        doc_options: ["single-node", "multi-node"]
+      - model: Llama 2 70B
+        mad_tag: jax_maxtext_train_llama-2-70b
+        model_repo: Llama-2-70B
+        precision: bf16
+        multinode_training_script: llama2_70b_multinode.sh
+        doc_options: ["single-node", "multi-node"]
+      - model: Llama 3 8B (multi-node)
+        mad_tag: jax_maxtext_train_llama-3-8b
+        multinode_training_script: llama3_8b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3 70B (multi-node)
+        mad_tag: jax_maxtext_train_llama-3-70b
+        multinode_training_script: llama3_70b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3.1 8B
+        mad_tag: jax_maxtext_train_llama-3.1-8b
+        model_repo: Llama-3.1-8B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.1 70B
+        mad_tag: jax_maxtext_train_llama-3.1-70b
+        model_repo: Llama-3.1-70B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.3 70B
+        mad_tag: jax_maxtext_train_llama-3.3-70b
+        model_repo: Llama-3.3-70B
+        precision: bf16
+        doc_options: ["single-node"]
+  - group: DeepSeek
+    tag: deepseek
+    models:
+      - model: DeepSeek-V2-Lite (16B)
+        mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
+        model_repo: DeepSeek-V2-lite
+        precision: bf16
+        doc_options: ["single-node"]
+  - group: Mistral AI
+    tag: mistral
+    models:
+      - model: Mixtral 8x7B
+        mad_tag: jax_maxtext_train_mixtral-8x7b
+        model_repo: Mixtral-8x7B
+        precision: bf16
+        doc_options: ["single-node"]

@@ -33,18 +33,15 @@ It includes the following software components:
              - {{ component_version }}

            {% endfor %}
-   {% if jax_version == "0.6.0" %}
-   .. note::
-
-      Shardy is a new config in JAX 0.6.0. You might get related errors if it's
-      not configured correctly. For now you can turn it off by setting
-      ``shardy=False`` during the training run. You can also follow the `migration
-      guide <https://docs.jax.dev/en/latest/shardy_jax_migration.html>`__ to enable
-      it.
-   {% endif %}

    {% endfor %}

+   .. note::
+
+      The ``rocm/jax-training:maxtext-v25.9`` image has been updated to
+      ``rocm/jax-training:maxtext-v25.9.1``. This revision should include
+      a fix to address segmentation fault issues during launch. See the
+      :doc:`versioned documentation <previous-versions/jax-maxtext-v25.9>`.
+
 MaxText on ROCm provides the following key features to train large language models efficiently:

 - Transformer Engine (TE)

@@ -57,7 +54,7 @@ MaxText on ROCm provides the following key features to train large language

 - NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support

-.. _amd-maxtext-model-support-v259:
+.. _amd-maxtext-model-support-v25.11:

 Supported models
 ================

@@ -139,7 +136,7 @@ Use the following command to pull the Docker image from Docker Hub.

       docker pull {{ docker.pull_tag }}

-.. _amd-maxtext-multi-node-setup-v259:
+.. _amd-maxtext-multi-node-setup-v25.11:

 Multi-node configuration
 ------------------------
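For example, with the 25.11 tag this commit introduces, the rendered pull step resolves to the following (a minimal sketch, assuming Docker is installed on the host):

.. code-block:: shell

   docker pull rocm/jax-training:maxtext-v25.11
   # confirm the image and its digest are present locally
   docker images rocm/jax-training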

@@ -147,7 +144,7 @@ Multi-node configuration
 See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
 environment for multi-node training.

-.. _amd-maxtext-get-started-v259:
+.. _amd-maxtext-get-started-v25.11:

 Benchmarking
 ============

@@ -172,7 +169,7 @@ benchmark results:
       .. tab-item:: MAD-integrated benchmarking

          The following run command is tailored to {{ model.model }}.
-         See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
+         See :ref:`amd-maxtext-model-support-v25.11` to switch to another available model.

          1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
             directory and install the required packages on the host machine.
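Rendered for a concrete model, the MAD flow above looks like the following sketch (using the ``jax_maxtext_train_llama-2-7b`` tag defined in the model YAML in this commit; the token is a placeholder):

.. code-block:: shell

   git clone https://github.com/ROCm/MAD
   cd MAD
   pip install -r requirements.txt

   # a Hugging Face token is required for gated models such as Llama
   export MAD_SECRETS_HFTOKEN="<your Hugging Face token>"
   madengine run \
     --tags jax_maxtext_train_llama-2-7b \
     --keep-model-dir \
     --live-output \
     --timeout 28800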

@@ -203,7 +200,7 @@ benchmark results:
       .. tab-item:: Standalone benchmarking

          The following commands are optimized for {{ model.model }}. See
-         :ref:`amd-maxtext-model-support-v259` to switch to another
+         :ref:`amd-maxtext-model-support-v25.11` to switch to another
          available model. Some instructions and resources might not be
          available for all models and configurations.

@@ -325,15 +322,69 @@ benchmark results:

                  sbatch -N <num_nodes> {{ model.multinode_training_script }}

+            .. _maxtext-rocprofv3:
+
+            .. rubric:: Profiling with rocprofv3
+
+            If you need to collect a trace and the JAX profiler isn't working, use ``rocprofv3`` provided by the :doc:`ROCprofiler-SDK <rocprofiler-sdk:index>` as a workaround. For example:
+
+            .. code-block:: bash
+
+               # -d sets the trace output directory; replace the command after
+               # "--" with the workload you want to profile
+               rocprofv3 \
+                 --hip-trace \
+                 --kernel-trace \
+                 --memory-copy-trace \
+                 --rccl-trace \
+                 --output-format pftrace \
+                 -d ./v3_traces \
+                 -- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+            You can set the directory where you want the .json traces to be
+            saved using ``-d <TRACE_DIRECTORY>``. The resulting traces can be
+            opened in Perfetto: `<https://ui.perfetto.dev/>`__.
+
             {% else %}
             .. rubric:: Multi-node training

-            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
+            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v25.11`
             with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
             {% endif %}
    {% endfor %}
 {% endfor %}

+Known issues
+============
+
+- Minor performance regression (< 4%) for BF16 quantization in Llama models and Mixtral 8x7b.
+
+- You might see minor loss spikes, or the loss curve may have slightly higher
+  convergence end values compared to the previous ``jax-training`` image.
+
+- For FP8 training on MI355, many models will display a warning message like:
+  ``Warning: Latency not found for MI_M=16, MI_N=16, MI_K=128,
+  mi_input_type=BFloat8Float8_fnuz. Returning latency value of 32 (really
+  slow).`` The compile step may take longer than usual, but training will run.
+  This will be fixed in a future release.
+
+- The built-in JAX profiler isn't working. If you need to collect a trace, use
+  ``rocprofv3`` provided by the
+  :doc:`ROCprofiler-SDK <rocprofiler-sdk:index>` as a workaround. For example:
+
+  .. code-block:: bash
+
+     # -d sets the trace output directory; replace the command after
+     # "--" with the workload you want to profile
+     rocprofv3 \
+       --hip-trace \
+       --kernel-trace \
+       --memory-copy-trace \
+       --rccl-trace \
+       --output-format pftrace \
+       -d ./v3_traces \
+       -- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+  You can set the directory where you want the .json traces to be
+  saved using ``-d <TRACE_DIRECTORY>``. The resulting traces can be
+  opened in Perfetto: `<https://ui.perfetto.dev/>`__.
+
 Further reading
 ===============
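After a profiling run like the one above, the trace files land in the directory passed to ``-d`` (a sketch; exact file names vary by host and process):

.. code-block:: shell

   ls ./v3_traces
   # open the resulting .pftrace file at https://ui.perfetto.dev/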

@@ -17,13 +17,22 @@ previous releases of the ``ROCm/jax-training`` Docker image on `Docker Hub <http
      - Components
      - Resources

-   * - 25.9 (latest)
+   * - 25.11
+     -
+       * ROCm 7.1.0
+       * JAX 0.7.1
+     -
+       * :doc:`Documentation <../jax-maxtext>`
+       * `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a>`__
+
+   * - 25.9.1
      -
        * ROCm 7.0.0
        * JAX 0.6.2
      -
-       * :doc:`Documentation <../jax-maxtext>`
-       * `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7-jax060/images/sha256-7352212ae033a76dca2b9dceffc23c1b5f1a61a7a560082cf747a9bf1acfc9ce>`__
+       * :doc:`Documentation <jax-maxtext-v25.9>`
+       * `Docker Hub (25.9.1) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.9.1/images/sha256-60946cfbd470f6ee361fc9da740233a4fb2e892727f01719145b1f7627a1cff6>`__
+       * `Docker Hub (25.9) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.9/images/sha256-4bb16ab58279ef09cb7a5e362c38e3fe3f901de44d8dbac5d0cb3bac5686441e>`__

    * - 25.7
      -

@@ -24,7 +24,7 @@ provides a prebuilt environment for training on AMD Instinct MI300X and MI325X G
 including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
 It includes the following software components:

-.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml

    {% set dockers = data.dockers %}
    .. tab-set::

@@ -80,7 +80,7 @@ series GPUs. Some instructions, commands, and available training
 configurations in this documentation might vary by model -- select one to get
 started.

-.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml

    {% set model_groups = data.model_groups %}
    .. raw:: html

@@ -144,7 +144,7 @@ Pull the Docker image

 Use the following command to pull the Docker image from Docker Hub.

-.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml

    {% set dockers = data.dockers %}
    .. tab-set::

@@ -177,7 +177,7 @@ Benchmarking
 Once the setup is complete, choose between two options to reproduce the
 benchmark results:

-.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml

 .. _vllm-benchmark-mad:

@@ -0,0 +1,365 @@
+:orphan:
+
+.. meta::
+   :description: How to train a model using JAX MaxText for ROCm.
+   :keywords: ROCm, AI, LLM, train, jax, torch, Llama, flux, tutorial, docker
+
+******************************************
+Training a model with JAX MaxText on ROCm
+******************************************
+
+.. caution::
+
+   This documentation does not reflect the latest version of ROCm JAX MaxText
+   training performance documentation. See :doc:`../jax-maxtext` for the latest version.
+
+.. note::
+
+   We have refreshed the ``rocm/jax-training:maxtext-v25.9`` image as
+   ``rocm/jax-training:maxtext-v25.9.1``. This should include a fix to address
+   segmentation fault issues during launch.
+
+The MaxText for ROCm training Docker image
+provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs,
+including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
+It includes the following software components:
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
+
+   {% set dockers = data.dockers %}
+   .. tab-set::
+
+   {% for docker in dockers %}
+   {% set jax_version = docker.components["JAX"] %}
+
+      .. tab-item:: ``{{ docker.pull_tag }}``
+         :sync: {{ docker.pull_tag }}
+
+         .. list-table::
+            :header-rows: 1
+
+            * - Software component
+              - Version
+
+            {% for component_name, component_version in docker.components.items() %}
+            * - {{ component_name }}
+              - {{ component_version }}
+
+            {% endfor %}
+   {% if jax_version == "0.6.0" %}
+   .. note::
+
+      Shardy is a new config in JAX 0.6.0. You might get related errors if it's
+      not configured correctly. For now you can turn it off by setting
+      ``shardy=False`` during the training run. You can also follow the `migration
+      guide <https://docs.jax.dev/en/latest/shardy_jax_migration.html>`__ to enable
+      it.
+   {% endif %}
+
+   {% endfor %}
+
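Where the Shardy note applies, the override is passed as a MaxText config key on the training command line (a hypothetical invocation; the config file and remaining flags depend on the model being trained):

.. code-block:: shell

   # assumption: train.py accepts key=value overrides after the config file
   python3 MaxText/train.py MaxText/configs/base.yml shardy=false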
+MaxText on ROCm provides the following key features to train large language models efficiently:
+
+- Transformer Engine (TE)
+
+- Flash Attention (FA) 3 -- with or without sequence input packing
+
+- GEMM tuning
+
+- Multi-node support
+
+- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support
+
+.. _amd-maxtext-model-support-v259:
+
+Supported models
+================
+
+The following models are pre-optimized for performance on AMD Instinct
+GPUs. Some instructions, commands, and available training
+configurations in this documentation might vary by model -- select one to get
+started.
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
+
+   {% set model_groups = data.model_groups %}
+   .. raw:: html
+
+      <div id="vllm-benchmark-ud-params-picker" class="container-fluid">
+        <div class="row gx-0">
+          <div class="col-2 me-1 px-2 model-param-head">Model</div>
+          <div class="row col-10 pe-0">
+            {% for model_group in model_groups %}
+            <div class="col-4 px-2 model-param" data-param-k="model-group" data-param-v="{{ model_group.tag }}" tabindex="0">{{ model_group.group }}</div>
+            {% endfor %}
+          </div>
+        </div>
+
+        <div class="row gx-0 pt-1">
+          <div class="col-2 me-1 px-2 model-param-head">Variant</div>
+          <div class="row col-10 pe-0">
+            {% for model_group in model_groups %}
+            {% set models = model_group.models %}
+            {% for model in models %}
+            {% if models|length % 3 == 0 %}
+            <div class="col-4 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
+            {% else %}
+            <div class="col-6 px-2 model-param" data-param-k="model" data-param-v="{{ model.mad_tag }}" data-param-group="{{ model_group.tag }}" tabindex="0">{{ model.model }}</div>
+            {% endif %}
+            {% endfor %}
+            {% endfor %}
+          </div>
+        </div>
+      </div>
+
+.. note::
+
+   Some models, such as Llama 3, require an external license agreement through
+   a third party (for example, Meta).
+
+System validation
+=================
+
+Before running AI workloads, it's important to validate that your AMD hardware is configured
+correctly and performing optimally.
+
+If you have already validated your system settings, including aspects like NUMA auto-balancing, you
+can skip this step. Otherwise, complete the procedures in the :ref:`System validation and
+optimization <rocm-for-ai-system-optimization>` guide to properly configure your system settings
+before starting training.
+
+To test for optimal performance, consult the recommended :ref:`System health benchmarks
+<rocm-for-ai-system-health-bench>`. This suite of tests will help you verify and fine-tune your
+system's configuration.
+
+Environment setup
+=================
+
+This Docker image is optimized for specific model configurations outlined
+as follows. Performance can vary for other training workloads, as AMD
+doesn't validate configurations and run conditions outside those described.
+
+Pull the Docker image
+---------------------
+
+Use the following command to pull the Docker image from Docker Hub.
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
+
+   {% set docker = data.dockers[0] %}
+
+   .. code-block:: shell
+
+      docker pull {{ docker.pull_tag }}
+
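With the single Docker entry defined in the v25.9 YAML added earlier in this commit, this pull step renders as (a minimal sketch):

.. code-block:: shell

   docker pull rocm/jax-training:maxtext-v25.9.1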
+.. _amd-maxtext-multi-node-setup-v259:
+
+Multi-node configuration
+------------------------
+
+See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
+environment for multi-node training.
+
+.. _amd-maxtext-get-started-v259:
+
+Benchmarking
+============
+
+Once the setup is complete, choose between two options to reproduce the
+benchmark results:
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.9-benchmark-models.yaml
+
+.. _vllm-benchmark-mad:
+
+   {% set docker = data.dockers[0] %}
+   {% set model_groups = data.model_groups %}
+   {% for model_group in model_groups %}
+   {% for model in model_group.models %}
+
+   .. container:: model-doc {{model.mad_tag}}
+
+      .. tab-set::
+
+         {% if model.mad_tag and "single-node" in model.doc_options %}
+         .. tab-item:: MAD-integrated benchmarking
+
+            The following run command is tailored to {{ model.model }}.
+            See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
+
+            1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__) repository to a local
+               directory and install the required packages on the host machine.
+
+               .. code-block:: shell
+
+                  git clone https://github.com/ROCm/MAD
+                  cd MAD
+                  pip install -r requirements.txt
+
+            2. Use this command to run the performance benchmark test on the {{ model.model }} model
+               using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
+
+               .. code-block:: shell
+
+                  export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
+                  madengine run \
+                    --tags {{model.mad_tag}} \
+                    --keep-model-dir \
+                    --live-output \
+                    --timeout 28800
+
+            MAD launches a Docker container with the name
+            ``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
+            model are collected in the following path: ``~/MAD/perf.csv/``.
+            {% endif %}
+
+         .. tab-item:: Standalone benchmarking
+
+            The following commands are optimized for {{ model.model }}. See
+            :ref:`amd-maxtext-model-support-v259` to switch to another
+            available model. Some instructions and resources might not be
+            available for all models and configurations.
+
+            .. rubric:: Download the Docker image and required scripts
+
+            Run the JAX MaxText benchmark tool independently by starting the
+            Docker container as shown in the following snippet.
+
+            .. code-block:: shell
+
+               docker pull {{ docker.pull_tag }}
+
+            {% if model.model_repo and "single-node" in model.doc_options %}
+            .. rubric:: Single node training
+
+            1. Set up environment variables.
+
+               .. code-block:: shell
+
+                  export MAD_SECRETS_HFTOKEN=<Your Hugging Face token>
+                  export HF_HOME=<Location of saved/cached Hugging Face models>
+
+               ``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
+               See `User access tokens <https://huggingface.co/docs/hub/en/security-tokens>`__.
+
+               ``HF_HOME`` is where ``huggingface_hub`` will store local data. See `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-download>`__.
+               If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
+               Downloaded files typically get cached to ``~/.cache/huggingface``.
+
+            2. Launch the Docker container.
+
+               .. code-block:: shell
+
+                  docker run -it \
+                    --device=/dev/dri \
+                    --device=/dev/kfd \
+                    --network host \
+                    --ipc host \
+                    --group-add video \
+                    --cap-add=SYS_PTRACE \
+                    --security-opt seccomp=unconfined \
+                    --privileged \
+                    -v $HOME:$HOME \
+                    -v $HOME/.ssh:/root/.ssh \
+                    -v $HF_HOME:/hf_cache \
+                    -e HF_HOME=/hf_cache \
+                    -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
+                    --shm-size 64G \
+                    --name training_env \
+                    {{ docker.pull_tag }}
+
+            3. In the Docker container, clone the ROCm MAD repository and navigate to the
+               benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
+
+               .. code-block:: shell
+
+                  git clone https://github.com/ROCm/MAD
+                  cd MAD/scripts/jax-maxtext
+
+            4. Run the setup scripts to install libraries and datasets needed
+               for benchmarking.
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
+
+            5. To run the training benchmark without quantization, use the following command:
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+               For quantized training, run the script with the appropriate option for your Instinct GPU.
+
+               .. tab-set::
+
+                  .. tab-item:: MI355X and MI350X
+
+                     For ``fp8`` quantized training on MI355X and MI350X GPUs, use the following command:
+
+                     .. code-block:: shell
+
+                        ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q fp8
+
+                  {% if model.model_repo not in ["Llama-3.1-70B", "Llama-3.3-70B"] %}
+                  .. tab-item:: MI325X and MI300X
+
+                     For ``nanoo_fp8`` quantized training on MI300X series GPUs, use the following command:
+
+                     .. code-block:: shell
+
+                        ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
+                  {% endif %}
+
+            {% endif %}
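Put together for one concrete model, the standalone single-node flow above renders roughly as follows (a sketch using the ``Llama-2-7B`` repo name from the v25.9 YAML; paths and token are placeholders):

.. code-block:: shell

   # on the host
   export MAD_SECRETS_HFTOKEN="<your Hugging Face token>"
   export HF_HOME="$HOME/.cache/huggingface"

   # inside the container started in step 2 (reattach with
   # `docker exec -it training_env bash` if you detached)
   git clone https://github.com/ROCm/MAD
   cd MAD/scripts/jax-maxtext
   ./jax-maxtext_benchmark_setup.sh -m Llama-2-7B
   ./jax-maxtext_benchmark_report.sh -m Llama-2-7B          # bf16 baseline
   ./jax-maxtext_benchmark_report.sh -m Llama-2-7B -q fp8   # MI355X/MI350X only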
+            {% if model.multinode_training_script and "multi-node" in model.doc_options %}
+            .. rubric:: Multi-node training
+
+            The following examples use SLURM to run on multiple nodes.
+
+            .. note::
+
+               The following scripts will launch the Docker container and run the
+               benchmark. Run them outside of any Docker container.
+
+            1. Make sure ``$HF_HOME`` is set before running the test. See
+               `ROCm benchmarking <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/readme.md>`__
+               for more details on downloading the Llama models before running the
+               benchmark.
+
+            2. To run multi-node training for {{ model.model }},
+               use the
+               `multi-node training script <https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/gpu-rocm/{{ model.multinode_training_script }}>`__
+               under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
+
+            3. Run the multi-node training benchmark script.
+
+               .. code-block:: shell
+
+                  sbatch -N <num_nodes> {{ model.multinode_training_script }}
+
+            {% else %}
+            .. rubric:: Multi-node training
+
+            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
+            with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
+            {% endif %}
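Rendered for a model with a multi-node script, step 3 becomes, for example (a sketch using ``llama2_7b_multinode.sh`` from the v25.9 YAML; the node count is illustrative):

.. code-block:: shell

   # submit from a SLURM login node, outside any container
   sbatch -N 2 llama2_7b_multinode.sh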
+   {% endfor %}
+   {% endfor %}
+
+Further reading
+===============
+
+- To learn more about MAD and the ``madengine`` CLI, see the `MAD usage guide <https://github.com/ROCm/MAD?tab=readme-ov-file#usage-guide>`__.
+
+- To learn more about system settings and management practices to configure your system for
+  AMD Instinct MI300X Series GPUs, see `AMD Instinct MI300X system optimization <https://instinct.docs.amd.com/projects/amdgpu-docs/en/latest/system-optimization/mi300x.html>`_.
+
+- For a list of other ready-made Docker images for AI with ROCm, see
+  `AMD Infinity Hub <https://www.amd.com/en/developer/resources/infinity-hub.html#f-amd_hub_category=AI%20%26%20ML%20Models>`_.
+
+Previous versions
+=================
+
+See :doc:`jax-maxtext-history` to find documentation for previous releases
+of the ``ROCm/jax-training`` Docker image.