diff --git a/docs/conf.py b/docs/conf.py
index baefd9210..c66fb49b7 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -166,6 +166,8 @@ article_pages = [
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.4", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.5", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.9", "os": ["linux"]},
+    {"file": "how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.11", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry", "os": ["linux"]},
     {"file": "how-to/rocm-for-ai/fine-tuning/index", "os": ["linux"]},
diff --git a/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml b/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
index 144b48bea..a63eb5019 100644
--- a/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
+++ b/docs/data/how-to/rocm-for-ai/training/jax-maxtext-benchmark-models.yaml
@@ -1,11 +1,11 @@
 dockers:
-  - pull_tag: rocm/jax-training:maxtext-v25.11
-    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a
+  - pull_tag: rocm/jax-training:maxtext-v26.1
+    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v26.1/images/sha256-901083bde353fe6362ada3036e452c792b2c96124e5900f4e9b5946c02ff9d6a
     components:
-      ROCm: 7.1.0
-      JAX: 0.7.1
+      ROCm: 7.1.1
+      JAX: 0.8.2
       Python: 3.12
-      Transformer Engine: 2.4.0.dev0+281042de
+      Transformer Engine: 2.8.0.dev0+aec00a7f
       hipBLASLt: 1.2.x
 model_groups:
   - group: Meta Llama
@@ -15,21 +15,29 @@ model_groups:
       - model: Llama 2 7B
         mad_tag: jax_maxtext_train_llama-2-7b
         model_repo: Llama-2-7B
         precision: bf16
-        multinode_training_script: llama2_7b_multinode.sh
+        multinode_config:
+          gfx950: env_scripts/gfx950_llama2_7b.yml
+          gfx942: env_scripts/llama2_7b.yml
         doc_options: ["single-node", "multi-node"]
       - model: Llama 2 70B
         mad_tag: jax_maxtext_train_llama-2-70b
         model_repo: Llama-2-70B
         precision: bf16
-        multinode_training_script: llama2_70b_multinode.sh
+        multinode_config:
+          gfx950: env_scripts/gfx950_llama2_70b.yml
+          gfx942: env_scripts/llama2_70b.yml
         doc_options: ["single-node", "multi-node"]
       - model: Llama 3 8B (multi-node)
         mad_tag: jax_maxtext_train_llama-3-8b
-        multinode_training_script: llama3_8b_multinode.sh
+        multinode_config:
+          gfx950: env_scripts/gfx950_llama3_8b.yml
+          gfx942: env_scripts/llama3_8b.yml
         doc_options: ["multi-node"]
       - model: Llama 3 70B (multi-node)
         mad_tag: jax_maxtext_train_llama-3-70b
-        multinode_training_script: llama3_70b_multinode.sh
+        multinode_config:
+          gfx950: env_scripts/gfx950_llama3_70b.yml
+          gfx942: env_scripts/llama3_70b.yml
         doc_options: ["multi-node"]
       - model: Llama 3.1 8B
         mad_tag: jax_maxtext_train_llama-3.1-8b
@@ -41,11 +49,21 @@ model_groups:
         model_repo: Llama-3.1-8B
         precision: bf16
         doc_options: ["single-node"]
       - model: Llama 3.1 70B
         mad_tag: jax_maxtext_train_llama-3.1-70b
         model_repo: Llama-3.1-70B
         precision: bf16
         doc_options: ["single-node"]
+      - model: Llama 3.1 405B
+        mad_tag: jax_maxtext_train_llama-3.1-405b
+        model_repo: Llama-3.1-405B
+        precision: bf16
+        multinode_config:
+          gfx950: env_scripts/gfx950_llama3_405b.yml
+        doc_options: ["single-node", "multi-node"]
       - model: Llama 3.3 70B
         mad_tag: jax_maxtext_train_llama-3.3-70b
         model_repo: Llama-3.3-70B
         precision: bf16
-        doc_options: ["single-node"]
+        multinode_config:
+          gfx950: env_scripts/gfx950_llama3.3_70b.yml
+          gfx942: env_scripts/llama3.3_70b.yml
+        doc_options: ["single-node", "multi-node"]
   - group: DeepSeek
     tag: deepseek
     models:
@@ -53,7 +71,10 @@ model_groups:
       - model: DeepSeek-V2-Lite (16B)
         mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
        model_repo: DeepSeek-V2-lite
         precision: bf16
-        doc_options: ["single-node"]
+        multinode_config:
+          gfx950: env_scripts/gfx950_deepseek2_16b.yml
+          gfx942: env_scripts/deepseek2_16b.yml
+        doc_options: ["single-node", "multi-node"]
   - group: Mistral AI
     tag: mistral
     models:
@@ -61,4 +82,7 @@ model_groups:
       - model: Mixtral 8x7B
         mad_tag: jax_maxtext_train_mixtral-8x7b
         model_repo: Mixtral-8x7B
         precision: bf16
-        doc_options: ["single-node"]
+        multinode_config:
+          gfx950: env_scripts/gfx950_mixtral_8x7b.yml
+          gfx942: env_scripts/llama3_8x7b.yml
+        doc_options: ["single-node", "multi-node"]
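The new ``multinode_config`` mapping keys each supported GPU architecture to a launch
configuration under ``env_scripts/``. A minimal sketch of how a wrapper could pick the
matching file on a given node follows — the ``rocminfo`` parsing and the selection logic
are assumptions for illustration, not part of this change:

.. code-block:: shell

   #!/usr/bin/env bash
   # Hypothetical helper: choose the env_scripts config that matches the
   # local GPU architecture, mirroring the gfx950/gfx942 keys added above.
   arch=$(rocminfo | grep -m1 -o 'gfx9[0-9a-f]*')   # e.g. gfx950 or gfx942
   case "$arch" in
     gfx950) cfg=env_scripts/gfx950_llama2_7b.yml ;;   # MI355X and MI350X
     gfx942) cfg=env_scripts/llama2_7b.yml ;;          # MI325X and MI300X
     *) echo "unsupported GPU architecture: $arch" >&2; exit 1 ;;
   esac
   sbatch -N 2 jax_maxtext_multinode_benchmark.sh "$cfg"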
diff --git a/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml b/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml
new file mode 100644
index 000000000..144b48bea
--- /dev/null
+++ b/docs/data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml
@@ -0,0 +1,64 @@
+dockers:
+  - pull_tag: rocm/jax-training:maxtext-v25.11
+    docker_hub_url: https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a
+    components:
+      ROCm: 7.1.0
+      JAX: 0.7.1
+      Python: 3.12
+      Transformer Engine: 2.4.0.dev0+281042de
+      hipBLASLt: 1.2.x
+model_groups:
+  - group: Meta Llama
+    tag: llama
+    models:
+      - model: Llama 2 7B
+        mad_tag: jax_maxtext_train_llama-2-7b
+        model_repo: Llama-2-7B
+        precision: bf16
+        multinode_training_script: llama2_7b_multinode.sh
+        doc_options: ["single-node", "multi-node"]
+      - model: Llama 2 70B
+        mad_tag: jax_maxtext_train_llama-2-70b
+        model_repo: Llama-2-70B
+        precision: bf16
+        multinode_training_script: llama2_70b_multinode.sh
+        doc_options: ["single-node", "multi-node"]
+      - model: Llama 3 8B (multi-node)
+        mad_tag: jax_maxtext_train_llama-3-8b
+        multinode_training_script: llama3_8b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3 70B (multi-node)
+        mad_tag: jax_maxtext_train_llama-3-70b
+        multinode_training_script: llama3_70b_multinode.sh
+        doc_options: ["multi-node"]
+      - model: Llama 3.1 8B
+        mad_tag: jax_maxtext_train_llama-3.1-8b
+        model_repo: Llama-3.1-8B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.1 70B
+        mad_tag: jax_maxtext_train_llama-3.1-70b
+        model_repo: Llama-3.1-70B
+        precision: bf16
+        doc_options: ["single-node"]
+      - model: Llama 3.3 70B
+        mad_tag: jax_maxtext_train_llama-3.3-70b
+        model_repo: Llama-3.3-70B
+        precision: bf16
+        doc_options: ["single-node"]
+  - group: DeepSeek
+    tag: deepseek
+    models:
+      - model: DeepSeek-V2-Lite (16B)
+        mad_tag: jax_maxtext_train_deepseek-v2-lite-16b
+        model_repo: DeepSeek-V2-lite
+        precision: bf16
+        doc_options: ["single-node"]
+  - group: Mistral AI
+    tag: mistral
+    models:
+      - model: Mixtral 8x7B
+        mad_tag: jax_maxtext_train_mixtral-8x7b
+        model_repo: Mixtral-8x7B
+        precision: bf16
+        doc_options: ["single-node"]
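Each ``mad_tag`` in these data files names a runnable MAD workload. For instance, the
newly added Llama 3.1 405B entry can be exercised with the same ``madengine`` invocation
the documentation template renders for the other models (a sketch mirroring the
MAD-integrated steps later in this diff; the token value is a placeholder):

.. code-block:: shell

   export MAD_SECRETS_HFTOKEN="<your personal Hugging Face token>"
   git clone https://github.com/ROCm/MAD && cd MAD
   pip install -r requirements.txt
   # Run the benchmark for the new Llama 3.1 405B tag from the YAML above.
   madengine run \
      --tags jax_maxtext_train_llama-3.1-405b \
      --keep-model-dir \
      --live-output \
      --timeout 28800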
diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst
index 1903b39d4..bf64fb68a 100644
--- a/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst
+++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.rst
@@ -35,13 +35,6 @@ It includes the following software components:
    {% endfor %}
    {% endfor %}
 
-.. note::
-
-   The ``rocm/jax-training:maxtext-v25.9`` has been updated to
-   ``rocm/jax-training:maxtext-v25.9.1``. This revision should include
-   a fix to address segmentation fault issues during launch. See the
-   :doc:`versioned documentation <previous-versions/jax-maxtext-v25.9>`.
-
 MaxText on ROCm provides the following key features to train large
 language models efficiently:
 
 - Transformer Engine (TE)
@@ -54,7 +47,7 @@ MaxText on ROCm provides the following key features to train large
 - NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X)
   quantization support
 
-.. _amd-maxtext-model-support-v25.11:
+.. _amd-maxtext-model-support-v26.1:
 
 Supported models
 ================
@@ -136,7 +129,7 @@ Use the following command to pull the Docker image from Docker Hub.
 
       docker pull {{ docker.pull_tag }}
 
-.. _amd-maxtext-multi-node-setup-v25.11:
+.. _amd-maxtext-multi-node-setup-v26.1:
 
 Multi-node configuration
 ------------------------
@@ -144,7 +137,7 @@ Multi-node configuration
 See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
 environment for multi-node training.
 
-.. _amd-maxtext-get-started-v25.11:
+.. _amd-maxtext-get-started-v26.1:
 
 Benchmarking
 ============
@@ -169,7 +162,7 @@ benchmark results:
          .. tab-item:: MAD-integrated benchmarking
 
             The following run command is tailored to {{ model.model }}.
-            See :ref:`amd-maxtext-model-support-v25.11` to switch to another available model.
+            See :ref:`amd-maxtext-model-support-v26.1` to switch to another available model.
 
             1. Clone the ROCm Model Automation and Dashboarding (`MAD <https://github.com/ROCm/MAD>`__) repository to a local
                directory and install the required packages on the host machine.
@@ -200,7 +193,7 @@ benchmark results:
          .. tab-item:: Standalone benchmarking
 
             The following commands are optimized for {{ model.model }}. See
-            :ref:`amd-maxtext-model-support-v25.11` to switch to another
+            :ref:`amd-maxtext-model-support-v26.1` to switch to another
             available model. Some instructions and resources might not be
             available for all models and configurations.
 
@@ -296,56 +289,57 @@ benchmark results:
             {% endif %}
             {% endif %}
 
-            {% if model.multinode_training_script and "multi-node" in model.doc_options %}
+            {% if model.multinode_config and "multi-node" in model.doc_options %}
             .. rubric:: Multi-node training
 
-            The following examples use SLURM to run on multiple nodes.
+            The following SLURM scripts will launch the Docker container and
+            run the benchmark. Run them outside of any Docker container. The
+            unified multi-node benchmark script accepts a configuration file
+            that specifies the model and training parameters.
 
-            .. note::
+            .. code-block:: shell
 
-               The following scripts will launch the Docker container and run the
-               benchmark. Run them outside of any Docker container.
+               sbatch -N <num_nodes> jax_maxtext_multinode_benchmark.sh <config_file> [docker_image]
 
-            1. Make sure ``$HF_HOME`` is set before running the test. See
-               `ROCm benchmarking `__
-               for more details on downloading the Llama models before running the
-               benchmark.
+            <num_nodes>
+               The number of nodes to use for training (for example, 2, 4,
+               8).
 
-            2. To run multi-node training for {{ model.model }},
-               use the
-               `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__
-               under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
+            <config_file>
+               Path to the YAML configuration file containing model and
+               training parameters. Configuration files are available in the
+               ``scripts/jax-maxtext/env_scripts/`` directory for different
+               models and GPU architectures.
 
-            3. Run the multi-node training benchmark script.
+            [docker_image] (optional)
+               The Docker image to use. If not specified, it defaults to
+               ``rocm/jax-training:maxtext-v26.1``.
 
-               .. code-block:: shell
+            For example, to run a multi-node training benchmark on {{ model.model }}:
 
-                  sbatch -N <num_nodes> {{ model.multinode_training_script }}
+            .. tab-set::
 
-            .. rubric:: Profiling with rocprofv3
+               {% if model.multinode_config.gfx950 %}
+               .. tab-item:: MI355X and MI350X (gfx950)
 
-            If you need to collect a trace and the JAX profiler isn't working, use ``rocprofv3`` provided by the :doc:`ROCprofiler-SDK ` as a workaround. For example:
+                  .. code-block:: bash
 
-            .. code-block:: bash
+                     sbatch -N 4 jax_maxtext_multinode_benchmark.sh {{ model.multinode_config.gfx950 }}
+               {% endif %}
 
-               rocprofv3 \
-                  --hip-trace \
-                  --kernel-trace \
-                  --memory-copy-trace \
-                  --rccl-trace \
-                  --output-format pftrace \
-                  -d ./v3_traces \ # output directory
-                  -- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} # or desired command
+               {% if model.multinode_config.gfx942 %}
+               .. tab-item:: MI325X and MI300X (gfx942)
 
-            You can set the directory where you want the .json traces to be
-            saved using ``-d <directory_path>``. The resulting traces can be
-            opened in Perfetto: `ui.perfetto.dev <https://ui.perfetto.dev>`__.
+                  .. code-block:: bash
+
+                     sbatch -N 4 jax_maxtext_multinode_benchmark.sh {{ model.multinode_config.gfx942 }}
+               {% endif %}
 
             {% else %}
             .. rubric:: Multi-node training
 
-            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v25.11`
-            with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
+            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v26.1`
+            with an available `multi-node training configuration <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext>`__.
             {% endif %}
    {% endfor %}
    {% endfor %}
@@ -353,35 +347,13 @@
 Known issues
 ============
 
-- Minor performance regression (< 4%) for BF16 quantization in Llama models and Mixtral 8x7b.
-
-- You might see minor loss spikes, or loss curve may have slightly higher
-  convergence end values compared to the previous ``jax-training`` image.
-
-- For FP8 training on MI355, many models will display a warning message like:
-  ``Warning: Latency not found for MI_M=16, MI_N=16, MI_K=128,
-  mi_input_type=BFloat8Float8_fnuz. Returning latency value of 32 (really
-  slow).`` The compile step may take longer than usual, but training will run.
+- You might see NaNs in the losses when setting ``packing=True``. As
+  a workaround, turn off input sequence packing (``packing=False``).
   This will be fixed in a future release.
 
-- The built-in JAX profiler isn't working. If you need to collect a trace and
-  the JAX profiler isn't working, use ``rocprofv3`` provided by the
-  :doc:`ROCprofiler-SDK ` as a workaround. For example:
-
-  .. code-block:: bash
-
-     rocprofv3 \
-        --hip-trace \
-        --kernel-trace \
-        --memory-copy-trace \
-        --rccl-trace \
-        --output-format pftrace \
-        -d ./v3_traces \ # output directory
-        -- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} # or desired command
-
-  You can set the directory where you want the .json traces to be
-  saved using ``-d <directory_path>``. The resulting traces can be
-  opened in Perfetto: `ui.perfetto.dev <https://ui.perfetto.dev>`__.
+- Docker ``rocm/jax-training:maxtext-v26.1`` does not include `Primus
+  <https://github.com/AMD-AIG-AIMA/Primus>`__. It is planned to be
+  supported in a future release.
 
 Further reading
 ===============
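Putting the unified launcher together, a full invocation that also overrides the optional
Docker image argument might look like the following (illustrative values — the node count
and config file come from the examples and tables above, and depend on your cluster):

.. code-block:: shell

   # 4 nodes, Llama 3.1 405B on gfx950, explicitly pinning the image tag.
   sbatch -N 4 jax_maxtext_multinode_benchmark.sh \
      env_scripts/gfx950_llama3_405b.yml \
      rocm/jax-training:maxtext-v26.1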
diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst
index c6306f118..e004020e9 100644
--- a/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst
+++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-history.rst
@@ -22,7 +22,7 @@ previous releases of the ``ROCm/jax-training`` Docker image on `Docker Hub
    * - 25.11
-     - * :doc:`Documentation <../jax-maxtext>`
+     - * :doc:`Documentation <jax-maxtext-v25.11>`
       * `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.11/images/sha256-18e4d8f0b8ce7a7422c58046940dd5f32249960449fca09a562b65fb8eb1562a>`__
    * - 25.9.1
diff --git a/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.11.rst b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.11.rst
new file mode 100644
index 000000000..dfb64bdcd
--- /dev/null
+++ b/docs/how-to/rocm-for-ai/training/benchmark-docker/previous-versions/jax-maxtext-v25.11.rst
@@ -0,0 +1,408 @@
+:orphan:
+
+.. meta::
+   :description: How to train a model using JAX MaxText for ROCm.
+   :keywords: ROCm, AI, LLM, train, jax, torch, Llama, flux, tutorial, docker
+
+******************************************
+Training a model with JAX MaxText on ROCm
+******************************************
+
+.. caution::
+
+   This documentation does not reflect the latest version of ROCm JAX MaxText
+   training performance documentation. See :doc:`../jax-maxtext` for the latest version.
+
+The MaxText for ROCm training Docker image
+provides a prebuilt environment for training on AMD Instinct MI355X, MI350X, MI325X, and MI300X GPUs,
+including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
+It includes the following software components:
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml
+
+   {% set dockers = data.dockers %}
+   .. tab-set::
+
+      {% for docker in dockers %}
+      {% set jax_version = docker.components["JAX"] %}
+
+      .. tab-item:: ``{{ docker.pull_tag }}``
+         :sync: {{ docker.pull_tag }}
+
+         .. list-table::
+            :header-rows: 1
+
+            * - Software component
+              - Version
+
+            {% for component_name, component_version in docker.components.items() %}
+            * - {{ component_name }}
+              - {{ component_version }}
+
+            {% endfor %}
+      {% endfor %}
+
+.. note::
+
+   The ``rocm/jax-training:maxtext-v25.9`` has been updated to
+   ``rocm/jax-training:maxtext-v25.9.1``. This revision should include
+   a fix to address segmentation fault issues during launch. See the
+   :doc:`versioned documentation <jax-maxtext-v25.9>`.
+
+MaxText on ROCm provides the following key features to train large
+language models efficiently:
+
+- Transformer Engine (TE)
+
+- Flash Attention (FA) 3 -- with or without sequence input packing
+
+- GEMM tuning
+
+- Multi-node support
+
+- NANOO FP8 (for MI300X series GPUs) and FP8 (for MI355X and MI350X)
+  quantization support
+
+.. _amd-maxtext-model-support-v25.11:
+
+Supported models
+================
+
+The following models are pre-optimized for performance on AMD Instinct
+GPUs. Some instructions, commands, and available training
+configurations in this documentation might vary by model -- select one to get
+started.
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml
+
+   {% set model_groups = data.model_groups %}
+   .. raw:: html
+
+      <!-- Model/variant selector: the buttons toggle the matching
+           "model-doc <mad_tag>" containers rendered below. -->
+      <div class="model-select">
+        <div class="model-select-header">Model</div>
+        <div class="model-select-options">
+          {% for model_group in model_groups %}
+          <button class="model-group-option" data-group="{{ model_group.tag }}">{{ model_group.group }}</button>
+          {% endfor %}
+        </div>
+      </div>
+      <div class="model-select">
+        <div class="model-select-header">Variant</div>
+        <div class="model-select-options">
+          {% for model_group in model_groups %}
+          {% set models = model_group.models %}
+          {% for model in models %}
+          {% if models|length % 3 == 0 %}
+          <button class="model-option col-3" data-group="{{ model_group.tag }}" data-model="{{ model.mad_tag }}">{{ model.model }}</button>
+          {% else %}
+          <button class="model-option col-4" data-group="{{ model_group.tag }}" data-model="{{ model.mad_tag }}">{{ model.model }}</button>
+          {% endif %}
+          {% endfor %}
+          {% endfor %}
+        </div>
+      </div>
+
+.. note::
+
+   Some models, such as Llama 3, require an external license agreement through
+   a third party (for example, Meta).
+
+System validation
+=================
+
+Before running AI workloads, it's important to validate that your AMD hardware is configured
+correctly and performing optimally.
+
+If you have already validated your system settings, including aspects like NUMA auto-balancing, you
+can skip this step. Otherwise, complete the procedures in the :ref:`System validation and
+optimization ` guide to properly configure your system settings
+before starting training.
+
+To test for optimal performance, consult the recommended :ref:`System health benchmarks
+`. This suite of tests will help you verify and fine-tune your
+system's configuration.
+
+Environment setup
+=================
+
+This Docker image is optimized for specific model configurations outlined
+as follows. Performance can vary for other training workloads, as AMD
+doesn’t validate configurations and run conditions outside those described.
+
+Pull the Docker image
+---------------------
+
+Use the following command to pull the Docker image from Docker Hub.
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml
+
+   {% set docker = data.dockers[0] %}
+
+   .. code-block:: shell
+
+      docker pull {{ docker.pull_tag }}
+
+.. _amd-maxtext-multi-node-setup-v25.11:
+
+Multi-node configuration
+------------------------
+
+See :doc:`/how-to/rocm-for-ai/system-setup/multi-node-setup` to configure your
+environment for multi-node training.
+
+.. _amd-maxtext-get-started-v25.11:
+
+Benchmarking
+============
+
+Once the setup is complete, choose between two options to reproduce the
+benchmark results:
+
+.. datatemplate:yaml:: /data/how-to/rocm-for-ai/training/previous-versions/jax-maxtext-v25.11-benchmark-models.yaml
+
+   .. _vllm-benchmark-mad:
+
+   {% set docker = data.dockers[0] %}
+   {% set model_groups = data.model_groups %}
+   {% for model_group in model_groups %}
+   {% for model in model_group.models %}
+
+   .. container:: model-doc {{model.mad_tag}}
+
+      .. tab-set::
+
+         {% if model.mad_tag and "single-node" in model.doc_options %}
+         .. tab-item:: MAD-integrated benchmarking
+
+            The following run command is tailored to {{ model.model }}.
+            See :ref:`amd-maxtext-model-support-v25.11` to switch to another available model.
+
+            1. Clone the ROCm Model Automation and Dashboarding (`MAD <https://github.com/ROCm/MAD>`__) repository to a local
+               directory and install the required packages on the host machine.
+
+               .. code-block:: shell
+
+                  git clone https://github.com/ROCm/MAD
+                  cd MAD
+                  pip install -r requirements.txt
+
+            2. Use this command to run the performance benchmark test on the {{ model.model }} model
+               using one GPU with the :literal:`{{model.precision}}` data type on the host machine.
+
+               .. code-block:: shell
+
+                  export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
+                  madengine run \
+                     --tags {{model.mad_tag}} \
+                     --keep-model-dir \
+                     --live-output \
+                     --timeout 28800
+
+            MAD launches a Docker container with the name
+            ``container_ci-{{model.mad_tag}}``. The latency and throughput reports of the
+            model are collected in the following path: ``~/MAD/perf.csv/``.
+         {% endif %}
+
+         .. tab-item:: Standalone benchmarking
+
+            The following commands are optimized for {{ model.model }}. See
+            :ref:`amd-maxtext-model-support-v25.11` to switch to another
+            available model. Some instructions and resources might not be
+            available for all models and configurations.
+
+            .. rubric:: Download the Docker image and required scripts
+
+            Run the JAX MaxText benchmark tool independently by starting the
+            Docker container as shown in the following snippet.
+
+            .. code-block:: shell
+
+               docker pull {{ docker.pull_tag }}
+
+            {% if model.model_repo and "single-node" in model.doc_options %}
+            .. rubric:: Single node training
+
+            1. Set up environment variables.
+
+               .. code-block:: shell
+
+                  export MAD_SECRETS_HFTOKEN=<your_hf_token>
+                  export HF_HOME=<path_to_hf_cache>
+
+               ``MAD_SECRETS_HFTOKEN`` is your Hugging Face access token to access models, tokenizers, and data.
+               See `User access tokens <https://huggingface.co/docs/hub/security-tokens>`__.
+
+               ``HF_HOME`` is where ``huggingface_hub`` will store local data. See the
+               `huggingface_hub CLI <https://huggingface.co/docs/huggingface_hub/guides/cli>`__.
+               If you already have downloaded or cached Hugging Face artifacts, set this variable to that path.
+               Downloaded files typically get cached to ``~/.cache/huggingface``.
+
+            2. Launch the Docker container.
+
+               .. code-block:: shell
+
+                  docker run -it \
+                     --device=/dev/dri \
+                     --device=/dev/kfd \
+                     --network host \
+                     --ipc host \
+                     --group-add video \
+                     --cap-add=SYS_PTRACE \
+                     --security-opt seccomp=unconfined \
+                     --privileged \
+                     -v $HOME:$HOME \
+                     -v $HOME/.ssh:/root/.ssh \
+                     -v $HF_HOME:/hf_cache \
+                     -e HF_HOME=/hf_cache \
+                     -e MAD_SECRETS_HFTOKEN=$MAD_SECRETS_HFTOKEN \
+                     --shm-size 64G \
+                     --name training_env \
+                     {{ docker.pull_tag }}
+
+            3. In the Docker container, clone the ROCm MAD repository and navigate to the
+               benchmark scripts directory at ``MAD/scripts/jax-maxtext``.
+
+               .. code-block:: shell
+
+                  git clone https://github.com/ROCm/MAD
+                  cd MAD/scripts/jax-maxtext
+
+            4. Run the setup scripts to install libraries and datasets needed
+               for benchmarking.
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_setup.sh -m {{ model.model_repo }}
+
+            5. To run the training benchmark without quantization, use the following command:
+
+               .. code-block:: shell
+
+                  ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+               For quantized training, run the script with the appropriate option for your Instinct GPU.
+
+               .. tab-set::
+
+                  .. tab-item:: MI355X and MI350X
+
+                     For ``fp8`` quantized training on MI355X and MI350X GPUs, use the following command:
+
+                     .. code-block:: shell
+
+                        ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q fp8
+
+                  {% if model.model_repo not in ["Llama-3.1-70B", "Llama-3.3-70B"] %}
+                  .. tab-item:: MI325X and MI300X
+
+                     For ``nanoo_fp8`` quantized training on MI300X series GPUs, use the following command:
+
+                     .. code-block:: shell
+
+                        ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }} -q nanoo_fp8
+                  {% endif %}
+
+            {% endif %}
+            {% if model.multinode_training_script and "multi-node" in model.doc_options %}
+            .. rubric:: Multi-node training
+
+            The following examples use SLURM to run on multiple nodes.
+
+            .. note::
+
+               The following scripts will launch the Docker container and run the
+               benchmark. Run them outside of any Docker container.
+
+            1. Make sure ``$HF_HOME`` is set before running the test. See
+               `ROCm benchmarking `__
+               for more details on downloading the Llama models before running the
+               benchmark.
+
+            2. To run multi-node training for {{ model.model }},
+               use the
+               `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__
+               under the ``scripts/jax-maxtext/gpu-rocm/`` directory.
+
+            3. Run the multi-node training benchmark script.
+
+               .. code-block:: shell
+
+                  sbatch -N <num_nodes> {{ model.multinode_training_script }}
+
+            .. rubric:: Profiling with rocprofv3
+
+            If you need to collect a trace and the JAX profiler isn't working,
+            use ``rocprofv3`` provided by the :doc:`ROCprofiler-SDK ` as a
+            workaround. For example:
+
+            .. code-block:: bash
+
+               rocprofv3 \
+                  --hip-trace \
+                  --kernel-trace \
+                  --memory-copy-trace \
+                  --rccl-trace \
+                  --output-format pftrace \
+                  -d ./v3_traces \
+                  -- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+            You can set the directory where you want the traces to be saved
+            using ``-d <directory_path>``, and replace the command after ``--``
+            with the workload you want to trace. The resulting traces can be
+            opened in Perfetto: `ui.perfetto.dev <https://ui.perfetto.dev>`__.
+
+            {% else %}
+            .. rubric:: Multi-node training
+
+            For multi-node training examples, choose a model from :ref:`amd-maxtext-model-support-v25.11`
+            with an available `multi-node training script <https://github.com/ROCm/MAD/tree/develop/scripts/jax-maxtext/gpu-rocm>`__.
+            {% endif %}
+   {% endfor %}
+   {% endfor %}
+
+Known issues
+============
+
+- Minor performance regression (< 4%) for BF16 quantization in Llama models and Mixtral 8x7B.
+
+- You might see minor loss spikes, or the loss curve may have slightly higher
+  convergence end values compared to the previous ``jax-training`` image.
+
+- For FP8 training on MI355, many models will display a warning message like:
+  ``Warning: Latency not found for MI_M=16, MI_N=16, MI_K=128,
+  mi_input_type=BFloat8Float8_fnuz. Returning latency value of 32 (really
+  slow).`` The compile step may take longer than usual, but training will run.
+  This will be fixed in a future release.
+
+- The built-in JAX profiler isn't working. If you need to collect a trace, use
+  ``rocprofv3`` provided by the :doc:`ROCprofiler-SDK ` as a workaround. For example:
+
+  .. code-block:: bash
+
+     rocprofv3 \
+        --hip-trace \
+        --kernel-trace \
+        --memory-copy-trace \
+        --rccl-trace \
+        --output-format pftrace \
+        -d ./v3_traces \
+        -- ./jax-maxtext_benchmark_report.sh -m {{ model.model_repo }}
+
+  You can set the directory where you want the traces to be saved using
+  ``-d <directory_path>``. The resulting traces can be opened in Perfetto:
+  `ui.perfetto.dev <https://ui.perfetto.dev>`__.
+
+Further reading
+===============
+
+- To learn more about MAD and the ``madengine`` CLI, see the `MAD usage guide
+  <https://github.com/ROCm/MAD>`__.
+
+- To learn more about system settings and management practices to configure your system for
+  AMD Instinct MI300X Series GPUs, see `AMD Instinct MI300X system optimization
+  <https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html>`_.
+
+- For a list of other ready-made Docker images for AI with ROCm, see
+  `AMD Infinity Hub <https://www.amd.com/en/developer/resources/infinity-hub.html>`_.
+
+Previous versions
+=================
+
+See :doc:`jax-maxtext-history` to find documentation for previous releases
+of the ``ROCm/jax-training`` Docker image.
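Both image generations referenced in this change remain available side by side, which is
useful when comparing results from the archived v25.11 page against the new v26.1
documentation (tags taken from the component tables above):

.. code-block:: shell

   docker pull rocm/jax-training:maxtext-v25.11   # covered by the archived previous-versions page
   docker pull rocm/jax-training:maxtext-v26.1    # covered by the updated jax-maxtext.rst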