Commit 6dec5f7ecd by Istvan Kiss: "JAX compatibility page update" (2025-05-08 18:32:27 +02:00)


JAX provides a NumPy-like API, which combines automatic differentiation and the
Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine
learning at scale.

JAX uses composable transformations of Python and NumPy through just-in-time
(JIT) compilation, automatic vectorization, and parallelization. To learn about
JAX, including profiling and optimizations, see the official `JAX documentation
<https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`_.
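
For example, the following minimal sketch (assuming a working ``jax``
installation on any supported backend; the function and variable names are
illustrative only) composes ``jax.jit``, ``jax.grad``, and ``jax.vmap``:

.. code-block:: python

   import jax
   import jax.numpy as jnp

   def loss(w, x):
       # Toy scalar loss: mean squared activation.
       return jnp.mean(jnp.tanh(x @ w) ** 2)

   # Compose transformations: JIT-compile the gradient of the loss,
   # then vectorize it over a batch of candidate weight matrices.
   grad_loss = jax.jit(jax.grad(loss))
   batched_grad = jax.vmap(grad_loss, in_axes=(0, None))

   x = jnp.ones((8, 4))              # batch of inputs
   ws = jnp.ones((3, 4, 2)) * 0.1    # three candidate weight matrices
   print(batched_grad(ws, x).shape)  # (3, 4, 2): one gradient per candidate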

ROCm support for JAX is upstreamed, and users can build the official source code
with ROCm support:

- ROCm JAX release:

  - Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>`
    with ROCm and JAX preinstalled.
  - ROCm JAX repository: `ROCm/jax <https://github.com/ROCm/jax>`_

...

  - Official JAX repository: `jax-ml/jax <https://github.com/jax-ml/jax>`_
  - See the `AMD GPU (Linux) installation section
    <https://jax.readthedocs.io/en/latest/installation.html#amd-gpu-linux>`_ in
    the JAX documentation.

.. note::

   ...

   `Community ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax-community>`_
   follow upstream JAX releases and use the latest available ROCm version.

Use cases and recommendations
================================================================================

* The `nanoGPT in JAX <https://rocm.blogs.amd.com/artificial-intelligence/nanoGPT-JAX/README.html>`_
  blog explores the implementation and training of a Generative Pre-trained
  Transformer (GPT) model in JAX, inspired by Andrej Karpathy's PyTorch-based
  nanoGPT. By comparing how essential GPT components, such as self-attention
  mechanisms and optimizers, are realized in PyTorch and JAX, it also
  highlights JAX's unique features.
* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
  ROCm on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-mixed-precision/README.html>`_
  blog post provides a comprehensive guide on enhancing the training efficiency
  of GPT models by implementing mixed precision techniques in JAX, specifically
  tailored for AMD GPUs utilizing the ROCm platform.
* The `Supercharging JAX with Triton Kernels on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-triton/README.html>`_
  blog demonstrates how to develop a custom fused dropout-activation kernel for
  matrices using Triton, integrate it with JAX, and benchmark its performance
  using ROCm.
* The `Distributed fine-tuning with JAX on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/distributed-sft-jax/README.html>`_
  blog outlines the process of fine-tuning a Bidirectional Encoder
  Representations from Transformers (BERT)-based large language model (LLM)
  using JAX for a text classification task. The blog post discusses techniques
  for parallelizing the fine-tuning across multiple AMD GPUs and assesses the
  model's performance on a holdout dataset. For the fine-tuning, a
  BERT-base-cased transformer model and the General Language Understanding
  Evaluation (GLUE) benchmark dataset were used on a multi-GPU setup.
* The `MI300X workload optimization guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html>`_
  provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
  accelerator using ROCm. The page is aimed at helping users achieve optimal
  performance for deep learning and other high-performance computing tasks on
  the MI300X GPU.

For more use cases and recommendations, see `ROCm JAX blog posts <https://rocm.blogs.amd.com/blog/tag/jax.html>`_.

.. _jax-docker-compat:

Docker image compatibility
================================================================================

AMD validates and publishes ready-made `ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax>`_
with ROCm backends on Docker Hub. The following Docker image tags and
associated inventories represent the latest JAX version from the official
Docker Hub and are validated for `ROCm 6.4.0 <https://repo.radeon.com/rocm/apt/6.4/>`_.
Click the |docker-icon| icon to view the image on Docker Hub.
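
After pulling one of these images (or installing JAX with ROCm support another
way), a quick, image-agnostic sanity check is to ask JAX which devices and
backend it sees. This is a minimal sketch; the exact device repr varies by JAX
version:

.. code-block:: python

   import jax

   # On a working ROCm backend this lists the visible AMD GPUs;
   # if the GPU backend fails to initialize, JAX falls back to CPU.
   print(jax.devices())

   # Expected to print 'gpu' when the ROCm backend is active.
   print(jax.default_backend())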

...

     - Ubuntu 22.04
     - `3.10.16 <https://www.python.org/downloads/release/python-31016/>`_

Key ROCm libraries for JAX
================================================================================

JAX functionality on ROCm is determined by its underlying library
dependencies. These ROCm components affect the capabilities, performance, and
feature set available to developers.

.. list-table::
   :header-rows: 1

...

       distributed training, which involves parallel reductions or
       operations like ``jax.numpy.cumsum``, can use rocThrust (see the
       sketch below).
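
As a hedged illustration of this kind of operation, the following sketch runs
a JIT-compiled prefix sum with ``jax.numpy.cumsum``. Whether it actually
dispatches to rocThrust-backed kernels depends on the XLA lowering in your JAX
and ROCm versions:

.. code-block:: python

   import jax
   import jax.numpy as jnp

   x = jnp.arange(1_000_000, dtype=jnp.float32)

   # JIT-compiled inclusive prefix sum over one million elements.
   cumsum = jax.jit(jnp.cumsum)

   # The last element equals the total sum; float32 rounds the exact
   # 499,999,500,000 to roughly 4.999995e+11.
   print(cumsum(x)[-1])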

Supported features
===============================================================================

The following table maps the public JAX API modules to their supported
ROCm and JAX versions.

.. list-table::
   :header-rows: 1

   * - Module
     - Description
     - As of JAX
     - As of ROCm
   * - ``jax.numpy``
     - Implements the NumPy API, using the primitives in ``jax.lax``.
     - 0.1.56

...

       devices.
     - 0.3.20
     - 5.1.0
   * - ``jax.distributed``
     - Enables the scaling of computations across multiple devices on a single
       machine or across multiple machines.
     - 0.1.74
     - 5.0.0
   * - ``jax.image``
     - Contains image manipulation functions like resize, scale and translation.
     - 0.1.57

...

       array.
     - 0.1.57
     - 5.0.0
   * - ``jax.stages``
     - Contains interfaces to stages of the compiled execution process.
     - 0.3.4
     - 5.0.0
   * - ``jax.extend``
     - Provides modules for access to JAX's internal machinery. The
       ``jax.extend`` module defines a library view of some of JAX's internal

...

A SciPy-like API for scientific computing.

.. list-table::
   :header-rows: 1

   * - Module
     - As of JAX
     - As of ROCm
   * - ``jax.scipy.cluster``
     - 0.3.11
     - 5.1.0

...

jax.scipy.stats module
--------------------------------------------------------------------------------

.. list-table::
   :header-rows: 1

   * - Module
     - As of JAX
     - As of ROCm
   * - ``jax.scipy.stats.bernoulli``
     - 0.1.56
     - 5.0.0

...

Modules for JAX extensions.

.. list-table::
   :header-rows: 1

   * - Module
     - As of JAX
     - As of ROCm
   * - ``jax.extend.ffi``
     - 0.4.30
     - 6.0.0

...

     - 0.4.15
     - 5.5.0

Unsupported JAX features
--------------------------------------------------------------------------------

The following GPU-accelerated JAX features are not supported by ROCm for
the listed supported JAX versions.

.. list-table::
   :header-rows: 1

   * - Feature
     - Description
   * - Mixed Precision with TF32
     - Mixed precision with TF32 is used for matrix multiplications,
       convolutions, and other linear algebra operations, particularly in
       deep learning workloads like CNNs and transformers.
   * - XLA int4 support
     - 4-bit integer (int4) precision in the XLA compiler.
   * - Mosaic (GPU)
     - Mosaic is a library of kernel-building abstractions for JAX's Pallas
       system. It is not supported on ROCm.