Docs: Overhaul JAX compatibility page for ROCm 7.0

This commit is contained in:
Adel Johar
2025-08-28 17:08:27 +02:00
committed by GitHub
parent 53bd9b5da4
commit 04beef8773

View File

@@ -27,7 +27,7 @@ with ROCm support:
- Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>`
with ROCm and JAX preinstalled.
- ROCm JAX repository: `ROCm/jax <https://github.com/ROCm/jax>`_
- ROCm JAX repository: `ROCm/rocm-jax <https://github.com/ROCm/rocm-jax>`_
- See the :doc:`ROCm JAX installation guide <rocm-install-on-linux:install/3rd-party/jax-install>`
to get started.
@@ -310,5 +310,54 @@ For a complete and up-to-date list of JAX public modules (for example, ``jax.num
Since version 0.1.56, JAX has full support for ROCm, and the
:ref:`Known issues and important notes <jax_comp_known_issues>` section
contains details about limitations specific to the ROCm backend. The list of
JAX API modules is maintained by the JAX project and is subject to change.
JAX API modules are maintained by the JAX project and is subject to change.
Refer to the official Jax documentation for the most up-to-date information.
Key features and enhancements for ROCm 7.0
===============================================================================
- Upgraded XLA backend: Integrates a newer XLA version, enabling better
optimizations, broader operator support, and potential performance gains.
- RNN support: Native RNN support (including LSTMs via ``jax.experimental.rnn``)
now available on ROCm, aiding sequence model development.
- Comprehensive linear algebra capabilities: Offers robust ``jax.linalg``
operations, essential for scientific and machine learning tasks.
- Expanded AMD GPU architecture support: Provides ongoing support for gfx1101
GPUs and introduces support for gfx950 and gfx12xx GPUs.
- Mixed FP8 precision support: Enables ``lax.dot_general`` operations with mixed FP8
types, offering pathways for memory and compute efficiency.
- Streamlined PyPi packaging: Provides reliable PyPi wheels for JAX on ROCm,
simplifying the installation process.
- Pallas experimental kernel development: Continued Pallas framework
enhancements for custom GPU kernels, including new intrinsics (specific
kernel behaviors under review).
- Improved build system and CI: Enhanced ROCm build system and CI for greater
reliability and maintainability.
- Enhanced distributed computing setup: Improved JAX setup in multi-GPU
distributed environments.
.. _jax_comp_known_issues:
Known issues and notes for ROCm 7.0
===============================================================================
- ``nn.dot_product_attention``: Certain configurations of ``jax.nn.dot_product_attention``
may cause segmentation faults, though the majority of use cases work correctly.
- SVD with dynamic shapes: SVD on inputs with dynamic/symbolic shapes might result in an error.
SVD with static shapes is unaffected.
- QR decomposition with symbolic shapes: QR decomposition operations may fail when using
symbolic/dynamic shapes in shape polymorphic contexts.
- Pallas kernels: Specific advanced Pallas kernels may exhibit variations in
numerical output or resource usage. These are actively reviewed as part of
Pallas's experimental development.