diff --git a/docs/compatibility/ml-compatibility/jax-compatibility.rst b/docs/compatibility/ml-compatibility/jax-compatibility.rst
index a84d43d16..471b2efc3 100644
--- a/docs/compatibility/ml-compatibility/jax-compatibility.rst
+++ b/docs/compatibility/ml-compatibility/jax-compatibility.rst
@@ -14,17 +14,18 @@ JAX provides a NumPy-like API, which combines automatic differentiation and the
 Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine
 learning at scale.
 
-JAX uses composable transformations of Python and NumPy through just-in-time (JIT) compilation,
-automatic vectorization, and parallelization. To learn about JAX, including profiling and
-optimizations, see the official `JAX documentation
+JAX uses composable transformations of Python and NumPy through just-in-time
+(JIT) compilation, automatic vectorization, and parallelization. To learn about
+JAX, including profiling and optimizations, see the official `JAX documentation
 `_.
 
-ROCm support for JAX is upstreamed and users can build the official source code with ROCm
-support:
+ROCm support for JAX is upstreamed, and users can build the official source code
+with ROCm support:
 
 - ROCm JAX release:
 
-  - Offers AMD-validated and community :ref:`Docker images ` with ROCm and JAX pre-installed.
+  - Offers AMD-validated and community :ref:`Docker images `
+    with ROCm and JAX preinstalled.
 
   - ROCm JAX repository: `ROCm/jax `_
 
@@ -36,8 +37,8 @@ support:
   - Official JAX repository: `jax-ml/jax `_
 
   - See the `AMD GPU (Linux) installation section
-    `_ in the JAX
-    documentation.
+    `_ in
+    the JAX documentation.
 
 .. note::
 
@@ -46,6 +47,44 @@ support:
    `Community ROCm JAX Docker images
   `_
   follow upstream JAX releases and use the latest available ROCm version.
 
+Use cases and recommendations
+================================================================================
+
+* The `nanoGPT in JAX `_
+  blog explores the implementation and training of a Generative Pre-trained
+  Transformer (GPT) model in JAX, inspired by Andrej Karpathy’s PyTorch-based
+  nanoGPT. By comparing how essential GPT components, such as self-attention
+  mechanisms and optimizers, are realized in PyTorch and JAX, it also
+  highlights JAX’s unique features.
+
+* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
+  ROCm on AMD GPUs `_
+  blog post provides a comprehensive guide on enhancing the training efficiency
+  of GPT models by implementing mixed precision techniques in JAX, specifically
+  tailored for AMD GPUs utilizing the ROCm platform. A minimal sketch of this
+  pattern follows at the end of this section.
+
+* The `Supercharging JAX with Triton Kernels on AMD GPUs `_
+  blog demonstrates how to develop a custom fused dropout-activation kernel for
+  matrices using Triton, integrate it with JAX, and benchmark its performance
+  using ROCm.
+
+* The `Distributed fine-tuning with JAX on AMD GPUs `_
+  blog post outlines the process of fine-tuning a Bidirectional Encoder
+  Representations from Transformers (BERT)-based large language model (LLM)
+  using JAX for a text classification task. The blog post discusses techniques
+  for parallelizing the fine-tuning across multiple AMD GPUs and assesses the
+  model's performance on a holdout dataset. The fine-tuning used a
+  BERT-base-cased transformer model and the General Language Understanding
+  Evaluation (GLUE) benchmark dataset on a multi-GPU setup.
+
+* The `MI300X workload optimization guide `_
+  provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
+  accelerator using ROCm. The page is aimed at helping users achieve optimal
+  performance for deep learning and other high-performance computing tasks on
+  the MI300X GPU.
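+
+The following is a minimal, hypothetical sketch of the mixed precision pattern
+discussed in the mixed precision training blog post above: parameters stay in
+float32, the compute-heavy matrix multiply runs in bfloat16, and the result is
+cast back to float32. Names and shapes here are illustrative only and are not
+taken from the blog post.
+
+.. code-block:: python
+
+   import jax
+   import jax.numpy as jnp
+
+   def mixed_precision_matmul(w, x):
+       # Cast the inputs to bfloat16 for the compute-heavy step ...
+       y = x.astype(jnp.bfloat16) @ w.astype(jnp.bfloat16)
+       # ... and return the result accumulated back in float32.
+       return y.astype(jnp.float32)
+
+   key = jax.random.PRNGKey(0)
+   w = jax.random.normal(key, (128, 64), dtype=jnp.float32)
+   x = jax.random.normal(key, (32, 128), dtype=jnp.float32)
+
+   out = jax.jit(mixed_precision_matmul)(w, x)
+   print(out.dtype)  # float32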
+
+For more use cases and recommendations, see `ROCm JAX blog posts `_.
+
 .. _jax-docker-compat:
 
 Docker image compatibility
@@ -57,7 +96,7 @@ Docker image compatibility
 AMD validates and publishes ready-made `ROCm JAX Docker images
 `_
 with ROCm backends on Docker Hub. The following Docker image tags and
-associated inventories are validated for
+associated inventories represent the latest JAX version from the official
+Docker Hub and are validated for
 `ROCm 6.4.0 `_.
 Click the |docker-icon| icon to view the image on Docker Hub.
 
@@ -121,13 +160,12 @@ associated inventories are tested for
 `ROCm 6.3.2 `_
 
-Critical ROCm libraries for JAX
+Key ROCm libraries for JAX
 ================================================================================
 
-The functionality of JAX with ROCm is determined by its underlying library
-dependencies. These critical ROCm components affect the capabilities,
-performance, and feature set available to developers. The versions described
-are available in ROCm :version:`rocm_version`.
+JAX functionality on ROCm is determined by its underlying library
+dependencies. These ROCm components affect the capabilities, performance, and
+feature set available to developers.
 
 .. list-table::
    :header-rows: 1
@@ -215,10 +253,10 @@ are available in ROCm :version:`rocm_version`.
    distributed training, which involves parallel reductions or operations like
    ``jax.numpy.cumsum`` can use rocThrust.
 
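+As a quick illustration, assuming a working ROCm build of JAX, the following
+hypothetical snippet exercises two of the operation classes mentioned above: a
+parallel scan (``jax.numpy.cumsum``, which can be served by rocThrust) and a
+matrix multiply dispatched to the GPU.
+
+.. code-block:: python
+
+   import jax
+   import jax.numpy as jnp
+
+   print(jax.devices())   # lists the visible AMD GPUs on a ROCm install
+
+   x = jnp.arange(16, dtype=jnp.float32)
+   print(jnp.cumsum(x))   # parallel scan, as noted above
+
+   a = jnp.ones((64, 64), dtype=jnp.float32)
+   print((a @ a).sum())   # matrix multiply on the default device
+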
-Supported and unsupported features
+Supported features
 ===============================================================================
 
-The following table maps GPU-accelerated JAX modules to their supported
+The following table maps the public JAX API modules to their supported
 ROCm and JAX versions.
 
 .. list-table::
@@ -226,8 +264,8 @@ ROCm and JAX versions.
 
    * - Module
      - Description
-     - Since JAX
-     - Since ROCm
+     - As of JAX
+     - As of ROCm
    * - ``jax.numpy``
      - Implements the NumPy API, using the primitives in ``jax.lax``.
      - 0.1.56
@@ -255,21 +293,11 @@ ROCm and JAX versions.
       devices.
      - 0.3.20
      - 5.1.0
-   * - ``jax.dlpack``
-     - For exchanging tensor data between JAX and other libraries that support the
-       DLPack standard.
-     - 0.1.57
-     - 5.0.0
    * - ``jax.distributed``
      - Enables the scaling of computations across multiple devices on a single
       machine or across multiple machines.
      - 0.1.74
      - 5.0.0
-   * - ``jax.dtypes``
-     - Provides utilities for working with and managing data types in JAX
-       arrays and computations.
-     - 0.1.66
-     - 5.0.0
    * - ``jax.image``
      - Contains image manipulation functions like resize, scale and translation.
      - 0.1.57
@@ -283,27 +311,10 @@ ROCm and JAX versions.
       array.
      - 0.1.57
      - 5.0.0
-   * - ``jax.profiler``
-     - Contains JAX’s tracing and time profiling features.
-     - 0.1.57
-     - 5.0.0
    * - ``jax.stages``
      - Contains interfaces to stages of the compiled execution process.
      - 0.3.4
      - 5.0.0
-   * - ``jax.tree``
-     - Provides utilities for working with tree-like container data structures.
-     - 0.4.26
-     - 5.6.0
-   * - ``jax.tree_util``
-     - Provides utilities for working with nested data structures, or
-       ``pytrees``.
-     - 0.1.65
-     - 5.0.0
-   * - ``jax.typing``
-     - Provides JAX-specific static type annotations.
-     - 0.3.18
-     - 5.1.0
    * - ``jax.extend``
      - Provides modules for access to JAX internal machinery module. The
       ``jax.extend`` module defines a library view of some of JAX’s internal
@@ -339,8 +350,8 @@ A SciPy-like API for scientific computing.
    :header-rows: 1
 
    * - Module
-     - Since JAX
-     - Since ROCm
+     - As of JAX
+     - As of ROCm
    * - ``jax.scipy.cluster``
      - 0.3.11
      - 5.1.0
@@ -385,8 +396,8 @@ jax.scipy.stats module
    :header-rows: 1
 
    * - Module
-     - Since JAX
-     - Since ROCm
+     - As of JAX
+     - As of ROCm
    * - ``jax.scipy.stats.bernouli``
      - 0.1.56
      - 5.0.0
@@ -469,8 +480,8 @@ Modules for JAX extensions.
    :header-rows: 1
 
    * - Module
-     - Since JAX
-     - Since ROCm
+     - As of JAX
+     - As of ROCm
    * - ``jax.extend.ffi``
      - 0.4.30
      - 6.0.0
      - 0.4.15
      - 5.5.0
 
-jax.experimental module
--------------------------------------------------------------------------------
-
-Experimental modules and APIs.
-
-.. list-table::
-   :header-rows: 1
-
-   * - Module
-     - Since JAX
-     - Since ROCm
-   * - ``jax.experimental.checkify``
-     - 0.1.75
-     - 5.0.0
-   * - ``jax.experimental.compilation_cache.compilation_cache``
-     - 0.1.68
-     - 5.0.0
-   * - ``jax.experimental.custom_partitioning``
-     - 0.4.0
-     - 5.3.0
-   * - ``jax.experimental.jet``
-     - 0.1.56
-     - 5.0.0
-   * - ``jax.experimental.key_reuse``
-     - 0.4.26
-     - 5.6.0
-   * - ``jax.experimental.mesh_utils``
-     - 0.1.76
-     - 5.0.0
-   * - ``jax.experimental.multihost_utils``
-     - 0.3.2
-     - 5.0.0
-   * - ``jax.experimental.pallas``
-     - 0.4.15
-     - 5.5.0
-   * - ``jax.experimental.pjit``
-     - 0.1.61
-     - 5.0.0
-   * - ``jax.experimental.serialize_executable``
-     - 0.4.0
-     - 5.3.0
-   * - ``jax.experimental.shard_map``
-     - 0.4.3
-     - 5.3.0
-   * - ``jax.experimental.sparse``
-     - 0.1.75
-     - 5.0.0
-
-.. list-table::
-   :header-rows: 1
-
-   * - API
-     - Since JAX
-     - Since ROCm
-   * - ``jax.experimental.enable_x64``
-     - 0.1.60
-     - 5.0.0
-   * - ``jax.experimental.disable_x64``
-     - 0.1.60
-     - 5.0.0
-
-jax.experimental.pallas module
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Module for Pallas, a JAX extension for custom kernels.
-
-.. list-table::
-   :header-rows: 1
-
-   * - Module
-     - Since JAX
-     - Since ROCm
-   * - ``jax.experimental.pallas.mosaic_gpu``
-     - 0.4.31
-     - 6.1.3
-   * - ``jax.experimental.pallas.tpu``
-     - 0.4.15
-     - 5.5.0
-   * - ``jax.experimental.pallas.triton``
-     - 0.4.32
-     - 6.1.3
-
-jax.experimental.sparse module
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Experimental support for sparse matrix operations.
-
-.. list-table::
-   :header-rows: 1
-
-   * - Module
-     - Since JAX
-     - Since ROCm
-   * - ``jax.experimental.sparse.linalg``
-     - 0.3.15
-     - 5.2.0
-   * - ``jax.experimental.sparse.sparsify``
-     - 0.3.25
-     - ❌
-
-.. list-table::
-   :header-rows: 1
-
-   * - ``sparse`` data structure API
-     - Since JAX
-     - Since ROCm
-   * - ``jax.experimental.sparse.BCOO``
-     - 0.1.72
-     - 5.0.0
-   * - ``jax.experimental.sparse.BCSR``
-     - 0.3.20
-     - 5.1.0
-   * - ``jax.experimental.sparse.CSR``
-     - 0.1.75
-     - 5.0.0
-   * - ``jax.experimental.sparse.NM``
-     - 0.4.27
-     - 5.6.0
-   * - ``jax.experimental.sparse.COO``
-     - 0.1.75
-     - 5.0.0
-
 Unsupported JAX features
-------------------------
+--------------------------------------------------------------------------------
 
-The following are GPU-accelerated JAX features not currently supported by
-ROCm.
+The following GPU-accelerated JAX features are not currently supported by
+ROCm.
 
 .. list-table::
    :header-rows: 1
 
    * - Feature
      - Description
-     - Since JAX
    * - Mixed Precision with TF32
      - Mixed precision with TF32 is used for matrix multiplications,
        convolutions, and other linear algebra operations, particularly in deep
        learning workloads like CNNs and transformers.
-     - 0.2.25
-   * - RNN support
-     - Currently only LSTM with double bias is supported with float32 input
-       and weight.
-     - 0.3.25
    * - XLA int4 support
      - 4-bit integer (int4) precision in the XLA compiler.
-     - 0.4.0
-   * - ``jax.experimental.sparsify``
-     - Converts a dense matrix to a sparse matrix representation.
-     - Experimental
-
-Use cases and recommendations
-================================================================================
-
-* The `nanoGPT in JAX `_
-  blog explores the implementation and training of a Generative Pre-trained
-  Transformer (GPT) model in JAX, inspired by Andrej Karpathy’s PyTorch-based
-  nanoGPT. By comparing how essential GPT components—such as self-attention
-  mechanisms and optimizers—are realized in PyTorch and JAX, also highlight
-  JAX’s unique features.
-
-* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
-  ROCm on AMD GPUs `_
-  blog post provides a comprehensive guide on enhancing the training efficiency
-  of GPT models by implementing mixed precision techniques in JAX, specifically
-  tailored for AMD GPUs utilizing the ROCm platform.
-
-* The `Supercharging JAX with Triton Kernels on AMD GPUs `_
-  blog demonstrates how to develop a custom fused dropout-activation kernel for
-  matrices using Triton, integrate it with JAX, and benchmark its performance
-  using ROCm.
-
-* The `Distributed fine-tuning with JAX on AMD GPUs `_
-  outlines the process of fine-tuning a Bidirectional Encoder Representations
-  from Transformers (BERT)-based large language model (LLM) using JAX for a text
-  classification task. The blog post discuss techniques for parallelizing the
-  fine-tuning across multiple AMD GPUs and assess the model's performance on a
-  holdout dataset. During the fine-tuning, a BERT-base-cased transformer model
-  and the General Language Understanding Evaluation (GLUE) benchmark dataset was
-  used on a multi-GPU setup.
-
-* The `MI300X workload optimization guide `_
-  provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
-  accelerator using ROCm. The page is aimed at helping users achieve optimal
-  performance for deep learning and other high-performance computing tasks on
-  the MI300X GPU.
-
-For more use cases and recommendations, see `ROCm JAX blog posts `_.
+   * - Mosaic GPU
+     - Mosaic GPU is a library of kernel-building abstractions for JAX's
+       Pallas system.
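+
+Since TF32 mixed precision is listed as unsupported, matrix multiply precision
+can instead be requested explicitly. The following is a minimal sketch using
+JAX's public ``jax.default_matmul_precision`` context manager; the shapes and
+values are illustrative only.
+
+.. code-block:: python
+
+   import jax
+   import jax.numpy as jnp
+
+   x = jnp.ones((256, 256), dtype=jnp.float32)
+
+   # Request full float32 precision for matrix multiplies in this scope,
+   # rather than relying on TF32-style reduced-precision math.
+   with jax.default_matmul_precision("float32"):
+       y = x @ x
+
+   print(y.dtype)  # float32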