mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge remote-tracking branch 'oai/main' into ifu230620
Conflicts: include/triton/Conversion/TritonToTritonGPU/Passes.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp python/test/unit/language/assert_helper.py python/triton/compiler/compiler.py python/triton/runtime/jit.py python/triton/tools/aot.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
37
.github/workflows/Dockerfile
vendored
Normal file
37
.github/workflows/Dockerfile
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
# CentOS 7 image that configures, builds, tests, and installs LLVM/MLIR into
# /install, using sccache (cache directory: /sccache) as the compiler launcher.
FROM centos:7
# Path, relative to the build context, of the LLVM source tree to build.
ARG llvm_dir=llvm-project

# Add the cache artifacts and the LLVM source tree to the container.
# COPY is used rather than ADD: both sources are plain local directories, and
# COPY never auto-extracts archives or fetches remote URLs.
COPY sccache /sccache
COPY "${llvm_dir}" /source/llvm-project
ENV SCCACHE_DIR="/sccache"
ENV SCCACHE_CACHE_SIZE="2G"

# Install build dependencies
RUN yum install --assumeyes centos-release-scl
RUN yum install --assumeyes devtoolset-9-gcc* python3-devel python3-pip
# From here on, shell-form RUN commands are executed through
# `scl enable devtoolset-9` instead of the default `/bin/sh -c`.
SHELL [ "/usr/bin/scl", "enable", "devtoolset-9" ]

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --upgrade cmake ninja sccache

# Install MLIR's Python Dependencies
RUN python3 -m pip install -r /source/llvm-project/mlir/python/requirements.txt

# Configure, Build, Test, and Install LLVM
RUN cmake -GNinja -Bbuild \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_C_COMPILER=gcc \
  -DCMAKE_CXX_COMPILER=g++ \
  -DCMAKE_C_COMPILER_LAUNCHER=sccache \
  -DCMAKE_CXX_COMPILER_LAUNCHER=sccache \
  -DCMAKE_INSTALL_PREFIX="/install" \
  -DLLVM_BUILD_UTILS=ON \
  -DLLVM_ENABLE_ASSERTIONS=ON \
  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
  -DLLVM_ENABLE_PROJECTS=mlir \
  -DLLVM_INSTALL_UTILS=ON \
  -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \
  /source/llvm-project/llvm

