Docs: Overhaul JAX compatibility page

This commit is contained in:
Adel Johar
2025-06-03 12:21:32 +02:00
parent 61c6749a10
commit c699aaf915
2 changed files with 92 additions and 294 deletions

View File

@@ -228,6 +228,7 @@ LM
LSAN
LSan
LTS
LSTMs
LanguageCrossEntropy
LoRA
MEM
@@ -679,6 +680,7 @@ installable
interop
interprocedural
intra
intrinsics
invariants
invocating
ipo
@@ -840,6 +842,7 @@ sm
smi
softmax
spack
spmm
src
stochastically
strided

View File

@@ -53,7 +53,7 @@ Use cases and recommendations
* The `nanoGPT in JAX <https://rocm.blogs.amd.com/artificial-intelligence/nanoGPT-JAX/README.html>`_
blog explores the implementation and training of a Generative Pre-trained
Transformer (GPT) model in JAX, inspired by Andrej Karpathys JAX-based
nanoGPT. Comparing how essential GPT components—such as self-attention
nanoGPT. Comparing how essential GPT components—such as self-attention
mechanisms and optimizers—are realized in JAX and JAX, also highlights
JAXs unique features.
@@ -160,12 +160,14 @@ associated inventories are tested for `ROCm 6.3.2 <https://repo.radeon.com/rocm/
- Ubuntu 22.04
- `3.10.16 <https://www.python.org/downloads/release/python-31016/>`_
.. _key_rocm_libraries:
Key ROCm libraries for JAX
================================================================================
JAX functionality on ROCm is determined by its underlying library
dependencies. These ROCm components affect the capabilities, performance, and
feature set available to developers.
The following ROCm libraries represent potential targets that could be utilized
by JAX on ROCm for various computational tasks. The actual libraries used will
depend on the specific implementation and operations performed.
.. list-table::
:header-rows: 1
@@ -173,347 +175,140 @@ feature set available to developers.
* - ROCm library
- Version
- Purpose
- Used in
* - `hipBLAS <https://github.com/ROCm/hipBLAS>`_
- :version-ref:`hipBLAS rocm_version`
- Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for
matrix and vector operations.
- Matrix multiplication in ``jax.numpy.matmul``, ``jax.lax.dot`` and
``jax.lax.dot_general``, operations like ``jax.numpy.dot``, which
involve vector and matrix computations and batch matrix multiplications
``jax.numpy.einsum`` with matrix-multiplication patterns algebra
operations.
* - `hipBLASLt <https://github.com/ROCm/hipBLASLt>`_
- :version-ref:`hipBLASLt rocm_version`
- hipBLASLt is an extension of hipBLAS, providing additional
features like epilogues fused into the matrix multiplication kernel or
use of integer tensor cores.
- Matrix multiplication in ``jax.numpy.matmul`` or ``jax.lax.dot``, and
the XLA (Accelerated Linear Algebra) use hipBLASLt for optimized matrix
operations, mixed-precision support, and hardware-specific
optimizations.
* - `hipCUB <https://github.com/ROCm/hipCUB>`_
- :version-ref:`hipCUB rocm_version`
- Provides a C++ template library for parallel algorithms for reduction,
scan, sort and select.
- Reduction functions (``jax.numpy.sum``, ``jax.numpy.mean``,
``jax.numpy.prod``, ``jax.numpy.max`` and ``jax.numpy.min``), prefix sum
(``jax.numpy.cumsum``, ``jax.numpy.cumprod``) and sorting
(``jax.numpy.sort``, ``jax.numpy.argsort``).
* - `hipFFT <https://github.com/ROCm/hipFFT>`_
- :version-ref:`hipFFT rocm_version`
- Provides GPU-accelerated Fast Fourier Transform (FFT) operations.
- Used in functions like ``jax.numpy.fft``.
* - `hipRAND <https://github.com/ROCm/hipRAND>`_
- :version-ref:`hipRAND rocm_version`
- Provides fast random number generation for GPUs.
- The ``jax.random.uniform``, ``jax.random.normal``,
``jax.random.randint`` and ``jax.random.split``.
* - `hipSOLVER <https://github.com/ROCm/hipSOLVER>`_
- :version-ref:`hipSOLVER rocm_version`
- Provides GPU-accelerated solvers for linear systems, eigenvalues, and
singular value decompositions (SVD).
- Solving linear systems (``jax.numpy.linalg.solve``), matrix
factorizations, SVD (``jax.numpy.linalg.svd``) and eigenvalue problems
(``jax.numpy.linalg.eig``).
* - `hipSPARSE <https://github.com/ROCm/hipSPARSE>`_
- :version-ref:`hipSPARSE rocm_version`
- Accelerates operations on sparse matrices, such as sparse matrix-vector
or matrix-matrix products.
- Sparse matrix multiplication (``jax.numpy.matmul``), sparse
matrix-vector and matrix-matrix products
(``jax.experimental.sparse.dot``), sparse linear system solvers and
sparse data handling.
* - `hipSPARSELt <https://github.com/ROCm/hipSPARSELt>`_
- :version-ref:`hipSPARSELt rocm_version`
- Accelerates operations on sparse matrices, such as sparse matrix-vector
or matrix-matrix products.
- Sparse matrix multiplication (``jax.numpy.matmul``), sparse
matrix-vector and matrix-matrix products
(``jax.experimental.sparse.dot``) and sparse linear system solvers.
* - `MIOpen <https://github.com/ROCm/MIOpen>`_
- :version-ref:`MIOpen rocm_version`
- Optimized for deep learning primitives such as convolutions, pooling,
normalization, and activation functions.
- Speeds up convolutional neural networks (CNNs), recurrent neural
networks (RNNs), and other layers. Used in operations like
``jax.nn.conv``, ``jax.nn.relu``, and ``jax.nn.batch_norm``.
* - `RCCL <https://github.com/ROCm/rccl>`_
- :version-ref:`RCCL rocm_version`
- Optimized for multi-GPU communication for operations like all-reduce,
broadcast, and scatter.
- Distribute computations across multiple GPU with ``pmap`` and
``jax.distributed``. XLA automatically uses rccl when executing
operations across multiple GPUs on AMD hardware.
* - `rocThrust <https://github.com/ROCm/rocThrust>`_
- :version-ref:`rocThrust rocm_version`
- Provides a C++ template library for parallel algorithms like sorting,
reduction, and scanning.
- Reduction operations like ``jax.numpy.sum``, ``jax.pmap`` for
distributed training, which involves parallel reductions or
operations like ``jax.numpy.cumsum`` can use rocThrust.
Supported features
.. note::
This table shows ROCm libraries that could potentially be utilized by JAX. Not
all libraries may be used in every configuration, and the actual library usage
will depend on the specific operations and implementation details.
Supported data types and modules
===============================================================================
The following table maps the public JAX API modules to their supported
ROCm and JAX versions.
The following tables lists the supported public JAX API data types and modules.
Supported data types
--------------------------------------------------------------------------------
ROCm supports all the JAX data types of `jax.dtypes <https://docs.jax.dev/en/latest/jax.dtypes.html>`_
module, `jax.numpy.dtype <https://docs.jax.dev/en/latest/_autosummary/jax.numpy.dtype.html>`_
and `default_dtype <https://docs.jax.dev/en/latest/default_dtypes.html>`_ .
The ROCm supported data types in JAX are collected in the following table.
.. list-table::
:header-rows: 1
* - Module
- Description
- As of JAX
- As of ROCm
* - ``jax.numpy``
- Implements the NumPy API, using the primitives in ``jax.lax``.
- 0.1.56
- 5.0.0
* - ``jax.scipy``
- Provides GPU-accelerated and differentiable implementations of many
functions from the SciPy library, leveraging JAX's transformations
(e.g., ``grad``, ``jit``, ``vmap``).
- 0.1.56
- 5.0.0
* - ``jax.lax``
- A library of primitives operations that underpins libraries such as
``jax.numpy.`` Transformation rules, such as Jacobian-vector product
(JVP) and batching rules, are typically defined as transformations on
``jax.lax`` primitives.
- 0.1.57
- 5.0.0
* - ``jax.random``
- Provides a number of routines for deterministic generation of sequences
of pseudorandom numbers.
- 0.1.58
- 5.0.0
* - ``jax.sharding``
- Allows to define partitioning and distributing arrays across multiple
devices.
- 0.3.20
- 5.1.0
* - ``jax.distributed``
- Enables the scaling of computations across multiple devices on a single
machine or across multiple machines.
- 0.1.74
- 5.0.0
* - ``jax.image``
- Contains image manipulation functions like resize, scale and translation.
- 0.1.57
- 5.0.0
* - ``jax.nn``
- Contains common functions for neural network libraries.
- 0.1.56
- 5.0.0
* - ``jax.ops``
- Computes the minimum, maximum, sum or product within segments of an
array.
- 0.1.57
- 5.0.0
* - ``jax.stages``
- Contains interfaces to stages of the compiled execution process.
- 0.3.4
- 5.0.0
* - ``jax.extend``
- Provides modules for access to JAX internal machinery module. The
``jax.extend`` module defines a library view of some of JAXs internal
components.
- 0.4.15
- 5.5.0
* - ``jax.example_libraries``
- Serves as a collection of example code and libraries that demonstrate
various capabilities of JAX.
- 0.1.74
- 5.0.0
* - ``jax.experimental``
- Namespace for experimental features and APIs that are in development or
are not yet fully stable for production use.
- 0.1.56
- 5.0.0
* - ``jax.lib``
- Set of internal tools and types for bridging between JAXs Python
frontend and its XLA backend.
- 0.4.6
- 5.3.0
* - ``jax_triton``
- Library that integrates the Triton deep learning compiler with JAX.
- jax_triton 0.2.0
- 6.2.4
jax.scipy module
-------------------------------------------------------------------------------
A SciPy-like API for scientific computing.
.. list-table::
:header-rows: 1
* - Module
- As of JAX
- As of ROCm
* - ``jax.scipy.cluster``
- 0.3.11
- 5.1.0
* - ``jax.scipy.fft``
- 0.1.71
- 5.0.0
* - ``jax.scipy.integrate``
- 0.4.15
- 5.5.0
* - ``jax.scipy.interpolate``
- 0.1.76
- 5.0.0
* - ``jax.scipy.linalg``
- 0.1.56
- 5.0.0
* - ``jax.scipy.ndimage``
- 0.1.56
- 5.0.0
* - ``jax.scipy.optimize``
- 0.1.57
- 5.0.0
* - ``jax.scipy.signal``
- 0.1.56
- 5.0.0
* - ``jax.scipy.spatial.transform``
- 0.4.12
- 5.4.0
* - ``jax.scipy.sparse.linalg``
- 0.1.56
- 5.0.0
* - ``jax.scipy.special``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats``
- 0.1.56
- 5.0.0
jax.scipy.stats module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
:header-rows: 1
* - Module
- As of JAX
- As of ROCm
* - ``jax.scipy.stats.bernouli``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.beta``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.betabinom``
- 0.1.61
- 5.0.0
* - ``jax.scipy.stats.binom``
- 0.4.14
- 5.4.0
* - ``jax.scipy.stats.cauchy``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.chi2``
- 0.1.61
- 5.0.0
* - ``jax.scipy.stats.dirichlet``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.expon``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.gamma``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.gennorm``
- 0.3.15
- 5.2.0
* - ``jax.scipy.stats.geom``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.laplace``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.logistic``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.multinomial``
- 0.3.18
- 5.1.0
* - ``jax.scipy.stats.multivariate_normal``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.nbinom``
- 0.1.72
- 5.0.0
* - ``jax.scipy.stats.norm``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.pareto``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.poisson``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.t``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.truncnorm``
- 0.4.0
- 5.3.0
* - ``jax.scipy.stats.uniform``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.vonmises``
- 0.4.2
- 5.3.0
* - ``jax.scipy.stats.wrapcauchy``
- 0.4.20
- 5.6.0
jax.extend module
-------------------------------------------------------------------------------
Modules for JAX extensions.
.. list-table::
:header-rows: 1
* - Module
- As of JAX
- As of ROCm
* - ``jax.extend.ffi``
- 0.4.30
- 6.0.0
* - ``jax.extend.linear_util``
- 0.4.17
- 5.6.0
* - ``jax.extend.mlir``
- 0.4.26
- 5.6.0
* - ``jax.extend.random``
- 0.4.15
- 5.5.0
Unsupported JAX features
===============================================================================
The following GPU-accelerated JAX features are not supported by ROCm for
the listed supported JAX versions.
.. list-table::
:header-rows: 1
* - Feature
* - Data type
- Description
* - Mixed Precision with TF32
- Mixed precision with TF32 is used for matrix multiplications,
convolutions, and other linear algebra operations, particularly in
deep learning workloads like CNNs and transformers.
* - ``bfloat16``
- 16-bit bfloat (brain floating point).
* - XLA int4 support
- 4-bit integer (int4) precision in the XLA compiler.
* - ``bool``
- Boolean.
* - MOSAIC (GPU)
- Mosaic is a library of kernel-building abstractions for JAX's Pallas system
* - ``complex128``
- 128-bit complex.
* - ``complex64``
- 64-bit complex.
* - ``float16``
- 16-bit (half precision) floating-point.
* - ``float32``
- 32-bit (single precision) floating-point.
* - ``float64``
- 64-bit (double precision) floating-point.
* - ``half``
- 16-bit (half precision) floating-point.
* - ``int16``
- Signed 16-bit integer.
* - ``int32``
- Signed 32-bit integer.
* - ``int64``
- Signed 64-bit integer.
* - ``int8``
- Signed 8-bit integer.
* - ``uint16``
- Unsigned 16-bit (word) integer.
* - ``uint32``
- Unsigned 32-bit (dword) integer.
* - ``uint64``
- Unsigned 64-bit (qword) integer.
* - ``uint8``
- Unsigned 8-bit (byte) integer.
.. note::
JAX data type support is effected by the :ref:`key_rocm_libraries` and it's
collected on :doc:`ROCm data types and precision support <rocm:reference/precision-support>`
page.
Supported modules
--------------------------------------------------------------------------------
For a complete and up-to-date list of JAX public modules (for example, ``jax.numpy``,
``jax.scipy``, ``jax.lax``), their descriptions, and usage, please refer directly to the
`official JAX API documentation <https://jax.readthedocs.io/en/latest/jax.html>`_.
.. note::
Since version 0.1.56, JAX has full support for ROCm, and the
:ref:`Known issues and important notes <jax_comp_known_issues>` section
contains details about limitations specific to the ROCm backend. The list of
JAX API modules is maintained by the JAX project and is subject to change.
Refer to the official Jax documentation for the most up-to-date information.