Compare commits

...

20 Commits

Author SHA1 Message Date
Istvan Kiss
e100cff912 Fix unsupported section structure on JAX (#4733) 2025-05-13 17:39:46 +02:00
Istvan Kiss
93bc9847cc Pytorch compatibility page update 2025-05-13 16:28:05 +02:00
Istvan Kiss
9c5b3c7f4c Fix compatibility list (#4731) 2025-05-13 16:27:48 +02:00
Istvan Kiss
028c2e4dbd JAX compatibility page upate (#4727) 2025-05-13 16:27:43 +02:00
Peter Park
d58d133762 Merge pull request #4725 from peterjunpark/docs/quark-model-quantization (#4726)
Add quark in model-quantization.rst

(cherry picked from commit 90a651d2b6)
2025-05-08 10:39:37 -04:00
Peter Park
065d1cdc95 Merge pull request #4725 from peterjunpark/docs/quark-model-quantization
Add quark in model-quantization.rst

(cherry picked from commit 90a651d2b6)
2025-05-08 10:35:33 -04:00
Peter Park
5b859352b2 Merge pull request #4724 from peterjunpark/docs/6.4.0
[docs/6.4.0] Fix incorrect throughput benchmark command in inference/vllm-benchmar…
2025-05-08 09:31:38 -04:00
Peter Park
f15a1e830e Fix incorrect throughput benchmark command in inference/vllm-benchmark.rst (#4723)
* update inference index to include pyt inference

* fix incorrect command in throughput benchmark

* wording

(cherry picked from commit bb7af3351a)
2025-05-08 09:27:44 -04:00
Pratik Basyal
a2628dce5d rocSHMEM component added to ROCm 6.4.0 documentation (#4719) (#4720)
* rocSHMEM added to ROCm 640

* Space removed

* link fixed
2025-05-07 15:42:38 -04:00
Peter Park
e0098d0668 fix links in pytorch-inference-benchmark.rst (#4713)
(cherry picked from commit 186c281aba)
2025-05-06 15:27:17 -04:00
Peter Park
71cffa9681 fix dynamic urls in toc.yml.in 2025-05-06 15:27:17 -04:00
Peter Park
ab49590526 Merge pull request #4708 from peterjunpark/docs/6.4.0
[docs/6.4.0] Add MPT-30B + LLM Foundry doc (#4704)
2025-05-02 11:17:06 -05:00
Peter Park
94337a9887 Add MPT-30B + LLM Foundry doc (#4704)
* add mpt-30b doc

* add tunableop note

* update MPT doc

* add section

* update wordlist

* fix flash attention version

* update "applies to"

* address review feedback

* Update docs/how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry.rst

Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

* Update docs/how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry.rst

Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

* Update docs/how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry.rst

Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>

* update docker details to pytorch-training-v25.5

* update

---------

Co-authored-by: Leo Paoletti <164940351+lpaoletti@users.noreply.github.com>
(cherry picked from commit d44ea40a0d)
2025-05-02 12:13:56 -04:00
Pratik Basyal
6700febed9 Link updated (#4706) (#4707) 2025-05-01 11:47:57 -04:00
Peter Park
c8054fc6d2 Update vLLM docker pull tag 20250415 in vllm-benchmark.rst (#4702) (#4703)
(cherry picked from commit 85778177a1)
2025-04-30 16:59:18 -04:00
Peter Park
18d98ca692 Update vLLM docker pull tag 20250415 in vllm-benchmark.rst (#4702)
(cherry picked from commit 85778177a1)
2025-04-30 16:10:27 -04:00
Peter Park
e014590d35 Merge pull request #4697 from peterjunpark/docs/6.4.0
Update JAX MaxText benchmark doc to v25.5
2025-04-28 18:04:04 -04:00
Peter Park
c8144c4a60 Update JAX MaxText benchmark doc to v25.5 (#4695)
* fix shell cmd formatting

* add previous versions section

* update docker details and add llama 3.3

* update missed docker image tags to 25.5

(cherry picked from commit 7458fcb7ab)
2025-04-28 17:53:37 -04:00
Peter Park
ed45d6add9 fix link to pytorch-training v25.4 doc (#4696)
(cherry picked from commit 16d6e59003)
2025-04-28 17:53:37 -04:00
randyh62
e93e0bf925 Update RELEASE.md (#4690)
Update deprecation notice for `roc-obj` tools in HIP
2025-04-25 18:12:36 -07:00
21 changed files with 650 additions and 392 deletions

View File

@@ -226,6 +226,7 @@ LM
LSAN
LSan
LTS
LanguageCrossEntropy
LoRA
MEM
MERCHANTABILITY
@@ -243,6 +244,7 @@ MMIOH
MMU
MNIST
MPI
MPT
MSVC
MVAPICH
MVFFR
@@ -259,6 +261,7 @@ Meta's
Miniconda
MirroredStrategy
Mixtral
MosaicML
Multicore
Multithreaded
MyEnvironment
@@ -329,6 +332,7 @@ PipelineParallel
PnP
PowerEdge
PowerShell
Pretrained
Pretraining
Profiler's
PyPi

View File

@@ -6,7 +6,7 @@ different versions of the ROCm software stack and its components.
## ROCm 6.4.0
See the [ROCm 6.4.0 release notes](https://rocm-stg.amd.com/en/latest/about/release-notes.html)
See the [ROCm 6.4.0 release notes](https://rocm.docs.amd.com/en/docs-6.4.0/about/release-notes.html)
for a complete overview of this release.
### **AMD SMI** (25.3.0)
@@ -745,6 +745,10 @@ and in-depth descriptions.
#### Added
- Support for VA-API and rocDecode tracing.
- Aggregation of MPI data collected across distributed nodes and ranks. The data is concatenated into a single proto file.
#### Changed
- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).
#### Resolved issues
@@ -755,9 +759,9 @@ and in-depth descriptions.
- Fixed interruption in config file generation.
- Fixed segmentation fault while running rocprof-sys-instrument.
- Fixed an issue where running `rocprof-sys-causal` or using the `-I all` option with `rocprof-sys-sample` caused the system to become non-responsive.
#### Changed
- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).
- Fixed an issue where sampling multi-GPU Python workloads caused the system to stop responding.
### **rocPRIM** (3.4.0)

View File

@@ -253,14 +253,19 @@ Click {fab}`github` to go to the component's source code on GitHub.
</tbody>
<tbody class="rocm-components-libs rocm-components-communication tbody-reverse-zebra">
<tr>
<th rowspan="1"></th>
<th rowspan="1">Communication</th>
<th rowspan="2"></th>
<th rowspan="2">Communication</th>
<td><a href="https://rocm.docs.amd.com/projects/rccl/en/docs-6.4.0/index.html">RCCL</a></td>
<td>2.21.5&nbsp;&Rightarrow;&nbsp;<a href="#rccl-2-22-3">2.22.3</a></td>
<td><a href="https://github.com/ROCm/rccl"><i class="fab fa-github fa-lg"></i></a></td>
</tr>
<tr>
<td><a href="https://github.com/ROCm/rocSHMEM">rocSHMEM</a></td>
<td>2.0.0</td>
<td><a href="https://github.com/ROCm/rocSHMEM"><i class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-libs rocm-components-math">
<tbody class="rocm-components-libs rocm-components-math tbody-reverse-zebra">
<tr>
<th rowspan="16"></th>
<th rowspan="16">Math</th>
@@ -344,7 +349,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
<td><a href="https://github.com/ROCm/Tensile"><i class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-libs rocm-components-primitives">
<tbody class="rocm-components-libs rocm-components-primitives tbody-reverse-zebra">
<tr>
<th rowspan="4"></th>
<th rowspan="4">Primitives</th>
@@ -368,7 +373,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
<td><a href="https://github.com/ROCm/rocThrust"><i class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-tools rocm-components-system">
<tbody class="rocm-components-tools rocm-components-system tbody-reverse-zebra">
<tr>
<th rowspan="7">Tools</th>
<th rowspan="7">System management</th>
@@ -397,7 +402,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
<td><a href="https://github.com/ROCm/ROCmValidationSuite"><i class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-tools rocm-components-perf tbody-reverse-zebra">
<tbody class="rocm-components-tools rocm-components-perf">
<tr>
<th rowspan="6"></th>
<th rowspan="6">Performance</th>
@@ -438,7 +443,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-tools rocm-components-dev tbody-reverse-zebra">
<tbody class="rocm-components-tools rocm-components-dev">
<tr>
<th rowspan="5"></th>
<th rowspan="5">Development</th>
@@ -474,7 +479,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-compilers">
<tbody class="rocm-components-compilers tbody-reverse-zebra">
<tr>
<th rowspan="2" colspan="2">Compilers</th>
<td><a href="https://rocm.docs.amd.com/projects/HIPCC/en/docs-6.4.0/index.html">HIPCC</a></td>
@@ -489,7 +494,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
class="fab fa-github fa-lg"></i></a></td>
</tr>
</tbody>
<tbody class="rocm-components-runtimes">
<tbody class="rocm-components-runtimes tbody-reverse-zebra">
<tr>
<th rowspan="2" colspan="2">Runtimes</th>
<td><a href="https://rocm.docs.amd.com/projects/HIP/en/docs-6.4.0/index.html">HIP</a></td>
@@ -762,10 +767,10 @@ and in-depth descriptions.
#### Changed
* `roc-obj` tools is deprecated and will be removed in an upcoming release.
* The `roc-obj` tools have been deprecated and will be removed in a future release.
- Perl package installation is not required, and users will need to install this themselves if they want to.
- Support for ROCm Object tooling has moved into `llvm-objdump` provided by package `rocm-llvm`.
- `llvm-objdump`, `llvm-objcopy`, and `llvm-readobj` will be enhanced to provide similar functionality as that provided by the `roc-obj` tools . The LLVM tools are available in the `rocm-llvm` pkg.
- While not related to the deprecation, also note that the `roc-obj` tools package dependency on Perl has been changed to recommended. It is the users responsibility to install Perl to use these tools.
* SDMA retainer logic is removed for engine selection in operation of runtime buffer copy.
@@ -1249,6 +1254,11 @@ and in-depth descriptions.
#### Added
- Support for VA-API and rocDecode tracing.
- Aggregation of MPI data collected across distributed nodes and ranks. The data is concatenated into a single proto file.
#### Changed
- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).
#### Resolved issues
@@ -1259,9 +1269,9 @@ and in-depth descriptions.
- Fixed interruption in config file generation.
- Fixed segmentation fault while running rocprof-sys-instrument.
- Fixed an issue where running `rocprof-sys-causal` or using the `-I all` option with `rocprof-sys-sample` caused the system to become non-responsive.
#### Changed
- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).
- Fixed an issue where sampling multi-GPU Python workloads caused the system to stop responding.
### **rocPRIM** (3.4.0)

View File

@@ -81,6 +81,7 @@ additional licenses. Please review individual repositories for more information.
| [rocRAND](https://github.com/ROCm/rocRAND/) | [MIT](https://github.com/ROCm/rocRAND/blob/develop/LICENSE.txt) |
| [ROCr Debug Agent](https://github.com/ROCm/rocr_debug_agent/) | [The University of Illinois/NCSA](https://github.com/ROCm/rocr_debug_agent/blob/amd-staging/LICENSE.txt) |
| [ROCR-Runtime](https://github.com/ROCm/ROCR-Runtime/) | [The University of Illinois/NCSA](https://github.com/ROCm/ROCR-Runtime/blob/amd-staging/LICENSE.txt) |
| [rocSHMEM](https://github.com/ROCm/rocSHMEM/) | [MIT](https://github.com/ROCm/rocSHMEM/blob/develop/LICENSE.md) |
| [rocSOLVER](https://github.com/ROCm/rocSOLVER/) | [BSD-2-Clause](https://github.com/ROCm/rocSOLVER/blob/develop/LICENSE.md) |
| [rocSPARSE](https://github.com/ROCm/rocSPARSE/) | [MIT](https://github.com/ROCm/rocSPARSE/blob/develop/LICENSE.md) |
| [rocThrust](https://github.com/ROCm/rocThrust/) | [Apache 2.0](https://github.com/ROCm/rocThrust/blob/develop/LICENSE) |

View File

@@ -27,7 +27,6 @@ ROCm Version,6.4.0,6.3.3,6.3.2,6.3.1,6.3.0,6.2.4,6.2.2,6.2.1,6.2.0, 6.1.5, 6.1.2
:doc:`TensorFlow <../compatibility/ml-compatibility/tensorflow-compatibility>`,"2.18.1, 2.17.1, 2.16.2","2.17.0, 2.16.2, 2.15.1","2.17.0, 2.16.2, 2.15.1","2.17.0, 2.16.2, 2.15.1","2.17.0, 2.16.2, 2.15.1","2.16.1, 2.15.1, 2.14.1","2.16.1, 2.15.1, 2.14.1","2.16.1, 2.15.1, 2.14.1","2.16.1, 2.15.1, 2.14.1","2.15.0, 2.14.0, 2.13.1","2.15.0, 2.14.0, 2.13.1","2.15.0, 2.14.0, 2.13.1","2.15.0, 2.14.0, 2.13.1","2.14.0, 2.13.1, 2.12.1","2.14.0, 2.13.1, 2.12.1"
:doc:`JAX <../compatibility/ml-compatibility/jax-compatibility>`,0.4.35,0.4.31,0.4.31,0.4.31,0.4.31,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26
`ONNX Runtime <https://onnxruntime.ai/docs/build/eps.html#amd-migraphx>`_,1.2,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.14.1,1.14.1
,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,
THIRD PARTY COMMS,.. _thirdpartycomms-support-compatibility-matrix-past-60:,,,,,,,,,,,,,,
`UCC <https://github.com/ROCm/ucc>`_,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.2.0,>=1.2.0
@@ -53,6 +52,7 @@ ROCm Version,6.4.0,6.3.3,6.3.2,6.3.1,6.3.0,6.2.4,6.2.2,6.2.1,6.2.0, 6.1.5, 6.1.2
,,,,,,,,,,,,,,,
COMMUNICATION,.. _commlibs-support-compatibility-matrix-past-60:,,,,,,,,,,,,,,
:doc:`RCCL <rccl:index>`,2.22.3,2.21.5,2.21.5,2.21.5,2.21.5,2.20.5,2.20.5,2.20.5,2.20.5,2.18.6,2.18.6,2.18.6,2.18.6,2.18.3,2.18.3
`rocSHMEM <https://github.com/ROCm/rocSHMEM>`_,2.0.0,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A
,,,,,,,,,,,,,,,
MATH LIBS,.. _mathlibs-support-compatibility-matrix-past-60:,,,,,,,,,,,,,,
`half <https://github.com/ROCm/half>`_ ,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0
1 ROCm Version 6.4.0 6.3.3 6.3.2 6.3.1 6.3.0 6.2.4 6.2.2 6.2.1 6.2.0 6.1.5 6.1.2 6.1.1 6.1.0 6.0.2 6.0.0
27 :doc:`TensorFlow <../compatibility/ml-compatibility/tensorflow-compatibility>` 2.18.1, 2.17.1, 2.16.2 2.17.0, 2.16.2, 2.15.1 2.17.0, 2.16.2, 2.15.1 2.17.0, 2.16.2, 2.15.1 2.17.0, 2.16.2, 2.15.1 2.16.1, 2.15.1, 2.14.1 2.16.1, 2.15.1, 2.14.1 2.16.1, 2.15.1, 2.14.1 2.16.1, 2.15.1, 2.14.1 2.15.0, 2.14.0, 2.13.1 2.15.0, 2.14.0, 2.13.1 2.15.0, 2.14.0, 2.13.1 2.15.0, 2.14.0, 2.13.1 2.14.0, 2.13.1, 2.12.1 2.14.0, 2.13.1, 2.12.1
28 :doc:`JAX <../compatibility/ml-compatibility/jax-compatibility>` 0.4.35 0.4.31 0.4.31 0.4.31 0.4.31 0.4.26 0.4.26 0.4.26 0.4.26 0.4.26 0.4.26 0.4.26 0.4.26 0.4.26 0.4.26
29 `ONNX Runtime <https://onnxruntime.ai/docs/build/eps.html#amd-migraphx>`_ 1.2 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.17.3 1.14.1 1.14.1
30
31 THIRD PARTY COMMS .. _thirdpartycomms-support-compatibility-matrix-past-60:
32 `UCC <https://github.com/ROCm/ucc>`_ >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.3.0 >=1.2.0 >=1.2.0
52
53 COMMUNICATION .. _commlibs-support-compatibility-matrix-past-60:
54 :doc:`RCCL <rccl:index>` 2.22.3 2.21.5 2.21.5 2.21.5 2.21.5 2.20.5 2.20.5 2.20.5 2.20.5 2.18.6 2.18.6 2.18.6 2.18.6 2.18.3 2.18.3
55 `rocSHMEM <https://github.com/ROCm/rocSHMEM>`_ 2.0.0 N/A N/A N/A N/A N/A N/A N/A N/A N/A N/A N/A N/A N/A N/A
56
57 MATH LIBS .. _mathlibs-support-compatibility-matrix-past-60:
58 `half <https://github.com/ROCm/half>`_ 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0 1.12.0

View File

@@ -77,6 +77,7 @@ compatibility and system requirements.
,,,
COMMUNICATION,.. _commlibs-support-compatibility-matrix:,,
:doc:`RCCL <rccl:index>`,2.22.3,2.21.5,2.20.5
`rocSHMEM <https://github.com/ROCm/rocSHMEM>`_ ,2.0.0,N/A,N/A
,,,
MATH LIBS,.. _mathlibs-support-compatibility-matrix:,,
`half <https://github.com/ROCm/half>`_ ,1.12.0,1.12.0,1.12.0

View File

@@ -14,17 +14,18 @@ JAX provides a NumPy-like API, which combines automatic differentiation and the
Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine
learning at scale.
JAX uses composable transformations of Python and NumPy through just-in-time (JIT) compilation,
automatic vectorization, and parallelization. To learn about JAX, including profiling and
optimizations, see the official `JAX documentation
JAX uses composable transformations of Python and NumPy through just-in-time
(JIT) compilation, automatic vectorization, and parallelization. To learn about
JAX, including profiling and optimizations, see the official `JAX documentation
<https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`_.
ROCm support for JAX is upstreamed and users can build the official source code with ROCm
support:
ROCm support for JAX is upstreamed, and users can build the official source code
with ROCm support:
- ROCm JAX release:
- Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>` with ROCm and JAX pre-installed.
- Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>`
with ROCm and JAX preinstalled.
- ROCm JAX repository: `ROCm/jax <https://github.com/ROCm/jax>`_
@@ -36,8 +37,8 @@ support:
- Official JAX repository: `jax-ml/jax <https://github.com/jax-ml/jax>`_
- See the `AMD GPU (Linux) installation section
<https://jax.readthedocs.io/en/latest/installation.html#amd-gpu-linux>`_ in the JAX
documentation.
<https://jax.readthedocs.io/en/latest/installation.html#amd-gpu-linux>`_ in
the JAX documentation.
.. note::
@@ -46,6 +47,44 @@ support:
`Community ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax-community>`_
follow upstream JAX releases and use the latest available ROCm version.
Use cases and recommendations
================================================================================
* The `nanoGPT in JAX <https://rocm.blogs.amd.com/artificial-intelligence/nanoGPT-JAX/README.html>`_
blog explores the implementation and training of a Generative Pre-trained
Transformer (GPT) model in JAX, inspired by Andrej Karpathys JAX-based
nanoGPT. Comparing how essential GPT components—such as self-attention
mechanisms and optimizers—are realized in JAX and JAX, also highlights
JAXs unique features.
* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
ROCm on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-mixed-precision/README.html>`_
blog post provides a comprehensive guide on enhancing the training efficiency
of GPT models by implementing mixed precision techniques in JAX, specifically
tailored for AMD GPUs utilizing the ROCm platform.
* The `Supercharging JAX with Triton Kernels on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-triton/README.html>`_
blog demonstrates how to develop a custom fused dropout-activation kernel for
matrices using Triton, integrate it with JAX, and benchmark its performance
using ROCm.
* The `Distributed fine-tuning with JAX on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/distributed-sft-jax/README.html>`_
outlines the process of fine-tuning a Bidirectional Encoder Representations
from Transformers (BERT)-based large language model (LLM) using JAX for a text
classification task. The blog post discuss techniques for parallelizing the
fine-tuning across multiple AMD GPUs and assess the model's performance on a
holdout dataset. During the fine-tuning, a BERT-base-cased transformer model
and the General Language Understanding Evaluation (GLUE) benchmark dataset was
used on a multi-GPU setup.
* The `MI300X workload optimization guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html>`_
provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
accelerator using ROCm. The page is aimed at helping users achieve optimal
performance for deep learning and other high-performance computing tasks on
the MI300X GPU.
For more use cases and recommendations, see `ROCm JAX blog posts <https://rocm.blogs.amd.com/blog/tag/jax.html>`_.
.. _jax-docker-compat:
Docker image compatibility
@@ -57,7 +96,7 @@ Docker image compatibility
AMD validates and publishes ready-made `ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax>`_
with ROCm backends on Docker Hub. The following Docker image tags and
associated inventories are validated for
associated inventories represent the latest JAX version from the official Docker Hub and are validated for
`ROCm 6.4.0 <https://repo.radeon.com/rocm/apt/6.4/>`_. Click the |docker-icon|
icon to view the image on Docker Hub.
@@ -121,13 +160,12 @@ associated inventories are tested for `ROCm 6.3.2 <https://repo.radeon.com/rocm/
- Ubuntu 22.04
- `3.10.16 <https://www.python.org/downloads/release/python-31016/>`_
Critical ROCm libraries for JAX
Key ROCm libraries for JAX
================================================================================
The functionality of JAX with ROCm is determined by its underlying library
dependencies. These critical ROCm components affect the capabilities,
performance, and feature set available to developers. The versions described
are available in ROCm :version:`rocm_version`.
JAX functionality on ROCm is determined by its underlying library
dependencies. These ROCm components affect the capabilities, performance, and
feature set available to developers.
.. list-table::
:header-rows: 1
@@ -215,10 +253,10 @@ are available in ROCm :version:`rocm_version`.
distributed training, which involves parallel reductions or
operations like ``jax.numpy.cumsum`` can use rocThrust.
Supported and unsupported features
Supported features
===============================================================================
The following table maps GPU-accelerated JAX modules to their supported
The following table maps the public JAX API modules to their supported
ROCm and JAX versions.
.. list-table::
@@ -226,8 +264,8 @@ ROCm and JAX versions.
* - Module
- Description
- Since JAX
- Since ROCm
- As of JAX
- As of ROCm
* - ``jax.numpy``
- Implements the NumPy API, using the primitives in ``jax.lax``.
- 0.1.56
@@ -255,21 +293,11 @@ ROCm and JAX versions.
devices.
- 0.3.20
- 5.1.0
* - ``jax.dlpack``
- For exchanging tensor data between JAX and other libraries that support the
DLPack standard.
- 0.1.57
- 5.0.0
* - ``jax.distributed``
- Enables the scaling of computations across multiple devices on a single
machine or across multiple machines.
- 0.1.74
- 5.0.0
* - ``jax.dtypes``
- Provides utilities for working with and managing data types in JAX
arrays and computations.
- 0.1.66
- 5.0.0
* - ``jax.image``
- Contains image manipulation functions like resize, scale and translation.
- 0.1.57
@@ -283,27 +311,10 @@ ROCm and JAX versions.
array.
- 0.1.57
- 5.0.0
* - ``jax.profiler``
- Contains JAXs tracing and time profiling features.
- 0.1.57
- 5.0.0
* - ``jax.stages``
- Contains interfaces to stages of the compiled execution process.
- 0.3.4
- 5.0.0
* - ``jax.tree``
- Provides utilities for working with tree-like container data structures.
- 0.4.26
- 5.6.0
* - ``jax.tree_util``
- Provides utilities for working with nested data structures, or
``pytrees``.
- 0.1.65
- 5.0.0
* - ``jax.typing``
- Provides JAX-specific static type annotations.
- 0.3.18
- 5.1.0
* - ``jax.extend``
- Provides modules for access to JAX internal machinery module. The
``jax.extend`` module defines a library view of some of JAXs internal
@@ -339,8 +350,8 @@ A SciPy-like API for scientific computing.
:header-rows: 1
* - Module
- Since JAX
- Since ROCm
- As of JAX
- As of ROCm
* - ``jax.scipy.cluster``
- 0.3.11
- 5.1.0
@@ -385,8 +396,8 @@ jax.scipy.stats module
:header-rows: 1
* - Module
- Since JAX
- Since ROCm
- As of JAX
- As of ROCm
* - ``jax.scipy.stats.bernouli``
- 0.1.56
- 5.0.0
@@ -469,8 +480,8 @@ Modules for JAX extensions.
:header-rows: 1
* - Module
- Since JAX
- Since ROCm
- As of JAX
- As of ROCm
* - ``jax.extend.ffi``
- 0.4.30
- 6.0.0
@@ -484,190 +495,25 @@ Modules for JAX extensions.
- 0.4.15
- 5.5.0
jax.experimental module
-------------------------------------------------------------------------------
Experimental modules and APIs.
.. list-table::
:header-rows: 1
* - Module
- Since JAX
- Since ROCm
* - ``jax.experimental.checkify``
- 0.1.75
- 5.0.0
* - ``jax.experimental.compilation_cache.compilation_cache``
- 0.1.68
- 5.0.0
* - ``jax.experimental.custom_partitioning``
- 0.4.0
- 5.3.0
* - ``jax.experimental.jet``
- 0.1.56
- 5.0.0
* - ``jax.experimental.key_reuse``
- 0.4.26
- 5.6.0
* - ``jax.experimental.mesh_utils``
- 0.1.76
- 5.0.0
* - ``jax.experimental.multihost_utils``
- 0.3.2
- 5.0.0
* - ``jax.experimental.pallas``
- 0.4.15
- 5.5.0
* - ``jax.experimental.pjit``
- 0.1.61
- 5.0.0
* - ``jax.experimental.serialize_executable``
- 0.4.0
- 5.3.0
* - ``jax.experimental.shard_map``
- 0.4.3
- 5.3.0
* - ``jax.experimental.sparse``
- 0.1.75
- 5.0.0
.. list-table::
:header-rows: 1
* - API
- Since JAX
- Since ROCm
* - ``jax.experimental.enable_x64``
- 0.1.60
- 5.0.0
* - ``jax.experimental.disable_x64``
- 0.1.60
- 5.0.0
jax.experimental.pallas module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Module for Pallas, a JAX extension for custom kernels.
.. list-table::
:header-rows: 1
* - Module
- Since JAX
- Since ROCm
* - ``jax.experimental.pallas.mosaic_gpu``
- 0.4.31
- 6.1.3
* - ``jax.experimental.pallas.tpu``
- 0.4.15
- 5.5.0
* - ``jax.experimental.pallas.triton``
- 0.4.32
- 6.1.3
jax.experimental.sparse module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Experimental support for sparse matrix operations.
.. list-table::
:header-rows: 1
* - Module
- Since JAX
- Since ROCm
* - ``jax.experimental.sparse.linalg``
- 0.3.15
- 5.2.0
* - ``jax.experimental.sparse.sparsify``
- 0.3.25
- ❌
.. list-table::
:header-rows: 1
* - ``sparse`` data structure API
- Since JAX
- Since ROCm
* - ``jax.experimental.sparse.BCOO``
- 0.1.72
- 5.0.0
* - ``jax.experimental.sparse.BCSR``
- 0.3.20
- 5.1.0
* - ``jax.experimental.sparse.CSR``
- 0.1.75
- 5.0.0
* - ``jax.experimental.sparse.NM``
- 0.4.27
- 5.6.0
* - ``jax.experimental.sparse.COO``
- 0.1.75
- 5.0.0
Unsupported JAX features
------------------------
===============================================================================
The following are GPU-accelerated JAX features not currently supported by
ROCm.
The following GPU-accelerated JAX features are not supported by ROCm for
the listed supported JAX versions.
.. list-table::
:header-rows: 1
* - Feature
- Description
- Since JAX
* - Mixed Precision with TF32
- Mixed precision with TF32 is used for matrix multiplications,
convolutions, and other linear algebra operations, particularly in
deep learning workloads like CNNs and transformers.
- 0.2.25
* - RNN support
- Currently only LSTM with double bias is supported with float32 input
and weight.
- 0.3.25
* - XLA int4 support
- 4-bit integer (int4) precision in the XLA compiler.
- 0.4.0
* - ``jax.experimental.sparsify``
- Converts a dense matrix to a sparse matrix representation.
- Experimental
Use cases and recommendations
================================================================================
* The `nanoGPT in JAX <https://rocm.blogs.amd.com/artificial-intelligence/nanoGPT-JAX/README.html>`_
blog explores the implementation and training of a Generative Pre-trained
Transformer (GPT) model in JAX, inspired by Andrej Karpathys PyTorch-based
nanoGPT. By comparing how essential GPT components—such as self-attention
mechanisms and optimizers—are realized in PyTorch and JAX, also highlight
JAXs unique features.
* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
ROCm on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-mixed-precision/README.html>`_
blog post provides a comprehensive guide on enhancing the training efficiency
of GPT models by implementing mixed precision techniques in JAX, specifically
tailored for AMD GPUs utilizing the ROCm platform.
* The `Supercharging JAX with Triton Kernels on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-triton/README.html>`_
blog demonstrates how to develop a custom fused dropout-activation kernel for
matrices using Triton, integrate it with JAX, and benchmark its performance
using ROCm.
* The `Distributed fine-tuning with JAX on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/distributed-sft-jax/README.html>`_
outlines the process of fine-tuning a Bidirectional Encoder Representations
from Transformers (BERT)-based large language model (LLM) using JAX for a text
classification task. The blog post discuss techniques for parallelizing the
fine-tuning across multiple AMD GPUs and assess the model's performance on a
holdout dataset. During the fine-tuning, a BERT-base-cased transformer model
and the General Language Understanding Evaluation (GLUE) benchmark dataset was
used on a multi-GPU setup.
* The `MI300X workload optimization guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html>`_
provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
accelerator using ROCm. The page is aimed at helping users achieve optimal
performance for deep learning and other high-performance computing tasks on
the MI300X GPU.
For more use cases and recommendations, see `ROCm JAX blog posts <https://rocm.blogs.amd.com/blog/tag/jax.html>`_.
* - MOSAIC (GPU)
- Mosaic is a library of kernel-building abstractions for JAX's Pallas system

View File

@@ -21,31 +21,68 @@ release cycles for PyTorch on ROCm:
- ROCm PyTorch release:
- Provides the latest version of ROCm but doesn't immediately support the latest stable PyTorch
version.
- Provides the latest version of ROCm but might not necessarily support the
latest stable PyTorch version.
- Offers :ref:`Docker images <pytorch-docker-compat>` with ROCm and PyTorch
pre-installed.
preinstalled.
- ROCm PyTorch repository: `<https://github.com/ROCm/pytorch>`_
- See the :doc:`ROCm PyTorch installation guide <rocm-install-on-linux:install/3rd-party/pytorch-install>` to get started.
- See the :doc:`ROCm PyTorch installation guide <rocm-install-on-linux:install/3rd-party/pytorch-install>`
to get started.
- Official PyTorch release:
- Provides the latest stable version of PyTorch but doesn't immediately support the latest ROCm version.
- Provides the latest stable version of PyTorch but might not necessarily
support the latest ROCm version.
- Official PyTorch repository: `<https://github.com/pytorch/pytorch>`_
- See the `Nightly and latest stable version installation guide <https://pytorch.org/get-started/locally/>`_
or `Previous versions <https://pytorch.org/get-started/previous-versions/>`_ to get started.
or `Previous versions <https://pytorch.org/get-started/previous-versions/>`_
to get started.
The upstream PyTorch includes an automatic HIPification solution that automatically generates HIP
source code from the CUDA backend. This approach allows PyTorch to support ROCm without requiring
manual code modifications.
PyTorch includes tooling that generates HIP source code from the CUDA backend.
This approach allows PyTorch to support ROCm without requiring manual code
modifications. For more information, see :doc:`HIPIFY <hipify:index>`.
Development of ROCm is aligned with the stable release of PyTorch while upstream PyTorch testing uses
the stable release of ROCm to maintain consistency.
ROCm development is aligned with the stable release of PyTorch, while upstream
PyTorch testing uses the stable release of ROCm to maintain consistency.
.. _pytorch-recommendations:
Use cases and recommendations
================================================================================
* :doc:`Using ROCm for AI: training a model </how-to/rocm-for-ai/training/benchmark-docker/pytorch-training>`
guides how to leverage the ROCm platform for training AI models. It covers the
steps, tools, and best practices for optimizing training workflows on AMD GPUs
using PyTorch features.
* :doc:`Single-GPU fine-tuning and inference </how-to/rocm-for-ai/fine-tuning/single-gpu-fine-tuning-and-inference>`
describes and demonstrates how to use the ROCm platform for the fine-tuning
and inference of machine learning models, particularly large language models
(LLMs), on systems with a single GPU. This topic provides a detailed guide for
setting up, optimizing, and executing fine-tuning and inference workflows in
such environments.
* :doc:`Multi-GPU fine-tuning and inference optimization </how-to/rocm-for-ai/fine-tuning/multi-gpu-fine-tuning-and-inference>`
describes and demonstrates the fine-tuning and inference of machine learning
models on systems with multiple GPUs.
* The :doc:`Instinct MI300X workload optimization guide </how-to/rocm-for-ai/inference-optimization/workload>`
provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
accelerator using ROCm. This guide helps users achieve optimal performance for
deep learning and other high-performance computing tasks on the MI300X
accelerator.
* The :doc:`Inception with PyTorch documentation </conceptual/ai-pytorch-inception>`
describes how PyTorch integrates with ROCm for AI workloads It outlines the
use of PyTorch on the ROCm platform and focuses on efficiently leveraging AMD
GPU hardware for training and inference tasks in AI applications.
For more use cases and recommendations, see `ROCm PyTorch blog posts <https://rocm.blogs.amd.com/blog/tag/pytorch.html>`_.
.. _pytorch-docker-compat:
@@ -56,10 +93,10 @@ Docker image compatibility
<i class="fab fa-docker"></i>
AMD validates and publishes ready-made `PyTorch images <https://hub.docker.com/r/rocm/pytorch>`_
with ROCm backends on Docker Hub. The following Docker image tags and
associated inventories are validated for `ROCm 6.4.0 <https://repo.radeon.com/rocm/apt/6.4/>`_.
Click the |docker-icon| icon to view the image on Docker Hub.
AMD validates and publishes `PyTorch images <https://hub.docker.com/r/rocm/pytorch>`_
with ROCm backends on Docker Hub. The following Docker image tags and associated
inventories were tested on `ROCm 6.4.0 <https://repo.radeon.com/rocm/apt/6.4/>`_.
Click |docker-icon| to view the image on Docker Hub.
.. list-table:: PyTorch Docker image components
:header-rows: 1
@@ -212,13 +249,12 @@ Click the |docker-icon| icon to view the image on Docker Hub.
- `4.0.3 <https://github.com/open-mpi/ompi/tree/v4.0.3>`_
- `5.3-1.0.5.0 <https://content.mellanox.com/ofed/MLNX_OFED-5.3-1.0.5.0/MLNX_OFED_LINUX-5.3-1.0.5.0-ubuntu20.04-x86_64.tgz>`_
Critical ROCm libraries for PyTorch
Key ROCm libraries for PyTorch
================================================================================
The functionality of PyTorch with ROCm is determined by its underlying library
dependencies. These critical ROCm components affect the capabilities,
performance, and feature set available to developers. The versions described
are available in ROCm :version:`rocm_version`.
PyTorch functionality on ROCm is determined by its underlying library
dependencies. These ROCm components affect the capabilities, performance, and
feature set available to developers.
.. list-table::
:header-rows: 1
@@ -238,24 +274,23 @@ are available in ROCm :version:`rocm_version`.
- :version-ref:`hipBLAS rocm_version`
- Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for
matrix and vector operations.
- Supports operations like matrix multiplication, matrix-vector products,
and tensor contractions. Utilized in both dense and batched linear
algebra operations.
- Supports operations such as matrix multiplication, matrix-vector
products, and tensor contractions. Utilized in both dense and batched
linear algebra operations.
* - `hipBLASLt <https://github.com/ROCm/hipBLASLt>`_
- :version-ref:`hipBLASLt rocm_version`
- hipBLASLt is an extension of the hipBLAS library, providing additional
features like epilogues fused into the matrix multiplication kernel or
use of integer tensor cores.
- It accelerates operations like ``torch.matmul``, ``torch.mm``, and the
- Accelerates operations such as ``torch.matmul``, ``torch.mm``, and the
matrix multiplications used in convolutional and linear layers.
* - `hipCUB <https://github.com/ROCm/hipCUB>`_
- :version-ref:`hipCUB rocm_version`
- Provides a C++ template library for parallel algorithms for reduction,
scan, sort and select.
- Supports operations like ``torch.sum``, ``torch.cumsum``, ``torch.sort``
and ``torch.topk``. Operations on sparse tensors or tensors with
irregular shapes often involve scanning, sorting, and filtering, which
hipCUB handles efficiently.
- Supports operations such as ``torch.sum``, ``torch.cumsum``,
``torch.sort`` irregular shapes often involve scanning, sorting, and
filtering, which hipCUB handles efficiently.
* - `hipFFT <https://github.com/ROCm/hipFFT>`_
- :version-ref:`hipFFT rocm_version`
- Provides GPU-accelerated Fast Fourier Transform (FFT) operations.
@@ -263,8 +298,8 @@ are available in ROCm :version:`rocm_version`.
* - `hipRAND <https://github.com/ROCm/hipRAND>`_
- :version-ref:`hipRAND rocm_version`
- Provides fast random number generation for GPUs.
- The ``torch.rand``, ``torch.randn`` and stochastic layers like
``torch.nn.Dropout``.
- The ``torch.rand``, ``torch.randn``, and stochastic layers like
``torch.nn.Dropout`` rely on hipRAND.
* - `hipSOLVER <https://github.com/ROCm/hipSOLVER>`_
- :version-ref:`hipSOLVER rocm_version`
- Provides GPU-accelerated solvers for linear systems, eigenvalues, and
@@ -335,7 +370,7 @@ are available in ROCm :version:`rocm_version`.
- :version-ref:`RPP rocm_version`
- Speeds up data augmentation, transformation, and other preprocessing steps.
- Easy to integrate into PyTorch's ``torch.utils.data`` and
``torchvision`` data load workloads.
``torchvision`` data load workloads to speed up data processing.
* - `rocThrust <https://github.com/ROCm/rocThrust>`_
- :version-ref:`rocThrust rocm_version`
- Provides a C++ template library for parallel algorithms like sorting,
@@ -352,11 +387,11 @@ are available in ROCm :version:`rocm_version`.
involve matrix products, such as ``torch.matmul``, ``torch.bmm``, and
more.
Supported and unsupported features
Supported features
================================================================================
The following section maps GPU-accelerated PyTorch features to their supported
ROCm and PyTorch versions.
This section maps GPU-accelerated PyTorch features to their supported ROCm and
PyTorch versions.
torch
--------------------------------------------------------------------------------
@@ -364,23 +399,24 @@ torch
`torch <https://pytorch.org/docs/stable/index.html>`_ is the central module of
PyTorch, providing data structures for multi-dimensional tensors and
implementing mathematical operations on them. It also includes utilities for
efficient serialization of tensors and arbitrary data types, along with various
other tools.
efficient serialization of tensors and arbitrary data types and other tools.
Tensor data types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The data type of a tensor is specified using the ``dtype`` attribute or argument, and PyTorch supports a wide range of data types for different use cases.
The tensor data type is specified using the ``dtype`` attribute or argument.
PyTorch supports many data types for different use cases.
The following table lists `torch.Tensor <https://pytorch.org/docs/stable/tensors.html>`_'s single data types:
The following table lists `torch.Tensor <https://pytorch.org/docs/stable/tensors.html>`_
single data types:
.. list-table::
:header-rows: 1
* - Data type
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - ``torch.float8_e4m3fn``
- 8-bit floating point, e4m3
- 2.3
@@ -472,11 +508,11 @@ The following table lists `torch.Tensor <https://pytorch.org/docs/stable/tensors
.. note::
Unsigned types aside from ``uint8`` are currently only have limited support in
eager mode (they primarily exist to assist usage with ``torch.compile``).
Unsigned types except ``uint8`` have limited support in eager mode. They
primarily exist to assist usage with ``torch.compile``.
The :doc:`ROCm precision support page <rocm:reference/precision-support>`
collected the native HW support of different data types.
See :doc:`ROCm precision support <rocm:reference/precision-support>` for the
native hardware support of data types.
torch.cuda
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -491,8 +527,8 @@ leveraging ROCm and CUDA as the underlying frameworks.
* - Feature
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - Device management
- Utilities for managing and interacting with GPUs.
- 0.4.0
@@ -566,8 +602,8 @@ PyTorch interacts with the ROCm or CUDA environment.
* - Feature
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - ``cufft_plan_cache``
- Manages caching of GPU FFT plans to optimize repeated FFT computations.
- 1.7.0
@@ -615,8 +651,8 @@ Supported ``torch`` options include:
* - Option
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - ``allow_tf32``
- TensorFloat-32 tensor cores may be used in cuDNN convolutions on NVIDIA
Ampere or newer GPUs.
@@ -631,28 +667,28 @@ Supported ``torch`` options include:
Automatic mixed precision: torch.amp
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PyTorch that automates the process of using both 16-bit (half-precision,
float16) and 32-bit (single-precision, float32) floating-point types in model
training and inference.
PyTorch automates the process of using both 16-bit (half-precision, float16) and
32-bit (single-precision, float32) floating-point types in model training and
inference.
.. list-table::
:header-rows: 1
* - Feature
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - Autocasting
- Instances of autocast serve as context managers or decorators that allow
- Autocast instances serve as context managers or decorators that allow
regions of your script to run in mixed precision.
- 1.9
- 2.5
* - Gradient scaling
- To prevent underflow, “gradient scaling” multiplies the networks
loss(es) by a scale factor and invokes a backward pass on the scaled
loss(es). Gradients flowing backward through the network are then
scaled by the same factor. In other words, gradient values have a
larger magnitude, so they dont flush to zero.
loss by a scale factor and invokes a backward pass on the scaled
loss. The same factor then scales gradients flowing backward through
the network. In other words, gradient values have a larger magnitude so
that they dont flush to zero.
- 1.9
- 2.5
* - CUDA op-specific behavior
@@ -666,7 +702,7 @@ training and inference.
Distributed library features
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The PyTorch distributed library includes a collective of parallelism modules, a
PyTorch distributed library includes a collective of parallelism modules, a
communications layer, and infrastructure for launching and debugging large
training jobs. See :ref:`rocm-for-ai-pytorch-distributed` for more information.
@@ -680,13 +716,13 @@ of computational resources and scalability for large-scale tasks.
* - Feature
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - TensorPipe
- A point-to-point communication library integrated into
PyTorch for distributed training. It is designed to handle tensor data
transfers efficiently between different processes or devices, including
those on separate machines.
PyTorch for distributed training. It handles tensor data transfers
efficiently between different processes or devices, including those on
separate machines.
- 1.8
- 5.4
* - Gloo
@@ -705,8 +741,8 @@ torch.compiler
* - Feature
- Description
- Since PyTorch
- Since ROCm
- As of PyTorch
- As of ROCm
* - ``torch.compiler`` (AOT Autograd)
- Autograd captures not only the user-level code, but also backpropagation,
which results in capturing the backwards pass “ahead-of-time”. This
@@ -729,8 +765,8 @@ The `torchaudio <https://pytorch.org/audio/stable/index.html>`_ library provides
utilities for processing audio data in PyTorch, such as audio loading,
transformations, and feature extraction.
To ensure GPU-acceleration with ``torchaudio.transforms``, you need to move audio
data (waveform tensor) explicitly to GPU using ``.to('cuda')``.
To ensure GPU-acceleration with ``torchaudio.transforms``, you need to
explicitly move audio data (waveform tensor) to GPU using ``.to('cuda')``.
The following ``torchaudio`` features are GPU-accelerated.
@@ -739,10 +775,10 @@ The following ``torchaudio`` features are GPU-accelerated.
* - Feature
- Description
- Since torchaudio version
- Since ROCm
- As of torchaudio version
- As of ROCm
* - ``torchaudio.transforms.Spectrogram``
- Generates spectrogram of an input waveform using STFT.
- Generate a spectrogram of an input waveform using STFT.
- 0.6.0
- 4.5
* - ``torchaudio.transforms.MelSpectrogram``
@@ -762,7 +798,7 @@ torchvision
--------------------------------------------------------------------------------
The `torchvision <https://pytorch.org/vision/stable/index.html>`_ library
provide datasets, model architectures, and common image transformations for
provides datasets, model architectures, and common image transformations for
computer vision.
The following ``torchvision`` features are GPU-accelerated.
@@ -772,8 +808,8 @@ The following ``torchvision`` features are GPU-accelerated.
* - Feature
- Description
- Since torchvision version
- Since ROCm
- As of torchvision version
- As of ROCm
* - ``torchvision.transforms.functional``
- Provides GPU-compatible transformations for image preprocessing like
resize, normalize, rotate and crop.
@@ -819,7 +855,7 @@ torchtune
The `torchtune <https://pytorch.org/torchtune/stable/index.html>`_ library for
authoring, fine-tuning and experimenting with LLMs.
* Usage: It works out-of-the-box, enabling developers to fine-tune ROCm PyTorch solutions.
* Usage: Enabling developers to fine-tune ROCm PyTorch solutions.
* Only official release exists.
@@ -830,7 +866,8 @@ The `torchserve <https://pytorch.org/serve/>`_ is a PyTorch domain library
for common sparsity and parallelism primitives needed for large-scale recommender
systems.
* torchtext does not implement its own kernels. ROCm support is enabled by linking against ROCm libraries.
* torchtext does not implement its own kernels. ROCm support is enabled by
linking against ROCm libraries.
* Only official release exists.
@@ -841,14 +878,16 @@ The `torchrec <https://pytorch.org/torchrec/>`_ is a PyTorch domain library for
common sparsity and parallelism primitives needed for large-scale recommender
systems.
* torchrec does not implement its own kernels. ROCm support is enabled by linking against ROCm libraries.
* torchrec does not implement its own kernels. ROCm support is enabled by
linking against ROCm libraries.
* Only official release exists.
Unsupported PyTorch features
----------------------------
================================================================================
The following are GPU-accelerated PyTorch features not currently supported by ROCm.
The following GPU-accelerated PyTorch features are not supported by ROCm for
the listed supported PyTorch versions.
.. list-table::
:widths: 30, 60, 10
@@ -856,7 +895,7 @@ The following are GPU-accelerated PyTorch features not currently supported by RO
* - Feature
- Description
- Since PyTorch
- As of PyTorch
* - APEX batch norm
- Use APEX batch norm instead of PyTorch batch norm.
- 1.6.0
@@ -912,31 +951,3 @@ The following are GPU-accelerated PyTorch features not currently supported by RO
utilized effectively through custom CUDA extensions or advanced
workflows.
- Not a core feature
Use cases and recommendations
================================================================================
* :doc:`Using ROCm for AI: training a model </how-to/rocm-for-ai/training/train-a-model>` provides
guidance on how to leverage the ROCm platform for training AI models. It covers the steps, tools, and best practices
for optimizing training workflows on AMD GPUs using PyTorch features.
* :doc:`Single-GPU fine-tuning and inference </how-to/rocm-for-ai/fine-tuning/single-gpu-fine-tuning-and-inference>`
describes and demonstrates how to use the ROCm platform for the fine-tuning and inference of
machine learning models, particularly large language models (LLMs), on systems with a single AMD
Instinct MI300X accelerator. This page provides a detailed guide for setting up, optimizing, and
executing fine-tuning and inference workflows in such environments.
* :doc:`Multi-GPU fine-tuning and inference optimization </how-to/rocm-for-ai/fine-tuning/multi-gpu-fine-tuning-and-inference>`
describes and demonstrates the fine-tuning and inference of machine learning models on systems
with multi MI300X accelerators.
* The :doc:`Instinct MI300X workload optimization guide </how-to/rocm-for-ai/inference-optimization/workload>` provides detailed
guidance on optimizing workloads for the AMD Instinct MI300X accelerator using ROCm. This guide is aimed at helping
users achieve optimal performance for deep learning and other high-performance computing tasks on the MI300X
accelerator.
* The :doc:`Inception with PyTorch documentation </conceptual/ai-pytorch-inception>`
describes how PyTorch integrates with ROCm for AI workloads It outlines the use of PyTorch on the ROCm platform and
focuses on how to efficiently leverage AMD GPU hardware for training and inference tasks in AI applications.
For more use cases and recommendations, see `ROCm PyTorch blog posts <https://rocm.blogs.amd.com/blog/tag/pytorch.html>`_.

View File

@@ -57,6 +57,7 @@ article_pages = [
{"file": "how-to/rocm-for-ai/training/prerequisite-system-validation", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/megatron-lm", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/pytorch-training", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/training/scale-model-training", "os": ["linux"]},
{"file": "how-to/rocm-for-ai/fine-tuning/index", "os": ["linux"]},

View File

@@ -1,8 +1,8 @@
vllm_benchmark:
unified_docker:
latest:
pull_tag: rocm/vllm:rocm6.3.1_instinct_vllm0.8.3_20250410
docker_hub_url: https://hub.docker.com/layers/rocm/vllm/rocm6.3.1_instinct_vllm0.8.3_20250410/images/sha256-a0b55c6c0f3fa5d437fb54a66e32a108306c36d4776e570dfd0ae902719bd190
pull_tag: rocm/vllm:rocm6.3.1_instinct_vllm0.8.3_20250415
docker_hub_url: https://hub.docker.com/layers/rocm/vllm/rocm6.3.1_instinct_vllm0.8.3_20250415/images/sha256-ad9062dea3483d59dedb17c67f7c49f30eebd6eb37c3fac0a171fb19696cc845
rocm_version: 6.3.1
vllm_version: 0.8.3
pytorch_version: 2.7.0 (dev nightly)

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

View File

@@ -1,15 +1,178 @@
.. meta::
:description: How to use model quantization techniques to speed up inference.
:keywords: ROCm, LLM, fine-tuning, usage, tutorial, quantization, GPTQ, transformers, bitsandbytes
:keywords: ROCm, LLM, fine-tuning, usage, tutorial, quantization, Quark, GPTQ, transformers, bitsandbytes
*****************************
Model quantization techniques
*****************************
Quantization reduces the model size compared to its native full-precision version, making it easier to fit large models
onto accelerators or GPUs with limited memory usage. This section explains how to perform LLM quantization using GPTQ
onto accelerators or GPUs with limited memory usage. This section explains how to perform LLM quantization using AMD Quark, GPTQ
and bitsandbytes on AMD Instinct hardware.
.. _quantize-llms-quark:
AMD Quark
=========
`AMD Quark <https://quark.docs.amd.com/latest/>`_ offers the leading efficient and scalable quantization solution tailored to AMD Instinct GPUs. It supports ``FP8`` and ``INT8`` quantization for activations, weights, and KV cache,
including ``FP8`` attention. For very large models, it employs a two-level ``INT4-FP8`` scheme—storing weights in ``INT4`` while computing with ``FP8``—for nearly 4× compression without sacrificing accuracy.
Quark scales efficiently across multiple GPUs, efficiently handling ultra-large models like Llama-3.1-405B. Quantized ``FP8`` models like Llama, Mixtral, and Grok-1 are available under the `AMD organization on Hugging Face <https://huggingface.co/collections/amd/quark-quantized-ocp-fp8-models-66db7936d18fcbaf95d4405c>`_, and can be deployed directly via `vLLM <https://github.com/vllm-project/vllm/tree/main/vllm>`_.
Installing Quark
-------------------
The latest release of Quark can be installed with pip
.. code-block:: shell
pip install amd-quark
For detailed installation instructions, refer to the `Quark documentation <https://quark.docs.amd.com/latest/install.html>`_.
Using Quark for quantization
-----------------------------
#. First, load the pre-trained model and its corresponding tokenizer using the Hugging Face ``transformers`` library.
.. code-block:: python
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
MAX_SEQ_LEN = 512
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, model_max_length=MAX_SEQ_LEN)
tokenizer.pad_token = tokenizer.eos_token
#. Prepare the calibration DataLoader (static quantization requires calibration data).
.. code-block:: python
from datasets import load_dataset
from torch.utils.data import DataLoader
BATCH_SIZE = 1
NUM_CALIBRATION_DATA = 512
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
text_data = dataset["text"][:NUM_CALIBRATION_DATA]
tokenized_outputs = tokenizer(
text_data, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN
)
calib_dataloader = DataLoader(
tokenized_outputs['input_ids'], batch_size=BATCH_SIZE, drop_last=True
)
#. Define the quantization configuration. See the comments in the following code snippet for descriptions of each configuration option.
.. code-block:: python
from quark.torch.quantization import (Config, QuantizationConfig,
FP8E4M3PerTensorSpec)
# Define fp8/per-tensor/static spec.
FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
is_dynamic=False).to_quantization_spec()
# Define global quantization config, input tensors and weight apply FP8_PER_TENSOR_SPEC.
global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
weight=FP8_PER_TENSOR_SPEC)
# Define quantization config for kv-cache layers, output tensors apply FP8_PER_TENSOR_SPEC.
KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC
kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
kv_cache_quant_config = {name :
QuantizationConfig(input_tensors=global_quant_config.input_tensors,
weight=global_quant_config.weight,
output_tensors=KV_CACHE_SPEC)
for name in kv_cache_layer_names_for_llama}
layer_quant_config = kv_cache_quant_config.copy()
EXCLUDE_LAYERS = ["lm_head"]
quant_config = Config(
global_quant_config=global_quant_config,
layer_quant_config=layer_quant_config,
kv_cache_quant_config=kv_cache_quant_config,
exclude=EXCLUDE_LAYERS)
#. Quantize the model and export
.. code-block:: python
import torch
from quark.torch import ModelQuantizer, ModelExporter
from quark.torch.export import ExporterConfig, JsonExporterConfig
# Apply quantization.
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)
# Freeze quantized model to export.
freezed_model = quantizer.freeze(model)
# Define export config.
LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"]
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP
EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor"
exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR)
with torch.no_grad():
exporter.export_safetensors_model(freezed_model,
quant_config=quant_config, tokenizer=tokenizer)
Evaluating the quantized model with vLLM
----------------------------------------
The exported Quark-quantized model can be loaded directly by vLLM for inference. You need to specify the model path and inform vLLM about the quantization method (``quantization='quark'``) and the KV cache data type (``kv_cache_dtype='fp8'``).
Use the ``LLM`` interface to load the model:
.. code-block:: python
from vllm import LLM, SamplingParamsinterface
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor",
kv_cache_dtype='fp8',quantization='quark')
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
You can also evaluate the quantized model's accuracy on standard benchmarks using the `lm-evaluation-harness <https://github.com/EleutherAI/lm-evaluation-harness>`_. Pass the necessary vLLM arguments to ``lm_eval`` via ``--model_args``.
.. code-block:: shell
lm_eval --model vllm \
--model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor,kv_cache_dtype='fp8',quantization='quark' \
--tasks gsm8k
This provides a standardized way to measure the performance impact of quantization.
.. _fine-tune-llms-gptq:
GPTQ
@@ -33,7 +196,7 @@ The AutoGPTQ library implements the GPTQ algorithm.
.. code-block:: shell
# This will install pre-built wheel for a specific ROCm version.
pip install auto-gptq --no-build-isolation --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm573/
Or, install AutoGPTQ from source for the appropriate ROCm version (for example, ROCm 6.1).
@@ -43,10 +206,10 @@ The AutoGPTQ library implements the GPTQ algorithm.
# Clone the source code.
git clone https://github.com/AutoGPTQ/AutoGPTQ.git
cd AutoGPTQ
# Speed up the compilation by specifying PYTORCH_ROCM_ARCH to target device.
PYTORCH_ROCM_ARCH=gfx942 ROCM_VERSION=6.1 pip install .
# Show the package after the installation
#. Run ``pip show auto-gptq`` to print information for the installed ``auto-gptq`` package. Its output should look like
@@ -112,7 +275,7 @@ Using GPTQ with Hugging Face Transformers
.. code-block:: python
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
base_model_name = " NousResearch/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
@@ -212,10 +375,10 @@ To get started with bitsandbytes primitives, use the following code as reference
.. code-block:: python
import bitsandbytes as bnb
# Use Int8 Matrix Multiplication
bnb.matmul(..., threshold=6.0)
# Use bitsandbytes 8-bit Optimizers
adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995))
@@ -227,14 +390,14 @@ To load a Transformers model in 4-bit, set ``load_in_4bit=true`` in ``BitsAndByt
.. code-block:: python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
base_model_name = "NousResearch/Llama-2-7b-hf"
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
bnb_model_4bit = AutoModelForCausalLM.from_pretrained(
base_model_name,
device_map="auto",
quantization_config=quantization_config)
# Check the memory footprint with get_memory_footprint method
print(bnb_model_4bit.get_memory_footprint())
@@ -243,9 +406,9 @@ To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
.. code-block:: python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
base_model_name = "NousResearch/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
@@ -253,7 +416,7 @@ To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
base_model_name,
device_map="auto",
quantization_config=quantization_config)
prompt = "What is a large language model?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
generated_ids = model.generate(**inputs)

View File

@@ -20,6 +20,8 @@ training, fine-tuning, and inference. It leverages popular machine learning fram
- :doc:`LLM inference frameworks <llm-inference-frameworks>`
- :doc:`Performance testing <vllm-benchmark>`
- :doc:`vLLM inference performance testing <vllm-benchmark>`
- :doc:`PyTorch inference performance testing <pytorch-inference-benchmark>`
- :doc:`Deploying your model <deploy-your-model>`

View File

@@ -86,7 +86,7 @@ PyTorch inference performance testing
.. container:: model-doc pyt_chai1_inference
2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/latest/images/sha256-05b55983e5154f46e7441897d0908d79877370adca4d1fff4899d9539d6c4969>`_ from Docker Hub.
2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0_triton_llvm_reg_issue/images/sha256-b736a4239ab38a9d0e448af6d4adca83b117debed00bfbe33846f99c4540f79b>`_ from Docker Hub.
.. code-block:: shell
@@ -98,7 +98,7 @@ PyTorch inference performance testing
.. container:: model-doc pyt_clip_inference
2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0_triton_llvm_reg_issue/images/sha256-b736a4239ab38a9d0e448af6d4adca83b117debed00bfbe33846f99c4540f79b>`_ from Docker Hub.
2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/latest/images/sha256-05b55983e5154f46e7441897d0908d79877370adca4d1fff4899d9539d6c4969>`_ from Docker Hub.
.. code-block:: shell

View File

@@ -276,7 +276,7 @@ vLLM inference performance testing
* Latency benchmark
Use this command to benchmark the latency of the {{model.model}} model on eight GPUs with the ``{{model.precision}}`` data type.
Use this command to benchmark the latency of the {{model.model}} model on eight GPUs with ``{{model.precision}}`` precision.
.. code-block::
@@ -286,11 +286,11 @@ vLLM inference performance testing
* Throughput benchmark
Use this command to throughput the latency of the {{model.model}} model on eight GPUs with the ``{{model.precision}}`` data type.
Use this command to benchmark the throughput of the {{model.model}} model on eight GPUs with ``{{model.precision}}`` precision.
.. code-block:: shell
./vllm_benchmark_report.sh -s latency -m {{model.model_repo}} -g 8 -d {{model.precision}}
./vllm_benchmark_report.sh -s throughput -m {{model.model_repo}} -g 8 -d {{model.precision}}
Find the throughput report at ``./reports_{{model.precision}}_vllm_rocm{{unified_docker.rocm_version}}/summary/{{model.model_repo.split('/', 1)[1] if '/' in model.model_repo else model.model_repo}}_throughput_report.csv``.

View File

@@ -12,7 +12,7 @@ ROCm is an optimized fork of the upstream
`<https://github.com/AI-Hypercomputer/maxtext>`__ enabling efficient AI workloads
on AMD MI300X series accelerators.
The MaxText for ROCm training Docker (``rocm/jax-training:maxtext-v25.4``) image
The MaxText for ROCm training Docker (``rocm/jax-training:maxtext-v25.5``) image
provides a prebuilt environment for training on AMD Instinct MI300X and MI325X accelerators,
including essential components like JAX, XLA, ROCm libraries, and MaxText utilities.
It includes the following software components:
@@ -20,15 +20,15 @@ It includes the following software components:
+--------------------------+--------------------------------+
| Software component | Version |
+==========================+================================+
| ROCm | 6.3.0 |
| ROCm | 6.3.4 |
+--------------------------+--------------------------------+
| JAX | 0.4.31 |
| JAX | 0.4.35 |
+--------------------------+--------------------------------+
| Python | 3.10 |
| Python | 3.10.12 |
+--------------------------+--------------------------------+
| Transformer Engine | 1.12.0.dev0+f81a3eb |
| Transformer Engine | 1.12.0.dev0+b8b92dc |
+--------------------------+--------------------------------+
| hipBLASLt | git78ec8622 |
| hipBLASLt | 0.13.0-ae9c477a |
+--------------------------+--------------------------------+
Supported features and models
@@ -48,6 +48,8 @@ MaxText provides the following key features to train large language models effic
The following models are pre-optimized for performance on AMD Instinct MI300X series accelerators.
* Llama 3.3 70B
* Llama 3.1 8B
* Llama 3.1 70B
@@ -115,7 +117,7 @@ with RDMA, skip ahead to :ref:`amd-maxtext-download-docker`.
a. Master address
Change `localhost` to the master node's resolvable hostname or IP address:
Change ``localhost`` to the master node's resolvable hostname or IP address:
.. code-block:: bash
@@ -180,13 +182,15 @@ Download the Docker image
.. code-block:: shell
docker pull rocm/jax-training:maxtext-v25.4
docker pull rocm/jax-training:maxtext-v25.5
2. Run the Docker container.
2. Use the following command to launch the Docker container. Note that the benchmarking scripts
used in the :ref:`following section <amd-maxtext-get-started>` automatically launch the Docker container
and execute the benchmark.
.. code-block:: shell
docker run -it --device /dev/dri --device /dev/kfd --network host --ipc host --group-add video --cap-add SYS_PTRACE --security-opt seccomp=unconfined --privileged -v $HOME/.ssh:/root/.ssh --shm-size 128G --name maxtext_training rocm/jax-training:maxtext-v25.4
docker run -it --device /dev/dri --device /dev/kfd --network host --ipc host --group-add video --cap-add SYS_PTRACE --security-opt seccomp=unconfined --privileged -v $HOME/.ssh:/root/.ssh --shm-size 128G --name maxtext_training rocm/jax-training:maxtext-v25.5
.. _amd-maxtext-get-started:
@@ -219,7 +223,9 @@ Single node training benchmarking examples
Run the single node training benchmark:
IMAGE="rocm/jax-training:maxtext-v25.4" bash ./llama2_7b.sh
.. code-block:: shell
IMAGE="rocm/jax-training:maxtext-v25.5" bash ./llama2_7b.sh
* Example 2: Single node training with Llama 2 70B
@@ -233,7 +239,7 @@ Single node training benchmarking examples
.. code-block:: shell
IMAGE="rocm/jax-training:maxtext-v25.4" bash ./llama2_70b.sh
IMAGE="rocm/jax-training:maxtext-v25.5" bash ./llama2_70b.sh
* Example 3: Single node training with Llama 3 8B
@@ -247,7 +253,7 @@ Single node training benchmarking examples
.. code-block:: shell
IMAGE="rocm/jax-training:maxtext-v25.4" bash ./llama3_8b.sh
IMAGE="rocm/jax-training:maxtext-v25.5" bash ./llama3_8b.sh
* Example 4: Single node training with Llama 3 70B
@@ -261,9 +267,23 @@ Single node training benchmarking examples
.. code-block:: shell
IMAGE="rocm/jax-training:maxtext-v25.4" bash ./llama3_70b.sh
IMAGE="rocm/jax-training:maxtext-v25.5" bash ./llama3_70b.sh
* Example 5: Single node training with DeepSeek V2 16B
* Example 5: Single node training with Llama 3.3 70B
Download the benchmarking script:
.. code-block:: shell
wget https://raw.githubusercontent.com/ROCm/maxtext/refs/heads/main/benchmarks/gpu-rocm/llama3.3_70b.sh
Run the single node training benchmark:
.. code-block:: shell
IMAGE="rocm/jax-training:maxtext-v25.5" bash ./llama3.3_70b.sh
* Example 6: Single node training with DeepSeek V2 16B
Download the benchmarking script:
@@ -275,7 +295,7 @@ Single node training benchmarking examples
.. code-block:: shell
IMAGE="rocm/jax-training:maxtext-v25.4" bash ./deepseek_v2_16b.sh
IMAGE="rocm/jax-training:maxtext-v25.5" bash ./deepseek_v2_16b.sh
.. note::
@@ -343,3 +363,26 @@ own cluster setup.
.. code-block:: shell
sbatch -N <num_nodes> llama3_70b_multinode.sh
Previous versions
=================
This table lists previous versions of the ROCm JAX MaxText Docker image for training
performance testing. For detailed information about available models for
benchmarking, see the version-specific documentation.
.. list-table::
:header-rows: 1
:stub-columns: 1
* - Image version
- ROCm version
- JAX version
- Resources
* - 25.4
- 6.3.0
- 0.4.31
-
* `Documentation <https://rocm.docs.amd.com/en/docs-6.3.3/how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext.html>`_
* `Docker Hub <https://hub.docker.com/layers/rocm/jax-training/maxtext-v25.4/images/sha256-fb3eb71cd74298a7b3044b7130cf84113f14d518ff05a2cd625c11ea5f6a7b01>`_

View File

@@ -0,0 +1,168 @@
.. meta::
:description: How to train a model using LLM Foundry for ROCm.
:keywords: ROCm, AI, LLM, train, PyTorch, torch, Llama, flux, tutorial, docker
******************************************
Training MPT-30B with LLM Foundry and ROCm
******************************************
MPT-30B is a 30-billion parameter decoder-style transformer-based model from
the Mosaic Pretrained Transformer (MPT) family -- learn more about it in
MosaicML's research blog `MPT-30B: Raising the bar for open-source foundation
models <https://www.databricks.com/blog/mpt-30b>`_.
ROCm and `<https://github.com/ROCm/MAD>`__ provide a pre-configured training
environment for the MPT-30B model using the ``rocm/pytorch-training:v25.5``
base `Docker image <https://hub.docker.com/layers/rocm/pytorch-training/v25.5/images/sha256-d47850a9b25b4a7151f796a8d24d55ea17bba545573f0d50d54d3852f96ecde5>`_
and the `LLM Foundry <https://github.com/mosaicml/llm-foundry>`_ framework.
This environment packages the following software components to train
on AMD Instinct MI300X series accelerators:
+--------------------------+--------------------------------+
| Software component | Version |
+==========================+================================+
| ROCm | 6.3.4 |
+--------------------------+--------------------------------+
| PyTorch | 2.7.0a0+git6374332 |
+--------------------------+--------------------------------+
| Flash Attention | 3.0.0.post1 |
+--------------------------+--------------------------------+
Using this image, you can build, run, and test the training process
for MPT-30B with access to detailed logs and performance metrics.
System validation
=================
If you have already validated your system settings, including NUMA
auto-balancing, skip this step. Otherwise, complete the :ref:`system validation
and optimization steps <train-a-model-system-validation>` to set up your system
before starting training.
Getting started
===============
The following procedures help you set up the training environment in a
reproducible Docker container. This training environment is tailored for
training MPT-30B using LLM Foundry and the specific model configurations outlined.
Other configurations and run conditions outside those described in this
document are not validated.
.. tab-set::
.. tab-item:: MAD-integrated benchmarking
On your host machine, clone the ROCm Model Automation and Dashboarding
(`<https://github.com/ROCm/MAD>`__) repository to a local directory and
install the required packages.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD
pip install -r requirements.txt
Use this command to initiate the MPT-30B training benchmark.
.. code-block:: shell
python3 tools/run_models.py --tags pyt_mpt30b_training --keep-model-dir --live-output --clean-docker-cache
.. tip::
If you experience data download failures, set the
``MAD_SECRETS_HFTOKEN`` variable to your Hugging Face access token. See
`User access tokens <https://huggingface.co/docs/hub/security-tokens>`_
for details.
.. code-block:: shell
export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"
.. note::
For improved performance (training throughput), consider enabling TunableOp.
By default, ``pyt_mpt30b_training`` runs with TunableOp disabled. To enable it,
run ``tools/run_models.py`` with the ``--tunableop on`` argument or edit the
``models.json`` configuration before running training.
Although this might increase the initial training time, it can result in a performance gain.
.. tab-item:: Standalone benchmarking
To set up the training environment, clone the
`<https://github.com/ROCm/MAD>`__ repo and build the Docker image. In
this snippet, the image is named ``mosaic_mpt30_image``.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD
docker build --build-arg MAD_SYSTEM_GPU_ARCHITECTURE=gfx942 -f docker/pyt_mpt30b_training.ubuntu.amd.Dockerfile -t mosaic_mpt30_image .
Start a ``mosaic_mpt30_image`` container using the following command.
.. code-block:: shell
docker run -it --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --shm-size=8G mosaic_mpt30_image
In the Docker container, clone the `<https://github.com/ROCm/MAD>`__
repository and navigate to the benchmark scripts directory at
``/workspace/MAD/scripts/pyt_mpt30b_training``.
.. code-block:: shell
git clone https://github.com/ROCm/MAD
cd MAD/scripts/pyt_mpt30b_training
To initiate the training process, use the following command. This script uses the hyperparameters defined in
``mpt-30b-instruct.yaml``.
.. code-block:: shell
source run.sh
.. note::
For improved performance (training throughput), consider enabling TunableOp.
To enable it, add the ``--tunableop on`` flag.
.. code-block:: shell
source run.sh --tunableop on
Although this might increase the initial training time, it can result in a performance gain.
Interpreting the output
=======================
The training output will be displayed in the terminal and simultaneously saved
to the ``output.txt`` file in the current directory. Key performance metrics will
also be extracted and appended to the ``perf_pyt_mpt30b_training.csv`` file.
Key performance metrics include:
- Training logs: Real-time display of loss metrics, accuracy, and training progress.
- Model checkpoints: Periodically saved model snapshots for potential resume or evaluation.
- Performance metrics: Detailed summaries of training speed and training loss metrics.
- Performance (throughput/samples_per_sec)
Overall throughput, measuring the total samples processed per second. Higher values indicate better hardware utilization.
- Performance per device (throughput/samples_per_sec)
Throughput on a per-device basis, showing how each GPU or CPU is performing.
- Language Cross Entropy (metrics/train/LanguageCrossEntropy)
Measures prediction accuracy. Lower cross entropy suggests the models output is closer to the expected distribution.
- Training loss (loss/train/total)
Overall training loss. A decreasing trend indicates the model is learning effectively.

View File

@@ -443,7 +443,7 @@ benchmarking, see the version-specific documentation.
- 6.3.0
- 2.7.0a0+git637433
-
* `Documentation <https://rocm.docs.amd.com/en/docs-6.3.4/how-to/rocm-for-ai/training/benchmark-docker/pytorch-training.html>`_
* `Documentation <https://rocm.docs.amd.com/en/docs-6.3.3/how-to/rocm-for-ai/training/benchmark-docker/pytorch-training.html>`_
* `Docker Hub <https://hub.docker.com/layers/rocm/pytorch-training/v25.4/images/sha256-fa98a9aa69968e654466c06f05aaa12730db79b48b113c1ab4f7a5fe6920a20b>`_
* - v25.3

View File

@@ -45,6 +45,7 @@
(communication-libraries)=
* {doc}`RCCL <rccl:index>`
* [rocSHMEM](https://github.com/ROCm/rocSHMEM)
:::
:::{grid-item-card} Math

View File

@@ -12,14 +12,14 @@ subtrees:
- file: compatibility/compatibility-matrix.rst
title: Compatibility matrix
entries:
- url: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html
- url: https://rocm.docs.amd.com/projects/install-on-linux/en/${branch}/reference/system-requirements.html
title: Linux system requirements
- url: https://rocm.docs.amd.com/projects/install-on-windows/en/${branch}/reference/system-requirements.html
title: Windows system requirements
- caption: Install
entries:
- url: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/
- url: https://rocm.docs.amd.com/projects/install-on-linux/en/${branch}/
title: ROCm on Linux
- url: https://rocm.docs.amd.com/projects/install-on-windows/en/${branch}/
title: HIP SDK on Windows
@@ -46,6 +46,8 @@ subtrees:
title: Train a model with PyTorch
- file: how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext
title: Train a model with JAX MaxText
- file: how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry
title: Train a model with LLM Foundry
- file: how-to/rocm-for-ai/training/scale-model-training.rst
title: Scale model training

View File

@@ -10,7 +10,7 @@ ROCm is a software stack, composed primarily of open-source software, that
provides the tools for programming AMD Graphics Processing Units (GPUs), from
low-level kernels to high-level end-user applications.
.. image:: data/rocm-software-stack-6_3_2.jpg
.. image:: data/rocm-software-stack-6_4_0.jpg
:width: 800
:alt: AMD's ROCm software stack and enabling technologies.
:align: center
@@ -52,6 +52,7 @@ Communication
:header: "Component", "Description"
":doc:`RCCL <rccl:index>`", "Standalone library that provides multi-GPU and multi-node collective communication primitives"
"`rocSHMEM <https://github.com/ROCm/rocSHMEM>`_", "Runtime that provides GPU-centric networking through an OpenSHMEM-like interface. This intra-kernel networking library simplifies application code complexity and enables more fine-grained communication/computation overlap than traditional host-driven networking"
Math
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^