# `check-mlir` runs the MLIR test suite; `install` populates /install.
RUN ninja -C build check-mlir install
|
||||
59
.github/workflows/integration-tests.yml
vendored
59
.github/workflows/integration-tests.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
|
||||
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"]]'
|
||||
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "H100"], ["self-hosted", "gfx908"], ["self-hosted", "arc770"]]'
|
||||
else
|
||||
echo '::set-output name=matrix::["ubuntu-latest"]'
|
||||
fi
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set CUDA ENV
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
|
||||
run: |
|
||||
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
|
||||
|
||||
@@ -53,6 +53,11 @@ jobs:
|
||||
run: |
|
||||
echo "BACKEND=ROCM" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Set XPU ENV
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'arc770')}}
|
||||
run: |
|
||||
echo "BACKEND=XPU" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Clear cache
|
||||
run: |
|
||||
rm -rf ~/.triton
|
||||
@@ -62,13 +67,22 @@ jobs:
|
||||
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Check pre-commit
|
||||
if: ${{ matrix.runner != 'macos-10.15' }}
|
||||
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] != 'arc770') }}
|
||||
run: |
|
||||
python3 -m pip install --upgrade pre-commit
|
||||
python3 -m pre_commit run --all-files
|
||||
|
||||
- name: Check pre-commit arc770
|
||||
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }}
|
||||
run: |
|
||||
source ${HOME}/triton_vars.sh
|
||||
source ${HOME}/miniconda3/bin/activate
|
||||
conda activate triton-xpu-ci
|
||||
python3 -m pip install --upgrade pre-commit
|
||||
python3 -m pre_commit run --all-files
|
||||
|
||||
- name: Install Triton
|
||||
if: ${{ env.BACKEND != 'ROCM'}}
|
||||
if: ${{ env.BACKEND == 'CUDA'}}
|
||||
run: |
|
||||
cd python
|
||||
python3 -m pip install --upgrade pip
|
||||
@@ -84,8 +98,23 @@ jobs:
|
||||
python3 -m pip install torch==1.13.1 --index-url https://download.pytorch.org/whl/rocm5.2
|
||||
python3 -m pip install --no-build-isolation -vvv '.[tests]'
|
||||
|
||||
- name: Install Triton on XPU
|
||||
if: ${{ env.BACKEND == 'XPU'}}
|
||||
run: |
|
||||
source ${HOME}/triton_vars.sh
|
||||
source ${HOME}/miniconda3/bin/activate
|
||||
conda activate triton-xpu-ci
|
||||
git submodule update --init --recursive
|
||||
cd python
|
||||
python3 -m pip install --upgrade pip
|
||||
python3 -m pip install cmake==3.24
|
||||
export TRITON_CODEGEN_INTEL_XPU_BACKEND=1
|
||||
python3 -m pip uninstall -y triton
|
||||
python3 setup.py build
|
||||
python3 -m pip install --no-build-isolation -vvv '.[tests]'
|
||||
|
||||
- name: Run lit tests
|
||||
if: ${{ env.BACKEND != 'ROCM'}}
|
||||
if: ${{ env.BACKEND == 'CUDA'}}
|
||||
run: |
|
||||
python3 -m pip install lit
|
||||
cd python
|
||||
@@ -102,19 +131,20 @@ jobs:
|
||||
python3 -m pytest
|
||||
|
||||
- name: Create artifacts archive
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
|
||||
run: |
|
||||
tar -czvf artifacts.tar.gz ~/.triton/cache
|
||||
cd ~/.triton
|
||||
tar -czvf artifacts.tar.gz cache
|
||||
|
||||
- name: Upload artifacts archive
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
|
||||
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: artifacts
|
||||
path: artifacts.tar.gz
|
||||
path: ~/.triton/artifacts.tar.gz
|
||||
|
||||
- name: Run CXX unittests
|
||||
if: ${{ env.BACKEND != 'ROCM'}}
|
||||
if: ${{ env.BACKEND == 'CUDA'}}
|
||||
run: |
|
||||
cd python
|
||||
cd "build/$(ls build | grep -i cmake)"
|
||||
@@ -126,6 +156,15 @@ jobs:
|
||||
cd python/test/unit/language
|
||||
python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py::test_empty_kernel"
|
||||
|
||||
- name: Run python tests on XPU
|
||||
if: ${{ env.BACKEND == 'XPU'}}
|
||||
run: |
|
||||
source ${HOME}/triton_vars.sh
|
||||
source ${HOME}/miniconda3/bin/activate
|
||||
conda activate triton-xpu-ci
|
||||
cd python/test/backend/third_party_backends
|
||||
python3 -m pytest --capture=tee-sys -rfs --verbose --backend xpu
|
||||
|
||||
- name: Regression tests
|
||||
if: ${{ contains(matrix.runner, 'A100') }}
|
||||
run: |
|
||||
|
||||
195
.github/workflows/llvm-build.yml
vendored
Normal file
195
.github/workflows/llvm-build.yml
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
name: LLVM Build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- llvm-head
|
||||
paths:
|
||||
- llvm-hash.txt
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
SCCACHE_DIR: ${{ github.workspace }}/sccache
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
|
||||
strategy:
|
||||
fail-fast: true
|
||||
|
||||
matrix:
|
||||
platform: [
|
||||
ubuntu-20.04-x64,
|
||||
ubuntu-22.04-x64,
|
||||
centos-7-x64,
|
||||
macos-x64,
|
||||
macos-arm64
|
||||
]
|
||||
|
||||
include:
|
||||
# Specify OS versions
|
||||
- platform: ubuntu-20.04-x64
|
||||
host-os: ubuntu-20.04
|
||||
target-os: ubuntu
|
||||
arch: x64
|
||||
- platform: ubuntu-22.04-x64
|
||||
host-os: ubuntu-22.04
|
||||
target-os: ubuntu
|
||||
arch: x64
|
||||
- platform: centos-7-x64
|
||||
host-os: ubuntu-22.04
|
||||
target-os: centos
|
||||
arch: x64
|
||||
- platform: macos-x64
|
||||
host-os: macos-12
|
||||
target-os: macos
|
||||
arch: x64
|
||||
- platform: macos-arm64
|
||||
host-os: macos-12
|
||||
target-os: macos
|
||||
arch: arm64
|
||||
|
||||
runs-on: ${{ matrix.host-os }}
|
||||
|
||||
steps:
|
||||
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
path: llvm-build
|
||||
|
||||
- name: Fetch LLVM Commit Hash
|
||||
run: |
|
||||
LLVM_COMMIT_HASH="$(cat llvm-build/llvm-hash.txt)"
|
||||
echo "Found LLVM commit hash: ${LLVM_COMMIT_HASH}"
|
||||
echo "llvm_commit_hash=${LLVM_COMMIT_HASH}" >> ${GITHUB_ENV}
|
||||
|
||||
SHORT_LLVM_COMMIT_HASH="${LLVM_COMMIT_HASH:0:8}"
|
||||
echo "Short LLVM commit hash: ${SHORT_LLVM_COMMIT_HASH}"
|
||||
echo "short_llvm_commit_hash=${SHORT_LLVM_COMMIT_HASH}" >> ${GITHUB_ENV}
|
||||
|
||||
INSTALL_DIR="llvm-${SHORT_LLVM_COMMIT_HASH}-${{ matrix.platform }}"
|
||||
echo "LLVM installation directory name: ${INSTALL_DIR}"
|
||||
echo "llvm_install_dir=${INSTALL_DIR}" >> ${GITHUB_ENV}
|
||||
|
||||
- name: Checkout LLVM
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
repository: llvm/llvm-project
|
||||
path: llvm-project
|
||||
ref: ${{ env.llvm_commit_hash }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
|
||||
- name: Install Prerequisites
|
||||
run: |
|
||||
python3 -m pip install cmake ninja sccache
|
||||
mkdir -p ${{ env.SCCACHE_DIR }}
|
||||
rm -rf ${{ env.SCCACHE_DIR }}/*
|
||||
|
||||
- name: Enable Cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ${{ env.SCCACHE_DIR }}
|
||||
key: ${{ matrix.platform }}-${{ env.short_llvm_commit_hash }}
|
||||
restore-keys: ${{ matrix.platform }}-
|
||||
|
||||
- name: Configure, Build, Test, and Install LLVM (Ubuntu and macOS x64)
|
||||
if: matrix.arch == 'x64' && contains(fromJSON('["ubuntu", "macos"]'), matrix.target-os)
|
||||
run: >
|
||||
python3 -m pip install -r llvm-project/mlir/python/requirements.txt
|
||||
|
||||
cmake -GNinja -Bllvm-project/build
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache
|
||||
-DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}"
|
||||
-DCMAKE_LINKER=lld
|
||||
-DLLVM_BUILD_UTILS=ON
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
|
||||
-DLLVM_ENABLE_PROJECTS=mlir
|
||||
-DLLVM_INSTALL_UTILS=ON
|
||||
-DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
|
||||
llvm-project/llvm
|
||||
|
||||
ninja -C llvm-project/build check-mlir install
|
||||
|
||||
tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
|
||||
|
||||
- name: Configure, Build, and Install LLVM (macOS arm64)
|
||||
if: matrix.arch == 'arm64' && matrix.target-os == 'macos'
|
||||
run: >
|
||||
python3 -m pip install -r llvm-project/mlir/python/requirements.txt
|
||||
|
||||
cmake -GNinja -Bllvm-project/build
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache
|
||||
-DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}"
|
||||
-DCMAKE_LINKER=lld
|
||||
-DCMAKE_OSX_ARCHITECTURES=arm64
|
||||
-DLLVM_BUILD_UTILS=ON
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
|
||||
-DLLVM_ENABLE_PROJECTS=mlir
|
||||
-DLLVM_ENABLE_ZSTD=OFF
|
||||
-DLLVM_INSTALL_UTILS=ON
|
||||
-DLLVM_TARGETS_TO_BUILD="AArch64"
|
||||
-DLLVM_USE_HOST_TOOLS=ON
|
||||
llvm-project/llvm
|
||||
|
||||
ninja -C llvm-project/build install
|
||||
|
||||
tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
|
||||
|
||||
- name: Configure, Build, Test, and Install LLVM (CentOS)
|
||||
if: matrix.target-os == 'centos'
|
||||
run: |
|
||||
docker build --tag llvm-build --build-arg llvm_dir=llvm-project \
|
||||
-f llvm-build/.github/workflows/Dockerfile .
|
||||
|
||||
# Create temporary container to copy cache and installed artifacts.
|
||||
CONTAINER_ID=$(docker create llvm-build)
|
||||
docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}"
|
||||
tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
|
||||
|
||||
# We remove the existing directory, otherwise docker will
|
||||
# create a subdirectory inside the existing directory.
|
||||
rm -rf "${{ env.SCCACHE_DIR }}"
|
||||
docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}"
|
||||
sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}"
|
||||
|
||||
docker rm "${CONTAINER_ID}"
|
||||
|
||||
- name: Azure Login
|
||||
uses: azure/login@v1
|
||||
with:
|
||||
client-id: ${{ secrets.AZURE_CLIENT_ID }}
|
||||
tenant-id: ${{ secrets.AZURE_TENANT_ID }}
|
||||
subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
|
||||
|
||||
- name: Upload LLVM Artifacts to Azure
|
||||
run: |
|
||||
az storage blob upload --account-name tritonlang --auth-mode login --container-name llvm-builds --file "${{ env.llvm_install_dir }}.tar.gz" --name "${{ env.llvm_install_dir }}.tar.gz" --overwrite
|
||||
|
||||
URL=$(az storage blob url --account-name tritonlang --auth-mode login --container-name llvm-builds --name "${{ env.llvm_install_dir }}.tar.gz")
|
||||
echo "Blob URL: ${URL}"
|
||||
|
||||
- name: Azure Logout
|
||||
run: |
|
||||
az logout
|
||||
az cache purge
|
||||
az account clear
|
||||
if: always()
|
||||
|
||||
- name: Dump Sccache Statistics
|
||||
run: sccache --show-stats
|
||||
40
.github/workflows/wheels.yml
vendored
40
.github/workflows/wheels.yml
vendored
@@ -1,24 +1,45 @@
|
||||
name: Wheels
|
||||
on:
|
||||
workflow_dispatch:
|
||||
#schedule:
|
||||
# - cron: "0 0 * * *"
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
jobs:
|
||||
|
||||
Build-Wheels:
|
||||
|
||||
runs-on: [self-hosted, V100]
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Azure CLI
|
||||
run: |
|
||||
curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
|
||||
|
||||
- name: Azure login
|
||||
uses: azure/login@v1
|
||||
with:
|
||||
client-id: ${{ secrets.AZURE_CLIENT_ID }}
|
||||
tenant-id: ${{ secrets.AZURE_TENANT_ID }}
|
||||
subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
|
||||
|
||||
- id: generate-token
|
||||
name: Generate token
|
||||
run: |
|
||||
AZ_TOKEN=$(az account get-access-token --query accessToken)
|
||||
echo "::add-mask::$AZ_TOKEN"
|
||||
echo "access_token=$AZ_TOKEN" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Patch setup.py
|
||||
run: |
|
||||
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
|
||||
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
|
||||
sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
|
||||
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d%H%M%S' --format="%cd")
|
||||
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
|
||||
echo "" >> python/setup.cfg
|
||||
echo "[build_ext]" >> python/setup.cfg
|
||||
@@ -29,11 +50,10 @@ jobs:
|
||||
export CIBW_MANYLINUX_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest"
|
||||
#export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest"
|
||||
export CIBW_BEFORE_BUILD="pip install cmake;"
|
||||
export CIBW_SKIP="{cp,pp}35-*"
|
||||
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64 cp3*-musllinux_x86_64"
|
||||
export CIBW_SKIP="{cp,pp}{35,36}-*"
|
||||
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
|
||||
python3 -m cibuildwheel python --output-dir wheelhouse
|
||||
|
||||
|
||||
- name: Upload wheels to PyPI
|
||||
- name: Publish wheels to Azure DevOps
|
||||
run: |
|
||||
python3 -m twine upload wheelhouse/* -u __token__ -p ${{ secrets.PYPY_API_TOKEN }}
|
||||
python3 -m twine upload -r Triton-Nightly -u TritonArtifactsSP -p ${{ steps.generate-token.outputs.access_token }} --config-file utils/nightly.pypirc --non-interactive --verbose wheelhouse/*
|
||||
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# Fetch the submodule over HTTPS rather than plaintext HTTP so the clone is
# encrypted and the server is authenticated.
[submodule "third_party/intel_xpu_backend"]
	path = third_party/intel_xpu_backend
	url = https://github.com/intel/intel-xpu-backend-for-triton
|
||||
@@ -22,6 +22,7 @@ endif()
|
||||
# Options
|
||||
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
||||
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
|
||||
|
||||
# Force TRITON_USE_ROCM for ROCm support
|
||||
set(TRITON_USE_ROCM ON)
|
||||
@@ -276,6 +277,14 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
|
||||
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
|
||||
endif()
|
||||
|
||||
list(LENGTH TRITON_CODEGEN_BACKENDS CODEGEN_BACKENDS_LEN)
|
||||
if (${CODEGEN_BACKENDS_LEN} GREATER 0)
|
||||
set(PYTHON_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/third_party)
|
||||
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
|
||||
add_subdirectory(third_party/${CODEGEN_BACKEND})
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
add_subdirectory(test)
|
||||
|
||||
add_subdirectory(unittest)
|
||||
|
||||
@@ -25,12 +25,12 @@ You can install the latest stable release of Triton from pip:
|
||||
```bash
|
||||
pip install triton
|
||||
```
|
||||
Binary wheels are available for CPython 3.6-3.11 and PyPy 3.7-3.9.
|
||||
Binary wheels are available for CPython 3.8-3.11 and PyPy 3.8-3.9.
|
||||
|
||||
And the latest nightly release:
|
||||
|
||||
```bash
|
||||
pip install -U --pre triton
|
||||
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
|
||||
```
|
||||
# Install from source
|
||||
```
|
||||
|
||||
16
docs/backend/ldmatrixOperand0.svg
Normal file
16
docs/backend/ldmatrixOperand0.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 54 KiB |
16
docs/backend/ldmatrixOperand1.svg
Normal file
16
docs/backend/ldmatrixOperand1.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 49 KiB |
@@ -12,13 +12,13 @@ You can install the latest stable release of Triton from pip:
|
||||
|
||||
pip install triton
|
||||
|
||||
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
|
||||
Binary wheels are available for CPython 3.8-3.11 and PyPy 3.8-3.9.
|
||||
|
||||
And the latest nightly release:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U --pre triton
|
||||
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
|
||||
|
||||
|
||||
-----------
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
Welcome to Triton's documentation!
|
||||
==================================
|
||||
|
||||
Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
|
||||
Triton_ is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
|
||||
|
||||
|
||||
Getting Started
|
||||
@@ -52,3 +52,5 @@ Check out the following documents to learn more about Triton and how it compares
|
||||
|
||||
programming-guide/chapter-1/introduction
|
||||
programming-guide/chapter-2/related-work
|
||||
|
||||
.. _Triton: https://github.com/openai/triton
|
||||
|
||||
@@ -125,7 +125,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size of the given buffer is a virtual buffer.
|
||||
/// Returns whether the given buffer is a virtual buffer.
|
||||
bool isVirtualBuffer(BufferId bufferId) const {
|
||||
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
|
||||
}
|
||||
|
||||
@@ -73,9 +73,11 @@ bool supportMMA(triton::DotOp op, int version);
|
||||
|
||||
bool supportMMA(Value value, int version);
|
||||
|
||||
Type getElementType(Value value);
|
||||
bool isSingleValue(Value value);
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state);
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
|
||||
Type getElementType(Value value);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
@@ -91,7 +93,7 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
|
||||
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
|
||||
|
||||
// output[i] = input[order[i]]
|
||||
/// output[i] = input[order[i]]
|
||||
template <typename T, typename RES_T = T>
|
||||
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
@@ -103,6 +105,7 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Get the highest power of 2 divisor of an integer.
|
||||
template <typename T> T highestPowOf2Divisor(T n) {
|
||||
if (n == 0) {
|
||||
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
|
||||
@@ -110,9 +113,18 @@ template <typename T> T highestPowOf2Divisor(T n) {
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
/// Get the next power of 2 for an integer (or the integer itself if it is a
|
||||
/// power of 2).
|
||||
template <typename T> T nextPowOf2(T n) {
  // By convention the next power of two for 0 is 1; this also avoids
  // decrementing below zero.
  if (n == 0) {
    return 1;
  }
  // Work on n - 1 so that an exact power of two maps to itself.
  n--;
  // Smear the highest set bit into every lower bit position by OR-ing in
  // shifts of 1, 2, 4, ... bits, up to the bit width of T.
  for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) {
    n |= n >> i;
  }
  // n is now of the form 0b0...01...1; adding one gives the power of two.
  // NOTE(review): if the result does not fit in T (or n is negative for a
  // signed T) the value returned is not a meaningful power of two.
  return n + 1;
}
|
||||
|
||||
/// Multi-root DAG topological sort.
|
||||
/// Performs a topological sort of the Operation in the `toSort` SetVector.
|
||||
|
||||
@@ -23,7 +23,11 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
"number of warps">,
|
||||
|
||||
// Default comes from the build-configured TRITONGPU_DEFAULT_WARPSIZE (the HEAD
// side of the merge) rather than a hard-coded 32; the committed conflict
// markers are removed so the file parses again.
// NOTE(review): confirm TRITONGPU_DEFAULT_WARPSIZE is visible wherever the
// generated pass declarations are included.
Option<"threadsPerWarp", "threads-per-warp",
       "int32_t", /*default*/"TRITONGPU_DEFAULT_WARPSIZE",
       "number of threads per warp">,
|
||||
];
|
||||
}
|
||||
|
||||
@@ -14,6 +14,17 @@ def TT_CacheModifierAttr : I32EnumAttr<
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_MemSemanticAttr : I32EnumAttr<
|
||||
"MemSemantic", "",
|
||||
[
|
||||
I32EnumAttrCase<"RELAXED", 1, "relaxed">,
|
||||
I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
|
||||
I32EnumAttrCase<"RELEASE", 3, "release">,
|
||||
I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
"EvictionPolicy", "",
|
||||
[
|
||||
|
||||
@@ -234,7 +234,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
|
||||
TT_Type:$val, Optional<TT_BoolLike>:$mask);
|
||||
TT_Type:$val, Optional<TT_BoolLike>:$mask,
|
||||
TT_MemSemanticAttr:$sem);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
@@ -255,7 +256,8 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val,
|
||||
TT_MemSemanticAttr:$sem);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
@@ -444,7 +446,6 @@ class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
|
||||
}
|
||||
|
||||
|
||||
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
|
||||
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ namespace gpu {
|
||||
|
||||
unsigned getTotalElemsPerThread(Type type);
|
||||
|
||||
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
Type eltTy);
|
||||
|
||||
SmallVector<unsigned> getElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
|
||||
@@ -74,11 +77,11 @@ SmallVector<unsigned> getOrder(Attribute layout);
|
||||
|
||||
bool isaDistributedLayout(Attribute layout);
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
|
||||
@@ -34,7 +34,11 @@ def TritonGPU_Dialect : Dialect {
|
||||
// Returns the threads-per-warp value stored on `mod` in the discardable
// attribute "triton_gpu.threads-per-warp". When the attribute is absent,
// returns TRITONGPU_DEFAULT_WARPSIZE.
static int getThreadsPerWarp(ModuleOp mod) {
  Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
  if(!threadsPerWarp) {
    // Use the build-configured default (HEAD side of the merge) instead of a
    // hard-coded 32, matching the default of the `threads-per-warp` pass
    // option on the same side of the merge.
    return TRITONGPU_DEFAULT_WARPSIZE;
  }
  // The attribute is assumed to be an IntegerAttr; cast<> asserts otherwise.
  return threadsPerWarp.cast<IntegerAttr>().getInt();
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
#ifndef TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
#define TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
LogicalResult fixupLoops(ModuleOp mod);
|
||||
@@ -12,9 +16,11 @@ LogicalResult fixupLoops(ModuleOp mod);
|
||||
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
Attribute &ret);
|
||||
|
||||
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
|
||||
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
|
||||
|
||||
bool expensiveToRemat(Operation *op, Attribute &targetEncoding);
|
||||
bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding);
|
||||
|
||||
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);
|
||||
|
||||
// skipInit is True when we only consider the operands of the initOp but
|
||||
// not the initOp itself.
|
||||
@@ -36,4 +42,4 @@ LogicalResult canMoveOutOfLoop(BlockArgument arg,
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
@@ -1,19 +0,0 @@
|
||||
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
|
||||
#define TRITON_TARGET_AMDGCNTRANSLATION_H
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
} // namespace llvm
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Translate LLVM IR to AMDGCN code.
|
||||
std::tuple<std::string, std::string>
|
||||
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
||||
@@ -38,7 +38,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
|
||||
// insert_slice %src into %dst[%offsets]
|
||||
aliasInfo = AliasInfo(operands[1]->getValue());
|
||||
pessimistic = false;
|
||||
} else if (isSharedEncoding(result)) {
|
||||
} else if (triton::gpu::isSharedEncoding(result)) {
|
||||
aliasInfo.insert(result);
|
||||
pessimistic = false;
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ private:
|
||||
}
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
if (isSharedEncoding(result)) {
|
||||
if (triton::gpu::isSharedEncoding(result)) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
@@ -60,9 +60,9 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
|
||||
auto argLayout = getSrcLayout();
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
|
||||
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
return {{1, 1}, {1, 1}};
|
||||
// if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
|
||||
// triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
// return {{1, 1}, {1, 1}};
|
||||
|
||||
/// shared memory block0
|
||||
smemShapes[0] = convertType<unsigned>(getSrcShape());
|
||||
@@ -181,13 +181,6 @@ Type getElementType(Value value) {
|
||||
return type;
|
||||
}
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state) {
|
||||
std::string opName;
|
||||
llvm::raw_string_ostream ss(opName);
|
||||
value.printAsOperand(ss, state);
|
||||
return opName;
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
|
||||
@@ -107,7 +107,7 @@ static Value loadA(Value tensor, const SharedMemoryObject &smemObj,
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
|
||||
bool isARow = order[0] != 0;
|
||||
auto resultEncoding = resultTy.cast<RankedTensorType>()
|
||||
@@ -232,7 +232,7 @@ static Value loadB(Value tensor, const SharedMemoryObject &smemObj,
|
||||
auto shape = tensorTy.getShape();
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
Value smem = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0; // is row-major in shared memory layout
|
||||
// isBRow_ indicates whether B is row-major in DotOperand layout
|
||||
auto resultEncoding = resultTy.cast<RankedTensorType>()
|
||||
|
||||
@@ -18,7 +18,8 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
// Data loader for mma.16816 instruction.
|
||||
class MMA16816SmemLoader {
|
||||
public:
|
||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
||||
MMA16816SmemLoader(int warpsPerTile, ArrayRef<uint32_t> order,
|
||||
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
|
||||
int kWidth, ArrayRef<Value> smemStrides,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase,
|
||||
@@ -27,7 +28,7 @@ public:
|
||||
const Location &loc);
|
||||
|
||||
// lane = thread % 32
|
||||
// warpOff = (thread/32) % wpt(0)
|
||||
// warpOff = (thread/32) % warpsPerTile(0)
|
||||
llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane,
|
||||
Value cSwizzleOffset) {
|
||||
if (canUseLdmatrix)
|
||||
@@ -48,12 +49,14 @@ public:
|
||||
Value cSwizzleOffset);
|
||||
|
||||
// Loads 4 matrices and returns 4 vec<2> elements.
|
||||
std::tuple<Value, Value, Value, Value>
|
||||
loadX4(int mat0, int mat1, ArrayRef<Value> offs, ArrayRef<Value> ptrs,
|
||||
Type matTy, Type shemPtrTy) const;
|
||||
std::tuple<Value, Value, Value, Value> loadX4(int mat0, int mat1,
|
||||
ArrayRef<Value> ptrs,
|
||||
Type matTy,
|
||||
Type shemPtrTy) const;
|
||||
|
||||
private:
|
||||
SmallVector<uint32_t> order;
|
||||
SmallVector<uint32_t> warpsPerCTA;
|
||||
int kOrder;
|
||||
int kWidth;
|
||||
SmallVector<int64_t> tileShape;
|
||||
@@ -66,42 +69,50 @@ private:
|
||||
const Location &loc;
|
||||
MLIRContext *ctx{};
|
||||
|
||||
int cMatShape;
|
||||
int sMatShape;
|
||||
// ldmatrix loads a matrix of size stridedMatShape x contiguousMatShape
|
||||
int contiguousMatShape;
|
||||
int stridedMatShape;
|
||||
|
||||
Value sStride;
|
||||
// Offset in shared memory to increment on the strided axis
|
||||
// This would be different than the tile shape in the case of a sliced tensor
|
||||
Value stridedSmemOffset;
|
||||
|
||||
bool needTrans;
|
||||
bool canUseLdmatrix;
|
||||
|
||||
int numPtrs;
|
||||
|
||||
int pLoadStrideInMat;
|
||||
int sMatStride;
|
||||
// Load operations offset in number of Matrices on contiguous and strided axes
|
||||
int contiguousLoadMatOffset;
|
||||
int stridedLoadMatOffset;
|
||||
|
||||
int matArrStride;
|
||||
int warpOffStride;
|
||||
// Offset in number of matrices to increment on non-k dim within a warp's 2x2
|
||||
// matrices
|
||||
int inWarpMatOffset;
|
||||
// Offset in number of matrices to increment on non-k dim across warps
|
||||
int warpMatOffset;
|
||||
};
|
||||
|
||||
SmallVector<Value>
|
||||
MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
Value cSwizzleOffset) {
|
||||
// 4x4 matrices
|
||||
Value c = urem(lane, i32_val(8));
|
||||
Value s = udiv(lane, i32_val(8)); // sub-warp-id
|
||||
Value rowInMat = urem(lane, i32_val(8)); // row in the 8x8 matrix
|
||||
Value matIndex =
|
||||
udiv(lane, i32_val(8)); // linear index of the matrix in the 2x2 matrices
|
||||
|
||||
// Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a
|
||||
// Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in a
|
||||
// warp
|
||||
Value s0 = urem(s, i32_val(2));
|
||||
Value s1 = udiv(s, i32_val(2));
|
||||
Value s0 = urem(matIndex, i32_val(2));
|
||||
Value s1 = udiv(matIndex, i32_val(2));
|
||||
|
||||
// We use different orders for a and b for better performance.
|
||||
Value kMatArr = kOrder == 1 ? s1 : s0;
|
||||
Value nkMatArr = kOrder == 1 ? s0 : s1;
|
||||
Value kMatArr = kOrder == 1 ? s1 : s0; // index of matrix on the k dim
|
||||
Value nkMatArr = kOrder == 1 ? s0 : s1; // index of matrix on the non-k dim
|
||||
|
||||
// Matrix coordinates inside a CTA,
|
||||
// the matrix layout is [2wpt[0], 2] for A and [2, 2wpt[1]] for B.
|
||||
// e.g., Setting wpt=4, the data layout for A(kOrder=1) is
|
||||
// the matrix layout is [2warpsPerTile[0], 2] for A and [2, 2warpsPerTile[1]]
|
||||
// for B. e.g., Setting warpsPerTile=4, the data layout for A(kOrder=1) is
|
||||
// |0 0| -> 0,1,2,3 are the warpids
|
||||
// |0 0|
|
||||
// |1 1|
|
||||
@@ -118,34 +129,42 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
// address (s0,s1) annotates.
|
||||
|
||||
Value matOff[2];
|
||||
matOff[kOrder ^ 1] =
|
||||
add(mul(warpId, i32_val(warpOffStride)), // warp offset (kOrder=1)
|
||||
mul(nkMatArr,
|
||||
i32_val(matArrStride))); // matrix offset inside a warp (kOrder=1)
|
||||
matOff[kOrder ^ 1] = add(
|
||||
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
|
||||
mul(nkMatArr,
|
||||
i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
|
||||
matOff[kOrder] = kMatArr;
|
||||
|
||||
// Physical offset (before swizzling)
|
||||
Value cMatOff = matOff[order[0]];
|
||||
Value sMatOff = matOff[order[1]];
|
||||
Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
|
||||
cMatOff = add(cMatOff, cSwizzleMatOff);
|
||||
|
||||
// row offset inside a matrix, each matrix has 8 rows.
|
||||
Value sOffInMat = c;
|
||||
Value contiguousMatIndex = matOff[order[0]];
|
||||
Value stridedMatIndex = matOff[order[1]];
|
||||
// Add the offset of the slice
|
||||
Value contiguousSliceMatOffset =
|
||||
udiv(cSwizzleOffset, i32_val(contiguousMatShape));
|
||||
|
||||
SmallVector<Value> offs(numPtrs);
|
||||
Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
|
||||
// To prevent out-of-bound access of B when wpt * 16 > tile_size.
|
||||
Value phase = urem(udiv(rowInMat, i32_val(perPhase)), i32_val(maxPhase));
|
||||
// To prevent out-of-bound access of B when warpsPerTile * 16 > tile_size.
|
||||
// In such a case, we need to wrap around the offset of B.
|
||||
// |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |
|
||||
// |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) |
|
||||
// ~~~~~~~ out-of-bound access
|
||||
Value sOff = urem(add(sOffInMat, mul(sMatOff, i32_val(sMatShape))),
|
||||
i32_val(tileShape[order[1]]));
|
||||
|
||||
Value rowOffset =
|
||||
urem(add(rowInMat, mul(stridedMatIndex, i32_val(stridedMatShape))),
|
||||
i32_val(tileShape[order[1]]));
|
||||
auto contiguousTileNumMats = tileShape[order[0]] / matShape[order[0]];
|
||||
|
||||
for (int i = 0; i < numPtrs; ++i) {
|
||||
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
|
||||
cMatOffI = xor_(cMatOffI, phase);
|
||||
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride));
|
||||
Value contiguousIndex =
|
||||
add(contiguousMatIndex, i32_val(i * contiguousLoadMatOffset));
|
||||
if (warpsPerCTA[order[0]] > contiguousTileNumMats ||
|
||||
contiguousTileNumMats % warpsPerCTA[order[0]] != 0)
|
||||
contiguousIndex = urem(contiguousIndex, i32_val(contiguousTileNumMats));
|
||||
contiguousIndex = add(contiguousIndex, contiguousSliceMatOffset);
|
||||
Value contiguousIndexSwizzled = xor_(contiguousIndex, phase);
|
||||
offs[i] = add(mul(contiguousIndexSwizzled, i32_val(contiguousMatShape)),
|
||||
mul(rowOffset, stridedSmemOffset));
|
||||
}
|
||||
|
||||
return offs;
|
||||
@@ -208,14 +227,14 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
|
||||
Value jOff = i32_val(0);
|
||||
if (!needTrans) {
|
||||
jOff = add(jOff, i32_val(quadId));
|
||||
jOff = add(jOff, i32_val(rep * pLoadStrideInMat));
|
||||
jOff = add(jOff, i32_val(rep * contiguousLoadMatOffset));
|
||||
}
|
||||
// outer index offset
|
||||
Value iOff = mul(warpOff, i32_val(warpOffStride));
|
||||
Value iOff = mul(warpOff, i32_val(warpMatOffset));
|
||||
if (needTrans) {
|
||||
int pStride = kOrder == 1 ? 1 : 2;
|
||||
iOff = add(iOff, i32_val(quadId * matArrStride));
|
||||
iOff = add(iOff, i32_val(rep * pLoadStrideInMat * pStride));
|
||||
iOff = add(iOff, i32_val(quadId * inWarpMatOffset));
|
||||
iOff = add(iOff, i32_val(rep * contiguousLoadMatOffset * pStride));
|
||||
}
|
||||
// swizzle
|
||||
if (!needTrans) {
|
||||
@@ -234,9 +253,9 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
|
||||
// i = urem(i, i32_val(cTileShape));
|
||||
// j = urem(j, i32_val(sTileShape));
|
||||
if (needTrans) {
|
||||
offs[idx] = add(i, mul(j, sStride));
|
||||
offs[idx] = add(i, mul(j, stridedSmemOffset));
|
||||
} else {
|
||||
offs[idx] = add(mul(i, sStride), j);
|
||||
offs[idx] = add(mul(i, stridedSmemOffset), j);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,8 +263,7 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
|
||||
}
|
||||
|
||||
std::tuple<Value, Value, Value, Value>
|
||||
MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
ArrayRef<Value> ptrs, Type matTy,
|
||||
MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
Type shemPtrTy) const {
|
||||
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
|
||||
int matIdx[2] = {mat0, mat1};
|
||||
@@ -282,15 +300,16 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
}
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
Value sOffset =
|
||||
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride);
|
||||
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
|
||||
Value stridedOffset =
|
||||
mul(i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape),
|
||||
stridedSmemOffset);
|
||||
Value readPtr = gep(shemPtrTy, ptr, stridedOffset);
|
||||
|
||||
PTXBuilder builder;
|
||||
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
|
||||
// thread.
|
||||
auto resArgs = builder.newListOperand(4, "=r");
|
||||
auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
|
||||
auto addrArg = builder.newAddrOperand(readPtr, "r");
|
||||
|
||||
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
|
||||
->o("trans", needTrans /*predicate*/)
|
||||
@@ -313,14 +332,14 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
for (int i = 0; i < vecWidth; i++)
|
||||
ptrs[1][i] = getPtr(ptrIdx + i + vecWidth);
|
||||
// static offsets along outer dimension
|
||||
int _i0 = matIdx[order[1]] * (sMatStride * sMatShape);
|
||||
int _i0 = matIdx[order[1]] * (stridedLoadMatOffset * stridedMatShape);
|
||||
int _i1 = _i0;
|
||||
if (needTrans)
|
||||
_i1 += sMatStride * sMatShape;
|
||||
_i1 += stridedLoadMatOffset * stridedMatShape;
|
||||
else
|
||||
_i1 += (kOrder == 1 ? 1 : sMatStride) * sMatShape;
|
||||
Value i0 = mul(i32_val(_i0), sStride);
|
||||
Value i1 = mul(i32_val(_i1), sStride);
|
||||
_i1 += (kOrder == 1 ? 1 : stridedLoadMatOffset) * stridedMatShape;
|
||||
Value i0 = mul(i32_val(_i0), stridedSmemOffset);
|
||||
Value i1 = mul(i32_val(_i1), stridedSmemOffset);
|
||||
std::array<Value, 2> ii = {i0, i1};
|
||||
// load 4 32-bit values from shared memory
|
||||
// (equivalent to ldmatrix.x4)
|
||||
@@ -363,21 +382,23 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
|
||||
}
|
||||
|
||||
MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
int wpt, ArrayRef<uint32_t> order, uint32_t kOrder, int kWidth,
|
||||
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
|
||||
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
int warpsPerTile, ArrayRef<uint32_t> order, ArrayRef<uint32_t> warpsPerCTA,
|
||||
uint32_t kOrder, int kWidth, ArrayRef<Value> smemStrides,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase, int elemBytes,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
|
||||
: order(order.begin(), order.end()), kOrder(kOrder), kWidth(kWidth),
|
||||
tileShape(tileShape.begin(), tileShape.end()),
|
||||
: order(order.begin(), order.end()),
|
||||
warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
|
||||
kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
|
||||
instrShape(instrShape.begin(), instrShape.end()),
|
||||
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
|
||||
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
|
||||
ctx(rewriter.getContext()) {
|
||||
cMatShape = matShape[order[0]];
|
||||
sMatShape = matShape[order[1]];
|
||||
contiguousMatShape = matShape[order[0]];
|
||||
stridedMatShape = matShape[order[1]];
|
||||
|
||||
sStride = smemStrides[order[1]];
|
||||
stridedSmemOffset = smemStrides[order[1]];
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
needTrans = kOrder != order[0];
|
||||
@@ -385,12 +406,13 @@ MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes);
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
// Each CTA, the warps is arranged as [1xwpt] if not transposed,
|
||||
// otherwise [wptx1], and each warp will perform a mma.
|
||||
numPtrs =
|
||||
tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]];
|
||||
// In each CTA, the warps are arranged as [1xwarpsPerTile] if not transposed,
|
||||
// otherwise [warpsPerTilex1], and each warp will perform a mma.
|
||||
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
|
||||
instrShape[order[0]];
|
||||
} else {
|
||||
numPtrs = tileShape[order[0]] / (needTrans ? wpt : 1) / matShape[order[0]];
|
||||
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
|
||||
matShape[order[0]];
|
||||
numPtrs *= 4 / elemBytes;
|
||||
}
|
||||
numPtrs = std::max<int>(numPtrs, 2);
|
||||
@@ -398,23 +420,24 @@ MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
// Special rule for i8/u8, 4 ptrs for each matrix
|
||||
// if (!canUseLdmatrix && elemBytes == 1)
|
||||
|
||||
int loadStrideInMat[2];
|
||||
loadStrideInMat[kOrder] =
|
||||
int loadOffsetInMat[2];
|
||||
loadOffsetInMat[kOrder] =
|
||||
2; // instrShape[kOrder] / matShape[kOrder], always 2
|
||||
loadStrideInMat[kOrder ^ 1] =
|
||||
wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
|
||||
loadOffsetInMat[kOrder ^ 1] =
|
||||
warpsPerTile * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
|
||||
|
||||
pLoadStrideInMat = loadStrideInMat[order[0]];
|
||||
contiguousLoadMatOffset = loadOffsetInMat[order[0]];
|
||||
|
||||
sMatStride =
|
||||
loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
|
||||
stridedLoadMatOffset =
|
||||
loadOffsetInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
|
||||
|
||||
// Each matArr contains warpOffStride matrices.
|
||||
matArrStride = kOrder == 1 ? 1 : wpt;
|
||||
warpOffStride = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
|
||||
// The stride (in number of matrices) within warp
|
||||
inWarpMatOffset = kOrder == 1 ? 1 : warpsPerTile;
|
||||
// The stride (in number of matrices) of each warp
|
||||
warpMatOffset = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
|
||||
}
|
||||
|
||||
Type getShemPtrTy(Type argType) {
|
||||
Type getSharedMemPtrTy(Type argType) {
|
||||
MLIRContext *ctx = argType.getContext();
|
||||
if (argType.isF16())
|
||||
return ptr_ty(type::f16Ty(ctx), 3);
|
||||
@@ -451,13 +474,12 @@ Value composeValuesToDotOperandLayoutStruct(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::function<void(int, int)>
|
||||
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
|
||||
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, int kWidth,
|
||||
SmallVector<int> instrShape, SmallVector<int> matShape,
|
||||
Value warpId, Value lane, ValueTable &vals, bool isA,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
std::function<void(int, int)> getLoadMatrixFn(
|
||||
Value tensor, const SharedMemoryObject &smemObj, MmaEncodingAttr mmaLayout,
|
||||
int warpsPerTile, uint32_t kOrder, int kWidth, SmallVector<int> instrShape,
|
||||
SmallVector<int> matShape, Value warpId, Value lane, ValueTable &vals,
|
||||
bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
Type eltTy = tensorTy.getElementType();
|
||||
// We assume that the input operand of Dot should be from shared layout.
|
||||
@@ -471,25 +493,27 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
|
||||
// (a, b) is the coordinate.
|
||||
auto load = [=, &rewriter, &vals](int a, int b) {
|
||||
MMA16816SmemLoader loader(
|
||||
wpt, sharedLayout.getOrder(), kOrder, kWidth, smemObj.strides,
|
||||
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
|
||||
maxPhase, elemBytes, rewriter, typeConverter, loc);
|
||||
warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(),
|
||||
kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/,
|
||||
instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter,
|
||||
typeConverter, loc);
|
||||
// Offset of a slice within the original tensor in shared memory
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
SmallVector<Value> offs =
|
||||
loader.computeOffsets(warpId, lane, cSwizzleOffset);
|
||||
// initialize pointers
|
||||
const int numPtrs = loader.getNumPtrs();
|
||||
SmallVector<Value> ptrs(numPtrs);
|
||||
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
Type smemPtrTy = getShemPtrTy(eltTy);
|
||||
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
Type smemPtrTy = getSharedMemPtrTy(eltTy);
|
||||
for (int i = 0; i < numPtrs; ++i)
|
||||
ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy);
|
||||
// actually load from shared memory
|
||||
auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(),
|
||||
SmallVector<Type>(4, i32_ty));
|
||||
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
|
||||
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
|
||||
ptrs, matTy, getShemPtrTy(eltTy));
|
||||
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, ptrs,
|
||||
matTy, getSharedMemPtrTy(eltTy));
|
||||
if (!isA)
|
||||
std::swap(ha1, ha2);
|
||||
// the following is incorrect
|
||||
@@ -527,32 +551,34 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
|
||||
int kWidth = encoding.getMMAv2kWidth();
|
||||
|
||||
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
|
||||
int wpt1 = mmaLayout.getWarpsPerCTA()[1];
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value lane = urem(thread, i32_val(32));
|
||||
Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16));
|
||||
Value warpMN = udiv(warp, i32_val(wpt0));
|
||||
Value warpN = urem(urem(warpMN, i32_val(wpt1)), i32_val(shape[1] / 8));
|
||||
// Note: warps are currently column major in MMA layout
|
||||
Value warpRowIndex = urem(warp, i32_val(warpsPerCTA[0]));
|
||||
Value warpColIndex =
|
||||
urem(udiv(warp, i32_val(warpsPerCTA[0])), i32_val(warpsPerCTA[1]));
|
||||
Value warpM = urem(warpRowIndex, i32_val(shape[0] / 16));
|
||||
Value warpN = urem(warpColIndex, i32_val(shape[1] / 8));
|
||||
|
||||
int wpt;
|
||||
int warpsPerTile;
|
||||
if (isA)
|
||||
wpt = std::min<int>(wpt0, shape[0] / 16);
|
||||
warpsPerTile = std::min<int>(warpsPerCTA[0], shape[0] / 16);
|
||||
else
|
||||
wpt = std::min<int>(wpt1, shape[1] / 16);
|
||||
warpsPerTile = std::min<int>(warpsPerCTA[1], shape[1] / 16);
|
||||
|
||||
std::function<void(int, int)> loadFn;
|
||||
if (isA)
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, kWidth,
|
||||
{mmaInstrM, mmaInstrK} /*instrShape*/,
|
||||
tensor, smemObj, mmaLayout, warpsPerTile /*warpsPerTile*/, 1 /*kOrder*/,
|
||||
kWidth, {mmaInstrM, mmaInstrK} /*instrShape*/,
|
||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
|
||||
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
else
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, kWidth,
|
||||
{mmaInstrK, mmaInstrN} /*instrShape*/,
|
||||
tensor, smemObj, mmaLayout, warpsPerTile /*warpsPerTile*/, 0 /*kOrder*/,
|
||||
kWidth, {mmaInstrK, mmaInstrN} /*instrShape*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
|
||||
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
|
||||
@@ -402,7 +402,14 @@ struct StoreOpConversion
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
|
||||
auto &ptxStoreInstr =
|
||||
ptxBuilder.create<>("st")->global().v(nWords).b(width);
|
||||
ptxBuilder.create<>("st")
|
||||
->global()
|
||||
.o("L1::evict_first",
|
||||
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
|
||||
.o("L1::evict_last",
|
||||
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
.v(nWords)
|
||||
.b(width);
|
||||
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
|
||||
|
||||
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
|
||||
@@ -528,11 +535,6 @@ struct AtomicCASOpConversion
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfence();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
@@ -546,7 +548,10 @@ struct AtomicCASOpConversion
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
atom.global().o("cas").o("b32");
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.global().o(semStr).o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
barrier();
|
||||
@@ -557,8 +562,8 @@ struct AtomicCASOpConversion
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
@@ -731,7 +736,7 @@ struct AtomicRMWOpConversion
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
//
|
||||
auto atomicRmwAttr = op.getAtomicRmwOp();
|
||||
|
||||
Value val = op.getVal();
|
||||
@@ -832,7 +837,10 @@ struct AtomicRMWOpConversion
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
atom.o(rmwOp).o(sTy);
|
||||
std::string semStr;
|
||||
llvm::raw_string_ostream os(semStr);
|
||||
os << op.getSem();
|
||||
atom.o(semStr).o(rmwOp).o(sTy);
|
||||
if (tensorTy) {
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
@@ -842,13 +850,13 @@ struct AtomicRMWOpConversion
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
|
||||
}
|
||||
} else {
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfenc();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
if (op->user_begin() == op->user_end()) {
|
||||
rewriter.replaceOp(op, {old});
|
||||
return success();
|
||||
}
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
// Only threads with rmwMask = True store the result
|
||||
|
||||
@@ -253,8 +253,8 @@ struct SharedMemoryObject {
|
||||
return offsets[order];
|
||||
}
|
||||
|
||||
Value getBaseBeforeSwizzle(int order, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value getBaseBeforeSlice(int order, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value cSwizzleOffset = getCSwizzleOffset(order);
|
||||
Value offset = sub(i32_val(0), cSwizzleOffset);
|
||||
Type type = base.getType();
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
@@ -309,11 +310,43 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// For now, this behaves like generic, but this will evolve when
|
||||
// we add support for `can_reorder=False`
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
// The cat op satisfies two conditions:
|
||||
// 1. output.numel = lhs.numel + rhs.numel
|
||||
// 2. output.total_elems_per_thread =
|
||||
// next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread)
|
||||
// For now, this behaves like generic, but this
|
||||
// will evolve when we add support for `can_reorder=False`.
|
||||
auto retType = this->getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retType.getEncoding().cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto lhsType = adaptor.getLhs().getType().cast<RankedTensorType>();
|
||||
auto rhsType = adaptor.getRhs().getType().cast<RankedTensorType>();
|
||||
auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType);
|
||||
auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType);
|
||||
auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType);
|
||||
auto retShape = retType.getShape();
|
||||
auto retOrder = retEncoding.getOrder();
|
||||
auto retSizePerThread = retEncoding.getSizePerThread();
|
||||
auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
|
||||
auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
|
||||
// Get new retSizePerThread if ret elems per thread is not enough.
|
||||
// We have to round it up to the next power of 2 due to triton's tensor size
|
||||
// constraint.
|
||||
auto newRetTotalElemsPerThread =
|
||||
nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
|
||||
auto newRetSizePerThread = retSizePerThread.vec();
|
||||
newRetSizePerThread[retOrder[0]] *=
|
||||
newRetTotalElemsPerThread / retTotalElemsPerThread;
|
||||
triton::gpu::BlockedEncodingAttr newRetEncoding =
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
|
||||
newRetEncoding);
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
|
||||
op, retType, adaptor.getOperands()),
|
||||
op, newRetType, adaptor.getOperands()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
@@ -393,7 +426,8 @@ struct TritonAtomicCASPattern
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal()),
|
||||
adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal(),
|
||||
op.getSem()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
@@ -409,7 +443,7 @@ struct TritonAtomicRMWPattern
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getAtomicRmwOp(), adaptor.getPtr(),
|
||||
adaptor.getVal(), adaptor.getMask()),
|
||||
adaptor.getVal(), adaptor.getMask(), op.getSem()),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -410,9 +411,6 @@ bool isaDistributedLayout(Attribute layout) {
|
||||
layout.isa<MfmaEncodingAttr>() || layout.isa<SliceEncodingAttr>();
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
bool isSharedEncoding(Value value) {
|
||||
auto type = value.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
@@ -422,6 +420,9 @@ bool isSharedEncoding(Value value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
|
||||
@@ -1282,6 +1283,10 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
}
|
||||
// cvt(cat) -> cat
|
||||
if (auto cat = dyn_cast<triton::CatOp>(arg)) {
|
||||
auto encoding =
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding();
|
||||
if (isExpensiveCat(cat, encoding))
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<triton::CatOp>(op, op->getResult(0).getType(),
|
||||
cat.getOperands());
|
||||
return mlir::success();
|
||||
@@ -1289,7 +1294,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
|
||||
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
|
||||
if (alloc_tensor) {
|
||||
if (!isSharedEncoding(op->getResult(0))) {
|
||||
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
|
||||
return mlir::failure();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
|
||||
@@ -1299,7 +1304,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
|
||||
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
||||
if (insert_slice) {
|
||||
if (!isSharedEncoding(op->getResult(0))) {
|
||||
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
@@ -1321,7 +1326,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
|
||||
if (extract_slice) {
|
||||
if (!isSharedEncoding(op->getResult(0))) {
|
||||
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto origType =
|
||||
@@ -1351,12 +1356,13 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
|
||||
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
|
||||
if (arg->getOperand(0).getDefiningOp() &&
|
||||
!isSharedEncoding(arg->getOperand(0)) &&
|
||||
isSharedEncoding(op.getOperand()) &&
|
||||
!isSharedEncoding(op.getResult())) {
|
||||
!triton::gpu::isSharedEncoding(arg->getOperand(0)) &&
|
||||
triton::gpu::isSharedEncoding(op.getOperand()) &&
|
||||
!triton::gpu::isSharedEncoding(op.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) {
|
||||
if (triton::gpu::isSharedEncoding(op.getOperand()) &&
|
||||
triton::gpu::isSharedEncoding(op.getResult())) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
|
||||
|
||||
@@ -7,7 +7,7 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
|
||||
return failure();
|
||||
|
||||
for (auto result : op->getResults())
|
||||
if (!isSharedEncoding(result))
|
||||
if (!triton::gpu::isSharedEncoding(result))
|
||||
return op->emitOpError() << "requires all results to be shared encoding";
|
||||
|
||||
return success();
|
||||
|
||||
@@ -47,7 +47,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<triton::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
@@ -234,8 +234,8 @@ public:
|
||||
if (versionMajor == 1) {
|
||||
SetVector<Operation *> aBwdSlices, bBwdSlices;
|
||||
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
|
||||
getBackwardSlice(a, &aBwdSlices, isCvt);
|
||||
getBackwardSlice(b, &bBwdSlices, isCvt);
|
||||
getBackwardSlice(a, &aBwdSlices, {isCvt});
|
||||
getBackwardSlice(b, &bBwdSlices, {isCvt});
|
||||
// get the source of the first conversion found in slices
|
||||
auto getCvtArgOrder = [](Operation *op) {
|
||||
return cast<ConvertLayoutOp>(op)
|
||||
|
||||
@@ -14,8 +14,8 @@ using namespace mlir::triton;
|
||||
template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
|
||||
SmallVector<unsigned, 4> ret(arr.size());
|
||||
std::iota(ret.begin(), ret.end(), 0);
|
||||
std::sort(ret.begin(), ret.end(),
|
||||
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
|
||||
std::stable_sort(ret.begin(), ret.end(),
|
||||
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -6,6 +5,7 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -172,7 +172,7 @@ LogicalResult Prefetcher::initialize() {
|
||||
break;
|
||||
rets.push_back(op->getOperand(0));
|
||||
if (auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op))
|
||||
if (isSharedEncoding(cvt.getOperand())) {
|
||||
if (triton::gpu::isSharedEncoding(cvt.getOperand())) {
|
||||
foundConvertFromShared = true;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
@@ -16,6 +15,7 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
@@ -84,6 +84,7 @@ public:
|
||||
!dstParent.isa<triton::gpu::MfmaEncodingAttr>()))
|
||||
return mlir::failure();
|
||||
|
||||
<<<<<<< HEAD
|
||||
if (dstParent.isa<triton::gpu::MmaEncodingAttr>()) {
|
||||
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
|
||||
@@ -97,6 +98,15 @@ public:
|
||||
return mlir::failure();
|
||||
return processEncoding(dstParentMfma, convert, dstType, rewriter);
|
||||
}
|
||||
=======
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), dstParentMma);
|
||||
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
convert.getLoc(), tmpType, convert.getOperand());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, dstType,
|
||||
tmp);
|
||||
return mlir::success();
|
||||
>>>>>>> oai/main
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -373,13 +383,13 @@ public:
|
||||
!(isa<triton::ReduceOp>(op) &&
|
||||
!op->getResult(0).getType().isa<RankedTensorType>());
|
||||
};
|
||||
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
|
||||
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, {filter});
|
||||
if (cvtSlices.empty())
|
||||
return failure();
|
||||
|
||||
for (Operation *op : cvtSlices) {
|
||||
// don't rematerialize anything expensive
|
||||
if (expensiveToRemat(op, dstEncoding))
|
||||
if (isExpensiveToRemat(op, srcEncoding))
|
||||
return failure();
|
||||
// don't rematerialize non-element-wise
|
||||
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
@@ -428,8 +438,8 @@ public:
|
||||
if (!op)
|
||||
return mlir::failure();
|
||||
// we don't want to rematerialize any conversion to/from shared
|
||||
if (isSharedEncoding(cvt->getResults()[0]) ||
|
||||
isSharedEncoding(cvt->getOperand(0)))
|
||||
if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) ||
|
||||
triton::gpu::isSharedEncoding(cvt->getOperand(0)))
|
||||
return mlir::failure();
|
||||
// we don't handle conversions to DotOperandEncodingAttr
|
||||
// this is a heuristic to accommodate fused attention
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "Utility.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -88,7 +88,7 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
return success();
|
||||
}
|
||||
|
||||
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
// Case 1: A size 1 tensor is not expensive since all threads will load the
|
||||
// same
|
||||
if (isSingleValue(op->getOperand(0)))
|
||||
@@ -96,22 +96,34 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
// Case 2: Tensor of pointers has more threads than elements
|
||||
// we can presume a high hit-rate that makes it cheap to load
|
||||
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
|
||||
IntegerAttr numWarps =
|
||||
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
|
||||
"triton_gpu.num-warps");
|
||||
if (numWarps) {
|
||||
int sizePerThread = triton::gpu::getTotalElemsPerThread(ptrType);
|
||||
if (ptrType.getNumElements() < numWarps.getInt() * 32)
|
||||
return false;
|
||||
}
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
|
||||
if (ptrType.getNumElements() < numWarps * threadsPerWarp)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
/// Decide whether a `cat` op counts as expensive under `targetEncoding`.
///
/// Compares the total elements per thread of the cat result's current tensor
/// type with the value computed for `targetEncoding` using the same shape and
/// element type. Returns true when the target value is strictly smaller.
bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
  auto resultTy = cat.getResult().getType().cast<RankedTensorType>();
  // Elements per thread with the encoding the result currently carries.
  auto currentElemsPerThread = triton::gpu::getTotalElemsPerThread(resultTy);
  // Elements per thread if the same tensor used the target encoding.
  auto targetElemsPerThread = triton::gpu::getTotalElemsPerThread(
      targetEncoding, resultTy.getShape(), resultTy.getElementType());
  // If the new elements per thread is less than the old one, a convert
  // encoding that goes through shared memory is needed anyway, so the cat is
  // considered expensive.
  return targetElemsPerThread < currentElemsPerThread;
}
|
||||
|
||||
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
if (!op)
|
||||
return true;
|
||||
if (isa<triton::LoadOp, triton::StoreOp>(op))
|
||||
return expensiveLoadOrStore(op, targetEncoding);
|
||||
return isExpensiveLoadOrStore(op, targetEncoding);
|
||||
if (isa<triton::CatOp>(op))
|
||||
return isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
|
||||
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
|
||||
triton::AtomicCASOp, triton::DotOp>(op))
|
||||
@@ -122,10 +134,11 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool canFoldConversion(Operation *op) {
|
||||
bool canFoldConversion(Operation *op, Attribute &targetEncoding) {
|
||||
if (isa<triton::CatOp>(op))
|
||||
return !isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
|
||||
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp,
|
||||
triton::CatOp>(*op);
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
|
||||
}
|
||||
|
||||
int simulateBackwardRematerialization(
|
||||
@@ -145,7 +158,7 @@ int simulateBackwardRematerialization(
|
||||
queue.pop_back();
|
||||
// If the current operation is expensive to rematerialize,
|
||||
// we stop everything
|
||||
if (expensiveToRemat(currOp, currLayout))
|
||||
if (isExpensiveToRemat(currOp, currLayout))
|
||||
break;
|
||||
// A conversion will be removed here (i.e. transferred to operands)
|
||||
numCvts -= 1;
|
||||
@@ -173,7 +186,7 @@ int simulateBackwardRematerialization(
|
||||
continue;
|
||||
// If the conversion can be folded into opArgI then
|
||||
// we don't count this conversion as expensive
|
||||
if (canFoldConversion(opArgI))
|
||||
if (canFoldConversion(opArgI, newEncoding))
|
||||
continue;
|
||||
|
||||
// We add one expensive conversion for the current operand
|
||||
|
||||
@@ -31,6 +31,17 @@ def get_build_type():
|
||||
# TODO: change to release when stable enough
|
||||
return "TritonRelBuildWithAsserts"
|
||||
|
||||
|
||||
def get_codegen_backends():
    """Collect the codegen backends selected through the environment.

    Every environment variable whose name starts with ``TRITON_CODEGEN_`` and
    for which ``check_env_flag`` is truthy contributes one entry: the rest of
    the variable name, lower-cased.

    Returns:
        list[str]: backend names, in environment iteration order.
    """
    prefix = "TRITON_CODEGEN_"
    selected = []
    for var_name in os.environ:
        if not var_name.startswith(prefix):
            continue
        if not check_env_flag(var_name):
            continue
        # The prefix must not occur a second time inside the variable name.
        assert var_name.count(prefix) <= 1
        selected.append(var_name.replace(prefix, '').lower())
    return selected
|
||||
|
||||
|
||||
# --- third party packages -----
|
||||
|
||||
|
||||
@@ -210,6 +221,11 @@ class CMakeBuild(build_ext):
|
||||
cfg = get_build_type()
|
||||
build_args = ["--config", cfg]
|
||||
|
||||
codegen_backends = get_codegen_backends()
|
||||
if len(codegen_backends) > 0:
|
||||
all_codegen_backends = ';'.join(codegen_backends)
|
||||
cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
|
||||
if sys.maxsize > 2**32:
|
||||
@@ -249,16 +265,14 @@ setup(
|
||||
"triton/_C",
|
||||
"triton/common",
|
||||
"triton/compiler",
|
||||
"triton/debugger",
|
||||
"triton/interpreter",
|
||||
"triton/language",
|
||||
"triton/language/extra",
|
||||
"triton/ops",
|
||||
"triton/ops/blocksparse",
|
||||
"triton/runtime",
|
||||
"triton/runtime/backends",
|
||||
"triton/third_party/cuda/bin",
|
||||
"triton/third_party/cuda/include",
|
||||
"triton/third_party/cuda/lib",
|
||||
"triton/third_party",
|
||||
"triton/tools",
|
||||
],
|
||||
install_requires=[
|
||||
|
||||
@@ -92,6 +92,13 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("CG", mlir::triton::CacheModifier::CG)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC")
|
||||
.value("ACQUIRE_RELEASE", mlir::triton::MemSemantic::ACQUIRE_RELEASE)
|
||||
.value("ACQUIRE", mlir::triton::MemSemantic::ACQUIRE)
|
||||
.value("RELEASE", mlir::triton::MemSemantic::RELEASE)
|
||||
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
|
||||
.export_values();
|
||||
|
||||
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY")
|
||||
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
|
||||
.value("EVICT_FIRST", mlir::triton::EvictionPolicy::EVICT_FIRST)
|
||||
@@ -1370,7 +1377,7 @@ void init_triton_ir(py::module &&m) {
|
||||
// // atomic
|
||||
.def("create_atomic_cas",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
|
||||
mlir::Value &val) -> mlir::Value {
|
||||
mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType =
|
||||
@@ -1386,12 +1393,12 @@ void init_triton_ir(py::module &&m) {
|
||||
dstType = ptrType.getPointeeType();
|
||||
}
|
||||
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
|
||||
cmp, val);
|
||||
cmp, val, sem);
|
||||
})
|
||||
.def("create_atomic_rmw",
|
||||
[](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
|
||||
mlir::Value &ptr, mlir::Value &val,
|
||||
mlir::Value &mask) -> mlir::Value {
|
||||
mlir::Value &ptr, mlir::Value &val, mlir::Value &mask,
|
||||
mlir::triton::MemSemantic sem) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType =
|
||||
@@ -1407,7 +1414,7 @@ void init_triton_ir(py::module &&m) {
|
||||
dstType = ptrType.getPointeeType();
|
||||
}
|
||||
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
|
||||
ptr, val, mask);
|
||||
ptr, val, mask, sem);
|
||||
})
|
||||
// External
|
||||
.def("create_extern_elementwise",
|
||||
|
||||
42
python/test/backend/extension_backend.c
Normal file
42
python/test/backend/extension_backend.c
Normal file
@@ -0,0 +1,42 @@
|
||||
#include <Python.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Return a dict of fixed, hard-coded properties for the dummy extension
// device. Neither `self` nor `args` is read.
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  const int maxSharedMem = 1024;
  const int multiprocessorCount = 16;
  const int smClockRate = 2100;
  const int memClockRate = 2300;
  const int memBusWidth = 2400;

  PyObject *properties = Py_BuildValue(
      "{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", maxSharedMem,
      "multiprocessor_count", multiprocessorCount, "sm_clock_rate",
      smClockRate, "mem_clock_rate", memClockRate, "mem_bus_width",
      memBusWidth);
  return properties;
}
|
||||
|
||||
// Dummy "load binary" entry point: nothing is loaded. Returns the tuple
// (module handle, function handle, allocated registers, spilled registers)
// with every field set to zero. Neither `self` nor `args` is read.
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const uint64_t moduleHandle = 0;
  const uint64_t functionHandle = 0;
  const int allocatedRegs = 0;
  const int spilledRegs = 0;

  PyObject *result = Py_BuildValue("(KKii)", moduleHandle, functionHandle,
                                   allocatedRegs, spilledRegs);
  return result;
}
|
||||
|
||||
static PyMethodDef ModuleMethods[] = {
|
||||
{"load_binary", loadBinary, METH_VARARGS,
|
||||
"Load dummy binary for the extension device"},
|
||||
{"get_device_properties", getDeviceProperties, METH_VARARGS,
|
||||
"Get the properties for the extension device"},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "ext_utils",
|
||||
NULL, // documentation
|
||||
-1, // size
|
||||
ModuleMethods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_ext_utils(void) {
|
||||
PyObject *m = PyModule_Create(&ModuleDef);
|
||||
if (m == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyModule_AddFunctions(m, ModuleMethods);
|
||||
return m;
|
||||
}
|
||||
262
python/test/backend/test_device_backend.py
Normal file
262
python/test/backend/test_device_backend.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sysconfig
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import setuptools
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.backend import BaseBackend, register_backend
|
||||
from triton.common.build import quiet
|
||||
from triton.compiler.make_launcher import make_so_cache_key
|
||||
from triton.runtime.cache import get_cache_manager
|
||||
from triton.runtime.driver import DriverBase
|
||||
from triton.runtime.jit import version_key
|
||||
|
||||
|
||||
def build_for_backend(name, src, srcdir):
    """Compile the C source file ``src`` into an extension module in ``srcdir``.

    The C compiler is invoked directly first. It is taken from the ``CC``
    environment variable, or else ``gcc`` / ``clang`` as found on ``PATH``.
    If that compiler exits with a non-zero status, the build is retried
    through setuptools.

    Args:
        name: module name; the output file is ``<name><EXT_SUFFIX>``.
        src: path of the C source file to compile.
        srcdir: directory used as include path, build directory and output
            directory.

    Returns:
        Path of the shared object inside ``srcdir``.

    Raises:
        RuntimeError: if ``CC`` is unset and neither gcc nor clang is found.
    """
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
    # try to avoid setuptools if possible
    cc = os.environ.get("CC")
    if cc is None:
        # TODO: support more things here.
        clang = shutil.which("clang")
        gcc = shutil.which("gcc")
        cc = gcc if gcc is not None else clang
        if cc is None:
            raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, 'get_default_scheme'):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

    # subprocess.call returns the exit status instead of raising on failure
    # (check_call would raise CalledProcessError), so a failed direct compile
    # actually reaches the setuptools fallback below.
    ret = subprocess.call([cc, src, f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so])
    if ret == 0:
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = []
    include_dirs = [srcdir]
    libraries = []
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # build extension module
    args = ['build_ext']
    args.append('--build-temp=' + srcdir)
    args.append('--build-lib=' + srcdir)
    args.append('-q')
    args = dict(
        name=name,
        ext_modules=[ext],
        script_args=args,
    )
    with quiet():
        setuptools.setup(**args)
    return so
|
||||
|
||||
|
||||
class ExtensionUtils:
    """Process-wide singleton exposing the functions of the compiled ext_utils module.

    On construction the sibling ``extension_backend.c`` file is built (or taken
    from the cache, keyed by the MD5 of its source), loaded as the module
    ``ext_utils``, and its ``load_binary`` / ``get_device_properties`` functions
    are bound onto the instance.
    """

    def __new__(cls):
        # Create the shared instance on first use, then always hand it back.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super(ExtensionUtils, cls).__new__(cls)
            return cls.instance

    def __init__(self):
        here = os.path.dirname(os.path.realpath(__file__))
        c_source = (Path(here) / "extension_backend.c").read_text()
        # The cache entry is keyed by the content of the C source.
        digest = hashlib.md5(c_source.encode("utf-8")).hexdigest()
        manager = get_cache_manager(digest)
        lib_name = "ext_utils.so"
        lib_path = manager.get_file(lib_name)
        if lib_path is None:
            # Cache miss: compile in a scratch directory and store the result.
            with tempfile.TemporaryDirectory() as build_dir:
                c_file = os.path.join(build_dir, "main.c")
                with open(c_file, "w") as out:
                    out.write(c_source)
                built = build_for_backend("ext_utils", c_file, build_dir)
                with open(built, "rb") as lib:
                    lib_path = manager.put(lib.read(), lib_name, binary=True)
        import importlib.util
        module_spec = importlib.util.spec_from_file_location("ext_utils", lib_path)
        ext_module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(ext_module)
        self.load_binary = ext_module.load_binary
        self.get_device_properties = ext_module.get_device_properties
|
||||
|
||||
|
||||
class ExtensionDriver(DriverBase):
    """Singleton driver for the dummy extension backend; holds the ExtensionUtils helper."""

    def __new__(cls):
        # Create the shared instance on first use, then always hand it back.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super(ExtensionDriver, cls).__new__(cls)
            return cls.instance

    def __init__(self):
        self.utils = ExtensionUtils()
|
||||
|
||||
|
||||
class ExtensionBackend(BaseBackend):
    """Dummy backend used by the test below to exercise the backend registration hooks."""

    # Path of the most recently produced launcher stub, shared on the class so
    # the test can load the stub module afterwards.
    stub_so_path = ""

    def __init__(self, device_type: str) -> None:
        super().__init__(device_type)
        self.driver = ExtensionDriver()

    def add_stages(self, arch, extern_libs, stages):
        """Remove, in place, every stage other than "ast", "ttir" and "ttgir"."""
        kept = ["ast", "ttir", "ttgir"]
        # Collect first, then pop, so the mapping is not mutated while iterating.
        dropped = [stage for stage, _ in stages.items() if stage not in kept]
        for stage in dropped:
            stages.pop(stage)

    def add_meta_info(self, ir, cur_module, next_module, metadata, asm):
        """Record a fixed kernel name in ``metadata``."""
        metadata["name"] = "extension_backend_name"

    def get_driver(self):
        """Return the driver created in ``__init__``."""
        return self.driver

    def get_stream(self):
        """Return an empty string as the stream."""
        return ""

    @functools.lru_cache(None)
    def get_device_properties(self, device):
        """Return the properties reported by the ext_utils module (cached per argument)."""
        return self.driver.utils.get_device_properties()

    def get_current_device(self):
        """Return the torch CPU device."""
        return torch.device("cpu")

    def set_current_device(self, device):
        """Do nothing."""
        pass

    def get_load_binary_fn(self):
        """Return the ``load_binary`` function of the ext_utils module."""
        return self.driver.utils.load_binary

    def get_kernel_bin(self):
        """Return "ttgir"."""
        return "ttgir"

    def get_architecture_descriptor(self, **kwargs):
        """Return an empty string as the architecture descriptor."""
        return ""

    def make_launcher_stub(self, name, signature, constants):
        """Build the launcher stub (or fetch it from the cache) and return its path.

        The resulting path is also stored in ``stub_so_path`` on the class.
        """
        # name of files that are cached
        cache_key = make_so_cache_key(version_key(), signature, constants)
        manager = get_cache_manager(cache_key)
        stub_name = f"{name}.so"
        # retrieve stub from cache if it exists
        stub_path = manager.get_file(stub_name)
        if stub_path is None:
            with tempfile.TemporaryDirectory() as build_dir:
                launcher_source = self._generate_launcher(constants, signature)
                c_file = os.path.join(build_dir, "main.c")
                with open(c_file, "w") as out:
                    out.write(launcher_source)
                built = build_for_backend(name, c_file, build_dir)
                with open(built, "rb") as lib:
                    stub_path = manager.put(lib.read(), stub_name, binary=True)
        type(self).stub_so_path = stub_path
        return stub_path

    def _generate_launcher(self, constants, signature):
        """Return the C source of the launcher stub; both arguments are unused."""
        # generate glue code
        launcher_source = """
#define __EXTENSION_BACKEND__
#include <Python.h>
#include <stdio.h>

static PyObject* launch_counter(PyObject* self, PyObject* args) {
static int64_t launch_counter = 0;
launch_counter += 1;
return PyLong_FromLong(launch_counter);
}

static PyObject* launch(PyObject* self, PyObject* args) {
if (PyErr_Occurred()) {
return NULL;
}
launch_counter(self, args);
// return None
Py_INCREF(Py_None);
return Py_None;
}

static PyMethodDef ModuleMethods[] = {
{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
{"launch_counter", launch_counter, METH_VARARGS, "Entry point to get launch counter"},
{NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
};

PyMODINIT_FUNC PyInit___triton_launcher(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
"""

        return launcher_source
|
||||
|
||||
|
||||
def test_dummy_backend():
    """Register the dummy backend for "cpu", launch a kernel on CPU tensors, and
    check through the launcher stub's counter that launches happened."""
    register_backend("cpu", ExtensionBackend)

    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        xnumel = 10
        block_start = tl.program_id(0) * XBLOCK
        offsets = block_start + tl.arange(0, XBLOCK)[:]
        in_bounds = offsets < xnumel
        loaded = tl.load(in_ptr0 + (offsets), in_bounds)
        tl.store(out_ptr0 + (offsets + tl.zeros([XBLOCK], tl.int32)), loaded, in_bounds)

    inp = torch.randn(10)
    out = torch.randn(10)
    grid = (10,)
    # The first launch produces the stub whose path is recorded on the backend class.
    kernel[grid](inp, out, 10, XBLOCK=16)
    stub_spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
    stub_module = importlib.util.module_from_spec(stub_spec)
    stub_spec.loader.exec_module(stub_module)
    launch_counter = stub_module.launch_counter

    launches = 0
    while launches < 100:
        kernel[grid](inp, out, 10, XBLOCK=16)
        launches += 1

    assert launch_counter() > 0
|
||||
14
python/test/backend/third_party_backends/conftest.py
Normal file
14
python/test/backend/third_party_backends/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# content of conftest.py
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_addoption(parser):
    """Register the ``--backend`` command-line option (stored value, empty string by default)."""
    backend_option = {
        "action": "store",
        "default": "",
        "help": "Codegen backend",
    }
    parser.addoption("--backend", **backend_option)
|
||||
|
||||
|
||||
@pytest.fixture
def cmdopt(request):
    """Fixture returning the value of the ``--backend`` command-line option."""
    config = request.config
    return config.getoption("--backend")
|
||||
33
python/test/backend/third_party_backends/test_xpu_backend.py
Normal file
33
python/test/backend/third_party_backends/test_xpu_backend.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def test_xpu_backend(cmdopt):
|
||||
if cmdopt == "xpu":
|
||||
has_ipex = False
|
||||
try:
|
||||
# Import IPEX to provide Intel GPU runtime
|
||||
import intel_extension_for_pytorch # type: ignore # noqa: F401
|
||||
has_ipex = True if hasattr(torch, "xpu") else False
|
||||
except Exception:
|
||||
has_ipex = False
|
||||
|
||||
@triton.jit()
|
||||
def kernel(x_ptr, y_ptr, out_ptr):
|
||||
pid = tl.program_id(axis=0)
|
||||
x = tl.load(x_ptr + pid)
|
||||
y = tl.load(y_ptr + pid)
|
||||
out = x + y
|
||||
tl.store(out_ptr + pid, out)
|
||||
|
||||
if has_ipex:
|
||||
for _ in range(1000):
|
||||
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
|
||||
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
|
||||
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
|
||||
kernel[(65536,)](x, y, z, num_warps=32)
|
||||
assert torch.all(x + y == z)
|
||||
else:
|
||||
return
|
||||
@@ -69,7 +69,7 @@ matmul_data = {
|
||||
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
|
||||
(16, 4096, 4096): {'float16': 0.044, 'float32': 0.0457, 'int8': 0.0259},
|
||||
(16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431},
|
||||
(64, 1024, 1024): {'float16': 0.030, 'float32': 0.0509, 'int8': 0.0169},
|
||||
(64, 1024, 1024): {'float16': 0.028, 'float32': 0.0509, 'int8': 0.0169},
|
||||
(64, 4096, 4096): {'float16': 0.163, 'float32': 0.162, 'int8': 0.097},
|
||||
(64, 8192, 8192): {'float16': 0.285, 'float32': 0.257, 'int8': 0.174},
|
||||
(1024, 64, 1024): {'float16': 0.033, 'float32': 0.0458, 'int8': 0.017},
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.debugger.debugger import program_ids_from_grid
|
||||
from triton.interpreter.interpreter import program_ids_from_grid
|
||||
|
||||
|
||||
def test_addition():
|
||||
@@ -51,8 +51,13 @@ def test_assert(func: str):
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
<<<<<<< HEAD
|
||||
# TRITON_DEBUG=True can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
=======
|
||||
# TRITON_DEBUG=1 can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
|
||||
>>>>>>> oai/main
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
|
||||
@@ -937,15 +937,16 @@ def test_noinline(mode):
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
|
||||
@pytest.mark.parametrize("op, dtype_x_str, mode, sem", itertools.chain.from_iterable([
|
||||
[
|
||||
('add', 'float16', mode),
|
||||
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
|
||||
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
|
||||
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
|
||||
('add', 'float16', mode, sem),
|
||||
('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem),
|
||||
('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem),
|
||||
('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem),
|
||||
]
|
||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
|
||||
for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
|
||||
def test_atomic_rmw(op, dtype_x_str, mode, sem, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
if dtype_x_str == 'float16':
|
||||
@@ -959,7 +960,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
x = tl.load(X + pid)
|
||||
old = GENERATE_TEST_HERE
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
|
||||
sem_arg = sem if sem is None else f'"{sem}"'
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
|
||||
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
|
||||
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
|
||||
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
|
||||
@@ -981,7 +983,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
x_tri = to_triton(x, device=device)
|
||||
|
||||
z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
|
||||
kernel[(n_programs, )](x_tri, z_tri)
|
||||
h = kernel[(n_programs, )](x_tri, z_tri)
|
||||
# torch result
|
||||
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
|
||||
# compare
|
||||
@@ -990,6 +992,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
assert z_ref.item() == to_numpy(z_tri).item()
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
sem_str = "acq_rel" if sem is None else sem
|
||||
assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
|
||||
|
||||
|
||||
def test_atomic_rmw_predicate(device="cuda"):
|
||||
@@ -1047,7 +1051,8 @@ def test_tensor_atomic_rmw_block(device="cuda"):
|
||||
assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
|
||||
def test_atomic_cas(sem):
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
@@ -1060,9 +1065,9 @@ def test_atomic_cas():
|
||||
|
||||
# 2. only one block enters the critical section
|
||||
@triton.jit
|
||||
def serialized_add(data, Lock):
|
||||
def serialized_add(data, Lock, SEM: tl.constexpr):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
|
||||
pass
|
||||
|
||||
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
@@ -1073,8 +1078,10 @@ def test_atomic_cas():
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
h = serialized_add[(64,)](data, Lock, SEM=sem)
|
||||
sem_str = "acq_rel" if sem is None else sem
|
||||
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
|
||||
assert f"atom.global.{sem_str}" in h.asm["ptx"]
|
||||
|
||||
|
||||
# ---------------
|
||||
@@ -1140,6 +1147,28 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
||||
assert to_numpy(z_tri) == z_ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "dtype_str, num_warps",
    [(dtype_str, num_warps)
     for dtype_str in int_dtypes + float_dtypes
     for num_warps in [4, 8]])
def test_cat(dtype_str, num_warps):
    """Check ``tl.cat(x, y, can_reorder=True)`` against ``torch.cat``.

    The result is validated with order-insensitive properties only: the sum
    of the output must equal the sum of the torch concatenation, and the
    output must not contain any duplicated value.
    """
    check_type_supported(dtype_str)

    @triton.jit
    def kernel(X, Y, Z, N: tl.constexpr):
        idx = tl.arange(0, N)
        lhs = tl.load(X + idx)
        rhs = tl.load(Y + idx)
        joined = tl.cat(lhs, rhs, can_reorder=True)
        tl.store(Z + tl.arange(0, 2 * N), joined)

    torch_dtype = getattr(torch, dtype_str)
    n = 128

    # two disjoint ranges, so every value in the concatenation is distinct
    x = torch.arange(0, n, device='cuda').to(torch_dtype)
    y = torch.arange(-n, 0, device='cuda').to(torch_dtype)
    expected_sum = torch.cat([x, y], dim=0).sum()
    z = torch.zeros((2 * n,), dtype=torch_dtype, device='cuda')

    kernel[(1, )](x, y, z, N=n, num_warps=num_warps)

    assert z.sum() == expected_sum
    # check if there's no duplicate value in z
    assert z.unique().size(0) == z.size(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
|
||||
def test_store_constant(dtype_str):
|
||||
check_type_supported(dtype_str)
|
||||
@@ -1197,8 +1226,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
|
||||
# special cases, exp is 0b11..1
|
||||
if dtype == tl.float8e4:
|
||||
# float8e4m3 does not have infinities
|
||||
output[fp == torch.tensor(0b01111111, dtype=torch.int8)] = torch.nan
|
||||
output[fp == torch.tensor(0b11111111, dtype=torch.int8)] = torch.nan
|
||||
output[fp == 0b01111111] = torch.nan
|
||||
output[fp == 0b11111111] = torch.nan
|
||||
else:
|
||||
output = torch.where(exp == (1 << exp_width) - 1,
|
||||
((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
|
||||
@@ -1338,7 +1367,11 @@ def get_reduced_dtype(dtype_str, op):
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape",
|
||||
[(op, dtype, shape)
|
||||
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
|
||||
for op in ['min', 'max',
|
||||
'min-with-indices',
|
||||
'max-with-indices',
|
||||
'argmin', 'argmax',
|
||||
'sum']
|
||||
for dtype in dtypes_with_bfloat16
|
||||
for shape in [32, 64, 128, 512]])
|
||||
def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
@@ -1348,15 +1381,22 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(X, Z, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.store(Z, GENERATE_TEST_HERE)
|
||||
GENERATE_TEST_HERE
|
||||
tl.store(Z, z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
|
||||
if 'with-indices' in op:
|
||||
patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)'
|
||||
else:
|
||||
patch = f'z = tl.{op}(x, axis=0)'
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch})
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
|
||||
x_tri = to_triton(x, device=device)
|
||||
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
|
||||
'max-with-indices': np.max,
|
||||
'min-with-indices': np.min,
|
||||
'argmin': np.argmin, 'argmax': np.argmax}[op]
|
||||
# numpy result
|
||||
z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str
|
||||
@@ -2444,11 +2484,11 @@ def test_call(type):
|
||||
# -------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and_dynamic", "if_and_static"])
|
||||
def test_if(if_type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
cond = tl.load(Cond)
|
||||
if IfType == "if":
|
||||
@@ -2458,17 +2498,22 @@ def test_if(if_type):
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and":
|
||||
elif IfType == "if_and_dynamic":
|
||||
if BoolVar and pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and_static":
|
||||
if StaticVaue != 0 and StaticVaue != 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
|
||||
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
|
||||
assert torch.equal(ret, x_true)
|
||||
|
||||
|
||||
|
||||
@@ -4,75 +4,103 @@ import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.ops
|
||||
|
||||
|
||||
def f8_to_f16(x):
    """Copy ``x``, reinterpreted as ``tl.float8e5``, into a new float16 tensor.

    A Triton kernel loads the elements of ``x`` through a ``tl.float8e5``
    view and stores them into a freshly allocated ``torch.float16`` tensor
    that has the same shape and device as ``x``. That tensor is returned.
    """

    @triton.jit
    def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
        block_id = tl.program_id(0)
        idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N  # mask off lanes past the end of the buffer
        src = tl.load(X + idx, mask=in_bounds)
        # NOTE(review): X is already passed in as float8e5, so this cast looks
        # like a no-op; the widening presumably happens on the store into the
        # float16 output -- confirm.
        val = src.to(tl.float8e5)
        tl.store(Y + idx, val, mask=in_bounds)

    out = torch.empty(x.shape, dtype=torch.float16, device=x.device)

    def grid(META):
        # one program per BLOCK_SIZE-sized slice of the input
        return (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)

    kernel[grid](out, triton.reinterpret(x, tl.float8e5), out.numel(), BLOCK_SIZE=1024)
    return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
|
||||
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
|
||||
],
|
||||
# mixed-precision
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(32, 64, 16, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 128, 32, 8, 4, 2, 1024, 1024, 1024, AT, BT, ADTYPE, BDTYPE),
|
||||
] for ADTYPE, BDTYPE in [("float8", "float16")] for AT in [False, True] for BT in [False, True]
|
||||
]
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
if capability[0] < 8 and DTYPE == "bfloat16":
|
||||
if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"):
|
||||
pytest.skip("Only test bfloat16 on devices with sm >= 80")
|
||||
if DTYPE == "bfloat16" and SPLIT_K != 1:
|
||||
if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1:
|
||||
pytest.skip("bfloat16 matmuls don't allow split_k for now")
|
||||
torch.manual_seed(0)
|
||||
# nuke kernel decorators -- will set meta-parameters manually
|
||||
@@ -87,12 +115,18 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
K = BLOCK_K * SPLIT_K if K is None else K
|
||||
|
||||
def get_input(n, m, t, dtype):
|
||||
if t:
|
||||
return get_input(m, n, False, dtype).t()
|
||||
if dtype == "float8":
|
||||
x = torch.randint(10, 50, (n, m), device="cuda", dtype=torch.int8)
|
||||
return f8_to_f16(x)
|
||||
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
|
||||
return .1 * torch.randn((n, m), device="cuda", dtype=dtype)
|
||||
# allocate/transpose inputs
|
||||
DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
|
||||
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
|
||||
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
a = get_input(M, K, AT, ADTYPE)
|
||||
b = get_input(K, N, BT, BDTYPE)
|
||||
# run test
|
||||
th_c = torch.matmul(a, b)
|
||||
try:
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_memory_leak() -> None:
|
||||
kernel[(10,)](inp, out, 10, XBLOCK=16)
|
||||
gc.collect()
|
||||
end, _ = tracemalloc.get_traced_memory()
|
||||
assert end - begin < 1000
|
||||
assert end - begin < 5000
|
||||
finally:
|
||||
tracemalloc.stop()
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from .runtime import (
|
||||
)
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from .debugger.debugger import program_ids_from_grid
|
||||
|
||||
from . import language
|
||||
from . import testing
|
||||
@@ -43,7 +42,6 @@ __all__ = [
|
||||
"runtime",
|
||||
"TensorWrapper",
|
||||
"testing",
|
||||
"program_ids_from_grid",
|
||||
]
|
||||
|
||||
|
||||
@@ -52,11 +50,11 @@ __all__ = [
|
||||
# into any specific module
|
||||
# -------------------------------------
|
||||
|
||||
def cdiv(x, y):
|
||||
def cdiv(x: int, y: int):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
def next_power_of_2(n: int):
|
||||
"""Return the smallest power of 2 greater than or equal to n"""
|
||||
n -= 1
|
||||
n |= n >> 1
|
||||
@@ -64,5 +62,6 @@ def next_power_of_2(n):
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n |= n >> 32
|
||||
n += 1
|
||||
return n
|
||||
|
||||
96
python/triton/common/backend.py
Normal file
96
python/triton/common/backend.py
Normal file
@@ -0,0 +1,96 @@
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
from typing import Dict
|
||||
|
||||
from ..runtime.driver import DriverBase
|
||||
|
||||
|
||||
class BaseBackend:
    """Interface that a device backend implements.

    Every hook below raises ``NotImplementedError``; a concrete backend is
    expected to override the ones it needs. Instances are normally built
    through :meth:`create_backend`.
    """

    def __init__(self, device_type: str) -> None:
        # name of the device type this backend instance was created for
        self.device_type = device_type

    def add_stages(self, arch, extern_libs, stages):
        """Customize ``arch``, ``extern_libs`` and ``stages`` to the backend's requirements."""
        raise NotImplementedError

    def add_meta_info(self, ir, cur_module, next_module, metadata, asm):
        """Customize the ir, modules, ``metadata`` and ``asm`` to the backend's requirements."""
        raise NotImplementedError

    def get_load_binary_fn(self):
        """Return a callable that loads a binary."""
        raise NotImplementedError

    def get_driver(self) -> DriverBase:
        """Return the backend driver; see ``DriverBase`` for details."""
        raise NotImplementedError

    def get_stream(self):
        """Return the stream for the current device."""
        raise NotImplementedError

    def get_device_properties(self, device):
        """Return the properties of ``device``.

        NOTE(review): the shape of the returned value is not specified here;
        presumably it is backend-defined -- confirm with implementers.
        """
        raise NotImplementedError

    def get_current_device(self):
        """Return the current device."""
        raise NotImplementedError

    def set_current_device(self, device):
        """Make ``device`` the current device."""
        raise NotImplementedError

    def get_kernel_bin(self):
        """Backend hook with no contract documented here.

        NOTE(review): presumably identifies the compiled kernel binary
        artifact -- confirm with implementers.
        """
        raise NotImplementedError

    def make_launcher_stub(self, name, signature, constants):
        """Generate the launcher stub used to launch the kernel."""
        raise NotImplementedError

    def get_architecture_descriptor(self, **kwargs):
        """Return the architecture descriptor of the backend."""
        raise NotImplementedError

    @classmethod
    def create_backend(cls, device_type: str):
        """Construct an instance of ``cls`` for ``device_type``."""
        return cls(device_type)
|
||||
|
||||
|
||||
_backends: Dict[str, BaseBackend] = {}
|
||||
|
||||
|
||||
def register_backend(device_type: str, backend_cls: type):
    """Create and record a backend for ``device_type`` if none is registered yet.

    The backend instance is built with ``backend_cls.create_backend(device_type)``.
    A second registration for the same ``device_type`` is ignored.
    """
    if device_type in _backends:
        # first registration wins
        return
    _backends[device_type] = backend_cls.create_backend(device_type)
|
||||
|
||||
|
||||
def get_backend(device_type: str):
    """Return the backend registered for ``device_type``, or ``None``.

    If no backend is registered yet, the package
    ``triton.third_party.<device_type>`` is looked up and imported, so that
    it gets a chance to register itself. The lookup is best-effort: a missing
    package, a missing parent package, or any exception raised while
    importing yields ``None`` instead of propagating.

    :param device_type: name of the device type, e.g. ``"cuda"``
    :return: the registered backend instance, or ``None`` if unavailable
    """
    if device_type not in _backends:
        device_backend_package_name = f"triton.third_party.{device_type}"
        try:
            # find_spec() imports the parent package of a dotted name and
            # raises ModuleNotFoundError when that parent is missing, so the
            # probe must sit inside the same best-effort guard as the import.
            if importlib.util.find_spec(device_backend_package_name) is None:
                return None
            importlib.import_module(device_backend_package_name)
        except Exception:
            return None
    # the import may or may not have registered a backend
    return _backends.get(device_type)
|
||||
@@ -6,12 +6,12 @@ import warnings
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from .. import language
|
||||
from .._C.libtriton.triton import ir
|
||||
from ..language import constexpr, tensor
|
||||
# ideally we wouldn't need any runtime component
|
||||
from ..runtime import JITFunction
|
||||
from .errors import (CompilationError, CompileTimeAssertionFailure,
|
||||
UnsupportedLanguageConstruct)
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
@@ -595,12 +595,14 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_Compare(self, node):
|
||||
if not (len(node.comparators) == 1 and len(node.ops) == 1):
|
||||
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported")
|
||||
lhs = _unwrap_if_constexpr(self.visit(node.left))
|
||||
rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.comparators[0])
|
||||
lhs_value = _unwrap_if_constexpr(lhs)
|
||||
rhs_value = _unwrap_if_constexpr(rhs)
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return constexpr(lhs is rhs)
|
||||
return constexpr(lhs_value is rhs_value)
|
||||
if type(node.ops[0]) == ast.IsNot:
|
||||
return constexpr(lhs is not rhs)
|
||||
return constexpr(lhs_value is not rhs_value)
|
||||
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
|
||||
if method_name is None:
|
||||
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
|
||||
@@ -988,7 +990,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not (0 < arg_count <= 2) or len(node.keywords):
|
||||
raise TypeError("`static_assert` requires one or two positional arguments only")
|
||||
|
||||
passed = self.visit(node.args[0])
|
||||
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
|
||||
if not isinstance(passed, bool):
|
||||
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
|
||||
if not passed:
|
||||
|
||||
@@ -11,12 +11,19 @@ from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..runtime import driver
|
||||
# import triton
|
||||
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
|
||||
get_shared_memory_size, ir,
|
||||
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
|
||||
translate_triton_gpu_to_llvmir)
|
||||
from ..common.backend import get_backend
|
||||
# from ..runtime import driver, jit, JITFunction
|
||||
# TODO: runtime.errors
|
||||
from ..runtime.autotuner import OutOfResources
|
||||
from ..runtime.cache import get_cache_manager
|
||||
from ..runtime.driver import driver
|
||||
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
|
||||
get_device_capability, version_key)
|
||||
from ..tools.disasm import extract
|
||||
from .code_generator import ast_to_ttir
|
||||
from .make_launcher import make_stub
|
||||
@@ -24,7 +31,7 @@ from .make_launcher import make_stub
|
||||
CUDA_DEFAULT_WARP_SIZE = 32
|
||||
|
||||
def inline_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.run(mod)
|
||||
@@ -34,7 +41,7 @@ def inline_triton_ir(mod):
|
||||
def ttir_compute_capability_rewrite(mod, arch):
|
||||
# For hardware without support, we must rewrite all load/store
|
||||
# with block (tensor) pointers into tensors of pointers
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
if _is_cuda(arch):
|
||||
pm.add_rewrite_tensor_pointer_pass(arch)
|
||||
@@ -45,7 +52,7 @@ def ttir_compute_capability_rewrite(mod, arch):
|
||||
def optimize_ttir(mod, arch):
|
||||
mod = inline_triton_ir(mod)
|
||||
mod = ttir_compute_capability_rewrite(mod, arch)
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
@@ -57,15 +64,22 @@ def optimize_ttir(mod, arch):
|
||||
return mod
|
||||
|
||||
|
||||
def ttir_to_ttgir(mod, num_warps, warpsize=CUDA_DEFAULT_WARP_SIZE):
    """Convert a Triton IR module to TritonGPU IR.

    This resolves the merge conflict that was committed here. It keeps the
    ``warpsize`` argument from the HEAD side and, from the incoming side,
    the ``pm.enable_debug()`` call and the directly imported ``ir`` module.
    ``warpsize`` defaults to ``CUDA_DEFAULT_WARP_SIZE`` so two-argument
    callers written against the incoming side keep working.

    :param mod: module to convert; it is modified by the pass manager
    :param num_warps: number of warps passed to the conversion pass
    :param warpsize: warp size passed to the conversion pass
    :return: ``mod`` after the pass has run
    """
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    # NOTE(review): assumes the native binding accepts (num_warps, warpsize),
    # as the HEAD side of the conflict called it -- confirm against the
    # merged C++ bindings.
    pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
    pm.run(mod)
    return mod
|
||||
|
||||
|
||||
def optimize_ttgir(mod, num_stages, arch):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm = ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_tritongpu_coalesce_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
@@ -92,7 +106,7 @@ def _add_external_libs(mod, libs):
|
||||
for name, path in libs.items():
|
||||
if len(name) == 0 or len(path) == 0:
|
||||
return
|
||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||
|
||||
|
||||
def ttgir_to_llir(mod, extern_libs, arch):
|
||||
@@ -100,9 +114,9 @@ def ttgir_to_llir(mod, extern_libs, arch):
|
||||
_add_external_libs(mod, extern_libs)
|
||||
# TODO: separate tritongpu_to_llvmir for different backends
|
||||
if _is_cuda(arch):
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
|
||||
return translate_triton_gpu_to_llvmir(mod, arch, False)
|
||||
else:
|
||||
return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)
|
||||
return translate_triton_gpu_to_llvmir(mod, 0, True)
|
||||
|
||||
|
||||
# PTX translation
|
||||
@@ -132,8 +146,9 @@ def path_to_ptxas():
|
||||
]
|
||||
|
||||
for ptxas in paths:
|
||||
if os.path.exists(ptxas) and os.path.isfile(ptxas):
|
||||
result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
|
||||
ptxas_bin = ptxas.split(" ")[0]
|
||||
if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin):
|
||||
result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT)
|
||||
if result is not None:
|
||||
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
|
||||
if version is not None:
|
||||
@@ -150,7 +165,7 @@ def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
|
||||
if ptx_version is None:
|
||||
_, cuda_version = path_to_ptxas()
|
||||
ptx_version = ptx_get_version(cuda_version)
|
||||
return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)
|
||||
return translate_llvmir_to_ptx(mod, arch, ptx_version)
|
||||
|
||||
|
||||
def ptx_to_cubin(ptx: str, arch: int):
|
||||
@@ -161,7 +176,7 @@ def ptx_to_cubin(ptx: str, arch: int):
|
||||
:return: str
|
||||
'''
|
||||
ptxas, _ = path_to_ptxas()
|
||||
return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
|
||||
return compile_ptx_to_cubin(ptx, ptxas, arch)
|
||||
|
||||
|
||||
# AMDGCN translation
|
||||
@@ -228,7 +243,7 @@ def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_featu
|
||||
- AMDGCN code
|
||||
- Path to HSACO object
|
||||
'''
|
||||
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
return translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
@@ -255,7 +270,7 @@ def convert_type_repr(x):
|
||||
|
||||
|
||||
def make_hash(fn, arch, **kwargs):
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
constants = kwargs.get("constants", dict())
|
||||
@@ -268,7 +283,7 @@ def make_hash(fn, arch, **kwargs):
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
|
||||
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
|
||||
@@ -312,7 +327,7 @@ def _get_jsonable_constants(constants):
|
||||
|
||||
|
||||
def parse_mlir_module(path, context):
|
||||
module = _triton.ir.parse_mlir_module(path, context)
|
||||
module = ir.parse_mlir_module(path, context)
|
||||
# module takes ownership of the context
|
||||
module.context = context
|
||||
return module
|
||||
@@ -341,8 +356,8 @@ def get_architecture_descriptor(capability):
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
if capability is None:
|
||||
if torch.version.hip is None:
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
capability = triton.runtime.jit.get_device_capability(device)
|
||||
device = get_current_device()
|
||||
capability = get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
else:
|
||||
capability = get_amdgpu_arch_fulldetails()
|
||||
@@ -375,6 +390,7 @@ def add_cuda_stages(arch, extern_libs, stages):
|
||||
|
||||
|
||||
def compile(fn, **kwargs):
|
||||
<<<<<<< HEAD
|
||||
if is_hip():
|
||||
capability = None
|
||||
else:
|
||||
@@ -384,6 +400,22 @@ def compile(fn, **kwargs):
|
||||
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
|
||||
context = _triton.ir.context()
|
||||
asm = dict()
|
||||
=======
|
||||
# Get device type to decide which backend should be used
|
||||
device_type = kwargs.get("device_type", "cuda")
|
||||
_device_backend = get_backend(device_type)
|
||||
|
||||
if device_type in ["cuda", "hip"]:
|
||||
arch = get_architecture_descriptor(kwargs.get("cc", None))
|
||||
else:
|
||||
_device_backend = get_backend(device_type)
|
||||
assert _device_backend
|
||||
arch = _device_backend.get_architecture_descriptor(**kwargs)
|
||||
|
||||
is_cuda = device_type == "cuda" and _is_cuda(arch)
|
||||
is_hip = device_type in ["cuda", "hip"] and not is_cuda
|
||||
context = ir.context()
|
||||
>>>>>>> oai/main
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
|
||||
@@ -391,6 +423,7 @@ def compile(fn, **kwargs):
|
||||
if extern_libs is None:
|
||||
extern_libs = dict()
|
||||
debug = kwargs.get("debug", False)
|
||||
|
||||
# build compilation stages
|
||||
stages = dict()
|
||||
stages["ast"] = (lambda path: fn, None)
|
||||
@@ -402,11 +435,13 @@ def compile(fn, **kwargs):
|
||||
lambda src: ttgir_to_llir(src, extern_libs, arch))
|
||||
if is_cuda:
|
||||
add_cuda_stages(arch, extern_libs, stages)
|
||||
else:
|
||||
elif is_hip:
|
||||
add_rocm_stages(arch, extern_libs, stages)
|
||||
else:
|
||||
_device_backend.add_stages(arch, extern_libs, stages)
|
||||
|
||||
# find out the signature of the function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
if isinstance(fn, JITFunction):
|
||||
configs = kwargs.get("configs", None)
|
||||
signature = kwargs["signature"]
|
||||
if configs is None:
|
||||
@@ -420,27 +455,31 @@ def compile(fn, **kwargs):
|
||||
kwargs["signature"] = signature
|
||||
else:
|
||||
assert isinstance(fn, str)
|
||||
_, ir = os.path.basename(fn).split(".")
|
||||
_, ir_name = os.path.basename(fn).split(".")
|
||||
src = Path(fn).read_text()
|
||||
import re
|
||||
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
|
||||
match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
|
||||
name, signature = match.group(1), match.group(2)
|
||||
types = re.findall(arg_type_pattern[ir], signature)
|
||||
if ir == 'ttgir':
|
||||
types = re.findall(arg_type_pattern[ir_name], signature)
|
||||
if ir_name == 'ttgir':
|
||||
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
|
||||
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
|
||||
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
|
||||
num_warps = int(num_warps_matches[0])
|
||||
param_tys = [convert_type_repr(ty) for ty in types]
|
||||
signature = {k: v for k, v in enumerate(param_tys)}
|
||||
first_stage = list(stages.keys()).index(ir)
|
||||
first_stage = list(stages.keys()).index(ir_name)
|
||||
|
||||
# cache manager
|
||||
so_path = make_stub(name, signature, constants)
|
||||
if is_cuda or is_hip:
|
||||
so_path = make_stub(name, signature, constants)
|
||||
else:
|
||||
so_path = _device_backend.make_launcher_stub(name, signature, constants)
|
||||
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
if isinstance(fn, JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
else:
|
||||
name, ext = os.path.basename(fn).split(".")
|
||||
@@ -464,19 +503,23 @@ def compile(fn, **kwargs):
|
||||
"warp_size": warp_size,
|
||||
"num_stages": num_stages,
|
||||
"constants": _get_jsonable_constants(constants),
|
||||
"debug": debug}
|
||||
"debug": debug,
|
||||
"arch": arch, }
|
||||
if ext == "ptx":
|
||||
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
|
||||
metadata["shared"] = kwargs["shared"]
|
||||
|
||||
# Add device type to meta information
|
||||
metadata["device_type"] = device_type
|
||||
|
||||
first_stage = list(stages.keys()).index(ext)
|
||||
asm = dict()
|
||||
module = fn
|
||||
# run compilation pipeline and populate metadata
|
||||
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
|
||||
ir_filename = f"{name}.{ir}"
|
||||
for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]:
|
||||
ir_filename = f"{name}.{ir_name}"
|
||||
|
||||
if ir == ext:
|
||||
if ir_name == ext:
|
||||
next_module = parse(fn)
|
||||
else:
|
||||
path = metadata_group.get(ir_filename)
|
||||
@@ -490,7 +533,7 @@ def compile(fn, **kwargs):
|
||||
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
|
||||
fn_cache_manager.put(next_module, ir_filename)
|
||||
else:
|
||||
if ir == "amdgcn":
|
||||
if ir_name == "amdgcn":
|
||||
extra_file_name = f"{name}.hsaco_path"
|
||||
hasco_path = metadata_group.get(extra_file_name)
|
||||
assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
|
||||
@@ -498,19 +541,21 @@ def compile(fn, **kwargs):
|
||||
else:
|
||||
next_module = parse(path)
|
||||
|
||||
if ir == "cubin":
|
||||
asm[ir] = next_module
|
||||
elif ir == "amdgcn":
|
||||
asm[ir] = str(next_module[0])
|
||||
if ir_name == "cubin":
|
||||
asm[ir_name] = next_module
|
||||
elif ir_name == "amdgcn":
|
||||
asm[ir_name] = str(next_module[0])
|
||||
else:
|
||||
asm[ir] = str(next_module)
|
||||
if ir == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = _triton.get_shared_memory_size(module)
|
||||
if ir == "ptx":
|
||||
asm[ir_name] = str(next_module)
|
||||
if ir_name == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = get_shared_memory_size(module)
|
||||
if ir_name == "ptx":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
|
||||
if ir == "amdgcn":
|
||||
if ir_name == "amdgcn":
|
||||
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
|
||||
asm["hsaco_path"] = next_module[1]
|
||||
if not is_cuda and not is_hip:
|
||||
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
|
||||
module = next_module
|
||||
# write-back metadata, if it didn't come from the cache
|
||||
if metadata_path is None:
|
||||
@@ -536,11 +581,13 @@ class CompiledKernel:
|
||||
spec.loader.exec_module(mod)
|
||||
self.c_wrapper = getattr(mod, "launch")
|
||||
# initialize metadata
|
||||
self.shared = metadata["shared"]
|
||||
self.shared = metadata["shared"] if "shared" in metadata else 0
|
||||
self.num_warps = metadata["num_warps"]
|
||||
self.warp_size = metadata["warp_size"]
|
||||
self.num_stages = metadata["num_stages"]
|
||||
self.constants = metadata["constants"]
|
||||
self.device_type = metadata["device_type"]
|
||||
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
|
||||
# initialize asm dict
|
||||
self.asm = asm
|
||||
# binaries are lazily initialized
|
||||
@@ -553,15 +600,26 @@ class CompiledKernel:
|
||||
def _init_handles(self):
|
||||
if self.cu_module is not None:
|
||||
return
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
|
||||
if self.device_type in ["cuda", "hip"]:
|
||||
device = get_current_device()
|
||||
bin_path = {
|
||||
driver.HIP: "hsaco_path",
|
||||
driver.CUDA: "cubin"
|
||||
}[driver.backend]
|
||||
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
|
||||
fn_load_binary = driver.utils.load_binary
|
||||
else:
|
||||
assert self.device_backend
|
||||
device = self.device_backend.get_current_device()
|
||||
bin_path = self.device_backend.get_kernel_bin()
|
||||
max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"]
|
||||
fn_load_binary = self.device_backend.get_load_binary_fn()
|
||||
|
||||
if self.shared > max_shared:
|
||||
raise OutOfResources(self.shared, max_shared, "shared memory")
|
||||
mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
|
||||
|
||||
mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
|
||||
|
||||
self.n_spills = n_spills
|
||||
self.n_regs = n_regs
|
||||
@@ -578,8 +636,16 @@ class CompiledKernel:
|
||||
|
||||
def runner(*args, stream=None):
    """Launch the compiled kernel through the generated C wrapper.

    Args:
        *args: Kernel arguments, forwarded unchanged after the fixed
            launcher parameters.
        stream: Stream to launch on. When ``None`` it is looked up here:
            via ``get_cuda_stream()`` for the "cuda" and "hip" device
            types, otherwise from the backend registered for
            ``self.device_type``.
    """
    if stream is None:
        # "hip" (not "rocm") is the name the rest of this class compares
        # device_type against when selecting the built-in driver path.
        if self.device_type in ["cuda", "hip"]:
            stream = get_cuda_stream()
        else:
            stream = get_backend(self.device_type).get_stream(None)
    # NOTE(review): warp_size is kept from the HEAD side of the merge;
    # confirm the generated launcher stub still takes it in this position.
    self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.warp_size, self.shared, stream, self.cu_function,
                   CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
|
||||
return runner
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@ import itertools
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .. import language as tl
|
||||
# import .language.core as lcore
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
|
||||
debugger_constexpr)
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
tl_method_backup = {}
|
||||
@@ -59,12 +60,12 @@ class DebuggerFunction:
|
||||
self.grid = grid
|
||||
|
||||
def _is_constexpr(self, name):
|
||||
return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr
|
||||
return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr
|
||||
|
||||
def _get_constexpr(self):
|
||||
result = []
|
||||
for name, annotation in self.func.__annotations__.items():
|
||||
if annotation is triton.language.core.constexpr:
|
||||
if annotation is lcore.constexpr:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from triton.debugger import torch_wrapper
|
||||
from triton.interpreter import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import triton
|
||||
from __future__ import annotations
|
||||
|
||||
# import triton
|
||||
from ..language import core as lcore
|
||||
from . import torch_wrapper
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
@@ -387,7 +390,7 @@ class TritonLangProxy:
|
||||
if not isinstance(d.value, int):
|
||||
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
|
||||
shape = [x.value for x in shape]
|
||||
if isinstance(dtype, triton.language.core.dtype):
|
||||
if isinstance(dtype, lcore.dtype):
|
||||
if dtype.is_fp32():
|
||||
dtype = torch.float32
|
||||
elif dtype.is_fp16():
|
||||
@@ -405,7 +408,9 @@ class TritonLangProxy:
|
||||
return torch.zeros(size=shape, dtype=dtype, device="cuda")
|
||||
|
||||
@_tensor_operation
|
||||
def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
|
||||
def dequantize(self, input, scale, shift, nbit, dst_ty=None):
|
||||
if dst_ty is None:
|
||||
dst_ty = torch.float16
|
||||
raise NotImplementedError()
|
||||
|
||||
@_tensor_operation
|
||||
@@ -77,7 +77,7 @@ from .core import (
|
||||
static_range,
|
||||
tensor,
|
||||
trans,
|
||||
triton,
|
||||
# triton,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
|
||||
@@ -5,9 +5,10 @@ from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Sequence, TypeVar
|
||||
|
||||
import triton
|
||||
from .._C.libtriton.triton import ir
|
||||
# import triton
|
||||
from ..runtime.jit import jit
|
||||
from . import semantic
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@@ -1112,59 +1113,67 @@ def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("compare-and-swap")
|
||||
def atomic_cas(pointer, cmp, val, _builder=None):
|
||||
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
|
||||
cmp = _to_tensor(cmp, _builder)
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_cas(pointer, cmp, val, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("exchange")
|
||||
def atomic_xchg(pointer, val, mask=None, _builder=None):
|
||||
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_xchg(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("add")
|
||||
def atomic_add(pointer, val, mask=None, _builder=None):
|
||||
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_add(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_add(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("max")
|
||||
def atomic_max(pointer, val, mask=None, _builder=None):
|
||||
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_max(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_max(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("min")
|
||||
def atomic_min(pointer, val, mask=None, _builder=None):
|
||||
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_min(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_min(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical and")
|
||||
def atomic_and(pointer, val, mask=None, _builder=None):
|
||||
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_and(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_and(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical or")
|
||||
def atomic_or(pointer, val, mask=None, _builder=None):
|
||||
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_or(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_or(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_atomic_docstr("logical xor")
|
||||
def atomic_xor(pointer, val, mask=None, _builder=None):
|
||||
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
|
||||
val = _to_tensor(val, _builder)
|
||||
return semantic.atomic_xor(pointer, val, mask, _builder)
|
||||
sem = _constexpr_to_value(sem)
|
||||
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
@@ -1339,7 +1348,7 @@ def _promote_reduction_input(t, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
n = input.shape[axis]
|
||||
index = arange(0, n, _builder=_builder)
|
||||
@@ -1353,10 +1362,10 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
|
||||
rvalue, rindices = reduce((input, index), axis, combine_fn,
|
||||
_builder=_builder, _generator=_generator)
|
||||
return rindices
|
||||
return rvalue, rindices
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def minimum(x, y):
|
||||
"""
|
||||
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
||||
@@ -1369,7 +1378,7 @@ def minimum(x, y):
|
||||
return where(x < y, x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def maximum(x, y):
|
||||
"""
|
||||
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
||||
@@ -1381,83 +1390,93 @@ def maximum(x, y):
|
||||
"""
|
||||
return where(x > y, x, y)
|
||||
|
||||
# max and argmax
|
||||
|
||||
@triton.jit
|
||||
|
||||
@jit
|
||||
def _max_combine(a, b):
|
||||
return maximum(a, b)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("maximum")
|
||||
def max(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _max_combine)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _argmax_combine(value1, index1, value2, index2):
|
||||
gt = value1 > value2
|
||||
lt = value1 < value2
|
||||
index_min = minimum(index1, index2)
|
||||
index_ret = where(gt, index1, where(lt, index2, index_min))
|
||||
value_ret = maximum(value1, value2)
|
||||
value_ret = where(gt, value1, value2)
|
||||
index_ret = where(gt, index1, index2)
|
||||
return value_ret, index_ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
@_add_reduction_docstr("maximum")
|
||||
def max(input, axis=None, return_indices=False):
|
||||
input = _promote_reduction_input(input)
|
||||
if return_indices:
|
||||
return _reduce_with_indices(input, axis, _argmax_combine)
|
||||
else:
|
||||
return reduce(input, axis, _max_combine)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("maximum index")
|
||||
def argmax(input, axis):
|
||||
input = _promote_reduction_input(input)
|
||||
return _argreduce(input, axis, _argmax_combine)
|
||||
(_, ret) = max(input, axis, return_indices=True)
|
||||
return ret
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _min_combine(a, b):
|
||||
# TODO: minimum/maximum doesn't get lowered to fmin/fmax...
|
||||
return minimum(a, b)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("minimum")
|
||||
def min(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _min_combine)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _argmin_combine(value1, index1, value2, index2):
|
||||
lt = value1 < value2
|
||||
gt = value1 > value2
|
||||
index_min = minimum(index1, index2)
|
||||
index_ret = where(lt, index1, where(gt, index2, index_min))
|
||||
value_ret = minimum(value1, value2)
|
||||
value_ret = where(lt, value1, value2)
|
||||
index_ret = where(lt, index1, index2)
|
||||
return value_ret, index_ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
@_add_reduction_docstr("minimum")
|
||||
def min(input, axis=None, return_indices=False):
|
||||
input = _promote_reduction_input(input)
|
||||
if return_indices:
|
||||
return _reduce_with_indices(input, axis, _argmin_combine)
|
||||
else:
|
||||
return reduce(input, axis, _min_combine)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("minimum index")
|
||||
def argmin(input, axis):
|
||||
input = _promote_reduction_input(input)
|
||||
return _argreduce(input, axis, _argmin_combine)
|
||||
_, ret = min(input, axis, return_indices=True)
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
# sum
|
||||
|
||||
@triton.jit
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("sum")
|
||||
def sum(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _sum_combine)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
|
||||
# xor sum
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("xor sum")
|
||||
def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import triton
|
||||
from ..runtime.jit import jit
|
||||
from . import core as tl
|
||||
|
||||
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
|
||||
@@ -12,7 +12,7 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
||||
# -------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
||||
@@ -33,7 +33,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
seed = seed.to(tl.uint64)
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
|
||||
@@ -45,7 +45,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
|
||||
@@ -61,7 +61,7 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
return ret
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block, returns four
|
||||
@@ -82,7 +82,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# rand
|
||||
# -------------------
|
||||
|
||||
# @triton.jit
|
||||
# @jit
|
||||
# def uint32_to_uniform_float(x):
|
||||
# """
|
||||
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
@@ -90,7 +90,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
||||
# return x * two_to_the_minus_32
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def uint32_to_uniform_float(x):
|
||||
"""
|
||||
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
||||
@@ -102,7 +102,7 @@ def uint32_to_uniform_float(x):
|
||||
return x * scale
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
@@ -116,7 +116,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
return uint32_to_uniform_float(source)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offsets` block,
|
||||
@@ -138,7 +138,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
# -------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def pair_uniform_to_normal(u1, u2):
|
||||
"""Box-Muller transform"""
|
||||
u1 = tl.maximum(1.0e-7, u1)
|
||||
@@ -147,7 +147,7 @@ def pair_uniform_to_normal(u1, u2):
|
||||
return r * tl.cos(th), r * tl.sin(th)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
@@ -163,7 +163,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
return n1
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
"""
|
||||
Given a :code:`seed` scalar and an :code:`offset` block,
|
||||
|
||||
@@ -820,6 +820,22 @@ def _str_to_padding_option(padding_option):
|
||||
return padding
|
||||
|
||||
|
||||
def _str_to_sem(sem_option):
    """Translate a memory-semantic name into its ``ir.MEM_SEMANTIC`` member.

    Args:
        sem_option: One of "acquire", "release", "acq_rel" or "relaxed".
            Any falsy value (``None``, ``""``) selects the default.

    Returns:
        The matching ``ir.MEM_SEMANTIC`` member; ``ACQUIRE_RELEASE`` when
        no option is given.

    Raises:
        ValueError: If a non-empty option is not one of the names above.
    """
    # No explicit request: fall back to acquire-release.
    if not sem_option:
        return ir.MEM_SEMANTIC.ACQUIRE_RELEASE
    if sem_option == "acquire":
        return ir.MEM_SEMANTIC.ACQUIRE
    if sem_option == "release":
        return ir.MEM_SEMANTIC.RELEASE
    if sem_option == "acq_rel":
        return ir.MEM_SEMANTIC.ACQUIRE_RELEASE
    if sem_option == "relaxed":
        return ir.MEM_SEMANTIC.RELAXED
    raise ValueError(f"Memory semantic {sem_option} not supported")
|
||||
|
||||
|
||||
def _canonicalize_boundary_check(boundary_check, block_shape):
|
||||
if boundary_check:
|
||||
if not hasattr(boundary_check, "__iter__"):
|
||||
@@ -1030,11 +1046,13 @@ def store(ptr: tl.tensor,
|
||||
def atomic_cas(ptr: tl.tensor,
|
||||
cmp: tl.tensor,
|
||||
val: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
sem = _str_to_sem(sem)
|
||||
element_ty = ptr.type.scalar.element_ty
|
||||
if element_ty.primitive_bitwidth not in [16, 32, 64]:
|
||||
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
|
||||
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)
|
||||
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem), val.type)
|
||||
|
||||
|
||||
def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
@@ -1044,7 +1062,6 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
|
||||
if not ptr.type.scalar.is_ptr():
|
||||
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
|
||||
|
||||
element_ty = ptr.type.scalar.element_ty
|
||||
if element_ty is tl.float16 and op != 'add':
|
||||
raise ValueError("atomic_" + op + " does not support fp16")
|
||||
@@ -1069,8 +1086,10 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
|
||||
def atomic_max(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
sca_ty = val.type.scalar
|
||||
# direct call to atomic_max for integers
|
||||
if sca_ty.is_int():
|
||||
@@ -1078,13 +1097,15 @@ def atomic_max(ptr: tl.tensor,
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle),
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
else:
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle),
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
# ROCM TODO: implement atomic_max/min for f32 as they are supported by MI cards.
|
||||
# for float
|
||||
@@ -1094,16 +1115,18 @@ def atomic_max(ptr: tl.tensor,
|
||||
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
|
||||
pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
|
||||
neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
|
||||
return where(pos, pos_ret, neg_ret, builder)
|
||||
|
||||
|
||||
def atomic_min(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
sca_ty = val.type.scalar
|
||||
# direct call to atomic_min for integers
|
||||
if sca_ty.is_int():
|
||||
@@ -1111,13 +1134,15 @@ def atomic_min(ptr: tl.tensor,
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle),
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
else:
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
|
||||
ptr.handle,
|
||||
val.handle,
|
||||
mask.handle),
|
||||
mask.handle,
|
||||
sem),
|
||||
val.type)
|
||||
# for float
|
||||
# return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
@@ -1129,12 +1154,14 @@ def atomic_min(ptr: tl.tensor,
|
||||
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
and_(mask, pos, builder).handle),
|
||||
and_(mask, pos, builder).handle,
|
||||
sem),
|
||||
i_val.type)
|
||||
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
|
||||
i_ptr.handle,
|
||||
i_val.handle,
|
||||
and_(mask, neg, builder).handle),
|
||||
and_(mask, neg, builder).handle,
|
||||
sem),
|
||||
i_val.type)
|
||||
return where(pos, pos_ret, neg_ret, builder)
|
||||
|
||||
@@ -1142,43 +1169,53 @@ def atomic_min(ptr: tl.tensor,
|
||||
def atomic_add(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
|
||||
sem = _str_to_sem(sem)
|
||||
sca_ty = val.type.scalar
|
||||
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
|
||||
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
|
||||
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
|
||||
|
||||
def atomic_and(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
|
||||
|
||||
def atomic_or(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
|
||||
|
||||
def atomic_xor(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
|
||||
|
||||
def atomic_xchg(ptr: tl.tensor,
|
||||
val: tl.tensor,
|
||||
mask: tl.tensor,
|
||||
sem: str,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)
|
||||
sem = _str_to_sem(sem)
|
||||
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type)
|
||||
|
||||
# ===----------------------------------------------------------------------===//
|
||||
# Linear Algebra
|
||||
@@ -1213,12 +1250,13 @@ def dot(lhs: tl.tensor,
|
||||
out_dtype: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
assert lhs.shape[1].value == rhs.shape[0].value
|
||||
assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16,\
|
||||
"small blocks not supported!"
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
if lhs.type.scalar.is_int():
|
||||
assert lhs.type.scalar == tl.int8, "only int8 supported!"
|
||||
# TODO: This is CUDA specific, check if ROCm has the same limitation
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from ... import cdiv, heuristics, jit
|
||||
from ... import language as tl
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
|
||||
# ********************************************************
|
||||
# --------------------------------------------------------
|
||||
@@ -13,10 +16,10 @@ import triton.language as tl
|
||||
# ********************************************************
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
@heuristics({
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@triton.jit
|
||||
@jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
@@ -127,7 +130,7 @@ def sdd_lut(layout, block, device):
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
@@ -227,7 +230,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
|
||||
# meta-parameter heuristics
|
||||
TILE_N = 128
|
||||
# compute output
|
||||
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
_dsd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
# import triton
|
||||
# import language as tl
|
||||
from ... import jit
|
||||
from ... import language as tl
|
||||
from ... import next_power_of_2
|
||||
|
||||
|
||||
def num_warps(n):
|
||||
@@ -16,7 +19,7 @@ def num_warps(n):
|
||||
return 16
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
@@ -71,7 +74,7 @@ def _blocksparse_softmax_fwd(
|
||||
tl.store(Out + off_a + lane_n, out, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
@@ -169,7 +172,7 @@ class _softmax(torch.autograd.Function):
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=triton.next_power_of_2(maxlut),
|
||||
ROW_SIZE=next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
)
|
||||
@@ -208,7 +211,7 @@ class _softmax(torch.autograd.Function):
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
|
||||
ROW_SIZE=next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
# import triton
|
||||
# import language as tl
|
||||
from .. import heuristics, jit
|
||||
from .. import language as tl
|
||||
from .. import next_power_of_2
|
||||
|
||||
|
||||
def num_warps(N):
|
||||
@@ -12,9 +15,9 @@ def num_warps(N):
|
||||
return 16
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
|
||||
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
|
||||
@triton.jit
|
||||
@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
|
||||
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
|
||||
@jit
|
||||
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK)
|
||||
@@ -37,9 +40,9 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
|
||||
tl.store(LOSS + row, probs)
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
|
||||
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
|
||||
@triton.jit
|
||||
@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
|
||||
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
|
||||
@jit
|
||||
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK)
|
||||
|
||||
@@ -7,11 +7,14 @@ This is a Triton implementation of the Flash Attention algorithm
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .. import cdiv, jit
|
||||
from .. import language as tl
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L, M,
|
||||
@@ -87,7 +90,7 @@ def _fwd_kernel(
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
@@ -107,7 +110,7 @@ def _bwd_preprocess(
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
@@ -205,7 +208,7 @@ class _attention(torch.autograd.Function):
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
grid = (cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .. import Config, autotune, cdiv, heuristics, jit
|
||||
from .. import language as tl
|
||||
from .matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# import triton
|
||||
# import language as tl
|
||||
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
@@ -17,37 +20,37 @@ def get_configs_io_bound():
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
@@ -56,10 +59,10 @@ def get_configs_io_bound():
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
@heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
@jit
|
||||
def _kernel(A, B, C, M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
@@ -146,7 +149,7 @@ class _matmul(torch.autograd.Function):
|
||||
else:
|
||||
dot_out_dtype = tl.int32
|
||||
# launch kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_kernel[grid](a, b, c, M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
|
||||
@@ -2,10 +2,11 @@ import heapq
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from triton.runtime import driver
|
||||
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
# import triton
|
||||
from .. import cdiv
|
||||
from .._C.libtriton.triton import runtime
|
||||
from ..runtime import driver
|
||||
from ..testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
|
||||
@@ -41,13 +42,13 @@ def estimate_matmul_time(
|
||||
):
|
||||
''' return estimated running time in ms
|
||||
= max(compute, loading) + store '''
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
device = torch.cuda.current_device()
|
||||
dtype = A.dtype
|
||||
dtsize = A.element_size()
|
||||
|
||||
num_cta_m = triton.cdiv(M, BLOCK_M)
|
||||
num_cta_n = triton.cdiv(N, BLOCK_N)
|
||||
num_cta_m = cdiv(M, BLOCK_M)
|
||||
num_cta_n = cdiv(N, BLOCK_N)
|
||||
num_cta_k = SPLIT_K
|
||||
num_ctas = num_cta_m * num_cta_n * num_cta_k
|
||||
|
||||
|
||||
@@ -8,9 +8,15 @@ import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import defaultdict, namedtuple
|
||||
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
|
||||
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
|
||||
overload)
|
||||
|
||||
import triton
|
||||
# import triton
|
||||
# from .. import compile, CompiledKernel
|
||||
from ..common.backend import get_backend
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
|
||||
|
||||
def get_cuda_stream(idx=None):
|
||||
@@ -65,7 +71,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
lhs = self.visit(node.value)
|
||||
while isinstance(lhs, ast.Attribute):
|
||||
lhs = self.visit(lhs.value)
|
||||
if lhs is None or lhs is triton:
|
||||
if lhs is None or lhs.__name__ == "triton":
|
||||
return None
|
||||
return getattr(lhs, node.attr)
|
||||
|
||||
@@ -100,15 +106,15 @@ def version_key():
|
||||
with open(__file__, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# compiler
|
||||
compiler_path = os.path.join(*triton.__path__, 'compiler')
|
||||
compiler_path = os.path.join(TRITON_PATH, 'compiler')
|
||||
for lib in pkgutil.iter_modules([compiler_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# backend
|
||||
with open(triton._C.libtriton.__file__, "rb") as f:
|
||||
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
# language
|
||||
language_path = os.path.join(*triton.__path__, 'language')
|
||||
language_path = os.path.join(TRITON_PATH, 'language')
|
||||
for lib in pkgutil.iter_modules([language_path]):
|
||||
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
|
||||
contents += [hashlib.md5(f.read()).hexdigest()]
|
||||
@@ -117,7 +123,7 @@ def version_key():
|
||||
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
|
||||
except Exception:
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
|
||||
class KernelInterface(Generic[T]):
|
||||
@@ -158,6 +164,22 @@ class JITFunction(KernelInterface[T]):
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
|
||||
|
||||
@staticmethod
|
||||
def _device_of(arg):
|
||||
if hasattr(arg, "device"):
|
||||
if hasattr(arg.device, 'type'):
|
||||
return arg.device.type
|
||||
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def _pinned_memory_of(arg):
|
||||
if hasattr(arg, "is_pinned"):
|
||||
if isinstance(arg.is_pinned, Callable):
|
||||
return arg.is_pinned()
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spec_of(arg):
|
||||
if hasattr(arg, "data_ptr"):
|
||||
@@ -261,12 +283,29 @@ class JITFunction(KernelInterface[T]):
|
||||
else:
|
||||
return f'_key_of({arg})'
|
||||
|
||||
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
|
||||
device_types = [device_type for device_type in device_types if device_type != '']
|
||||
# Return cuda if one of the input tensors is cuda
|
||||
if 'cuda' in device_types:
|
||||
import torch
|
||||
return 'hip' if torch.version.hip else 'cuda'
|
||||
|
||||
is_cpu = all(device_type == 'cpu' for device_type in device_types)
|
||||
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
|
||||
# Return cuda if all the input tensors are cpu while the memory is pinned
|
||||
if is_cpu and is_pinned_memory:
|
||||
return 'cuda'
|
||||
|
||||
return device_types[0] if len(device_types) > 0 else 'cuda'
|
||||
|
||||
def _make_launcher(self):
|
||||
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
|
||||
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
|
||||
args = ', '.join(regular_args)
|
||||
# cache key for regular argument type
|
||||
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
|
||||
device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']'
|
||||
pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']'
|
||||
# cache key for constexpr argument values
|
||||
constexpr_keys = ', '.join(constexpr_args)
|
||||
# cache key for argument specialization
|
||||
@@ -280,7 +319,12 @@ class JITFunction(KernelInterface[T]):
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
|
||||
src = f"""
|
||||
<<<<<<< HEAD
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, warp_size=32, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None):
|
||||
=======
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
>>>>>>> oai/main
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
@@ -294,15 +338,38 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, warp_size
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
|
||||
if device_type is None:
|
||||
device_types = [_device_type for _device_type in {device_types} if _device_type != '']
|
||||
device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
|
||||
|
||||
device_backend = None
|
||||
if device_type not in ['cuda', 'hip']:
|
||||
device_backend = get_backend(device_type)
|
||||
if device_backend is None:
|
||||
raise ValueError('Cannot find backend for ' + device_type)
|
||||
|
||||
if device is None:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
if device_type in ['cuda', 'hip']:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
else:
|
||||
device = device_backend.get_current_device()
|
||||
device_backend.set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
stream = get_cuda_stream(device)
|
||||
if device_type in ['cuda', 'hip']:
|
||||
stream = get_cuda_stream(device)
|
||||
else:
|
||||
stream = device_backend.get_stream()
|
||||
|
||||
bin = cache[device].get(key, None)
|
||||
if bin is not None:
|
||||
if not warmup:
|
||||
<<<<<<< HEAD
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.warp_size, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
|
||||
=======
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, {args})
|
||||
>>>>>>> oai/main
|
||||
return bin
|
||||
# kernel not cached -- compile
|
||||
else:
|
||||
@@ -319,17 +386,31 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, warp_size
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
<<<<<<< HEAD
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, warp_size, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, warp_size=warp_size, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.warp_size, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
||||
=======
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
|
||||
>>>>>>> oai/main
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
"""
|
||||
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
|
||||
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
|
||||
"cache": self.cache, "triton": triton,
|
||||
scope = {"version_key": version_key(),
|
||||
"get_cuda_stream": get_cuda_stream,
|
||||
"self": self,
|
||||
"_spec_of": self._spec_of,
|
||||
"_key_of": self._key_of,
|
||||
"_device_of": self._device_of,
|
||||
"_pinned_memory_of": self._pinned_memory_of,
|
||||
"cache": self.cache,
|
||||
"__spec__": __spec__,
|
||||
"get_backend": get_backend,
|
||||
"get_current_device": get_current_device,
|
||||
"set_current_device": set_current_device}
|
||||
exec(src, scope)
|
||||
@@ -462,7 +543,7 @@ def jit(
|
||||
def decorator(fn: T) -> JITFunction[T]:
|
||||
assert callable(fn)
|
||||
if interpret:
|
||||
from ..debugger.debugger import GridSelector
|
||||
from ..interpreter.interpreter import GridSelector
|
||||
return GridSelector(fn)
|
||||
else:
|
||||
return JITFunction(
|
||||
|
||||
@@ -4,7 +4,7 @@ import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ._C.libtriton.triton import runtime
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
@@ -86,9 +86,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
|
||||
end_event[i].record()
|
||||
# Record clocks
|
||||
torch.cuda.synchronize()
|
||||
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
|
||||
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
|
||||
if quantiles is not None:
|
||||
ret = torch.quantile(times, torch.tensor(quantiles)).tolist()
|
||||
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
|
||||
if len(ret) == 1:
|
||||
ret = ret[0]
|
||||
return ret
|
||||
@@ -281,7 +281,7 @@ def get_dram_gbps(backend=None, device=None):
|
||||
|
||||
from .runtime import driver
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
mem_clock_khz = driver.utils.get_device_properties(device)["mem_clock_rate"] # in kHz
|
||||
@@ -295,7 +295,7 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None)
|
||||
|
||||
from .runtime import driver
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
@@ -398,7 +398,7 @@ def get_max_simd_tflops(dtype, backend=None, device=None):
|
||||
|
||||
from .runtime import driver
|
||||
if not backend:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
backend = runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import triton._C.libtriton.triton as libtriton
|
||||
import triton.compiler.compiler as tc
|
||||
from .._C.libtriton.triton import ir
|
||||
# import triton.compiler.compiler as tc
|
||||
from ..compiler.compiler import (get_amdgpu_arch_fulldetails, llir_to_amdgcn_and_hsaco,
|
||||
llir_to_ptx, optimize_ttgir, optimize_ttir,
|
||||
ttgir_to_llir, ttir_to_ttgir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -32,12 +35,12 @@ if __name__ == '__main__':
|
||||
sys.exit(0)
|
||||
|
||||
# parse source file to MLIR module
|
||||
context = libtriton.ir.context()
|
||||
module = libtriton.ir.parse_mlir_module(args.src, context)
|
||||
context = ir.context()
|
||||
module = ir.parse_mlir_module(args.src, context)
|
||||
module.context = context
|
||||
|
||||
# optimizer triton-ir
|
||||
module = tc.optimize_ttir(module, arch=args.sm)
|
||||
module = optimize_ttir(module, arch=args.sm)
|
||||
if args.target == 'triton-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
@@ -49,7 +52,7 @@ if __name__ == '__main__':
|
||||
if args.target == 'amdgcn':
|
||||
# auto detect available architecture and features
|
||||
# if nothing detected, set with default values
|
||||
arch_details = tc.get_amdgpu_arch_fulldetails()
|
||||
arch_details = get_amdgpu_arch_fulldetails()
|
||||
if not arch_details:
|
||||
arch_name = ""
|
||||
arch_triple = "amdgcn-amd-amdhsa"
|
||||
@@ -72,13 +75,21 @@ if __name__ == '__main__':
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
# use compute_capability == 80
|
||||
<<<<<<< HEAD
|
||||
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80)
|
||||
module = tc.optimize_ttgir(module, num_stages=3, arch=args.gfx)
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
# use compute_capability == 80
|
||||
module = tc.ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
|
||||
=======
|
||||
module = ttir_to_ttgir(module, num_warps=args.num_warps) # num_stages=3, compute_capability=80)
|
||||
module = optimize_ttgir(module, num_stages=3, arch=80)
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
# use compute_capability == 80
|
||||
module = ttgir_to_llir(module, extern_libs=None, arch=80)
|
||||
>>>>>>> oai/main
|
||||
# llvm-ir -> amdgcn asm, hsaco binary
|
||||
module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
|
||||
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
|
||||
|
||||
print(hsaco_path)
|
||||
print(module)
|
||||
@@ -93,14 +104,23 @@ if __name__ == '__main__':
|
||||
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")
|
||||
|
||||
# triton-ir -> triton-gpu-ir
|
||||
<<<<<<< HEAD
|
||||
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=tc.CUDA_DEFAULT_WARP_SIZE)
|
||||
module = tc.optimize_ttgir(module, num_stages=3, arch=arch)
|
||||
=======
|
||||
module = ttir_to_ttgir(module, num_warps=args.num_warps)
|
||||
module = optimize_ttgir(module, num_stages=3, arch=args.sm)
|
||||
>>>>>>> oai/main
|
||||
if args.target == 'triton-gpu-ir':
|
||||
print(module.str())
|
||||
sys.exit(0)
|
||||
|
||||
# triton-gpu-ir -> llvm-ir
|
||||
<<<<<<< HEAD
|
||||
module = tc.ttgir_to_llir(module, extern_libs=None, arch=arch)
|
||||
=======
|
||||
module = ttgir_to_llir(module, extern_libs=None, arch=args.sm)
|
||||
>>>>>>> oai/main
|
||||
if args.target == 'llvm-ir':
|
||||
print(module)
|
||||
sys.exit(0)
|
||||
@@ -111,12 +131,12 @@ if __name__ == '__main__':
|
||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||
if not args.ptx_version:
|
||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||
module = tc.llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
|
||||
module = llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
|
||||
|
||||
# llvm-ir -> amdgcn
|
||||
if args.target == 'amdgcn':
|
||||
if not args.gfx:
|
||||
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
|
||||
module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, args.gfx)
|
||||
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, args.gfx)
|
||||
|
||||
print(module)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
|
||||
|
||||
tt.func @ops() {
|
||||
<<<<<<< HEAD
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
|
||||
=======
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
|
||||
>>>>>>> oai/main
|
||||
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
|
||||
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
|
||||
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
|
||||
@@ -33,10 +37,17 @@ tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
|
||||
// Test if the total number of threadsPerWarp is 64
|
||||
// Test if the total number of warps is 2
|
||||
<<<<<<< HEAD
|
||||
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
|
||||
=======
|
||||
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
|
||||
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
|
||||
>>>>>>> oai/main
|
||||
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
|
||||
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
|
||||
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
|
||||
|
||||
@@ -132,6 +132,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_with_cache_attr
|
||||
tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: st.global.L1::evict_last.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: st.global.L1::evict_last.b32
|
||||
tt.store %a_ptr_init, %cst_0, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
// CHECK-LABEL: global_load_store_no_vec
|
||||
@@ -1462,6 +1477,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32
|
||||
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTC: llvm.inline_asm
|
||||
@@ -1469,6 +1485,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
=======
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
>>>>>>> oai/main
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
@@ -1478,6 +1501,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32_scalar
|
||||
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTX: llvm.icmp "eq"
|
||||
@@ -1485,6 +1509,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
=======
|
||||
// CHECK: llvm.icmp "eq"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
>>>>>>> oai/main
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
// CHECK-LABEL: cst
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
tt.func @cst() -> tensor<1024xi32, #layout1> {
|
||||
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
|
||||
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
|
||||
@@ -67,8 +69,6 @@ tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.return
|
||||
}
|
||||
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
|
||||
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
|
||||
@@ -80,7 +80,6 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.store %5, %4 : tensor<16xi32, #layout0>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if
|
||||
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
@@ -164,6 +163,8 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
|
||||
@@ -173,6 +174,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
|
||||
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
|
||||
// CHECK-LABEL: transpose
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[$row_layout]]>
|
||||
@@ -212,8 +214,10 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
|
||||
tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: loop
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
|
||||
@@ -266,8 +270,10 @@ tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32,
|
||||
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: loop_if
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
|
||||
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
|
||||
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
|
||||
@@ -318,8 +324,10 @@ tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i3
|
||||
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: vecadd
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c256_i32 = arith.constant 256 : i32
|
||||
@@ -349,9 +357,11 @@ tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
|
||||
tt.store %21, %22 : tensor<256xf32, #layout1>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// Select has args with different element types
|
||||
// CHECK-LABEL: select
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
|
||||
@@ -400,9 +410,11 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
|
||||
}
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the following IR doesn't hang the compiler.
|
||||
// CHECK-LABEL: long_func
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
|
||||
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
|
||||
@@ -796,10 +808,12 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
|
||||
tt.store %365, %366 : tensor<1024xf64, #blocked0>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// A mnist model from torch inductor.
|
||||
// Check if topological sort is working correctly and there's no unnecessary convert
|
||||
// CHECK-LABEL: mnist
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
|
||||
@@ -884,17 +898,19 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
|
||||
tt.store %61, %62, %63 : tensor<16x16xf32, #blocked4>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// cmpf and cmpi have different operands and result types
|
||||
// CHECK-LABEL: cmp
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}>
|
||||
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}>
|
||||
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
|
||||
// cmpf and cmpi have different operands and result types
|
||||
// CHECK-LABEL: cmp
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%c64 = arith.constant 64 : index
|
||||
%c2048 = arith.constant 2048 : index
|
||||
@@ -1034,11 +1050,13 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
|
||||
}
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Just make sure it doesn't crash on non-tensor types.
|
||||
// CHECK-LABEL: if_no_tensor
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c-1_i64 = arith.constant -1 : i64
|
||||
@@ -1062,6 +1080,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
|
||||
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK: tt.func @push_elementwise1
|
||||
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
|
||||
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
|
||||
@@ -122,3 +124,5 @@ tt.func @push_elementwise5(
|
||||
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
|
||||
tt.return %newc : tensor<16x16xf32, #Cv1>
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -313,3 +313,122 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt
|
||||
}
|
||||
tt.return %79#0 : tensor<16x16xf32, #C>
|
||||
}
|
||||
|
||||
// Test case post_load_inv: the loop body first loads through the pointer
// iter_args and only afterwards builds the pointers for the next iteration
// from the loop-invariant splats %60 / %82 plus an offset derived from
// (iv + 1) * 32.
// The directives below require, inside the function, an scf.for, an
// index_cast, a captured index_cast result (IV), one addi of IV and %c1_i32
// (NEXT_IV), and no subsequent addi whose first operand is NEXT_IV.
// NOTE(review): presumably this guards against the pass under test advancing
// the induction variable a second time - confirm against the file's RUN line.
// CHECK: tt.func @post_load_inv
|
||||
// CHECK: scf.for
|
||||
// CHECK: arith.index_cast
|
||||
// CHECK-DAG: %[[IV:.*]] = arith.index_cast
|
||||
// CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32
|
||||
// CHECK-NOT: arith.addi %[[NEXT_IV]]
|
||||
tt.func @post_load_inv(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg3: i32 {tt.divisibility = 16 : i32},
|
||||
%arg4: i32 {tt.divisibility = 16 : i32},
|
||||
%arg5: i32 {tt.divisibility = 16 : i32},
|
||||
%arg6: i32 {tt.divisibility = 16 : i32},
|
||||
%arg7: i32 {tt.divisibility = 16 : i32},
|
||||
%arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> {
|
||||
%c0_index = arith.constant 0 : index
|
||||
%c1_index = arith.constant 1 : index
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
// Loop upper bound (900).
%84 = arith.constant 900 : index
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL>
|
||||
%50 = tt.splat %arg3 : (i32) -> tensor<1x32xi32, #AL>
|
||||
%59 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%81 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%66 = tt.splat %arg4 : (i32) -> tensor<32x1xi32, #AL>
|
||||
%60 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%82 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
// Iterates from 0 to %84 in steps of 1, carrying the accumulator (%arg10)
// and the two pointer tensors that the loads read (%arg11, %arg12).
%85:3 = scf.for %arg9 = %c0_index to %84 step %c1_index iter_args(%arg10 = %cst, %arg11 = %59, %arg12 = %81) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>) {
|
||||
%130 = arith.index_cast %arg9 : index to i32
|
||||
%107 = arith.muli %130, %c32_i32 : i32
|
||||
%108 = arith.subi %arg5, %107 : i32
|
||||
%109 = tt.splat %108 : (i32) -> tensor<1x32xi32, #AL>
|
||||
%110 = "triton_gpu.cmpi"(%50, %109) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL>
|
||||
%111 = tt.broadcast %110 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL>
|
||||
%112 = tt.load %arg11, %111, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
|
||||
%113 = tt.splat %108 : (i32) -> tensor<32x1xi32, #AL>
|
||||
%114 = "triton_gpu.cmpi"(%66, %113) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL>
|
||||
%115 = tt.broadcast %114 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL>
|
||||
%116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
|
||||
%117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
|
||||
%118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
|
||||
%119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
|
||||
// Second index_cast of the same induction variable; it feeds the
// (iv + 1) * 32 offset used for the next iteration's pointers.
%131 = arith.index_cast %arg9 : index to i32
|
||||
%120 = arith.addi %131, %c1_i32 : i32
|
||||
%121 = arith.muli %120, %c32_i32 : i32
|
||||
%122 = tt.splat %121 : (i32) -> tensor<32x32xi32, #AL>
|
||||
%123 = tt.addptr %60, %122 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
|
||||
%124 = arith.muli %121, %arg7 : i32
|
||||
%125 = tt.splat %124 : (i32) -> tensor<32x32xi32, #AL>
|
||||
%126 = tt.addptr %82, %125 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
|
||||
// Yield the new accumulator and the two freshly computed pointer tensors.
scf.yield %119, %123, %126 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
}
|
||||
tt.return %85#0 : tensor<32x32xf32, #C>
|
||||
}
|
||||
|
||||
// Test case cross_iter_dep: pointers produced in one iteration are not
// consumed by the very next one. The yield forwards %arg13/%arg14 into the
// %arg11/%arg12 slots (which the loads read) and places the newly computed
// %157/%160 into the %arg13/%arg14 slots, so a pointer that is loaded from
// was computed two iterations earlier.
// The directives below expect four async_commit_group ops, two captured
// tt.addptr results, an scf.for whose iter_args bind those two results, and
// an scf.yield line that mentions both bound values.
// CHECK: tt.func @cross_iter_dep
|
||||
// CHECK: triton_gpu.async_commit_group
|
||||
// CHECK: triton_gpu.async_commit_group
|
||||
// CHECK: triton_gpu.async_commit_group
|
||||
// CHECK: triton_gpu.async_commit_group
|
||||
// CHECK: %[[PTR0:.*]] = tt.addptr
|
||||
// CHECK: %[[PTR1:.*]] = tt.addptr
|
||||
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[BUF0:.*]] = %[[PTR0]], {{.*}}, %[[BUF1:.*]] = %[[PTR1]]
|
||||
// CHECK: scf.yield
|
||||
// CHECK-SAME: %[[BUF0]]
|
||||
// CHECK-SAME: %[[BUF1]]
|
||||
tt.func @cross_iter_dep(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%arg3: i32 {tt.divisibility = 16 : i32},
|
||||
%arg4: i32 {tt.divisibility = 16 : i32},
|
||||
%arg5: i32 {tt.divisibility = 16 : i32},
|
||||
%arg6: i32 {tt.divisibility = 16 : i32},
|
||||
%arg7: i32 {tt.divisibility = 16 : i32},
|
||||
%arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> {
|
||||
// NOTE: %c0_i32 and %c1_i32 are of type index despite their names; they are
// the loop's lower bound and step, and %118 (32) is its upper bound.
%c0_i32 = arith.constant 0 : index
|
||||
%118 = arith.constant 32 : index
|
||||
%c1_i32 = arith.constant 1 : index
|
||||
%c2_i32 = arith.constant 2 : i32
|
||||
%c32_i32 = arith.constant 32 : i32
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C>
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL>
|
||||
%78 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%110 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%112 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%113 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%116 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
%65 = tt.splat %arg3 : (i32) -> tensor<1x32xi32, #AL>
|
||||
%88 = tt.splat %arg4 : (i32) -> tensor<32x1xi32, #AL>
|
||||
%80 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
// Five loop-carried values: the accumulator (%arg10), the two pointer
// tensors loaded from in this iteration (%arg11, %arg12) and the two
// pointer tensors queued for the following iteration (%arg13, %arg14).
%119:5 = scf.for %arg9 = %c0_i32 to %118 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %78, %arg12 = %110, %arg13 = %113, %arg14 = %116) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>) {
|
||||
%161 = arith.index_cast %arg9 : index to i32
|
||||
%141 = arith.muli %161, %c32_i32 : i32
|
||||
%142 = arith.subi %arg5, %141 : i32
|
||||
%143 = tt.splat %142 : (i32) -> tensor<1x32xi32, #AL>
|
||||
%144 = "triton_gpu.cmpi"(%65, %143) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL>
|
||||
%145 = tt.broadcast %144 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL>
|
||||
%146 = tt.load %arg11, %145, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
|
||||
%147 = tt.splat %142 : (i32) -> tensor<32x1xi32, #AL>
|
||||
%148 = "triton_gpu.cmpi"(%88, %147) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL>
|
||||
%149 = tt.broadcast %148 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL>
|
||||
%150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
|
||||
%151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
|
||||
%152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
|
||||
%153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
|
||||
// New pointer offsets are derived from (iv + 2) * 32.
%162 = arith.index_cast %arg9 : index to i32
|
||||
%154 = arith.addi %162, %c2_i32 : i32
|
||||
%155 = arith.muli %154, %c32_i32 : i32
|
||||
%156 = tt.splat %155 : (i32) -> tensor<32x32xi32, #AL>
|
||||
%157 = tt.addptr %80, %156 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
|
||||
%158 = arith.muli %155, %arg7 : i32
|
||||
%159 = tt.splat %158 : (i32) -> tensor<32x32xi32, #AL>
|
||||
%160 = tt.addptr %112, %159 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
|
||||
// Shift the queue: %arg13/%arg14 become next iteration's load pointers,
// while %157/%160 take the queued slots.
scf.yield %153, %arg13, %arg14, %157, %160 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>
|
||||
}
|
||||
tt.return %119#0 : tensor<32x32xf32, #C>
|
||||
}
|
||||
|
||||
@@ -13,6 +13,13 @@ struct TestAliasPass
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
|
||||
|
||||
// Renders `value` in operand form via Value::printAsOperand, using `state`
// for the printing, and returns the resulting text.
static std::string getValueOperandName(Value value, AsmState &state) {
  std::string rendered;
  llvm::raw_string_ostream stream(rendered);
  value.printAsOperand(stream, state);
  return rendered;
}
|
||||
|
||||
static void print(StringRef name, SmallVector<std::string, 4> &vals,
|
||||
raw_ostream &os) {
|
||||
if (vals.empty())
|
||||
|
||||
1
third_party/intel_xpu_backend
vendored
Submodule
1
third_party/intel_xpu_backend
vendored
Submodule
Submodule third_party/intel_xpu_backend added at 0bcc485f82
6
utils/nightly.pypirc
Normal file
6
utils/nightly.pypirc
Normal file
@@ -0,0 +1,6 @@
|
||||
# Package-upload configuration: declares one index server, Triton-Nightly,
# and the repository URL that uploads to it are sent to.
[distutils]
|
||||
Index-servers =
|
||||
Triton-Nightly
|
||||
|
||||
# Settings for the server named in the index-server list above.
[Triton-Nightly]
|
||||
Repository = https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/upload/
|
||||
Reference in New Issue
Block a user