mirror of
https://github.com/ROCm/ROCm.git
synced 2026-01-10 07:08:08 -05:00
Docs: Overhaul JAX compatibility page for ROCm 7.0
This commit is contained in:
@@ -27,7 +27,7 @@ with ROCm support:
|
||||
- Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>`
|
||||
with ROCm and JAX preinstalled.
|
||||
|
||||
- ROCm JAX repository: `ROCm/jax <https://github.com/ROCm/jax>`_
|
||||
- ROCm JAX repository: `ROCm/rocm-jax <https://github.com/ROCm/rocm-jax>`_
|
||||
|
||||
- See the :doc:`ROCm JAX installation guide <rocm-install-on-linux:install/3rd-party/jax-install>`
|
||||
to get started.
|
||||
@@ -310,5 +310,54 @@ For a complete and up-to-date list of JAX public modules (for example, ``jax.num
|
||||
Since version 0.1.56, JAX has full support for ROCm, and the
|
||||
:ref:`Known issues and important notes <jax_comp_known_issues>` section
|
||||
contains details about limitations specific to the ROCm backend. The list of
|
||||
JAX API modules is maintained by the JAX project and is subject to change.
|
||||
JAX API modules are maintained by the JAX project and is subject to change.
|
||||
Refer to the official Jax documentation for the most up-to-date information.
|
||||
|
||||
Key features and enhancements for ROCm 7.0
|
||||
===============================================================================
|
||||
|
||||
- Upgraded XLA backend: Integrates a newer XLA version, enabling better
|
||||
optimizations, broader operator support, and potential performance gains.
|
||||
|
||||
- RNN support: Native RNN support (including LSTMs via ``jax.experimental.rnn``)
|
||||
now available on ROCm, aiding sequence model development.
|
||||
|
||||
- Comprehensive linear algebra capabilities: Offers robust ``jax.linalg``
|
||||
operations, essential for scientific and machine learning tasks.
|
||||
|
||||
- Expanded AMD GPU architecture support: Provides ongoing support for gfx1101
|
||||
GPUs and introduces support for gfx950 and gfx12xx GPUs.
|
||||
|
||||
- Mixed FP8 precision support: Enables ``lax.dot_general`` operations with mixed FP8
|
||||
types, offering pathways for memory and compute efficiency.
|
||||
|
||||
- Streamlined PyPi packaging: Provides reliable PyPi wheels for JAX on ROCm,
|
||||
simplifying the installation process.
|
||||
|
||||
- Pallas experimental kernel development: Continued Pallas framework
|
||||
enhancements for custom GPU kernels, including new intrinsics (specific
|
||||
kernel behaviors under review).
|
||||
|
||||
- Improved build system and CI: Enhanced ROCm build system and CI for greater
|
||||
reliability and maintainability.
|
||||
|
||||
- Enhanced distributed computing setup: Improved JAX setup in multi-GPU
|
||||
distributed environments.
|
||||
|
||||
.. _jax_comp_known_issues:
|
||||
|
||||
Known issues and notes for ROCm 7.0
|
||||
===============================================================================
|
||||
|
||||
- ``nn.dot_product_attention``: Certain configurations of ``jax.nn.dot_product_attention``
|
||||
may cause segmentation faults, though the majority of use cases work correctly.
|
||||
|
||||
- SVD with dynamic shapes: SVD on inputs with dynamic/symbolic shapes might result in an error.
|
||||
SVD with static shapes is unaffected.
|
||||
|
||||
- QR decomposition with symbolic shapes: QR decomposition operations may fail when using
|
||||
symbolic/dynamic shapes in shape polymorphic contexts.
|
||||
|
||||
- Pallas kernels: Specific advanced Pallas kernels may exhibit variations in
|
||||
numerical output or resource usage. These are actively reviewed as part of
|
||||
Pallas's experimental development.
|
||||
|
||||
Reference in New Issue
Block a user