From c699aaf915fd6d798e50df93ab67f969d797a939 Mon Sep 17 00:00:00 2001 From: Adel Johar Date: Tue, 3 Jun 2025 12:21:32 +0200 Subject: [PATCH] Docs: Overhaul JAX compatibility page --- .wordlist.txt | 3 + .../ml-compatibility/jax-compatibility.rst | 383 ++++-------------- 2 files changed, 92 insertions(+), 294 deletions(-) diff --git a/.wordlist.txt b/.wordlist.txt index 747b118eb..32177354e 100644 --- a/.wordlist.txt +++ b/.wordlist.txt @@ -228,6 +228,7 @@ LM LSAN LSan LTS +LSTMs LanguageCrossEntropy LoRA MEM @@ -679,6 +680,7 @@ installable interop interprocedural intra +intrinsics invariants invocating ipo @@ -840,6 +842,7 @@ sm smi softmax spack +spmm src stochastically strided diff --git a/docs/compatibility/ml-compatibility/jax-compatibility.rst b/docs/compatibility/ml-compatibility/jax-compatibility.rst index 626ee10bc..c93c728ff 100644 --- a/docs/compatibility/ml-compatibility/jax-compatibility.rst +++ b/docs/compatibility/ml-compatibility/jax-compatibility.rst @@ -53,7 +53,7 @@ Use cases and recommendations * The `nanoGPT in JAX `_ blog explores the implementation and training of a Generative Pre-trained Transformer (GPT) model in JAX, inspired by Andrej Karpathy’s JAX-based - nanoGPT. Comparing how essential GPT components—such as self-attention + nanoGPT. Comparing how essential GPT components—such as self-attention mechanisms and optimizers—are realized in JAX and JAX, also highlights JAX’s unique features. @@ -160,12 +160,14 @@ associated inventories are tested for `ROCm 6.3.2 `_ +.. _key_rocm_libraries: + Key ROCm libraries for JAX ================================================================================ -JAX functionality on ROCm is determined by its underlying library -dependencies. These ROCm components affect the capabilities, performance, and -feature set available to developers. +The following ROCm libraries represent potential targets that could be utilized +by JAX on ROCm for various computational tasks. 
The actual libraries used will +depend on the specific implementation and operations performed. .. list-table:: :header-rows: 1 @@ -173,347 +175,140 @@ feature set available to developers. * - ROCm library - Version - Purpose - - Used in * - `hipBLAS `_ - :version-ref:`hipBLAS rocm_version` - Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for matrix and vector operations. - - Matrix multiplication in ``jax.numpy.matmul``, ``jax.lax.dot`` and - ``jax.lax.dot_general``, operations like ``jax.numpy.dot``, which - involve vector and matrix computations and batch matrix multiplications - ``jax.numpy.einsum`` with matrix-multiplication patterns algebra - operations. * - `hipBLASLt `_ - :version-ref:`hipBLASLt rocm_version` - hipBLASLt is an extension of hipBLAS, providing additional features like epilogues fused into the matrix multiplication kernel or use of integer tensor cores. - - Matrix multiplication in ``jax.numpy.matmul`` or ``jax.lax.dot``, and - the XLA (Accelerated Linear Algebra) use hipBLASLt for optimized matrix - operations, mixed-precision support, and hardware-specific - optimizations. * - `hipCUB `_ - :version-ref:`hipCUB rocm_version` - Provides a C++ template library for parallel algorithms for reduction, scan, sort and select. - - Reduction functions (``jax.numpy.sum``, ``jax.numpy.mean``, - ``jax.numpy.prod``, ``jax.numpy.max`` and ``jax.numpy.min``), prefix sum - (``jax.numpy.cumsum``, ``jax.numpy.cumprod``) and sorting - (``jax.numpy.sort``, ``jax.numpy.argsort``). * - `hipFFT `_ - :version-ref:`hipFFT rocm_version` - Provides GPU-accelerated Fast Fourier Transform (FFT) operations. - - Used in functions like ``jax.numpy.fft``. * - `hipRAND `_ - :version-ref:`hipRAND rocm_version` - Provides fast random number generation for GPUs. - - The ``jax.random.uniform``, ``jax.random.normal``, - ``jax.random.randint`` and ``jax.random.split``. 
* - `hipSOLVER `_ - :version-ref:`hipSOLVER rocm_version` - Provides GPU-accelerated solvers for linear systems, eigenvalues, and singular value decompositions (SVD). - - Solving linear systems (``jax.numpy.linalg.solve``), matrix - factorizations, SVD (``jax.numpy.linalg.svd``) and eigenvalue problems - (``jax.numpy.linalg.eig``). * - `hipSPARSE `_ - :version-ref:`hipSPARSE rocm_version` - Accelerates operations on sparse matrices, such as sparse matrix-vector or matrix-matrix products. - - Sparse matrix multiplication (``jax.numpy.matmul``), sparse - matrix-vector and matrix-matrix products - (``jax.experimental.sparse.dot``), sparse linear system solvers and - sparse data handling. * - `hipSPARSELt `_ - :version-ref:`hipSPARSELt rocm_version` - Accelerates operations on sparse matrices, such as sparse matrix-vector or matrix-matrix products. - - Sparse matrix multiplication (``jax.numpy.matmul``), sparse - matrix-vector and matrix-matrix products - (``jax.experimental.sparse.dot``) and sparse linear system solvers. * - `MIOpen `_ - :version-ref:`MIOpen rocm_version` - Optimized for deep learning primitives such as convolutions, pooling, normalization, and activation functions. - - Speeds up convolutional neural networks (CNNs), recurrent neural - networks (RNNs), and other layers. Used in operations like - ``jax.nn.conv``, ``jax.nn.relu``, and ``jax.nn.batch_norm``. * - `RCCL `_ - :version-ref:`RCCL rocm_version` - Optimized for multi-GPU communication for operations like all-reduce, broadcast, and scatter. - - Distribute computations across multiple GPU with ``pmap`` and - ``jax.distributed``. XLA automatically uses rccl when executing - operations across multiple GPUs on AMD hardware. * - `rocThrust `_ - :version-ref:`rocThrust rocm_version` - Provides a C++ template library for parallel algorithms like sorting, reduction, and scanning. 
- - Reduction operations like ``jax.numpy.sum``, ``jax.pmap`` for - distributed training, which involves parallel reductions or - operations like ``jax.numpy.cumsum`` can use rocThrust. -Supported features +.. note:: + + This table shows ROCm libraries that could potentially be utilized by JAX. Not + all libraries may be used in every configuration, and the actual library usage + will depend on the specific operations and implementation details. + +Supported data types and modules =============================================================================== -The following table maps the public JAX API modules to their supported -ROCm and JAX versions. +The following tables list the supported public JAX API data types and modules. + +Supported data types +-------------------------------------------------------------------------------- + +ROCm supports all the JAX data types of the `jax.dtypes `_ +module, `jax.numpy.dtype `_ +and `default_dtype `_. +The ROCm-supported data types in JAX are listed in the following table. .. list-table:: :header-rows: 1 - * - Module - - Description - - As of JAX - - As of ROCm - * - ``jax.numpy`` - - Implements the NumPy API, using the primitives in ``jax.lax``. - - 0.1.56 - - 5.0.0 - * - ``jax.scipy`` - - Provides GPU-accelerated and differentiable implementations of many - functions from the SciPy library, leveraging JAX's transformations - (e.g., ``grad``, ``jit``, ``vmap``). - - 0.1.56 - - 5.0.0 - * - ``jax.lax`` - - A library of primitives operations that underpins libraries such as - ``jax.numpy.`` Transformation rules, such as Jacobian-vector product - (JVP) and batching rules, are typically defined as transformations on - ``jax.lax`` primitives. - - 0.1.57 - - 5.0.0 - * - ``jax.random`` - - Provides a number of routines for deterministic generation of sequences - of pseudorandom numbers. - - 0.1.58 - - 5.0.0 - * - ``jax.sharding`` - - Allows to define partitioning and distributing arrays across multiple - devices.
- - 0.3.20 - - 5.1.0 - * - ``jax.distributed`` - - Enables the scaling of computations across multiple devices on a single - machine or across multiple machines. - - 0.1.74 - - 5.0.0 - * - ``jax.image`` - - Contains image manipulation functions like resize, scale and translation. - - 0.1.57 - - 5.0.0 - * - ``jax.nn`` - - Contains common functions for neural network libraries. - - 0.1.56 - - 5.0.0 - * - ``jax.ops`` - - Computes the minimum, maximum, sum or product within segments of an - array. - - 0.1.57 - - 5.0.0 - * - ``jax.stages`` - - Contains interfaces to stages of the compiled execution process. - - 0.3.4 - - 5.0.0 - * - ``jax.extend`` - - Provides modules for access to JAX internal machinery module. The - ``jax.extend`` module defines a library view of some of JAX’s internal - components. - - 0.4.15 - - 5.5.0 - * - ``jax.example_libraries`` - - Serves as a collection of example code and libraries that demonstrate - various capabilities of JAX. - - 0.1.74 - - 5.0.0 - * - ``jax.experimental`` - - Namespace for experimental features and APIs that are in development or - are not yet fully stable for production use. - - 0.1.56 - - 5.0.0 - * - ``jax.lib`` - - Set of internal tools and types for bridging between JAX’s Python - frontend and its XLA backend. - - 0.4.6 - - 5.3.0 - * - ``jax_triton`` - - Library that integrates the Triton deep learning compiler with JAX. - - jax_triton 0.2.0 - - 6.2.4 - -jax.scipy module -------------------------------------------------------------------------------- - -A SciPy-like API for scientific computing. - -.. 
list-table:: - :header-rows: 1 - - * - Module - - As of JAX - - As of ROCm - * - ``jax.scipy.cluster`` - - 0.3.11 - - 5.1.0 - * - ``jax.scipy.fft`` - - 0.1.71 - - 5.0.0 - * - ``jax.scipy.integrate`` - - 0.4.15 - - 5.5.0 - * - ``jax.scipy.interpolate`` - - 0.1.76 - - 5.0.0 - * - ``jax.scipy.linalg`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.ndimage`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.optimize`` - - 0.1.57 - - 5.0.0 - * - ``jax.scipy.signal`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.spatial.transform`` - - 0.4.12 - - 5.4.0 - * - ``jax.scipy.sparse.linalg`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.special`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats`` - - 0.1.56 - - 5.0.0 - -jax.scipy.stats module -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. list-table:: - :header-rows: 1 - - * - Module - - As of JAX - - As of ROCm - * - ``jax.scipy.stats.bernouli`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.beta`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.betabinom`` - - 0.1.61 - - 5.0.0 - * - ``jax.scipy.stats.binom`` - - 0.4.14 - - 5.4.0 - * - ``jax.scipy.stats.cauchy`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.chi2`` - - 0.1.61 - - 5.0.0 - * - ``jax.scipy.stats.dirichlet`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.expon`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.gamma`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.gennorm`` - - 0.3.15 - - 5.2.0 - * - ``jax.scipy.stats.geom`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.laplace`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.logistic`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.multinomial`` - - 0.3.18 - - 5.1.0 - * - ``jax.scipy.stats.multivariate_normal`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.nbinom`` - - 0.1.72 - - 5.0.0 - * - ``jax.scipy.stats.norm`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.pareto`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.poisson`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.t`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.truncnorm`` - 
- 0.4.0 - - 5.3.0 - * - ``jax.scipy.stats.uniform`` - - 0.1.56 - - 5.0.0 - * - ``jax.scipy.stats.vonmises`` - - 0.4.2 - - 5.3.0 - * - ``jax.scipy.stats.wrapcauchy`` - - 0.4.20 - - 5.6.0 - -jax.extend module -------------------------------------------------------------------------------- - -Modules for JAX extensions. - -.. list-table:: - :header-rows: 1 - - * - Module - - As of JAX - - As of ROCm - * - ``jax.extend.ffi`` - - 0.4.30 - - 6.0.0 - * - ``jax.extend.linear_util`` - - 0.4.17 - - 5.6.0 - * - ``jax.extend.mlir`` - - 0.4.26 - - 5.6.0 - * - ``jax.extend.random`` - - 0.4.15 - - 5.5.0 - -Unsupported JAX features -=============================================================================== - -The following GPU-accelerated JAX features are not supported by ROCm for -the listed supported JAX versions. - -.. list-table:: - :header-rows: 1 - - * - Feature + * - Data type - Description - * - Mixed Precision with TF32 - - Mixed precision with TF32 is used for matrix multiplications, - convolutions, and other linear algebra operations, particularly in - deep learning workloads like CNNs and transformers. + * - ``bfloat16`` + - 16-bit bfloat (brain floating point). - * - XLA int4 support - - 4-bit integer (int4) precision in the XLA compiler. + * - ``bool`` + - Boolean. - * - MOSAIC (GPU) - - Mosaic is a library of kernel-building abstractions for JAX's Pallas system + * - ``complex128`` + - 128-bit complex. + + * - ``complex64`` + - 64-bit complex. + + * - ``float16`` + - 16-bit (half precision) floating-point. + + * - ``float32`` + - 32-bit (single precision) floating-point. + + * - ``float64`` + - 64-bit (double precision) floating-point. + + * - ``half`` + - 16-bit (half precision) floating-point. + + * - ``int16`` + - Signed 16-bit integer. + + * - ``int32`` + - Signed 32-bit integer. + + * - ``int64`` + - Signed 64-bit integer. + + * - ``int8`` + - Signed 8-bit integer. + + * - ``uint16`` + - Unsigned 16-bit (word) integer. 
+ + * - ``uint32`` + - Unsigned 32-bit (dword) integer. + + * - ``uint64`` + - Unsigned 64-bit (qword) integer. + + * - ``uint8`` + - Unsigned 8-bit (byte) integer. + +.. note:: + + JAX data type support is affected by the :ref:`key_rocm_libraries` and is + documented on the :doc:`ROCm data types and precision support ` + page. + +Supported modules +-------------------------------------------------------------------------------- + +For a complete and up-to-date list of JAX public modules (for example, ``jax.numpy``, +``jax.scipy``, ``jax.lax``), their descriptions, and usage, please refer directly to the +`official JAX API documentation `_. + +.. note:: + + Since version 0.1.56, JAX has full support for ROCm, and the + :ref:`Known issues and important notes ` section + contains details about limitations specific to the ROCm backend. The list of + JAX API modules is maintained by the JAX project and is subject to change. + Refer to the official JAX documentation for the most up-to-date information.
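Reviewer note: the data type names in the new table can be sanity-checked on a local installation. The following is a minimal sketch, not part of the patch. It assumes JAX may or may not be installed, and the helper name ``resolve_listed_dtypes`` is hypothetical. It only reports whether each listed name resolves through ``jax.numpy.dtype``; it does not confirm that the ROCm backend accelerates operations on that type.

```python
import importlib.util

# Data type names copied from the "Supported data types" table in the patch.
LISTED_DTYPES = (
    "bfloat16", "bool", "complex128", "complex64",
    "float16", "float32", "float64", "half",
    "int16", "int32", "int64", "int8",
    "uint16", "uint32", "uint64", "uint8",
)


def resolve_listed_dtypes(names=LISTED_DTYPES):
    """Map each name to its resolved dtype string, or None if it does not resolve.

    Hypothetical helper for illustration. If JAX is not installed, every
    entry maps to None.
    """
    if importlib.util.find_spec("jax") is None:
        return {name: None for name in names}

    import jax.numpy as jnp  # imported lazily so the sketch runs without JAX

    resolved = {}
    for name in names:
        try:
            # jax.numpy.dtype accepts a type name and returns a dtype object.
            resolved[name] = str(jnp.dtype(name))
        except TypeError:
            # The name is not registered as a dtype in this environment.
            resolved[name] = None
    return resolved


if __name__ == "__main__":
    for name, dtype in resolve_listed_dtypes().items():
        print(f"{name:<12} {dtype if dtype is not None else 'not resolved'}")
```

Aliases in the table (for example ``half``) are expected to resolve to their canonical type rather than to a distinct one, so the printed dtype string can differ from the listed name.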