Mirror of https://github.com/ROCm/ROCm.git (synced 2026-01-10 15:18:11 -05:00)

Compare commits: falcon...jax_comp_p (12 commits)
Commits:

- e72d89c589
- f65e1412df
- ea1072b11d
- 90a651d2b6
- 16978a382b
- dc23bb09c2
- bb7af3351a
- 8ef1bb0139
- 1610837a95
- b7ce573c66
- 186c281aba
- d44ea40a0d
@@ -15,6 +15,7 @@ parameters:
  type: object
  default:
    - bison
    - cmake
    - dejagnu
    - flex
    - libbabeltrace-dev
@@ -39,17 +40,69 @@ parameters:
- name: jobMatrix
  type: object
  default:
    buildTestJobs:
    testJobs:
    - gfx942:
        target: gfx942
    - gfx90a:
        target: gfx90a

jobs:
- ${{ each job in parameters.jobMatrix.buildTestJobs }}:
  - job: ROCgdb_build_test_${{ job.target }}
  - job: ROCgdb
    variables:
    - group: common
    - template: /.azuredevops/variables-global.yml
    - name: PKG_CONFIG_PATH
      value: $(Agent.BuildDirectory)/rocm/share/pkgconfig
    pool:
      vmImage: ${{ variables.BASE_BUILD_POOL }}
    workspace:
      clean: all
    steps:
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/dependencies-other.yml
      parameters:
        aptPackages: ${{ parameters.aptPackages }}
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/preamble.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/checkout.yml
      parameters:
        checkoutRepo: ${{ parameters.checkoutRepo }}
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/dependencies-aqlprofile.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/dependencies-rocm.yml
      parameters:
        checkoutRef: ${{ parameters.checkoutRef }}
        dependencyList: ${{ parameters.rocmDependencies }}
        aggregatePipeline: ${{ parameters.aggregatePipeline }}
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/build-autotools.yml
      parameters:
        configureFlags: >-
          --program-prefix=roc
          --enable-64-bit-bfd
          --enable-targets="x86_64-linux-gnu,amdgcn-amd-amdhsa"
          --disable-ld
          --disable-gas
          --disable-gdbserver
          --disable-sim
          --enable-tui
          --disable-gdbtk
          --disable-shared
          --disable-gprofng
          --with-expat
          --with-system-zlib
          --without-guile
          --with-babeltrace
          --with-lzma
          --with-python=python3
          --with-rocm-dbgapi=$(Agent.BuildDirectory)/rocm
          LDFLAGS="-Wl,--enable-new-dtags,-rpath=$(Agent.BuildDirectory)/rocm/lib"
        makeCallPrefix: LD_RUN_PATH='${ORIGIN}/../lib'
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/manifest.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-upload.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-links.yml

- ${{ each job in parameters.jobMatrix.testJobs }}:
  - job: ROCgdb_test_${{ job.target }}
    dependsOn: ROCgdb
    condition:
      and(
        and(succeeded(),
        eq(variables['ENABLE_${{ upper(job.target) }}_TESTS'], 'true'),
        not(containsValue(split(variables['DISABLED_${{ upper(job.target) }}_TESTS'], ','), variables['Build.DefinitionName'])),
        eq(${{ parameters.aggregatePipeline }}, False)
@@ -99,8 +152,6 @@ jobs:
          --with-rocm-dbgapi=$(Agent.BuildDirectory)/rocm
          LDFLAGS="-Wl,--enable-new-dtags,-rpath=$(Agent.BuildDirectory)/rocm/lib"
        makeCallPrefix: LD_RUN_PATH='${ORIGIN}/../lib'
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/manifest.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-upload.yml
    - task: Bash@3
      displayName: Setup test environment
      inputs:
@@ -109,7 +160,6 @@ jobs:
          # Assuming that /opt is no longer persistent across runs, test environments are fully ephemeral
          sudo ln -s $(Agent.BuildDirectory)/rocm /opt/rocm
          echo "##vso[task.prependpath]/opt/rocm/bin"
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-links.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/gpu-diagnostics.yml
    - task: Bash@3
      displayName: check-gdb
@@ -27,6 +27,7 @@ parameters:
  type: object
  default:
    - amdsmi
    - aomp
    - clr
    - hipBLAS-common
    - hipBLASLt
@@ -43,6 +44,7 @@ parameters:
  type: object
  default:
    - amdsmi
    - aomp
    - clr
    - hipBLAS-common
    - hipBLASLt
@@ -108,6 +110,7 @@ jobs:
          -DROCM_PATH=$(Agent.BuildDirectory)/rocm
          -DCMAKE_CXX_COMPILER=$(Agent.BuildDirectory)/rocm/llvm/bin/clang++
          -DCMAKE_PREFIX_PATH=$(Agent.BuildDirectory)/rocm
          -DCMAKE_CXX_FLAGS=-I$(Agent.BuildDirectory)/rocm/llvm/include
          -DCPACK_PACKAGING_INSTALL_PREFIX=$(Build.BinariesDirectory)
          -GNinja
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/manifest.yml
@@ -26,9 +26,11 @@ jobs:
      parameters:
        componentName: HIP
        pipelineId: $(HIP_PIPELINE_ID)
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-prepare-package.yml
      parameters:
        sourceDir: $(Agent.BuildDirectory)/rocm
    - task: Bash@3
      displayName: Copy HIP artifacts
      inputs:
        targetType: inline
        script: cp -a $(Agent.BuildDirectory)/rocm/* $(Build.BinariesDirectory)/
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/manifest.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-upload.yml
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/artifact-links.yml
@@ -183,6 +183,7 @@ jobs:
      parameters:
        componentName: rocm-examples
        testDir: $(Build.SourcesDirectory)/build
        testParameters: '--output-on-failure --force-new-ctest-process --output-junit test_output.xml --exclude-regex "rocfft_callback"'
    - template: ${{ variables.CI_TEMPLATE_PATH }}/steps/docker-container.yml
      parameters:
        aptPackages: ${{ parameters.aptPackages }}
@@ -463,7 +463,7 @@ steps:
    displayName: 'List downloaded ROCm files'
    inputs:
      targetType: inline
      script: ls -1R $(Agent.BuildDirectory)/rocm
      script: ls -la1R $(Agent.BuildDirectory)/rocm
- ${{ if eq(parameters.skipLibraryLinking, false) }}:
  - task: Bash@3
    displayName: 'Link ROCm shared libraries'
@@ -226,6 +226,7 @@ LM
LSAN
LSan
LTS
LanguageCrossEntropy
LoRA
MEM
MERCHANTABILITY
@@ -243,6 +244,7 @@ MMIOH
MMU
MNIST
MPI
MPT
MSVC
MVAPICH
MVFFR
@@ -259,6 +261,7 @@ Meta's
Miniconda
MirroredStrategy
Mixtral
MosaicML
Multicore
Multithreaded
MyEnvironment
@@ -329,6 +332,7 @@ PipelineParallel
PnP
PowerEdge
PowerShell
Pretrained
Pretraining
Profiler's
PyPi
@@ -743,6 +743,10 @@ See the full [ROCm SMI changelog](https://github.com/ROCm/rocm_smi_lib/blob/rele
#### Added

- Support for VA-API and rocDecode tracing.
- Aggregation of MPI data collected across distributed nodes and ranks. The data is concatenated into a single proto file.

#### Changed

- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).

#### Resolved issues

@@ -753,9 +757,9 @@ See the full [ROCm SMI changelog](https://github.com/ROCm/rocm_smi_lib/blob/rele
- Fixed interruption in config file generation.

- Fixed segmentation fault while running rocprof-sys-instrument.
- Fixed an issue where running `rocprof-sys-causal` or using the `-I all` option with `rocprof-sys-sample` caused the system to become non-responsive.

#### Changed

- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).
- Fixed an issue where sampling multi-GPU Python workloads caused the system to stop responding.

### **rocPRIM** (3.4.0)
RELEASE.md (32 lines changed)
@@ -253,14 +253,19 @@ Click {fab}`github` to go to the component's source code on GitHub.
  </tbody>
  <tbody class="rocm-components-libs rocm-components-communication tbody-reverse-zebra">
    <tr>
      <th rowspan="1"></th>
      <th rowspan="1">Communication</th>
      <th rowspan="2"></th>
      <th rowspan="2">Communication</th>
      <td><a href="https://rocm.docs.amd.com/projects/rccl/en/docs-6.4.0/index.html">RCCL</a></td>
      <td>2.21.5 ⇒ <a href="#rccl-2-22-3">2.22.3</a></td>
      <td><a href="https://github.com/ROCm/rccl"><i class="fab fa-github fa-lg"></i></a></td>
    </tr>
    <tr>
      <td><a href="https://github.com/ROCm/rocSHMEM">rocSHMEM</a></td>
      <td>2.0.0</td>
      <td><a href="https://github.com/ROCm/rocSHMEM"><i class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-libs rocm-components-math">
  <tbody class="rocm-components-libs rocm-components-math tbody-reverse-zebra">
    <tr>
      <th rowspan="16"></th>
      <th rowspan="16">Math</th>
@@ -344,7 +349,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
      <td><a href="https://github.com/ROCm/Tensile"><i class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-libs rocm-components-primitives">
  <tbody class="rocm-components-libs rocm-components-primitives tbody-reverse-zebra">
    <tr>
      <th rowspan="4"></th>
      <th rowspan="4">Primitives</th>
@@ -368,7 +373,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
      <td><a href="https://github.com/ROCm/rocThrust"><i class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-tools rocm-components-system">
  <tbody class="rocm-components-tools rocm-components-system tbody-reverse-zebra">
    <tr>
      <th rowspan="7">Tools</th>
      <th rowspan="7">System management</th>
@@ -397,7 +402,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
      <td><a href="https://github.com/ROCm/ROCmValidationSuite"><i class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-tools rocm-components-perf tbody-reverse-zebra">
  <tbody class="rocm-components-tools rocm-components-perf">
    <tr>
      <th rowspan="6"></th>
      <th rowspan="6">Performance</th>
@@ -438,7 +443,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
        class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-tools rocm-components-dev tbody-reverse-zebra">
  <tbody class="rocm-components-tools rocm-components-dev">
    <tr>
      <th rowspan="5"></th>
      <th rowspan="5">Development</th>
@@ -474,7 +479,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
        class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-compilers">
  <tbody class="rocm-components-compilers tbody-reverse-zebra">
    <tr>
      <th rowspan="2" colspan="2">Compilers</th>
      <td><a href="https://rocm.docs.amd.com/projects/HIPCC/en/docs-6.4.0/index.html">HIPCC</a></td>
@@ -489,7 +494,7 @@ Click {fab}`github` to go to the component's source code on GitHub.
        class="fab fa-github fa-lg"></i></a></td>
    </tr>
  </tbody>
  <tbody class="rocm-components-runtimes">
  <tbody class="rocm-components-runtimes tbody-reverse-zebra">
    <tr>
      <th rowspan="2" colspan="2">Runtimes</th>
      <td><a href="https://rocm.docs.amd.com/projects/HIP/en/docs-6.4.0/index.html">HIP</a></td>
@@ -1247,6 +1252,11 @@ See the full [ROCm SMI changelog](https://github.com/ROCm/rocm_smi_lib/blob/rele
#### Added

- Support for VA-API and rocDecode tracing.
- Aggregation of MPI data collected across distributed nodes and ranks. The data is concatenated into a single proto file.

#### Changed

- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).

#### Resolved issues

@@ -1257,9 +1267,9 @@ See the full [ROCm SMI changelog](https://github.com/ROCm/rocm_smi_lib/blob/rele
- Fixed interruption in config file generation.

- Fixed segmentation fault while running rocprof-sys-instrument.
- Fixed an issue where running `rocprof-sys-causal` or using the `-I all` option with `rocprof-sys-sample` caused the system to become non-responsive.

#### Changed

- Backend refactored to use [ROCprofiler-SDK](https://github.com/ROCm/rocprofiler-sdk) rather than [ROCProfiler](https://github.com/ROCm/rocprofiler) and [ROCTracer](https://github.com/ROCm/ROCTracer).
- Fixed an issue where sampling multi-GPU Python workloads caused the system to stop responding.

### **rocPRIM** (3.4.0)
@@ -81,6 +81,7 @@ additional licenses. Please review individual repositories for more information.
| [rocRAND](https://github.com/ROCm/rocRAND/) | [MIT](https://github.com/ROCm/rocRAND/blob/develop/LICENSE.txt) |
| [ROCr Debug Agent](https://github.com/ROCm/rocr_debug_agent/) | [The University of Illinois/NCSA](https://github.com/ROCm/rocr_debug_agent/blob/amd-staging/LICENSE.txt) |
| [ROCR-Runtime](https://github.com/ROCm/ROCR-Runtime/) | [The University of Illinois/NCSA](https://github.com/ROCm/ROCR-Runtime/blob/amd-staging/LICENSE.txt) |
| [rocSHMEM](https://github.com/ROCm/rocSHMEM/) | [MIT](https://github.com/ROCm/rocSHMEM/blob/develop/LICENSE.md) |
| [rocSOLVER](https://github.com/ROCm/rocSOLVER/) | [BSD-2-Clause](https://github.com/ROCm/rocSOLVER/blob/develop/LICENSE.md) |
| [rocSPARSE](https://github.com/ROCm/rocSPARSE/) | [MIT](https://github.com/ROCm/rocSPARSE/blob/develop/LICENSE.md) |
| [rocThrust](https://github.com/ROCm/rocThrust/) | [Apache 2.0](https://github.com/ROCm/rocThrust/blob/develop/LICENSE) |
@@ -27,7 +27,6 @@ ROCm Version,6.4.0,6.3.3,6.3.2,6.3.1,6.3.0,6.2.4,6.2.2,6.2.1,6.2.0, 6.1.5, 6.1.2
:doc:`TensorFlow <../compatibility/ml-compatibility/tensorflow-compatibility>`,"2.18.1, 2.17.1, 2.16.2","2.17.0, 2.16.2, 2.15.1","2.17.0, 2.16.2, 2.15.1","2.17.0, 2.16.2, 2.15.1","2.17.0, 2.16.2, 2.15.1","2.16.1, 2.15.1, 2.14.1","2.16.1, 2.15.1, 2.14.1","2.16.1, 2.15.1, 2.14.1","2.16.1, 2.15.1, 2.14.1","2.15.0, 2.14.0, 2.13.1","2.15.0, 2.14.0, 2.13.1","2.15.0, 2.14.0, 2.13.1","2.15.0, 2.14.0, 2.13.1","2.14.0, 2.13.1, 2.12.1","2.14.0, 2.13.1, 2.12.1"
:doc:`JAX <../compatibility/ml-compatibility/jax-compatibility>`,0.4.35,0.4.31,0.4.31,0.4.31,0.4.31,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26,0.4.26
`ONNX Runtime <https://onnxruntime.ai/docs/build/eps.html#amd-migraphx>`_,1.2,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.17.3,1.14.1,1.14.1
,,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,
THIRD PARTY COMMS,.. _thirdpartycomms-support-compatibility-matrix-past-60:,,,,,,,,,,,,,,
`UCC <https://github.com/ROCm/ucc>`_,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.3.0,>=1.2.0,>=1.2.0
@@ -53,6 +52,7 @@ ROCm Version,6.4.0,6.3.3,6.3.2,6.3.1,6.3.0,6.2.4,6.2.2,6.2.1,6.2.0, 6.1.5, 6.1.2
,,,,,,,,,,,,,,,
COMMUNICATION,.. _commlibs-support-compatibility-matrix-past-60:,,,,,,,,,,,,,,
:doc:`RCCL <rccl:index>`,2.22.3,2.21.5,2.21.5,2.21.5,2.21.5,2.20.5,2.20.5,2.20.5,2.20.5,2.18.6,2.18.6,2.18.6,2.18.6,2.18.3,2.18.3
`rocSHMEM <https://github.com/ROCm/rocSHMEM>`_,2.0.0,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A,N/A
,,,,,,,,,,,,,,,
MATH LIBS,.. _mathlibs-support-compatibility-matrix-past-60:,,,,,,,,,,,,,,
`half <https://github.com/ROCm/half>`_ ,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0,1.12.0
@@ -77,6 +77,7 @@ compatibility and system requirements.
,,,
COMMUNICATION,.. _commlibs-support-compatibility-matrix:,,
:doc:`RCCL <rccl:index>`,2.22.3,2.21.5,2.20.5
`rocSHMEM <https://github.com/ROCm/rocSHMEM>`_ ,2.0.0,N/A,N/A
,,,
MATH LIBS,.. _mathlibs-support-compatibility-matrix:,,
`half <https://github.com/ROCm/half>`_ ,1.12.0,1.12.0,1.12.0
@@ -14,17 +14,18 @@ JAX provides a NumPy-like API, which combines automatic differentiation and the
Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine
learning at scale.

JAX uses composable transformations of Python and NumPy through just-in-time
(JIT) compilation, automatic vectorization, and parallelization. To learn about
JAX, including profiling and optimizations, see the official `JAX documentation
<https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`_.

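The three core transformations compose freely, as in the following minimal
sketch (the function and array names here are illustrative, not taken from the
ROCm docs):

.. code-block:: python

   import jax
   import jax.numpy as jnp

   def loss(w, x):
       # A toy quadratic loss; any pure Python/NumPy-style function works.
       return jnp.sum((x @ w) ** 2)

   # Differentiate, JIT-compile, then vectorize over a batch of inputs.
   grad_fn = jax.jit(jax.grad(loss))
   batched_grad = jax.vmap(grad_fn, in_axes=(None, 0))

   w = jnp.ones((3,))
   xs = jnp.ones((8, 4, 3))          # batch of eight (4, 3) inputs
   print(batched_grad(w, xs).shape)  # (8, 3): one gradient per batch element
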
ROCm support for JAX is upstreamed, and users can build the official source code
with ROCm support:

- ROCm JAX release:

  - Offers AMD-validated and community :ref:`Docker images <jax-docker-compat>`
    with ROCm and JAX preinstalled.

- ROCm JAX repository: `ROCm/jax <https://github.com/ROCm/jax>`_

@@ -36,8 +37,8 @@ support:
- Official JAX repository: `jax-ml/jax <https://github.com/jax-ml/jax>`_

  - See the `AMD GPU (Linux) installation section
    <https://jax.readthedocs.io/en/latest/installation.html#amd-gpu-linux>`_ in
    the JAX documentation.

.. note::
@@ -46,6 +47,44 @@ support:
   `Community ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax-community>`_
   follow upstream JAX releases and use the latest available ROCm version.

Use cases and recommendations
================================================================================

* The `nanoGPT in JAX <https://rocm.blogs.amd.com/artificial-intelligence/nanoGPT-JAX/README.html>`_
  blog explores the implementation and training of a Generative Pre-trained
  Transformer (GPT) model in JAX, inspired by Andrej Karpathy's PyTorch-based
  nanoGPT. By comparing how essential GPT components, such as self-attention
  mechanisms and optimizers, are realized in PyTorch and JAX, the blog also
  highlights JAX's unique features.

* The `Optimize GPT Training: Enabling Mixed Precision Training in JAX using
  ROCm on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-mixed-precision/README.html>`_
  blog post provides a comprehensive guide on enhancing the training efficiency
  of GPT models by implementing mixed precision techniques in JAX, specifically
  tailored for AMD GPUs utilizing the ROCm platform.

* The `Supercharging JAX with Triton Kernels on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/jax-triton/README.html>`_
  blog demonstrates how to develop a custom fused dropout-activation kernel for
  matrices using Triton, integrate it with JAX, and benchmark its performance
  using ROCm.

* The `Distributed fine-tuning with JAX on AMD GPUs <https://rocm.blogs.amd.com/artificial-intelligence/distributed-sft-jax/README.html>`_
  blog outlines the process of fine-tuning a Bidirectional Encoder Representations
  from Transformers (BERT)-based large language model (LLM) using JAX for a text
  classification task. The blog post discusses techniques for parallelizing the
  fine-tuning across multiple AMD GPUs and assesses the model's performance on a
  holdout dataset. During the fine-tuning, a BERT-base-cased transformer model
  and the General Language Understanding Evaluation (GLUE) benchmark dataset were
  used on a multi-GPU setup.

* The `MI300X workload optimization guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html>`_
  provides detailed guidance on optimizing workloads for the AMD Instinct MI300X
  accelerator using ROCm. The page is aimed at helping users achieve optimal
  performance for deep learning and other high-performance computing tasks on
  the MI300X GPU.

For more use cases and recommendations, see `ROCm JAX blog posts <https://rocm.blogs.amd.com/blog/tag/jax.html>`_.

.. _jax-docker-compat:

Docker image compatibility
@@ -57,7 +96,7 @@ Docker image compatibility

AMD validates and publishes ready-made `ROCm JAX Docker images <https://hub.docker.com/r/rocm/jax>`_
with ROCm backends on Docker Hub. The following Docker image tags and
associated inventories are validated for
associated inventories represent the latest JAX version from the official Docker Hub and are validated for
`ROCm 6.4.0 <https://repo.radeon.com/rocm/apt/6.4/>`_. Click the |docker-icon|
icon to view the image on Docker Hub.

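The images can be pulled and started like any other Docker Hub image. A generic
sketch follows; the ``<tag>`` placeholder must be replaced with a tag from the
table below, and the device flags mirror the ``docker run`` invocation used
elsewhere in this documentation:

.. code-block:: shell

   docker pull rocm/jax:<tag>
   docker run -it --device=/dev/kfd --device=/dev/dri --group-add=video \
       --ipc=host --shm-size=8G rocm/jax:<tag>
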
@@ -121,13 +160,12 @@ associated inventories are tested for `ROCm 6.3.2 <https://repo.radeon.com/rocm/
    - Ubuntu 22.04
    - `3.10.16 <https://www.python.org/downloads/release/python-31016/>`_

Critical ROCm libraries for JAX
Key ROCm libraries for JAX
================================================================================

The functionality of JAX with ROCm is determined by its underlying library
dependencies. These critical ROCm components affect the capabilities,
performance, and feature set available to developers. The versions described
are available in ROCm :version:`rocm_version`.
JAX functionality on ROCm is determined by its underlying library
dependencies. These ROCm components affect the capabilities, performance, and
feature set available to developers.

.. list-table::
   :header-rows: 1
@@ -215,10 +253,10 @@ are available in ROCm :version:`rocm_version`.
       distributed training, which involves parallel reductions or
       operations like ``jax.numpy.cumsum`` can use rocThrust.

Supported and unsupported features
Supported features
===============================================================================

The following table maps GPU-accelerated JAX modules to their supported
The following table maps the public JAX API modules to their supported
ROCm and JAX versions.

.. list-table::
@@ -226,8 +264,8 @@ ROCm and JAX versions.

   * - Module
     - Description
     - Since JAX
     - Since ROCm
     - As of JAX
     - As of ROCm
   * - ``jax.numpy``
     - Implements the NumPy API, using the primitives in ``jax.lax``.
     - 0.1.56
@@ -255,21 +293,11 @@ ROCm and JAX versions.
       devices.
     - 0.3.20
     - 5.1.0
   * - ``jax.dlpack``
     - For exchanging tensor data between JAX and other libraries that support the
       DLPack standard.
     - 0.1.57
     - 5.0.0
   * - ``jax.distributed``
     - Enables the scaling of computations across multiple devices on a single
       machine or across multiple machines.
     - 0.1.74
     - 5.0.0
   * - ``jax.dtypes``
     - Provides utilities for working with and managing data types in JAX
       arrays and computations.
     - 0.1.66
     - 5.0.0
   * - ``jax.image``
     - Contains image manipulation functions like resize, scale and translation.
     - 0.1.57
@@ -283,27 +311,10 @@ ROCm and JAX versions.
       array.
     - 0.1.57
     - 5.0.0
   * - ``jax.profiler``
     - Contains JAX's tracing and time profiling features.
     - 0.1.57
     - 5.0.0
   * - ``jax.stages``
     - Contains interfaces to stages of the compiled execution process.
     - 0.3.4
     - 5.0.0
   * - ``jax.tree``
     - Provides utilities for working with tree-like container data structures.
     - 0.4.26
     - 5.6.0
   * - ``jax.tree_util``
     - Provides utilities for working with nested data structures, or
       ``pytrees``.
     - 0.1.65
     - 5.0.0
   * - ``jax.typing``
     - Provides JAX-specific static type annotations.
     - 0.3.18
     - 5.1.0
   * - ``jax.extend``
     - Provides modules for access to JAX internal machinery. The
       ``jax.extend`` module defines a library view of some of JAX's internal
@@ -339,8 +350,8 @@ A SciPy-like API for scientific computing.
   :header-rows: 1

   * - Module
     - Since JAX
     - Since ROCm
     - As of JAX
     - As of ROCm
   * - ``jax.scipy.cluster``
     - 0.3.11
     - 5.1.0
@@ -385,8 +396,8 @@ jax.scipy.stats module
   :header-rows: 1

   * - Module
     - Since JAX
     - Since ROCm
     - As of JAX
     - As of ROCm
   * - ``jax.scipy.stats.bernoulli``
     - 0.1.56
     - 5.0.0
@@ -469,8 +480,8 @@ Modules for JAX extensions.
   :header-rows: 1

   * - Module
     - Since JAX
     - Since ROCm
     - As of JAX
     - As of ROCm
   * - ``jax.extend.ffi``
     - 0.4.30
     - 6.0.0
@@ -484,190 +495,25 @@ Modules for JAX extensions.
     - 0.4.15
     - 5.5.0

jax.experimental module
-------------------------------------------------------------------------------

Experimental modules and APIs.

.. list-table::
   :header-rows: 1

   * - Module
     - Since JAX
     - Since ROCm
   * - ``jax.experimental.checkify``
     - 0.1.75
     - 5.0.0
   * - ``jax.experimental.compilation_cache.compilation_cache``
     - 0.1.68
     - 5.0.0
   * - ``jax.experimental.custom_partitioning``
     - 0.4.0
     - 5.3.0
   * - ``jax.experimental.jet``
     - 0.1.56
     - 5.0.0
   * - ``jax.experimental.key_reuse``
     - 0.4.26
     - 5.6.0
   * - ``jax.experimental.mesh_utils``
     - 0.1.76
     - 5.0.0
   * - ``jax.experimental.multihost_utils``
     - 0.3.2
     - 5.0.0
   * - ``jax.experimental.pallas``
     - 0.4.15
     - 5.5.0
   * - ``jax.experimental.pjit``
     - 0.1.61
     - 5.0.0
   * - ``jax.experimental.serialize_executable``
     - 0.4.0
     - 5.3.0
   * - ``jax.experimental.shard_map``
     - 0.4.3
     - 5.3.0
   * - ``jax.experimental.sparse``
     - 0.1.75
     - 5.0.0

.. list-table::
   :header-rows: 1

   * - API
     - Since JAX
     - Since ROCm
   * - ``jax.experimental.enable_x64``
     - 0.1.60
     - 5.0.0
   * - ``jax.experimental.disable_x64``
     - 0.1.60
     - 5.0.0

jax.experimental.pallas module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Module for Pallas, a JAX extension for custom kernels.

.. list-table::
   :header-rows: 1

   * - Module
     - Since JAX
     - Since ROCm
   * - ``jax.experimental.pallas.mosaic_gpu``
     - 0.4.31
     - 6.1.3
   * - ``jax.experimental.pallas.tpu``
     - 0.4.15
     - 5.5.0
   * - ``jax.experimental.pallas.triton``
     - 0.4.32
     - 6.1.3

jax.experimental.sparse module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Experimental support for sparse matrix operations.

.. list-table::
   :header-rows: 1

   * - Module
     - Since JAX
     - Since ROCm
   * - ``jax.experimental.sparse.linalg``
     - 0.3.15
     - 5.2.0
   * - ``jax.experimental.sparse.sparsify``
     - 0.3.25
     - ❌

.. list-table::
   :header-rows: 1

   * - ``sparse`` data structure API
     - Since JAX
     - Since ROCm
   * - ``jax.experimental.sparse.BCOO``
     - 0.1.72
     - 5.0.0
   * - ``jax.experimental.sparse.BCSR``
     - 0.3.20
     - 5.1.0
   * - ``jax.experimental.sparse.CSR``
     - 0.1.75
     - 5.0.0
   * - ``jax.experimental.sparse.NM``
     - 0.4.27
     - 5.6.0
   * - ``jax.experimental.sparse.COO``
     - 0.1.75
     - 5.0.0

Unsupported JAX features
------------------------
===============================================================================

The following are GPU-accelerated JAX features not currently supported by
ROCm.
The following GPU-accelerated JAX features are not supported by ROCm for
the listed supported JAX versions.

.. list-table::
   :header-rows: 1

   * - Feature
     - Description
     - Since JAX

   * - Mixed Precision with TF32
     - Mixed precision with TF32 is used for matrix multiplications,
       convolutions, and other linear algebra operations, particularly in
       deep learning workloads like CNNs and transformers.
     - 0.2.25
   * - RNN support
     - Currently only LSTM with double bias is supported with float32 input
       and weight.
     - 0.3.25

   * - XLA int4 support
     - 4-bit integer (int4) precision in the XLA compiler.
     - 0.4.0
   * - ``jax.experimental.sparsify``
     - Converts a dense matrix to a sparse matrix representation.
     - Experimental
   * - MOSAIC (GPU)
     - Mosaic is a library of kernel-building abstractions for JAX's Pallas system

@@ -57,6 +57,7 @@ article_pages = [
    {"file": "how-to/rocm-for-ai/training/prerequisite-system-validation", "os": ["linux"]},
    {"file": "how-to/rocm-for-ai/training/benchmark-docker/megatron-lm", "os": ["linux"]},
    {"file": "how-to/rocm-for-ai/training/benchmark-docker/pytorch-training", "os": ["linux"]},
    {"file": "how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry", "os": ["linux"]},
    {"file": "how-to/rocm-for-ai/training/scale-model-training", "os": ["linux"]},

    {"file": "how-to/rocm-for-ai/fine-tuning/index", "os": ["linux"]},
BIN docs/data/rocm-software-stack-6_4_0.jpg (new file; binary file not shown)
After: Width | Height | Size: 1.2 MiB
@@ -1,15 +1,178 @@
.. meta::
  :description: How to use model quantization techniques to speed up inference.
  :keywords: ROCm, LLM, fine-tuning, usage, tutorial, quantization, GPTQ, transformers, bitsandbytes
  :keywords: ROCm, LLM, fine-tuning, usage, tutorial, quantization, Quark, GPTQ, transformers, bitsandbytes

*****************************
Model quantization techniques
*****************************

Quantization reduces the model size compared to its native full-precision version, making it easier to fit large models
onto accelerators or GPUs with limited memory usage. This section explains how to perform LLM quantization using GPTQ
onto accelerators or GPUs with limited memory usage. This section explains how to perform LLM quantization using AMD Quark, GPTQ,
and bitsandbytes on AMD Instinct hardware.

.. _quantize-llms-quark:

AMD Quark
=========

`AMD Quark <https://quark.docs.amd.com/latest/>`_ offers a leading, efficient, and scalable quantization solution tailored to AMD Instinct GPUs. It supports ``FP8`` and ``INT8`` quantization for activations, weights, and KV cache,
including ``FP8`` attention. For very large models, it employs a two-level ``INT4-FP8`` scheme, storing weights in ``INT4`` while computing with ``FP8``, for nearly 4× compression without sacrificing accuracy.
Quark scales efficiently across multiple GPUs, handling ultra-large models like Llama-3.1-405B. Quantized ``FP8`` models like Llama, Mixtral, and Grok-1 are available under the `AMD organization on Hugging Face <https://huggingface.co/collections/amd/quark-quantized-ocp-fp8-models-66db7936d18fcbaf95d4405c>`_, and can be deployed directly via `vLLM <https://github.com/vllm-project/vllm/tree/main/vllm>`_.

Installing Quark
-------------------

The latest release of Quark can be installed with pip:

.. code-block:: shell

   pip install amd-quark

For detailed installation instructions, refer to the `Quark documentation <https://quark.docs.amd.com/latest/install.html>`_.
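To confirm the installation, you can query the installed package version
through pip metadata (a quick sanity check using only the standard library;
``amd-quark`` is the package name from the install command above):

.. code-block:: python

   import importlib.metadata

   # Prints the installed amd-quark version, or raises PackageNotFoundError.
   print(importlib.metadata.version("amd-quark"))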

Using Quark for quantization
-----------------------------

#. First, load the pre-trained model and its corresponding tokenizer using the Hugging Face ``transformers`` library.

   .. code-block:: python

      from transformers import AutoTokenizer, AutoModelForCausalLM

      MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
      MAX_SEQ_LEN = 512

      model = AutoModelForCausalLM.from_pretrained(
          MODEL_ID, device_map="auto", torch_dtype="auto",
      )
      model.eval()

      tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, model_max_length=MAX_SEQ_LEN)
      tokenizer.pad_token = tokenizer.eos_token

#. Prepare the calibration DataLoader (static quantization requires calibration data).

   .. code-block:: python

      from datasets import load_dataset
      from torch.utils.data import DataLoader

      BATCH_SIZE = 1
      NUM_CALIBRATION_DATA = 512

      dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
      text_data = dataset["text"][:NUM_CALIBRATION_DATA]

      tokenized_outputs = tokenizer(
          text_data, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LEN
      )
      calib_dataloader = DataLoader(
          tokenized_outputs['input_ids'], batch_size=BATCH_SIZE, drop_last=True
      )

#. Define the quantization configuration. See the comments in the following code snippet for descriptions of each configuration option.

   .. code-block:: python

      from quark.torch.quantization import (Config, QuantizationConfig,
                                            FP8E4M3PerTensorSpec)

      # Define fp8/per-tensor/static spec.
      FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max",
                                                 is_dynamic=False).to_quantization_spec()

      # Define global quantization config; input tensors and weight apply FP8_PER_TENSOR_SPEC.
      global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
                                               weight=FP8_PER_TENSOR_SPEC)

      # Define quantization config for kv-cache layers; output tensors apply FP8_PER_TENSOR_SPEC.
      KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC
      kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"]
      kv_cache_quant_config = {name:
                               QuantizationConfig(input_tensors=global_quant_config.input_tensors,
                                                  weight=global_quant_config.weight,
                                                  output_tensors=KV_CACHE_SPEC)
                               for name in kv_cache_layer_names_for_llama}
      layer_quant_config = kv_cache_quant_config.copy()

      EXCLUDE_LAYERS = ["lm_head"]
      quant_config = Config(
          global_quant_config=global_quant_config,
          layer_quant_config=layer_quant_config,
          kv_cache_quant_config=kv_cache_quant_config,
          exclude=EXCLUDE_LAYERS)

#. Quantize the model and export it.

   .. code-block:: python

      import torch
      from quark.torch import ModelQuantizer, ModelExporter
      from quark.torch.export import ExporterConfig, JsonExporterConfig

      # Apply quantization.
      quantizer = ModelQuantizer(quant_config)
      quant_model = quantizer.quantize_model(model, calib_dataloader)

      # Freeze the quantized model to export it.
      freezed_model = quantizer.freeze(model)

      # Define the export config.
      LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"]
      export_config = ExporterConfig(json_export_config=JsonExporterConfig())
      export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP

      EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor"
      exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR)
      with torch.no_grad():
          exporter.export_safetensors_model(freezed_model,
                                            quant_config=quant_config, tokenizer=tokenizer)

Evaluating the quantized model with vLLM
----------------------------------------

The exported Quark-quantized model can be loaded directly by vLLM for inference. You need to specify the model path and inform vLLM about the quantization method (``quantization='quark'``) and the KV cache data type (``kv_cache_dtype='fp8'``).
Use the ``LLM`` interface to load the model:

.. code-block:: python

   from vllm import LLM, SamplingParams

   # Sample prompts.
   prompts = [
       "Hello, my name is",
       "The president of the United States is",
       "The capital of France is",
       "The future of AI is",
   ]
   # Create a sampling params object.
   sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

   # Create an LLM.
   llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor",
             kv_cache_dtype='fp8', quantization='quark')
   # Generate texts from the prompts. The output is a list of RequestOutput objects
   # that contain the prompt, generated text, and other information.
   outputs = llm.generate(prompts, sampling_params)
   # Print the outputs.
   print("\nGenerated Outputs:\n" + "-" * 60)
   for output in outputs:
       prompt = output.prompt
       generated_text = output.outputs[0].text
       print(f"Prompt: {prompt!r}")
       print(f"Output: {generated_text!r}")
       print("-" * 60)

You can also evaluate the quantized model's accuracy on standard benchmarks using the `lm-evaluation-harness <https://github.com/EleutherAI/lm-evaluation-harness>`_. Pass the necessary vLLM arguments to ``lm_eval`` via ``--model_args``.

.. code-block:: shell

   lm_eval --model vllm \
       --model_args pretrained=Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor,kv_cache_dtype='fp8',quantization='quark' \
       --tasks gsm8k

This provides a standardized way to measure the performance impact of quantization.

.. _fine-tune-llms-gptq:

GPTQ
@@ -33,7 +196,7 @@ The AutoGPTQ library implements the GPTQ algorithm.
   .. code-block:: shell

      # This will install the pre-built wheel for a specific ROCm version.
      pip install auto-gptq --no-build-isolation --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm573/

   Or, install AutoGPTQ from source for the appropriate ROCm version (for example, ROCm 6.1).
@@ -43,10 +206,10 @@ The AutoGPTQ library implements the GPTQ algorithm.
      # Clone the source code.
      git clone https://github.com/AutoGPTQ/AutoGPTQ.git
      cd AutoGPTQ

      # Speed up the compilation by specifying PYTORCH_ROCM_ARCH for the target device.
      PYTORCH_ROCM_ARCH=gfx942 ROCM_VERSION=6.1 pip install .

      # Show the package after the installation.

#. Run ``pip show auto-gptq`` to print information for the installed ``auto-gptq`` package. Its output should look like
@@ -112,7 +275,7 @@ Using GPTQ with Hugging Face Transformers
   .. code-block:: python

      from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

      base_model_name = "NousResearch/Llama-2-7b-hf"
      tokenizer = AutoTokenizer.from_pretrained(base_model_name)
      gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
@@ -212,10 +375,10 @@ To get started with bitsandbytes primitives, use the following code as reference
   .. code-block:: python

      import bitsandbytes as bnb

      # Use Int8 matrix multiplication.
      bnb.matmul(..., threshold=6.0)

      # Use bitsandbytes 8-bit optimizers.
      adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995))
@@ -227,14 +390,14 @@ To load a Transformers model in 4-bit, set ``load_in_4bit=true`` in ``BitsAndByt
   .. code-block:: python

      from transformers import AutoModelForCausalLM, BitsAndBytesConfig

      base_model_name = "NousResearch/Llama-2-7b-hf"
      quantization_config = BitsAndBytesConfig(load_in_4bit=True)
      bnb_model_4bit = AutoModelForCausalLM.from_pretrained(
          base_model_name,
          device_map="auto",
          quantization_config=quantization_config)

      # Check the memory footprint with the get_memory_footprint method.
      print(bnb_model_4bit.get_memory_footprint())
@@ -243,9 +406,9 @@ To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
   .. code-block:: python

      from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

      base_model_name = "NousResearch/Llama-2-7b-hf"

      tokenizer = AutoTokenizer.from_pretrained(base_model_name)
      quantization_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -253,7 +416,7 @@ To load a model in 8-bit for inference, use the ``load_in_8bit`` option.
          base_model_name,
          device_map="auto",
          quantization_config=quantization_config)

      prompt = "What is a large language model?"
      inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
      generated_ids = model.generate(**inputs)
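The excerpt stops at token generation. As a brief, assumed continuation (not
part of the original snippet), you can decode the generated ids back to text
with the same tokenizer:

.. code-block:: python

   # Hypothetical continuation: convert generated token ids back to a string.
   print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
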
@@ -20,6 +20,8 @@ training, fine-tuning, and inference. It leverages popular machine learning fram

- :doc:`LLM inference frameworks <llm-inference-frameworks>`

- :doc:`Performance testing <vllm-benchmark>`
- :doc:`vLLM inference performance testing <vllm-benchmark>`

- :doc:`PyTorch inference performance testing <pytorch-inference-benchmark>`

- :doc:`Deploying your model <deploy-your-model>`
@@ -86,7 +86,7 @@ PyTorch inference performance testing

.. container:: model-doc pyt_chai1_inference

   2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/latest/images/sha256-05b55983e5154f46e7441897d0908d79877370adca4d1fff4899d9539d6c4969>`_ from Docker Hub.
   2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0_triton_llvm_reg_issue/images/sha256-b736a4239ab38a9d0e448af6d4adca83b117debed00bfbe33846f99c4540f79b>`_ from Docker Hub.

   .. code-block:: shell

@@ -98,7 +98,7 @@ PyTorch inference performance testing

.. container:: model-doc pyt_clip_inference

   2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0_triton_llvm_reg_issue/images/sha256-b736a4239ab38a9d0e448af6d4adca83b117debed00bfbe33846f99c4540f79b>`_ from Docker Hub.
   2. Use the following command to pull the `ROCm PyTorch Docker image <https://hub.docker.com/layers/rocm/pytorch/latest/images/sha256-05b55983e5154f46e7441897d0908d79877370adca4d1fff4899d9539d6c4969>`_ from Docker Hub.

   .. code-block:: shell

@@ -276,7 +276,7 @@ vLLM inference performance testing

   * Latency benchmark

     Use this command to benchmark the latency of the {{model.model}} model on eight GPUs with the ``{{model.precision}}`` data type.
     Use this command to benchmark the latency of the {{model.model}} model on eight GPUs with ``{{model.precision}}`` precision.

     .. code-block::

@@ -286,11 +286,11 @@ vLLM inference performance testing

   * Throughput benchmark

     Use this command to throughput the latency of the {{model.model}} model on eight GPUs with the ``{{model.precision}}`` data type.
     Use this command to benchmark the throughput of the {{model.model}} model on eight GPUs with ``{{model.precision}}`` precision.

     .. code-block:: shell

        ./vllm_benchmark_report.sh -s latency -m {{model.model_repo}} -g 8 -d {{model.precision}}
        ./vllm_benchmark_report.sh -s throughput -m {{model.model_repo}} -g 8 -d {{model.precision}}

     Find the throughput report at ``./reports_{{model.precision}}_vllm_rocm{{unified_docker.rocm_version}}/summary/{{model.model_repo.split('/', 1)[1] if '/' in model.model_repo else model.model_repo}}_throughput_report.csv``.

@@ -0,0 +1,168 @@
.. meta::
  :description: How to train a model using LLM Foundry for ROCm.
  :keywords: ROCm, AI, LLM, train, PyTorch, torch, Llama, flux, tutorial, docker

******************************************
Training MPT-30B with LLM Foundry and ROCm
******************************************

MPT-30B is a 30-billion parameter decoder-style transformer-based model from
the Mosaic Pretrained Transformer (MPT) family -- learn more about it in
MosaicML's research blog `MPT-30B: Raising the bar for open-source foundation
models <https://www.databricks.com/blog/mpt-30b>`_.

ROCm and `<https://github.com/ROCm/MAD>`__ provide a pre-configured training
environment for the MPT-30B model using the ``rocm/pytorch-training:v25.5``
base `Docker image <https://hub.docker.com/layers/rocm/pytorch-training/v25.5/images/sha256-d47850a9b25b4a7151f796a8d24d55ea17bba545573f0d50d54d3852f96ecde5>`_
and the `LLM Foundry <https://github.com/mosaicml/llm-foundry>`_ framework.
This environment packages the following software components to train
on AMD Instinct MI300X series accelerators:

+--------------------------+--------------------------------+
| Software component       | Version                        |
+==========================+================================+
| ROCm                     | 6.3.4                          |
+--------------------------+--------------------------------+
| PyTorch                  | 2.7.0a0+git6374332             |
+--------------------------+--------------------------------+
| Flash Attention          | 3.0.0.post1                    |
+--------------------------+--------------------------------+

Using this image, you can build, run, and test the training process
for MPT-30B with access to detailed logs and performance metrics.

System validation
=================

If you have already validated your system settings, including NUMA
auto-balancing, skip this step. Otherwise, complete the :ref:`system validation
and optimization steps <train-a-model-system-validation>` to set up your system
before starting training.

Getting started
===============

The following procedures help you set up the training environment in a
reproducible Docker container. This training environment is tailored for
training MPT-30B using LLM Foundry and the specific model configurations outlined.
Other configurations and run conditions outside those described in this
document are not validated.

.. tab-set::

   .. tab-item:: MAD-integrated benchmarking

      On your host machine, clone the ROCm Model Automation and Dashboarding
      (`<https://github.com/ROCm/MAD>`__) repository to a local directory and
      install the required packages.

      .. code-block:: shell

         git clone https://github.com/ROCm/MAD
         cd MAD
         pip install -r requirements.txt

      Use this command to initiate the MPT-30B training benchmark.

      .. code-block:: shell

         python3 tools/run_models.py --tags pyt_mpt30b_training --keep-model-dir --live-output --clean-docker-cache

      .. tip::

         If you experience data download failures, set the
         ``MAD_SECRETS_HFTOKEN`` variable to your Hugging Face access token. See
         `User access tokens <https://huggingface.co/docs/hub/security-tokens>`_
         for details.

         .. code-block:: shell

            export MAD_SECRETS_HFTOKEN="your personal Hugging Face token to access gated models"

      .. note::

         For improved performance (training throughput), consider enabling TunableOp.
         By default, ``pyt_mpt30b_training`` runs with TunableOp disabled. To enable it,
         run ``tools/run_models.py`` with the ``--tunableop on`` argument or edit the
         ``models.json`` configuration before running training.

         Although this might increase the initial training time, it can result in a performance gain.

   .. tab-item:: Standalone benchmarking

      To set up the training environment, clone the
      `<https://github.com/ROCm/MAD>`__ repo and build the Docker image. In
      this snippet, the image is named ``mosaic_mpt30_image``.

      .. code-block:: shell

         git clone https://github.com/ROCm/MAD
         cd MAD

         docker build --build-arg MAD_SYSTEM_GPU_ARCHITECTURE=gfx942 -f docker/pyt_mpt30b_training.ubuntu.amd.Dockerfile -t mosaic_mpt30_image .

      Start a ``mosaic_mpt30_image`` container using the following command.

      .. code-block:: shell

         docker run -it --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --shm-size=8G mosaic_mpt30_image

      In the Docker container, clone the `<https://github.com/ROCm/MAD>`__
      repository and navigate to the benchmark scripts directory at
      ``/workspace/MAD/scripts/pyt_mpt30b_training``.

      .. code-block:: shell

         git clone https://github.com/ROCm/MAD
         cd MAD/scripts/pyt_mpt30b_training

      To initiate the training process, use the following command. This script uses the hyperparameters defined in
      ``mpt-30b-instruct.yaml``.

      .. code-block:: shell

         source run.sh

      .. note::

         For improved performance (training throughput), consider enabling TunableOp.
         To enable it, add the ``--tunableop on`` flag.

         .. code-block:: shell

            source run.sh --tunableop on

         Although this might increase the initial training time, it can result in a performance gain.

Interpreting the output
=======================

The training output is displayed in the terminal and simultaneously saved
to the ``output.txt`` file in the current directory. Key performance metrics are
also extracted and appended to the ``perf_pyt_mpt30b_training.csv`` file (a
short inspection sketch follows the list below).

Key performance metrics include:

- Training logs: Real-time display of loss metrics, accuracy, and training progress.

- Model checkpoints: Periodically saved model snapshots for potential resume or evaluation.

- Performance metrics: Detailed summaries of training speed and training loss metrics.

- Performance (throughput/samples_per_sec)

  Overall throughput, measuring the total samples processed per second. Higher values indicate better hardware utilization.

- Performance per device (throughput/samples_per_sec)

  Throughput on a per-device basis, showing how each GPU or CPU is performing.

- Language Cross Entropy (metrics/train/LanguageCrossEntropy)

  Measures prediction accuracy. Lower cross entropy suggests the model's output is closer to the expected distribution.

- Training loss (loss/train/total)

  Overall training loss. A decreasing trend indicates the model is learning effectively.
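As a quick way to inspect the extracted metrics, you can load the CSV with
pandas. This is a minimal sketch; the exact column names depend on the
benchmark version, so treat them as assumptions:

.. code-block:: python

   import pandas as pd

   # Load the per-run metrics appended by the benchmark script.
   df = pd.read_csv("perf_pyt_mpt30b_training.csv")
   print(df.tail())  # most recent runs, including the throughput columns
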
@@ -45,6 +45,7 @@
(communication-libraries)=

* {doc}`RCCL <rccl:index>`
* [rocSHMEM](https://github.com/ROCm/rocSHMEM)
:::

:::{grid-item-card} Math
@@ -46,6 +46,8 @@ subtrees:
          title: Train a model with PyTorch
        - file: how-to/rocm-for-ai/training/benchmark-docker/jax-maxtext
          title: Train a model with JAX MaxText
        - file: how-to/rocm-for-ai/training/benchmark-docker/mpt-llm-foundry
          title: Train a model with LLM Foundry
        - file: how-to/rocm-for-ai/training/scale-model-training.rst
          title: Scale model training
@@ -10,7 +10,7 @@ ROCm is a software stack, composed primarily of open-source software, that
provides the tools for programming AMD Graphics Processing Units (GPUs), from
low-level kernels to high-level end-user applications.

.. image:: data/rocm-software-stack-6_3_2.jpg
.. image:: data/rocm-software-stack-6_4_0.jpg
   :width: 800
   :alt: AMD's ROCm software stack and enabling technologies.
   :align: center
@@ -52,6 +52,7 @@ Communication
   :header: "Component", "Description"

   ":doc:`RCCL <rccl:index>`", "Standalone library that provides multi-GPU and multi-node collective communication primitives"
   "`rocSHMEM <https://github.com/ROCm/rocSHMEM>`_", "Runtime that provides GPU-centric networking through an OpenSHMEM-like interface. This intra-kernel networking library simplifies application code complexity and enables more fine-grained communication/computation overlap than traditional host-driven networking"

Math
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^