From 79acda67757817f16ebf8dcd49fb7e5459473dc9 Mon Sep 17 00:00:00 2001
From: peterjunpark
Date: Fri, 17 Oct 2025 11:54:39 -0400
Subject: [PATCH] JAX Maxtext v25.9 doc update (#5532) (#5533)

* archive previous version (25.7)
* update docker components list for 25.9
* update template
* update docker pull tag
* update
* fix intro

(cherry picked from commit a613bd68244078235df737e91304c0b9712b6da5)
---
 .../jax-maxtext-benchmark-models.yaml         |  64 ++-
 .../jax-maxtext-v25.7-benchmark-models.yaml   |  72 ++++
 .../training/benchmark-docker/jax-maxtext.rst | 125 +++---
 .../previous-versions/jax-maxtext-history.rst |  24 +-
 .../previous-versions/jax-maxtext-v25.7.rst   | 366 ++++++++++++++++++
 5 files changed, 541 insertions(+), 110 deletions(-)
 create mode 100644 docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
 create mode 100644 docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.7.rst

diff --git a/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml b/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
index b1b971708..f670f8a8c 100644
--- a/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
+++ b/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
@@ -1,47 +1,16 @@
 dockers:
-  - pull_tag: rocm/jax-training:maxtext-v25.7-jax060
+  - pull_tag: rocm/jax-training:maxtext-v25.9
     docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
     components:
-      ROCm: 6.4.1
-      JAX: 0.6.0
-      Python: 3.10.12
-      Transformer Engine: 2.1.0+90d703dd
-      hipBLASLt: 1.1.0-499ece1c21
-  - pull_tag: rocm/jax-training:maxtext-v25.7
-    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
-    components:
-      ROCm: 6.4.1
-      JAX: 0.5.0
-      Python: 3.10.12
-      Transformer Engine: 2.1.0+90d703dd
+      ROCm: 7.0.0
+      JAX: 0.6.2
+      Python: 3.10.18
+      Transformer Engine: 2.2.0.dev0+c91bac54
       hipBLASLt: 1.x.x
 model_groups:
   - group: Meta Llama
     tag: llama
     models:
-      - model: Llama 3.3 70B
-        mad_tag: jax_maxtext_train_llama-3.3-70b
-        model_repo: Llama-3.3-70B
-        precision: bf16
-        doc_options: ["single-node"]
-      - model: Llama 3.1 8B
-        mad_tag: jax_maxtext_train_llama-3.1-8b
-        model_repo: Llama-3.1-8B
-        precision: bf16
-        doc_options: ["single-node"]
-      - model: Llama 3.1 70B
-        mad_tag: jax_maxtext_train_llama-3.1-70b
-        model_repo: Llama-3.1-70B
-        precision: bf16
-        doc_options: ["single-node"]
-      - model: Llama 3 8B
-        mad_tag: jax_maxtext_train_llama-3-8b
-        multinode_training_script: llama3_8b_multinode.sh
-        doc_options: ["multi-node"]
-      - model: Llama 3 70B
-        mad_tag: jax_maxtext_train_llama-3-70b
-        multinode_training_script: llama3_70b_multinode.sh
-        doc_options: ["multi-node"]
       - model: Llama 2 7B
         mad_tag: jax_maxtext_train_llama-2-7b
         model_repo: Llama-2-7B
         precision: bf16
         multinode_training_script: llama2_7b_multinode.sh
         doc_options: ["single-node", "multi-node"]
       - model: Llama 2 70B
         mad_tag: jax_maxtext_train_llama-2-70b
         model_repo: Llama-2-70B
@@ -54,6 +23,29 @@ model_groups:
         precision: bf16
         multinode_training_script: llama2_70b_multinode.sh
         doc_options: ["single-node", "multi-node"]
+      - model: Llama 3 8B (multi-node)
+        mad_tag: jax_maxtext_train_llama-3-8b
+        multinode_training_script: llama3_8b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3 70B (multi-node)
+        mad_tag: jax_maxtext_train_llama-3-70b
+        multinode_training_script: llama3_70b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3.1 8B
+        mad_tag: jax_maxtext_train_llama-3.1-8b
+        model_repo: Llama-3.1-8B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.1 70B
+        mad_tag: jax_maxtext_train_llama-3.1-70b
+        model_repo: Llama-3.1-70B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.3 70B
+        mad_tag: jax_maxtext_train_llama-3.3-70b
+        model_repo: Llama-3.3-70B
+        precision: bf16
+        doc_options: ["single-node"]
   - group: DeepSeek
     tag: deepseek
     models:
diff --git a/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml b/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
new file mode 100644
index 000000000..b1b971708
--- /dev/null
+++ b/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
@@ -0,0 +1,72 @@
+dockers:
+  - pull_tag: rocm/jax-training:maxtext-v25.7-jax060
+    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
+    components:
+      ROCm: 6.4.1
+      JAX: 0.6.0
+      Python: 3.10.12
+      Transformer Engine: 2.1.0+90d703dd
+      hipBLASLt: 1.1.0-499ece1c21
+  - pull_tag: rocm/jax-training:maxtext-v25.7
+    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025
+    components:
+      ROCm: 6.4.1
+      JAX: 0.5.0
+      Python: 3.10.12
+      Transformer Engine: 2.1.0+90d703dd
+      hipBLASLt: 1.x.x
+model_groups:
+  - group: Meta Llama
+    tag: llama
+    models:
+      - model: Llama 3.3 70B
+        mad_tag: jax_maxtext_train_llama-3.3-70b
+        model_repo: Llama-3.3-70B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.1 8B
+        mad_tag: jax_maxtext_train_llama-3.1-8b
+        model_repo: Llama-3.1-8B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.1 70B
+        mad_tag: jax_maxtext_train_llama-3.1-70b
+        model_repo: Llama-3.1-70B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3 8B
+        mad_tag: jax_maxtext_train_llama-3-8b
+        multinode_training_script: llama3_8b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3 70B
+        mad_tag: jax_maxtext_train_llama-3-70b
+        multinode_training_script: llama3_70b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 2 7B
+        mad_tag: jax_maxtext_train_llama-2-7b
+        model_repo: Llama-2-7B
+        precision: bf16
+        multinode_training_script: llama2_7b_multinode.sh
+        doc_options: ["single-node", "multi-node"]
+      - model: Llama 2 70B
+        mad_tag: jax_maxtext_train_llama-2-70b
+        model_repo: Llama-2-70B
+        precision: bf16
+        multinode_training_script: llama2_70b_multinode.sh
+        doc_options: ["single-node", "multi-node"]
+  - group: DeepSeek
+    tag: deepseek
+    models:
+      - model: DeepSeek-V2-Lite (16B)
+        mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
+        model_repo: DeepSeek-V2-lite
+        precision: bf16
+        doc_options: ["single-node"]
+  - group: Mistral AI
+    tag: mistral
+    models:
+      - model: Mixtral 8x7B
+        mad_tag: jax_maxtext_train_mixtral-8x7b
+        model_repo: Mixtral-8x7B
+        precision: bf16
+        doc_options: ["single-node"]
diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst
index eec785b7b..7c073ac94 100644
--- a/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst
+++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst
@@ -6,14 +6,8 @@
 Training a model with JAX MaxText on ROCm
 ******************************************
 
-MaxText is a high-performance, open-source framework built on the Google JAX
-machine learning library to train LLMs at scale. The MaxText framework for
-ROCm is an optimized fork of the upstream
-`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
-on AMD MI300X series GPUs.
-
 The MaxText for ROCm training Docker image
-provides a prebuilt environment for training on AMD Instinct MI300X and MI325X GPUs,
+provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs,
 including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
 It includes the following software components:
 
@@ -61,15 +55,15 @@ MaxText on ROCm provides the following key features to train large language
 
 - Multi-node support
 
-- NANOO FP8 quantization support
+- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X) quantization support
 
-.. _amd-maxtext-model-support-v257:
+.. _amd-maxtext-model-support-v259:
 
 Supported models
 ================
 
-The following models are pre-optimized for performance on AMD Instinct MI300
-series GPUs. Some instructions, commands, and available training
+The following models are pre-optimized for performance on AMD Instinct
+GPUs. Some instructions, commands, and available training
 configurations in this documentation might vary by model -- select one to get
 started.
 
@@ -139,22 +133,13 @@ Use the following command to pull the Docker image from Docker Hub.
 
 .. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
 
-   {% set dockers = data.dockers %}
-   .. tab-set::
+   {% set docker = data.dockers[0] %}
 
-      {% for docker in dockers %}
-      {% set jax_version = docker.components["JAX"] %}
+   .. code-block:: shell
 
-      .. tab-item:: JAX {{ jax_version }}
-         :sync: {{ docker.pull_tag }}
+      docker pull {{ docker.pull_tag }}
 
-         .. code-block:: shell
-
-            docker pull {{ docker.pull_tag }}
-
-      {% endfor %}
-
-.. _amd-maxtext-multi-node-setup-v257:
+.. _amd-maxtext-multi-node-setup-v259:
 
 Multi-node configuration
 ------------------------
@@ -162,7 +147,7 @@ Multi-node configuration
 See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
 environment for multi-node training.
 
-.. _amd-maxtext-get-started-v257:
+.. _amd-maxtext-get-started-v259:
 
 Benchmarking
 ============
@@ -174,7 +159,7 @@ benchmark results:
 
    .. _vllm-benchmark-mad:
 
-   {% set dockers = data.dockers %}
+   {% set docker = data.dockers[0] %}
    {% set model_groups = data.model_groups %}
    {% for model_group in model_groups %}
    {% for model in model_group.models %}
@@ -186,6 +171,9 @@ benchmark results:
          {% if model.mad_tag and "single-node" in model.doc_options %}
          .. tab-item:: MAD-integrated benchmarking
 
+            The following run command is tailored to {{ model.model }}.
+            See :ref:`amd-maxtext-model-support-v259` to switch to another available model.
+
            1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__)
               repository to a local directory and install the required packages on the host
              machine.
 
              .. code-block:: shell
 
                 git clone https://github.com/ROCm/MAD
                 cd MAD
                 pip install -r requirements.txt
 
           2. Use this command to run the performance benchmark test on the {{ model.model }} model
              using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
 
              .. code-block:: shell
 
                 export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
                 madengine run \
                    --tags {{model.mad_tag}} \
                    --keep-model-dir \
                    --live-output \
                    --timeout 28800
 
              MAD launches a Docker container with the name
              ``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
              model are collected in the following path: ``~/MAD/perf.csv/``.
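+
+             For example, you can skim the collected report from the host once the
+             run completes (an illustrative check; the exact columns depend on the
+             MAD version):
+
+             .. code-block:: shell
+
+                # Render the comma-separated report as an aligned table.
+                column -s, -t < ~/MAD/perf.csv | less -S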
doc_options: ["single-node"] + - model: Llama 3.1 70B + mad_tag: jax_maxtext_train_llama-3.1-70b + model_repo: Llama-3.1-70B + precision: bf16 + doc_options: ["single-node"] + - model: Llama 3.3 70B + mad_tag: jax_maxtext_train_llama-3.3-70b + model_repo: Llama-3.3-70B + precision: bf16 + doc_options: ["single-node"] - group: DeepSeek tag: deepseek models: diff --git a/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml b/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml new file mode 100644 index 000000000..b1b971708 --- /dev/null +++ b/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml @@ -0,0 +1,72 @@ +dockers: + - pull_tag: rocm/jax-training:maxtext-v25.7-jax060 + docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025 + components: + ROCm: 6.4.1 + JAX: 0.6.0 + Python: 3.10.12 + Transformer Engine: 2.1.0+90d703dd + hipBLASLt: 1.1.0-499ece1c21 + - pull_tag: rocm/jax-training:maxtext-v25.7 + docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025 + components: + ROCm: 6.4.1 + JAX: 0.5.0 + Python: 3.10.12 + Transformer Engine: 2.1.0+90d703dd + hipBLASLt: 1.x.x +model_groups: + - group: Meta Llama + tag: llama + models: + - model: Llama 3.3 70B + mad_tag: jax_maxtext_train_llama-3.3-70b + model_repo: Llama-3.3-70B + precision: bf16 + doc_options: ["single-node"] + - model: Llama 3.1 8B + mad_tag: jax_maxtext_train_llama-3.1-8b + model_repo: Llama-3.1-8B + precision: bf16 + doc_options: ["single-node"] + - model: Llama 3.1 70B + mad_tag: jax_maxtext_train_llama-3.1-70b + model_repo: Llama-3.1-70B + precision: bf16 + doc_options: ["single-node"] + - model: Llama 3 8B + mad_tag: jax_maxtext_train_llama-3-8b + multinode_training_script: llama3_8b_multinode.sh + doc_options: ["multi-node"] + - model: Llama 3 70B + mad_tag: jax_maxtext_train_llama-3-70b + multinode_training_script: llama3_70b_multinode.sh + doc_options: ["multi-node"] + - model: Llama 2 7B + mad_tag: jax_maxtext_train_llama-2-7b + model_repo: Llama-2-7B + precision: bf16 + multinode_training_script: llama2_7b_multinode.sh + doc_options: ["single-node", "multi-node"] + - model: Llama 2 70B + mad_tag: jax_maxtext_train_llama-2-70b + model_repo: Llama-2-70B + precision: bf16 + multinode_training_script: llama2_70b_multinode.sh + doc_options: ["single-node", "multi-node"] + - group: DeepSeek + tag: deepseek + models: + - model: DeepSeek-V2-Lite (16B) + mad_tag: jax_maxtext_train_deepseek-v2-lite-16b + model_repo: DeepSeek-V2-lite + precision: bf16 + doc_options: ["single-node"] + - group: Mistral AI + tag: mistral + models: + - model: Mixtral 8x7B + mad_tag: jax_maxtext_train_mixtral-8x7b + model_repo: Mixtral-8x7B + precision: bf16 + doc_options: ["single-node"] diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst index eec785b7b..7c073ac94 100644 --- a/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst +++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst @@ -6,14 +6,8 @@ Training a model with JAX MaxText on ROCm ****************************************** -MaxText is a high-performance, open-source framework built on the Google JAX -machine learning library to train LLMs at 
 
           {% if model.model_repo and "single-node" in model.doc_options %}
           .. rubric:: Single node training
 
           1. Set up environment variables.
 
              .. code-block:: shell
 
                 export MAD_SECRETS_HFTOKEN=<your_hf_token>
                 export HF_HOME=<path_to_your_hf_cache>
 
              ``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
              See `User access tokens <https://huggingface.co/docs/hub/security-tokens>`__.
 
              ``HF_HOME`` is where ``huggingface_hub`` will store local data. See the
              `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/guides/cli>`__.
              If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
              Downloaded files typically get cached to ``~/.cache/huggingface``.
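+
+             If you want to pre-populate the cache, you can download a model ahead
+             of time (illustrative; assumes the ``huggingface_hub`` CLI is installed
+             and that you have access to the gated repository):
+
+             .. code-block:: shell
+
+                # Downloads into $HF_HOME so the container later sees it under /hf_cache.
+                huggingface-cli download meta-llama/Llama-2-7b-hf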
 
           2. Launch the Docker container.
 
-             .. tab-set::
-                {% for docker in dockers %}
-                {% set jax_version = docker.components["JAX"] %}
+             .. code-block:: shell
 
-                .. tab-item:: JAX {{ jax_version }}
-                   :sync: {{ docker.pull_tag }}
-
-                   .. code-block:: shell
-
-                      docker run -it \
-                         --device=/dev/dri \
-                         --device=/dev/kfd \
-                         --network host \
-                         --ipc host \
-                         --group-add video \
-                         --cap-add=SYS_PTRACE \
-                         --security-opt seccomp=unconfined \
-                         --privileged \
-                         -v $HOME:$HOME \
-                         -v $HOME/.ssh:/root/.ssh \
-                         -v $HF_HOME:/hf_cache \
-                         -e HF_HOME=/hf_cache \
-                         -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN
-                         --shm-size 64G \
-                         --name training_env \
-                         {{ docker.pull_tag }}
-                {% endfor %}
+                docker run -it \
+                   --device=/dev/dri \
+                   --device=/dev/kfd \
+                   --network host \
+                   --ipc host \
+                   --group-add video \
+                   --cap-add=SYS_PTRACE \
+                   --security-opt seccomp=unconfined \
+                   --privileged \
+                   -v $HOME:$HOME \
+                   -v $HOME/.ssh:/root/.ssh \
+                   -v $HF_HOME:/hf_cache \
+                   -e HF_HOME=/hf_cache \
+                   -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
+                   --shm-size 64G \
+                   --name training_env \
+                   {{ docker.pull_tag }}
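+
+             Once inside the container, you can optionally verify that the GPUs
+             are visible before training (an illustrative check):
+
+             .. code-block:: shell
+
+                # Should list every Instinct GPU available to the container.
+                rocm-smi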
 
           3. In the Docker container, clone the ROCm MAD repository and navigate to the
              benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
 
              .. code-block:: shell
 
                 git clone https://github.com/ROCm/MAD
                 cd MAD/scripts/jax-maxtext
 
           4. Run the setup scripts to install libraries and datasets needed
              for benchmarking.
 
              .. code-block:: shell
 
                 ./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
 
           5. To run the training benchmark without quantization, use the following command:
 
              .. code-block:: shell
 
                 ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
 
-             For quantized training, use the following command:
+             For quantized training, run the script with the appropriate option for your Instinct GPU.
 
-             .. code-block:: shell
+             .. tab-set::
 
-                ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
+                .. tab-item:: MI355X and MI350X
+
+                   For ``fp8`` quantized training on MI355X and MI350X GPUs, use the following command:
+
+                   .. code-block:: shell
+
+                      ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q fp8
+
+                {% if model.model_repo not in ["Llama-3.1-70B", "Llama-3.3-70B"] %}
+                .. tab-item:: MI325X and MI300X
+
+                   For ``nanoo_fp8`` quantized training on MI300X series GPUs, use the following command:
+
+                   .. code-block:: shell
+
+                      ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
+                {% endif %}
           {% endif %}
 
           {% if model.multinode_training_script and "multi-node" in model.doc_options %}
           .. rubric:: Multi-node training
 
@@ -335,7 +328,7 @@ benchmark results:
           {% else %}
           .. rubric:: Multi-node training
 
-          For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v257`
+          For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v259`
           with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
           {% endif %}
    {% endfor %}
diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst
index e4d039356..c1444a4a1 100644
--- a/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst
+++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst
@@ -17,27 +17,35 @@ previous releases of the ``ROCm/jax-training`` Docker image on `Docker Hub
+       * `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025>`__
+
+   * - 25.7
+     - * ROCm 6.4.1
        * JAX 0.6.0, 0.5.0
-     -
-       * :doc:`Documentation <../jax-maxtext>`
+     -
+       * :doc:`Documentation <jax-maxtext-v25.7>`
+       * `Docker Hub (JAX 0.6.0) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025>`__
        * `Docker Hub (JAX 0.5.0) <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.7/images/sha256-45f4c727d4019a63fc47313d3a5f5a5105569539294ddfd2d742218212ae9025>`__
 
    * - 25.5
-     -
+     - * ROCm 6.3.4
        * JAX 0.4.35
-     -
+     -
        * :doc:`Documentation <jax-maxtext-v25.5>`
        * `Docker Hub <https://hub.docker.com/r/rocm/jax-training/tags>`__
 
    * - 25.4
-     -
+     - * ROCm 6.3.0
        * JAX 0.4.31
-     -
+     -
        * :doc:`Documentation <jax-maxtext-v25.4>`
        * `Docker Hub <https://hub.docker.com/r/rocm/jax-training/tags>`__
diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.7.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.7.rst
new file mode 100644
index 000000000..ee4e264fe
--- /dev/null
+++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.7.rst
@@ -0,0 +1,366 @@
+:orphan:
+
+.. meta::
+   :description: How to train a model using JAX MaxText for ROCm.
+   :keywords: ROCm, AI, LLM, train, jax, torch, Llama, flux, tutorial, docker
+
+******************************************
+Training a model with JAX MaxText on ROCm
+******************************************
+
+.. caution::
+
+   This documentation does not reflect the latest version of ROCm JAX MaxText
+   training performance documentation. See :doc:`../jax-maxtext` for the latest version.
+
+MaxText is a high-performance, open-source framework built on the Google JAX
+machine learning library to train LLMs at scale. The MaxText framework for
+ROCm is an optimized fork of the upstream
+`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
+on AMD MI300X series GPUs.
+
+The MaxText for ROCm training Docker image
+provides a prebuilt environment for training on AMD Instinct MI300X and MI325X GPUs,
+including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
+It includes the following software components:
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
+
+   {% set dockers = data.dockers %}
+   .. tab-set::
+
+      {% for docker in dockers %}
+      {% set jax_version = docker.components["JAX"] %}
+
+      .. tab-item:: ``{{ docker.pull_tag }}``
+         :sync: {{ docker.pull_tag }}
+
+         .. list-table::
+            :header-rows: 1
+
+            * - Software component
+              - Version
+
+            {% for component_name, component_version in docker.components.items() %}
+            * - {{ component_name }}
+              - {{ component_version }}
+
+            {% endfor %}
+         {% if jax_version == "0.6.0" %}
+         .. note::
+
+            Shardy is a new config in JAX 0.6.0. You might get related errors if it's
+            not configured correctly. For now you can turn it off by setting
+            ``shardy=False`` during the training run. You can also follow the
+            `migration guide <https://openxla.org/shardy>`__ to enable
+            it.
+         {% endif %}
+
+      {% endfor %}
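+
+As the note above mentions for JAX 0.6.0, Shardy can be disabled for a given
+run by appending the override to the training command. A minimal sketch -- the
+``train.py`` invocation shown here is illustrative, not a validated
+configuration:
+
+.. code-block:: shell
+
+   # MaxText accepts key=value config overrides on the command line;
+   # shardy=False turns the new partitioner off for this run.
+   python3 MaxText/train.py MaxText/configs/base.yml shardy=False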
+
+MaxText on ROCm provides the following key features to train large language models efficiently:
+
+- Transformer Engine (TE)
+
+- Flash Attention (FA) 3 -- with or without sequence input packing
+
+- GEMM tuning
+
+- Multi-node support
+
+- NANOO FP8 quantization support
+
+.. _amd-maxtext-model-support-v257:
+
+Supported models
+================
+
+The following models are pre-optimized for performance on AMD Instinct MI300
+series GPUs. Some instructions, commands, and available training
+configurations in this documentation might vary by model -- select one to get
+started.
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
+
+   {% set model_groups = data.model_groups %}
+   .. raw:: html
+
+      <div class="model-selector">
+         <div class="selector-group">
+            <div class="selector-label">Model</div>
+            <select id="model-group-select">
+               {% for model_group in model_groups %}
+               <option value="{{ model_group.tag }}">{{ model_group.group }}</option>
+               {% endfor %}
+            </select>
+         </div>
+
+         <div class="selector-group">
+            <div class="selector-label">Variant</div>
+            <select id="model-select">
+               {% for model_group in model_groups %}
+               {% set models = model_group.models %}
+               {% for model in models %}
+               {% if models|length % 3 == 0 %}
+               <option value="{{ model.mad_tag }}" class="three-col">{{ model.model }}</option>
+               {% else %}
+               <option value="{{ model.mad_tag }}">{{ model.model }}</option>
+               {% endif %}
+               {% endfor %}
+               {% endfor %}
+            </select>
+         </div>
+      </div>
+
+.. note::
+
+   Some models, such as Llama 3, require an external license agreement through
+   a third party (for example, Meta).
+
+System validation
+=================
+
+Before running AI workloads, it's important to validate that your AMD hardware is configured
+correctly and performing optimally.
+
+If you have already validated your system settings, including aspects like NUMA auto-balancing, you
+can skip this step. Otherwise, complete the procedures in the :ref:`System validation and
+optimization ` guide to properly configure your system settings
+before starting training.
+
+To test for optimal performance, consult the recommended :ref:`System health benchmarks
+`. This suite of tests will help you verify and fine-tune your
+system's configuration.
+
+Environment setup
+=================
+
+This Docker image is optimized for specific model configurations outlined
+as follows. Performance can vary for other training workloads, as AMD
+doesn't validate configurations and run conditions outside those described.
+
+Pull the Docker image
+---------------------
+
+Use the following command to pull the Docker image from Docker Hub.
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
+
+   {% set dockers = data.dockers %}
+   .. tab-set::
+
+      {% for docker in dockers %}
+      {% set jax_version = docker.components["JAX"] %}
+
+      .. tab-item:: JAX {{ jax_version }}
+         :sync: {{ docker.pull_tag }}
+
+         .. code-block:: shell
+
+            docker pull {{ docker.pull_tag }}
+
+      {% endfor %}
+
+.. _amd-maxtext-multi-node-setup-v257:
+
+Multi-node configuration
+------------------------
+
+See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
+environment for multi-node training.
+
+.. _amd-maxtext-get-started-v257:
+
+Benchmarking
+============
+
+Once the setup is complete, choose between two options to reproduce the
+benchmark results:
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.7-benchmark-models.yaml
+
+   .. _vllm-benchmark-mad:
+
+   {% set dockers = data.dockers %}
+   {% set model_groups = data.model_groups %}
+   {% for model_group in model_groups %}
+   {% for model in model_group.models %}
+
+   .. container:: model-doc {{model.mad_tag}}
+
+      .. tab-set::
+
+         {% if model.mad_tag and "single-node" in model.doc_options %}
+         .. tab-item:: MAD-integrated benchmarking
+
+            1. Clone the ROCm Model Automation and Dashboarding (`<https://github.com/ROCm/MAD>`__)
+               repository to a local directory and install the required packages on the host
+               machine.
+
+               .. code-block:: shell
+
+                  git clone https://github.com/ROCm/MAD
+                  cd MAD
+                  pip install -r requirements.txt
+
+            2. Use this command to run the performance benchmark test on the {{ model.model }} model
+               using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
+
+               .. code-block:: shell
+
+                  export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
+                  madengine run \
+                     --tags {{model.mad_tag}} \
+                     --keep-model-dir \
+                     --live-output \
+                     --timeout 28800
+
+               MAD launches a Docker container with the name
+               ``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
+               model are collected in the following path: ``~/MAD/perf.csv/``.
+         {% endif %}
+
+         .. tab-item:: Standalone benchmarking
+
+            .. rubric:: Download the Docker image and required scripts
+
+            Run the JAX MaxText benchmark tool independently by starting the
+            Docker container as shown in the following snippet.
+
+            .. tab-set::
+               {% for docker in dockers %}
+               {% set jax_version = docker.components["JAX"] %}
+
+               .. tab-item:: JAX {{ jax_version }}
+                  :sync: {{ docker.pull_tag }}
+
+                  .. code-block:: shell
+
+                     docker pull {{ docker.pull_tag }}
+               {% endfor %}
+
+            {% if model.model_repo and "single-node" in model.doc_options %}
+            .. rubric:: Single node training
+
+            1. Set up environment variables.
+
+               .. code-block:: shell
+
+                  export MAD_SECRETS_HFTOKEN=<your_hf_token>
+                  export HF_HOME=<path_to_your_hf_cache>
+
+               ``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
+               See `User access tokens <https://huggingface.co/docs/hub/security-tokens>`__.
+
+               ``HF_HOME`` is where ``huggingface_hub`` will store local data. See the
+               `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/guides/cli>`__.
+               If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
+               Downloaded files typically get cached to ``~/.cache/huggingface``.
+
+            2. Launch the Docker container.
+
+               .. tab-set::
+                  {% for docker in dockers %}
+                  {% set jax_version = docker.components["JAX"] %}
+
+                  .. tab-item:: JAX {{ jax_version }}
+                     :sync: {{ docker.pull_tag }}
+
+                     .. code-block:: shell
+
+                        docker run -it \
+                           --device=/dev/dri \
+                           --device=/dev/kfd \
+                           --network host \
+                           --ipc host \
+                           --group-add video \
+                           --cap-add=SYS_PTRACE \
+                           --security-opt seccomp=unconfined \
+                           --privileged \
+                           -v $HOME:$HOME \
+                           -v $HOME/.ssh:/root/.ssh \
+                           -v $HF_HOME:/hf_cache \
+                           -e HF_HOME=/hf_cache \
+                           -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
+                           --shm-size 64G \
+                           --name training_env \
+                           {{ docker.pull_tag }}
+                  {% endfor %}
+
+            3. In the Docker container, clone the ROCm MAD repository and navigate to the
+               benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
+
+               .. code-block:: shell
+
+                  git clone https://github.com/ROCm/MAD
+                  cd MAD/scripts/jax-maxtext
+
+            4. Run the setup scripts to install libraries and datasets needed
+               for benchmarking.
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
+
+            5. To run the training benchmark without quantization, use the following command:
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+               For quantized training, use the following command:
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
+
+            {% endif %}
+            {% if model.multinode_training_script and "multi-node" in model.doc_options %}
+            .. rubric:: Multi-node training
+
+            The following examples use SLURM to run on multiple nodes.
+
+            .. note::
+
+               The following scripts will launch the Docker container and run the
+               benchmark. Run them outside of any Docker container.
+
+            1. Make sure ``$HF_HOME`` is set before running the test. See
+               `ROCm benchmarking <https://github.com/ROCm/MAD>`__
+               for more details on downloading the Llama models before running the
+               benchmark.
+
+            2. To run multi-node training for {{ model.model }},
+               use the
+               `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__
+               under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
+
+            3. Run the multi-node training benchmark script.
+
+               .. code-block:: shell
+
+                  sbatch -N <number_of_nodes> {{ model.multinode_training_script }}
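+
+               After submitting, you can watch the job with the standard SLURM
+               tools (illustrative):
+
+               .. code-block:: shell
+
+                  # Show your queued and running jobs, including the benchmark.
+                  squeue -u $USER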
+
+            {% else %}
+            .. rubric:: Multi-node training
+
+            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v257`
+            with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
+            {% endif %}
+   {% endfor %}
+   {% endfor %}
+
+Further reading
+===============
+
+- To learn more about MAD and the ``madengine`` CLI, see the `MAD usage guide <https://github.com/ROCm/MAD>`__.
+
+- To learn more about system settings and management practices to configure your system for
+  AMD Instinct MI300X series GPUs, see `AMD Instinct MI300X system optimization
+  <https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html>`_.
+
+- For a list of other ready-made Docker images for AI with ROCm, see
+  `AMD Infinity Hub <https://www.amd.com/en/developer/resources/infinity-hub.html>`_.
+
+Previous versions
+=================
+
+See :doc:`jax-maxtext-history` to find documentation for previous releases
+of the ``ROCm/jax-training`` Docker image.