Files
ROCm/docs/compatibility/ml-compatibility/jax-compatibility.rst
2025-05-13 17:39:25 +02:00

520 lines
19 KiB
ReStructuredText
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
:orphan:
.. meta::
:description: JAX compatibility
:keywords: GPU, JAX compatibility
.. version-set:: rocm_version latest
*******************************************************************************
JAX compatibility
*******************************************************************************
JAX provides a NumPy-like API, which combines automatic differentiation and the
Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine
learning at scale.
JAX uses composable transformations of Python and NumPy through just-in-time
(JIT) compilation, automatic vectorization, and parallelization. To learn about
JAX, including profiling and optimizations, see the official `JAX documentation
<https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`_.
ROCm support for JAX is upstreamed, and users can build the official source code
with ROCm support:
- ROCm JAX release:
- Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>`
with ROCm and JAX preinstalled.
- ROCm JAX repository: `ROCm/jax <https://github.com/ROCm/jax>`_
- See the :doc:`ROCm JAX installation guide <rocm-install-on-linux:install/3rd-party/jax-install>`
to get started.
- Official JAX release:
- Official JAX repository: `jax-ml/jax <https://github.com/jax-ml/jax>`_
- See the `AMD GPU (Linux) installation section
<https://jax.readthedocs.io/en/latest/installation.html#amd-gpu-linux>`_ in
the JAX documentation.
.. note::
AMD releases official `ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax>`_
quarterly alongside new ROCm releases. These images undergo full AMD testing.
`Community ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax-community>`_
follow upstream JAX releases and use the latest available ROCm version.
Use cases and recommendations
================================================================================
* The `nanoGPT in JAX <https://rocm.blogs.amd.com/artificial-intelligence/nanoGPT-JAX/README.html>`_
blog explores the implementation and training of a Generative Pre-trained
Transformer (GPT) model in JAX, inspired by Andrej Karpathys JAX-based
nanoGPT. Comparing how essential GPT components—such as self-attention
mechanisms and optimizers—are realized in JAX and JAX, also highlights
JAXs unique features.
* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
ROCm on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-mixed-precision/README.html>`_
blog post provides a comprehensive guide on enhancing the training efficiency
of GPT models by implementing mixed precision techniques in JAX, specifically
tailored for AMD GPUs utilizing the ROCm platform.
* The `Supercharging JAX with Triton Kernels on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-triton/README.html>`_
blog demonstrates how to develop a custom fused dropout-activation kernel for
matrices using Triton, integrate it with JAX, and benchmark its performance
using ROCm.
* The `Distributed fine-tuning with JAX on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/distributed-sft-jax/README.html>`_
outlines the process of fine-tuning a Bidirectional Encoder Representations
from Transformers (BERT)-based large language model (LLM) using JAX for a text
classification task. The blog post discuss techniques for parallelizing the
fine-tuning across multiple AMD GPUs and assess the model's performance on a
holdout dataset. During the fine-tuning, a BERT-base-cased transformer model
and the General Language Understanding Evaluation (GLUE) benchmark dataset was
used on a multi-GPU setup.
* The `MI300X workload optimization guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html>`_
provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
accelerator using ROCm. The page is aimed at helping users achieve optimal
performance for deep learning and other high-performance computing tasks on
the MI300X GPU.
For more use cases and recommendations, see `ROCm JAX blog posts <https://rocm.blogs.amd.com/blog/tag/jax.html>`_.
.. _jax-docker-compat:
Docker image compatibility
================================================================================
.. |docker-icon| raw:: html
<i class="fab fa-docker"></i>
AMD validates and publishes ready-made `ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax>`_
with ROCm backends on Docker Hub. The following Docker image tags and
associated inventories represent the latest JAX version from the official Docker Hub and are validated for
`ROCm 6.4.0 <https://repo.radeon.com/rocm/apt/6.4/>`_. Click the |docker-icon|
icon to view the image on Docker Hub.
.. list-table:: JAX Docker image components
:header-rows: 1
* - Docker image
- JAX
- Linux
- Python
* - .. raw:: html
<a href="https://hub.docker.com/layers/rocm/jax/rocm6.4-jax0.4.35-py3.12/images/sha256-4069398229078f3311128b6d276c6af377c7e97d3363d020b0bf7154fae619ca"><i class="fab fa-docker fa-lg"></i> rocm/jax</a>
- `0.4.35 <https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.4.35>`_
- Ubuntu 24.04
- `3.12.7 <https://www.python.org/downloads/release/python-3127/>`_
* - .. raw:: html
<a href="https://hub.docker.com/layers/rocm/jax/rocm6.4-jax0.4.35-py3.10/images/sha256-a137f901f91ce6c13b424c40a6cf535248d4d20fd36d5daf5eee0570190a4a11"><i class="fab fa-docker fa-lg"></i> rocm/jax</a>
- `0.4.35 <https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.4.35>`_
- Ubuntu 22.04
- `3.10.14 <https://www.python.org/downloads/release/python-31014/>`_
AMD publishes `Community ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax-community>`_
with ROCm backends on Docker Hub. The following Docker image tags and
associated inventories are tested for `ROCm 6.3.2 <https://repo.radeon.com/rocm/apt/6.3.2/>`_.
.. list-table:: JAX community Docker image components
:header-rows: 1
* - Docker image
- JAX
- Linux
- Python
* - .. raw:: html
<a href="https://hub.docker.com/layers/rocm/jax-community/rocm6.3.2-jax0.5.0-py3.12.8/images/sha256-25dfaa0183e274bd0a3554a309af3249c6f16a1793226cb5373f418e39d3146a"><i class="fab fa-docker fa-lg"></i> rocm/jax-community</a>
- `0.5.0 <https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.5.0>`_
- Ubuntu 22.04
- `3.12.8 <https://www.python.org/downloads/release/python-3128/>`_
* - .. raw:: html
<a href="https://hub.docker.com/layers/rocm/jax-community/rocm6.3.2-jax0.5.0-py3.11.11/images/sha256-ff9baeca9067d13e6c279c911e5a9e5beed0817d24fafd424367cc3d5bd381d7"><i class="fab fa-docker fa-lg"></i> rocm/jax-community</a>
- `0.5.0 <https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.5.0>`_
- Ubuntu 22.04
- `3.11.11 <https://www.python.org/downloads/release/python-31111/>`_
* - .. raw:: html
<a href="https://hub.docker.com/layers/rocm/jax-community/rocm6.3.2-jax0.5.0-py3.10.16/images/sha256-8bab484be1713655f74da51a191ed824bb9d03db1104fd63530a1ac3c37cf7b1"><i class="fab fa-docker fa-lg"></i> rocm/jax-community</a>
- `0.5.0 <https://github.com/ROCm/jax/releases/tag/rocm-jax-v0.5.0>`_
- Ubuntu 22.04
- `3.10.16 <https://www.python.org/downloads/release/python-31016/>`_
Key ROCm libraries for JAX
================================================================================
JAX functionality on ROCm is determined by its underlying library
dependencies. These ROCm components affect the capabilities, performance, and
feature set available to developers.
.. list-table::
:header-rows: 1
* - ROCm library
- Version
- Purpose
- Used in
* - `hipBLAS <https://github.com/ROCm/hipBLAS>`_
- :version-ref:`hipBLAS rocm_version`
- Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for
matrix and vector operations.
- Matrix multiplication in ``jax.numpy.matmul``, ``jax.lax.dot`` and
``jax.lax.dot_general``, operations like ``jax.numpy.dot``, which
involve vector and matrix computations and batch matrix multiplications
``jax.numpy.einsum`` with matrix-multiplication patterns algebra
operations.
* - `hipBLASLt <https://github.com/ROCm/hipBLASLt>`_
- :version-ref:`hipBLASLt rocm_version`
- hipBLASLt is an extension of hipBLAS, providing additional
features like epilogues fused into the matrix multiplication kernel or
use of integer tensor cores.
- Matrix multiplication in ``jax.numpy.matmul`` or ``jax.lax.dot``, and
the XLA (Accelerated Linear Algebra) use hipBLASLt for optimized matrix
operations, mixed-precision support, and hardware-specific
optimizations.
* - `hipCUB <https://github.com/ROCm/hipCUB>`_
- :version-ref:`hipCUB rocm_version`
- Provides a C++ template library for parallel algorithms for reduction,
scan, sort and select.
- Reduction functions (``jax.numpy.sum``, ``jax.numpy.mean``,
``jax.numpy.prod``, ``jax.numpy.max`` and ``jax.numpy.min``), prefix sum
(``jax.numpy.cumsum``, ``jax.numpy.cumprod``) and sorting
(``jax.numpy.sort``, ``jax.numpy.argsort``).
* - `hipFFT <https://github.com/ROCm/hipFFT>`_
- :version-ref:`hipFFT rocm_version`
- Provides GPU-accelerated Fast Fourier Transform (FFT) operations.
- Used in functions like ``jax.numpy.fft``.
* - `hipRAND <https://github.com/ROCm/hipRAND>`_
- :version-ref:`hipRAND rocm_version`
- Provides fast random number generation for GPUs.
- The ``jax.random.uniform``, ``jax.random.normal``,
``jax.random.randint`` and ``jax.random.split``.
* - `hipSOLVER <https://github.com/ROCm/hipSOLVER>`_
- :version-ref:`hipSOLVER rocm_version`
- Provides GPU-accelerated solvers for linear systems, eigenvalues, and
singular value decompositions (SVD).
- Solving linear systems (``jax.numpy.linalg.solve``), matrix
factorizations, SVD (``jax.numpy.linalg.svd``) and eigenvalue problems
(``jax.numpy.linalg.eig``).
* - `hipSPARSE <https://github.com/ROCm/hipSPARSE>`_
- :version-ref:`hipSPARSE rocm_version`
- Accelerates operations on sparse matrices, such as sparse matrix-vector
or matrix-matrix products.
- Sparse matrix multiplication (``jax.numpy.matmul``), sparse
matrix-vector and matrix-matrix products
(``jax.experimental.sparse.dot``), sparse linear system solvers and
sparse data handling.
* - `hipSPARSELt <https://github.com/ROCm/hipSPARSELt>`_
- :version-ref:`hipSPARSELt rocm_version`
- Accelerates operations on sparse matrices, such as sparse matrix-vector
or matrix-matrix products.
- Sparse matrix multiplication (``jax.numpy.matmul``), sparse
matrix-vector and matrix-matrix products
(``jax.experimental.sparse.dot``) and sparse linear system solvers.
* - `MIOpen <https://github.com/ROCm/MIOpen>`_
- :version-ref:`MIOpen rocm_version`
- Optimized for deep learning primitives such as convolutions, pooling,
normalization, and activation functions.
- Speeds up convolutional neural networks (CNNs), recurrent neural
networks (RNNs), and other layers. Used in operations like
``jax.nn.conv``, ``jax.nn.relu``, and ``jax.nn.batch_norm``.
* - `RCCL <https://github.com/ROCm/rccl>`_
- :version-ref:`RCCL rocm_version`
- Optimized for multi-GPU communication for operations like all-reduce,
broadcast, and scatter.
- Distribute computations across multiple GPU with ``pmap`` and
``jax.distributed``. XLA automatically uses rccl when executing
operations across multiple GPUs on AMD hardware.
* - `rocThrust <https://github.com/ROCm/rocThrust>`_
- :version-ref:`rocThrust rocm_version`
- Provides a C++ template library for parallel algorithms like sorting,
reduction, and scanning.
- Reduction operations like ``jax.numpy.sum``, ``jax.pmap`` for
distributed training, which involves parallel reductions or
operations like ``jax.numpy.cumsum`` can use rocThrust.
Supported features
===============================================================================
The following table maps the public JAX API modules to their supported
ROCm and JAX versions.
.. list-table::
:header-rows: 1
* - Module
- Description
- As of JAX
- As of ROCm
* - ``jax.numpy``
- Implements the NumPy API, using the primitives in ``jax.lax``.
- 0.1.56
- 5.0.0
* - ``jax.scipy``
- Provides GPU-accelerated and differentiable implementations of many
functions from the SciPy library, leveraging JAX's transformations
(e.g., ``grad``, ``jit``, ``vmap``).
- 0.1.56
- 5.0.0
* - ``jax.lax``
- A library of primitives operations that underpins libraries such as
``jax.numpy.`` Transformation rules, such as Jacobian-vector product
(JVP) and batching rules, are typically defined as transformations on
``jax.lax`` primitives.
- 0.1.57
- 5.0.0
* - ``jax.random``
- Provides a number of routines for deterministic generation of sequences
of pseudorandom numbers.
- 0.1.58
- 5.0.0
* - ``jax.sharding``
- Allows to define partitioning and distributing arrays across multiple
devices.
- 0.3.20
- 5.1.0
* - ``jax.distributed``
- Enables the scaling of computations across multiple devices on a single
machine or across multiple machines.
- 0.1.74
- 5.0.0
* - ``jax.image``
- Contains image manipulation functions like resize, scale and translation.
- 0.1.57
- 5.0.0
* - ``jax.nn``
- Contains common functions for neural network libraries.
- 0.1.56
- 5.0.0
* - ``jax.ops``
- Computes the minimum, maximum, sum or product within segments of an
array.
- 0.1.57
- 5.0.0
* - ``jax.stages``
- Contains interfaces to stages of the compiled execution process.
- 0.3.4
- 5.0.0
* - ``jax.extend``
- Provides modules for access to JAX internal machinery module. The
``jax.extend`` module defines a library view of some of JAXs internal
components.
- 0.4.15
- 5.5.0
* - ``jax.example_libraries``
- Serves as a collection of example code and libraries that demonstrate
various capabilities of JAX.
- 0.1.74
- 5.0.0
* - ``jax.experimental``
- Namespace for experimental features and APIs that are in development or
are not yet fully stable for production use.
- 0.1.56
- 5.0.0
* - ``jax.lib``
- Set of internal tools and types for bridging between JAXs Python
frontend and its XLA backend.
- 0.4.6
- 5.3.0
* - ``jax_triton``
- Library that integrates the Triton deep learning compiler with JAX.
- jax_triton 0.2.0
- 6.2.4
jax.scipy module
-------------------------------------------------------------------------------
A SciPy-like API for scientific computing.
.. list-table::
:header-rows: 1
* - Module
- As of JAX
- As of ROCm
* - ``jax.scipy.cluster``
- 0.3.11
- 5.1.0
* - ``jax.scipy.fft``
- 0.1.71
- 5.0.0
* - ``jax.scipy.integrate``
- 0.4.15
- 5.5.0
* - ``jax.scipy.interpolate``
- 0.1.76
- 5.0.0
* - ``jax.scipy.linalg``
- 0.1.56
- 5.0.0
* - ``jax.scipy.ndimage``
- 0.1.56
- 5.0.0
* - ``jax.scipy.optimize``
- 0.1.57
- 5.0.0
* - ``jax.scipy.signal``
- 0.1.56
- 5.0.0
* - ``jax.scipy.spatial.transform``
- 0.4.12
- 5.4.0
* - ``jax.scipy.sparse.linalg``
- 0.1.56
- 5.0.0
* - ``jax.scipy.special``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats``
- 0.1.56
- 5.0.0
jax.scipy.stats module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
:header-rows: 1
* - Module
- As of JAX
- As of ROCm
* - ``jax.scipy.stats.bernouli``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.beta``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.betabinom``
- 0.1.61
- 5.0.0
* - ``jax.scipy.stats.binom``
- 0.4.14
- 5.4.0
* - ``jax.scipy.stats.cauchy``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.chi2``
- 0.1.61
- 5.0.0
* - ``jax.scipy.stats.dirichlet``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.expon``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.gamma``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.gennorm``
- 0.3.15
- 5.2.0
* - ``jax.scipy.stats.geom``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.laplace``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.logistic``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.multinomial``
- 0.3.18
- 5.1.0
* - ``jax.scipy.stats.multivariate_normal``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.nbinom``
- 0.1.72
- 5.0.0
* - ``jax.scipy.stats.norm``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.pareto``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.poisson``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.t``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.truncnorm``
- 0.4.0
- 5.3.0
* - ``jax.scipy.stats.uniform``
- 0.1.56
- 5.0.0
* - ``jax.scipy.stats.vonmises``
- 0.4.2
- 5.3.0
* - ``jax.scipy.stats.wrapcauchy``
- 0.4.20
- 5.6.0
jax.extend module
-------------------------------------------------------------------------------
Modules for JAX extensions.
.. list-table::
:header-rows: 1
* - Module
- As of JAX
- As of ROCm
* - ``jax.extend.ffi``
- 0.4.30
- 6.0.0
* - ``jax.extend.linear_util``
- 0.4.17
- 5.6.0
* - ``jax.extend.mlir``
- 0.4.26
- 5.6.0
* - ``jax.extend.random``
- 0.4.15
- 5.5.0
Unsupported JAX features
===============================================================================
The following GPU-accelerated JAX features are not supported by ROCm for
the listed supported JAX versions.
.. list-table::
:header-rows: 1
* - Feature
- Description
* - Mixed Precision with TF32
- Mixed precision with TF32 is used for matrix multiplications,
convolutions, and other linear algebra operations, particularly in
deep learning workloads like CNNs and transformers.
* - XLA int4 support
- 4-bit integer (int4) precision in the XLA compiler.
* - MOSAIC (GPU)
- Mosaic is a library of kernel-building abstractions for JAX's Pallas system