diff --git a/.flake8 b/.flake8 index edeaeec83..b70a02e26 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,5 @@ [flake8] -ignore = E501,E701,E731 +# W503 (linebreak occurred before binary operator) seems to be enabled by +# default, even though it goes against pep8 and is incompatible with W504 +# (linebreak occurred *after* binary operator). Disable it. +ignore = E501,E701,E731,W503 diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index b4f4495cf..90b5750e7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -34,7 +34,7 @@ jobs: fi - Integration-Tests-Nvidia: + Integration-Tests: needs: Runner-Preparation runs-on: ${{ matrix.runner }} @@ -73,10 +73,10 @@ jobs: run: | cd python python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 - python3 -m pip install ninja - python3 -m pip install --no-build-isolation -vvv '.[tests]' - python3 -m pip install pytest-xdist + python3 -m pip install cmake==3.24 ninja pytest-xdist + sudo apt-get update -y + sudo apt-get install -y ccache clang lld + TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]' - name: Run lit tests if: ${{ env.BACKEND == 'CUDA'}} @@ -171,140 +171,8 @@ jobs: python3 -m pytest -vs . 
--reruns 10 sudo nvidia-smi -i 0 -rgc - Integration-Tests-Shared-Middle-Layer: - - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Clear cache - run: | - rm -rf ~/.triton - - - name: Update PATH - run: | - echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" - - - name: Check pre-commit - run: | - python3 -m pip install --upgrade pre-commit - python3 -m pre_commit run --all-files --verbose - - - name: Install Triton - run: | - export TRITON_CODEGEN_TRITON_SHARED=1 - git submodule update --init --recursive - cd python - python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 - python3 -m pip install ninja - python3 -m pip uninstall -y triton - python3 setup.py build - python3 -m pip install --no-build-isolation -vvv '.[tests]' - - - name: Run shared middle-layer lit tests - run: | - python3 -m pip install lit - cd python - LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test" - if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 - fi - lit -v "${LIT_TEST_DIR}" - - - Integration-Tests-Third-Party: - needs: Runner-Preparation - if: false - - runs-on: ${{ matrix.runner }} - - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-optional)}} - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Set ROCM ENV - if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'gfx908')}} - run: | - echo "BACKEND=ROCM" >> "${GITHUB_ENV}" - - - name: Set XPU ENV - if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'arc770')}} - run: | - echo "BACKEND=XPU" >> "${GITHUB_ENV}" - - - name: Clear cache - run: | - rm -rf ~/.triton - - - name: Update PATH - run: | - echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" - - - name: Check pre-commit - if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] != 'arc770') }} - run: | - python3 -m pip install --upgrade pre-commit - python3 
-m pre_commit run --all-files --verbose - - - name: Check pre-commit arc770 - if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }} - run: | - source ${HOME}/triton_vars.sh - source ${HOME}/miniconda3/bin/activate - conda activate triton-xpu-ci - python3 -m pip install --upgrade pre-commit - python3 -m pre_commit run --all-files - - - name: Install Triton on ROCM - if: ${{ env.BACKEND == 'ROCM'}} - run: | - git submodule update --init --recursive - cd python - python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 - python3 -m pip install torch==1.13.1 --index-url https://download.pytorch.org/whl/rocm5.2 - export TRITON_CODEGEN_AMD_HIP_BACKEND=1 - python3 -m pip install --no-build-isolation -vvv '.[tests]' - - - name: Install Triton on XPU - if: ${{ env.BACKEND == 'XPU'}} - run: | - source ${HOME}/triton_vars.sh - source ${HOME}/miniconda3/bin/activate - conda activate triton-xpu-ci - git submodule update --init --recursive - cd python - python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 - export TRITON_CODEGEN_INTEL_XPU_BACKEND=1 - python3 -m pip uninstall -y triton - python3 setup.py build - python3 -m pip install --no-build-isolation -vvv '.[tests]' - - - name: Run python tests on ROCM - if: ${{ env.BACKEND == 'ROCM'}} - run: | - cd python/test/unit/language - python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py" - - - name: Run python tests on XPU - if: ${{ env.BACKEND == 'XPU'}} - run: | - source ${HOME}/triton_vars.sh - source ${HOME}/miniconda3/bin/activate - conda activate triton-xpu-ci - cd python/test/backend/third_party_backends - python3 -m pytest --capture=tee-sys -rfs --verbose --backend xpu - Compare-artifacts: - needs: Integration-Tests-Nvidia + needs: Integration-Tests runs-on: ubuntu-latest @@ -333,7 +201,7 @@ jobs: - name: Download latest main artifacts env: ARTIFACT_NAME: artifacts A100 - ARTIFACT_JOB_NAME: Integration-Tests-Nvidia + ARTIFACT_JOB_NAME: Integration-Tests 
MAX_NUM_ACTIONS_PAGES: 30 GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | diff --git a/.github/workflows/llvm-build.yml b/.github/workflows/llvm-build.yml index 92394fec3..05b825501 100644 --- a/.github/workflows/llvm-build.yml +++ b/.github/workflows/llvm-build.yml @@ -18,43 +18,17 @@ permissions: jobs: build: + name: Build on ${{ matrix.config.runner }} + runs-on: ${{ matrix.config.runs_on }} strategy: fail-fast: true - matrix: - platform: [ - ubuntu-20.04-x64, - ubuntu-22.04-x64, - centos-7-x64, - macos-x64, - macos-arm64 - ] - - include: - # Specify OS versions - - platform: ubuntu-20.04-x64 - host-os: ubuntu-20.04 - target-os: ubuntu - arch: x64 - - platform: ubuntu-22.04-x64 - host-os: ubuntu-22.04 - target-os: ubuntu - arch: x64 - - platform: centos-7-x64 - host-os: ubuntu-22.04 - target-os: centos - arch: x64 - - platform: macos-x64 - host-os: macos-12 - target-os: macos - arch: x64 - - platform: macos-arm64 - host-os: macos-12 - target-os: macos - arch: arm64 - - runs-on: ${{ matrix.host-os }} + config: + - {runner: 'Ubuntu 20.04', runs_on: 'ubuntu-20.04', target-os: 'ubuntu', arch: 'x64'} + - {runner: 'CentOS 7', runs_on: ['self-hosted', 'CPU'], target-os: 'centos', arch: 'x64'} + - {runner: 'MacOS X64', runs_on: 'macos-12', target-os: 'macos', arch: 'x64'} + - {runner: 'MacOS ARM64', runs_on: 'macos-12', target-os: 'macos', arch: 'arm64'} steps: @@ -73,7 +47,7 @@ jobs: echo "Short LLVM commit hash: ${SHORT_LLVM_COMMIT_HASH}" echo "short_llvm_commit_hash=${SHORT_LLVM_COMMIT_HASH}" >> ${GITHUB_ENV} - INSTALL_DIR="llvm-${SHORT_LLVM_COMMIT_HASH}-${{ matrix.platform }}" + INSTALL_DIR="llvm-${SHORT_LLVM_COMMIT_HASH}-${{ matrix.config.target-os }}-${{ matrix.config.arch }}" echo "LLVM installation directory name: ${INSTALL_DIR}" echo "llvm_install_dir=${INSTALL_DIR}" >> ${GITHUB_ENV} @@ -99,11 +73,11 @@ jobs: uses: actions/cache@v3 with: path: ${{ env.SCCACHE_DIR }} - key: ${{ matrix.platform }}-${{ env.short_llvm_commit_hash }} - restore-keys: ${{ matrix.platform 
}}- + key: ${{ matrix.config.target-os }}-${{ matrix.config.arch }}-${{ env.short_llvm_commit_hash }} + restore-keys: ${{ matrix.config.target-os }}-${{ matrix.config.arch }}- - name: Configure, Build, Test, and Install LLVM (Ubuntu and macOS x64) - if: matrix.arch == 'x64' && contains(fromJSON('["ubuntu", "macos"]'), matrix.target-os) + if: matrix.config.arch == 'x64' && (matrix.config.target-os == 'ubuntu' || matrix.config.target-os == 'macos') run: > python3 -m pip install -r llvm-project/mlir/python/requirements.txt @@ -114,11 +88,13 @@ jobs: -DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}" -DCMAKE_LINKER=lld -DLLVM_BUILD_UTILS=ON + -DLLVM_BUILD_TOOLS=ON -DLLVM_ENABLE_ASSERTIONS=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_INSTALL_UTILS=ON -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" + -DLLVM_ENABLE_TERMINFO=OFF llvm-project/llvm ninja -C llvm-project/build check-mlir install @@ -126,7 +102,7 @@ jobs: tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - name: Configure, Build, and Install LLVM (macOS arm64) - if: matrix.arch == 'arm64' && matrix.target-os == 'macos' + if: matrix.config.arch == 'arm64' && matrix.config.target-os == 'macos' run: > python3 -m pip install -r llvm-project/mlir/python/requirements.txt @@ -138,6 +114,7 @@ jobs: -DCMAKE_LINKER=lld -DCMAKE_OSX_ARCHITECTURES=arm64 -DLLVM_BUILD_UTILS=ON + -DLLVM_BUILD_TOOLS=ON -DLLVM_ENABLE_ASSERTIONS=ON -DMLIR_ENABLE_BINDINGS_PYTHON=ON -DLLVM_ENABLE_PROJECTS=mlir @@ -145,6 +122,7 @@ jobs: -DLLVM_INSTALL_UTILS=ON -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_USE_HOST_TOOLS=ON + -DLLVM_ENABLE_TERMINFO=OFF llvm-project/llvm ninja -C llvm-project/build install @@ -152,10 +130,10 @@ jobs: tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - name: Configure, Build, Test, and Install LLVM (CentOS) - if: matrix.target-os == 'centos' + if: matrix.config.target-os == 'centos' run: | docker build --tag llvm-build --build-arg 
llvm_dir=llvm-project \ - -f llvm-build/.github/workflows/Dockerfile . + -f llvm-build/.github/workflows/llvm-build/Dockerfile . # Create temporary container to copy cache and installed artifacts. CONTAINER_ID=$(docker create llvm-build) diff --git a/.github/workflows/Dockerfile b/.github/workflows/llvm-build/Dockerfile similarity index 60% rename from .github/workflows/Dockerfile rename to .github/workflows/llvm-build/Dockerfile index 3ef0f90e3..cd7e852a8 100644 --- a/.github/workflows/Dockerfile +++ b/.github/workflows/llvm-build/Dockerfile @@ -1,16 +1,17 @@ FROM centos:7 ARG llvm_dir=llvm-project - # Add the cache artifacts and the LLVM source tree to the container ADD sccache /sccache ADD "${llvm_dir}" /source/llvm-project ENV SCCACHE_DIR="/sccache" ENV SCCACHE_CACHE_SIZE="2G" +RUN echo -e "[llvmtoolset-build]\nname=LLVM Toolset 13.0 - Build\nbaseurl=https://buildlogs.centos.org/c7-llvm-toolset-13.0.x86_64/\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/llvmtoolset-build.repo # Install build dependencies RUN yum install --assumeyes centos-release-scl -RUN yum install --assumeyes devtoolset-9-gcc* python3-devel python3-pip -SHELL [ "/usr/bin/scl", "enable", "devtoolset-9" ] +RUN yum install --assumeyes --nogpgcheck llvm-toolset-13.0 +RUN yum install --assumeyes rh-python38-python-devel rh-python38-python-pip +SHELL [ "/usr/bin/scl", "enable", "llvm-toolset-13.0", "rh-python38" ] RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --upgrade cmake ninja sccache @@ -21,17 +22,22 @@ RUN python3 -m pip install -r /source/llvm-project/mlir/python/requirements.txt # Configure, Build, Test, and Install LLVM RUN cmake -GNinja -Bbuild \ -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_C_COMPILER=gcc \ - -DCMAKE_CXX_COMPILER=g++ \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_ASM_COMPILER=clang \ -DCMAKE_C_COMPILER_LAUNCHER=sccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=sccache \ + -DCMAKE_CXX_FLAGS="-Wno-everything" \ + -DCMAKE_LINKER=lld \ 
-DCMAKE_INSTALL_PREFIX="/install" \ -DLLVM_BUILD_UTILS=ON \ + -DLLVM_BUILD_TOOLS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_ENABLE_TERMINFO=OFF \ -DLLVM_INSTALL_UTILS=ON \ -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \ /source/llvm-project/llvm -RUN ninja -C build check-mlir install +RUN ninja -C build install diff --git a/.github/workflows/third-party/integration-tests.yml b/.github/workflows/third-party/integration-tests.yml new file mode 100644 index 000000000..324eec3bb --- /dev/null +++ b/.github/workflows/third-party/integration-tests.yml @@ -0,0 +1,167 @@ +name: Integration Tests + +on: + workflow_dispatch: + pull_request: + branches: [main] + merge_group: + branches: [main] + types: [checks_requested] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/master' }} + +env: + TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" + +jobs: + Runner-Preparation: + runs-on: ubuntu-latest + outputs: + matrix-required: ${{ steps.set-matrix.outputs.matrix-required }} + matrix-optional: ${{ steps.set-matrix.outputs.matrix-optional }} + steps: + - name: Prepare runner matrix + id: set-matrix + run: | + if [ x"${{ github.repository }}" == x"openai/triton" ]; then + echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' + echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' + else + echo '::set-output name=matrix-required::["ubuntu-latest"]' + echo '::set-output name=matrix-optional::["ubuntu-latest"]' + fi + + + Integration-Tests-Shared-Middle-Layer: + + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Clear cache + run: | + rm -rf ~/.triton + + - name: Update PATH + run: | + echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files 
--verbose + + - name: Install Triton + run: | + export TRITON_CODEGEN_TRITON_SHARED=1 + git submodule update --init --recursive + cd python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + python3 -m pip install ninja + python3 -m pip uninstall -y triton + python3 setup.py build + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Run shared middle-layer lit tests + run: | + python3 -m pip install lit + cd python + LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test" + if [ ! -d "${LIT_TEST_DIR}" ]; then + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 + fi + lit -v "${LIT_TEST_DIR}" + + + Integration-Tests-Third-Party: + needs: Runner-Preparation + if: false + + runs-on: ${{ matrix.runner }} + + strategy: + matrix: + runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-optional)}} + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set ROCM ENV + if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'gfx908')}} + run: | + echo "BACKEND=ROCM" >> "${GITHUB_ENV}" + + - name: Set XPU ENV + if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'arc770')}} + run: | + echo "BACKEND=XPU" >> "${GITHUB_ENV}" + + - name: Clear cache + run: | + rm -rf ~/.triton + + - name: Update PATH + run: | + echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + + - name: Check pre-commit + if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] != 'arc770') }} + run: | + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files --verbose + + - name: Check pre-commit arc770 + if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }} + run: | + source ${HOME}/triton_vars.sh + source ${HOME}/miniconda3/bin/activate + conda activate triton-xpu-ci + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files + + - name: Install Triton on ROCM + if: ${{ env.BACKEND == 'ROCM'}} + run: | +
git submodule update --init --recursive + cd python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + python3 -m pip install torch==1.13.1 --index-url https://download.pytorch.org/whl/rocm5.2 + export TRITON_CODEGEN_AMD_HIP_BACKEND=1 + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Install Triton on XPU + if: ${{ env.BACKEND == 'XPU'}} + run: | + source ${HOME}/triton_vars.sh + source ${HOME}/miniconda3/bin/activate + conda activate triton-xpu-ci + git submodule update --init --recursive + cd python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + export TRITON_CODEGEN_INTEL_XPU_BACKEND=1 + python3 -m pip uninstall -y triton + python3 setup.py build + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Run python tests on ROCM + if: ${{ env.BACKEND == 'ROCM'}} + run: | + cd python/test/unit/language + python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py" + + - name: Run python tests on XPU + if: ${{ env.BACKEND == 'XPU'}} + run: | + source ${HOME}/triton_vars.sh + source ${HOME}/miniconda3/bin/activate + conda activate triton-xpu-ci + cd python/test/backend/third_party_backends + python3 -m pytest --capture=tee-sys -rfs --verbose --backend xpu diff --git a/.github/workflows/torch-inductor-tests.yml b/.github/workflows/torch-inductor-tests.yml index 73a45c797..75ceae545 100644 --- a/.github/workflows/torch-inductor-tests.yml +++ b/.github/workflows/torch-inductor-tests.yml @@ -1,7 +1,9 @@ name: Torchinductor on: - workflow_dispatch: + workflow_run: + workflows: ["Wheel"] + types: [completed] jobs: Runner-Preparation: @@ -23,17 +25,17 @@ jobs: steps: - name: Checkout uses: actions/checkout@v2 - #- name: Packages - # run: | - # ./.github/workflows/torchinductor/scripts/install_torchinductor.sh + - name: Packages + run: | + ./.github/workflows/torch-inductor/scripts/install_torchinductor.sh torchbench - name: Environment run: | source 
/opt/torchinductor_venv/bin/activate - ./.github/workflows/torchinductor/scripts/install_triton.sh + ./.github/workflows/torch-inductor/scripts/install_triton.sh - name: Performance run: | - ./.github/workflows/torchinductor/scripts/run_torchinductor_perf.sh + ./.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh torchbench # Runs too long time #- name: Accuracy # run: | - # ./.github/workflows/torchinductor/scripts/run_torchinductor_acc.sh + # ./.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh diff --git a/.github/workflows/torchinductor/scripts/check_acc.py b/.github/workflows/torch-inductor/scripts/check_acc.py similarity index 100% rename from .github/workflows/torchinductor/scripts/check_acc.py rename to .github/workflows/torch-inductor/scripts/check_acc.py diff --git a/.github/workflows/torchinductor/scripts/check_perf.py b/.github/workflows/torch-inductor/scripts/check_perf.py similarity index 62% rename from .github/workflows/torchinductor/scripts/check_perf.py rename to .github/workflows/torch-inductor/scripts/check_perf.py index 6c34ac28f..a966b7be2 100644 --- a/.github/workflows/torchinductor/scripts/check_perf.py +++ b/.github/workflows/torch-inductor/scripts/check_perf.py @@ -3,8 +3,7 @@ import csv from collections import namedtuple # Create a named tuple for the output of the benchmark -BenchmarkOutput = namedtuple( - 'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency']) +BenchmarkOutput = namedtuple('BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency']) def parse_output(file_path: str) -> dict: @@ -19,13 +18,11 @@ def parse_output(file_path: str) -> dict: batch_size = row[2] speedup = float(row[3]) latency = float(row[4]) - entries[name] = BenchmarkOutput( - dev, name, batch_size, speedup, latency) + entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency) return entries -def compare(baseline: dict, new: dict, threshold: float, - geomean_threshold: float) -> bool: +def 
compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool: baseline_geomean = 1.0 new_geomean = 1.0 for key in new: @@ -33,20 +30,27 @@ def compare(baseline: dict, new: dict, threshold: float, print(f"New benchmark {key} not found in baseline") baseline_latency = baseline[key].latency new_latency = new[key].latency + if baseline_latency == 0: + print(f"Baseline latency for {key} is 0") + continue + elif new_latency == 0: + print(f"New latency for {key} is 0") + continue + if new_latency < baseline_latency * (1 - threshold): - print( - f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}") + print(f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}") elif new_latency > baseline_latency * (1 + threshold): - print( - f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}") + print(f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}") + else: + print(f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}") baseline_geomean *= baseline[key].speedup new_geomean *= new[key].speedup - baseline_geomean = baseline_geomean ** (1 / len(baseline)) - new_geomean = new_geomean ** (1 / len(new)) + baseline_geomean = baseline_geomean**(1 / len(baseline)) + new_geomean = new_geomean**(1 / len(new)) print(f"Baseline geomean: {baseline_geomean}") print(f"New geomean: {new_geomean}") - assert new_geomean > baseline_geomean * (1 - geomean_threshold), \ + assert new_geomean >= baseline_geomean * (1 - geomean_threshold), \ f"New geomean is slower than baseline: {new_geomean} vs {baseline_geomean}" diff --git a/.github/workflows/torchinductor/scripts/common.sh b/.github/workflows/torch-inductor/scripts/common.sh similarity index 100% rename from .github/workflows/torchinductor/scripts/common.sh rename to .github/workflows/torch-inductor/scripts/common.sh diff --git 
a/.github/workflows/torch-inductor/scripts/install_torchinductor.sh b/.github/workflows/torch-inductor/scripts/install_torchinductor.sh new file mode 100755 index 000000000..4dc0d5209 --- /dev/null +++ b/.github/workflows/torch-inductor/scripts/install_torchinductor.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# remember where we started +ROOT="$(pwd)" +MODEL_SPEC=$1 + +# torchinductor venv +whoami +# clean up old venv +rm -rf /opt/torchinductor_venv +python3 -m venv /opt/torchinductor_venv +# shellcheck source=/dev/null +source /opt/torchinductor_venv/bin/activate +# shellcheck source=/dev/null +source ./.github/workflows/torch-inductor/scripts/common.sh + +# pytorch nightly +pip3 install --force-reinstall --pre torch torchtext torchvision torchaudio torchrec --extra-index-url https://download.pytorch.org/whl/nightly/cu121 +# pytorch source to get torchbench for dynamo +cd /opt || exit +# cleanup old pytorch +rm -rf pytorch +git clone --recursive https://github.com/pytorch/pytorch +cd pytorch || exit +# if you are updating an existing checkout +git submodule sync +git submodule update --init --recursive +cd .. + +# required packages +# https://github.com/pytorch/benchmark/blob/main/docker/gcp-a100-runner-dind.dockerfile#L17 +sudo apt-get install --yes libpango-1.0-0 libpangoft2-1.0-0 +pip3 install --upgrade pip +pip3 install expecttest psutil lightning-utilities pyre_extensions + +# torchbench +if [ "$MODEL_SPEC" == "torchbench" ] || [ "$MODEL_SPEC" != "all" ]; then + # clean up old torchbench + rm -rf benchmark + pip3 install pyyaml + git clone https://github.com/pytorch/benchmark.git + cd benchmark || exit + python3 install.py + cd .. +fi + +# timm +if [ "$MODEL_SPEC" == "timm_models" ] || [ "$MODEL_SPEC" != "all" ]; then + # clean up old timm + rm -rf pytorch-image-models + git clone https://github.com/huggingface/pytorch-image-models.git + cd pytorch-image-models || exit + pip3 install -e . + cd .. 
+fi + +# build our own triton +cd "$ROOT" || exit +cd python || exit +rm -rf build +pip3 install -e . +pip3 uninstall pytorch-triton -y + +# clean up cache +rm -rf /tmp/torchinductor_root/ +rm -rf ~/.triton/cache +rm -rf "$TEST_REPORTS_DIR" + +# go back to where we started +cd "$ROOT" || exit diff --git a/.github/workflows/torchinductor/scripts/install_triton.sh b/.github/workflows/torch-inductor/scripts/install_triton.sh similarity index 89% rename from .github/workflows/torchinductor/scripts/install_triton.sh rename to .github/workflows/torch-inductor/scripts/install_triton.sh index 13ae189fd..857c2fc15 100755 --- a/.github/workflows/torchinductor/scripts/install_triton.sh +++ b/.github/workflows/torch-inductor/scripts/install_triton.sh @@ -6,7 +6,7 @@ ROOT="$(pwd)" # shellcheck source=/dev/null source /opt/torchinductor_venv/bin/activate # shellcheck source=/dev/null -source ./.github/workflows/torchinductor/scripts/common.sh +source ./.github/workflows/torch-inductor/scripts/common.sh # build our own triton cd python || exit diff --git a/.github/workflows/torchinductor/scripts/run_torchinductor_acc.sh b/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh similarity index 84% rename from .github/workflows/torchinductor/scripts/run_torchinductor_acc.sh rename to .github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh index 1d03cbd46..e1e519d0f 100755 --- a/.github/workflows/torchinductor/scripts/run_torchinductor_acc.sh +++ b/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh @@ -2,7 +2,8 @@ # remember where we started ROOT="$(pwd)" -INDUCTOR="$ROOT"/.github/workflows/torchinductor +INDUCTOR="$ROOT"/.github/workflows/torch-inductor +MODEL_SPEC=$1 # shellcheck source=/dev/null source /opt/torchinductor_venv/bin/activate @@ -14,6 +15,9 @@ TEST_REPORTS_DIR=$TEST_REPORTS_DIR/acc mkdir -p "$TEST_REPORTS_DIR" for model in "${MODELS[@]}"; do + if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then + continue + fi echo 
"Running accuracy test for $model" python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --device cuda \ --output "$TEST_REPORTS_DIR"/inference_"$model".csv @@ -25,6 +29,9 @@ done cd "$ROOT" || exit for model in "${MODELS[@]}"; do + if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then + continue + fi echo "Checking accuracy test for $model" python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/inference_"$model".csv python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/training_"$model".csv diff --git a/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh b/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh new file mode 100755 index 000000000..0a80264c9 --- /dev/null +++ b/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# remember where we started +ROOT="$(pwd)" +INDUCTOR="$ROOT"/.github/workflows/torch-inductor +MODEL_SPEC=$1 + +# shellcheck source=/dev/null +source /opt/torchinductor_venv/bin/activate +# shellcheck source=/dev/null +source "$INDUCTOR"/scripts/common.sh + +# lock GPU clocks to 1350 MHz +sudo nvidia-smi -i 0 -pm 1 +sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350 + +cd "$PYTORCH_DIR" || exit +TRITON_TEST_REPORTS_DIR=$TEST_REPORTS_DIR/perf +BASE_TEST_REPORTS_DIR=$TEST_REPORTS_DIR/acc +mkdir -p "$TRITON_TEST_REPORTS_DIR" +mkdir -p "$BASE_TEST_REPORTS_DIR" + + +echo "Running with Triton Nightly" +for model in "${MODELS[@]}"; do + if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then + continue + fi + echo "Running performance test for $model" + python3 benchmarks/dynamo/"$model".py --float32 -dcuda --training --inductor --performance \ + --output "$TRITON_TEST_REPORTS_DIR"/"$model".csv +done + +# install pytorch-triton +pip3 uninstall triton -y +pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu121 + +echo "Running with pytorch-triton" +for model in 
"${MODELS[@]}"; do + if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then + continue + fi + echo "Running performance test for $model" + python3 benchmarks/dynamo/"$model".py --float32 -dcuda --training --inductor --performance \ + --output "$BASE_TEST_REPORTS_DIR"/"$model".csv +done + +# uninstall pytorch-triton +pip3 uninstall pytorch-triton -y + +cd "$ROOT" || exit +for model in "${MODELS[@]}"; do + if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then + continue + fi + echo "Checking performance test for $model" + python3 "$INDUCTOR"/scripts/check_perf.py --new "$TRITON_TEST_REPORTS_DIR"/"$model".csv --baseline "$BASE_TEST_REPORTS_DIR"/"$model".csv + EXIT_STATUS=$? + if [ "$EXIT_STATUS" -ne 0 ]; then + echo "Performance test for $model failed" + exit "$EXIT_STATUS" + fi +done + +# unlock GPU clocks +sudo nvidia-smi -i 0 -rgc + +# go back to where we started +cd "$ROOT" || exit diff --git a/.github/workflows/torchinductor/data/huggingface.csv b/.github/workflows/torchinductor/data/huggingface.csv deleted file mode 100644 index bda10ffcf..000000000 --- a/.github/workflows/torchinductor/data/huggingface.csv +++ /dev/null @@ -1,37 +0,0 @@ -dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio -cuda,AlbertForMaskedLM,4,1.5511,164.3373,26.8523,1.2647 -cuda,AlbertForQuestionAnswering,4,1.5501,163.5580,25.7983,1.3145 -cuda,BartForCausalLM,4,1.5080,71.7230,32.8907,0.9749 -cuda,BertForMaskedLM,16,1.5350,67.9451,35.3286,1.0494 -cuda,BertForQuestionAnswering,16,1.6735,53.2963,34.3754,1.1710 -cuda,BlenderbotSmallForCausalLM,64,1.2106,46.6466,23.8058,0.9120 -cuda,BlenderbotSmallForConditionalGeneration,64,1.3616,77.3013,55.3546,0.9803 -cuda,CamemBert,16,1.4779,76.1809,35.3883,1.0469 -cuda,DebertaForMaskedLM,4,0.8415,62.3395,35.9657,1.0418 -cuda,DebertaForQuestionAnswering,8,1.0609,67.5151,35.7728,1.1528 -cuda,DebertaV2ForMaskedLM,1,0.6026,134.6517,66.1783,0.9773 
-cuda,DistilBertForMaskedLM,128,1.2460,66.9382,18.3089,0.9624 -cuda,DistilBertForQuestionAnswering,256,1.3997,72.4126,18.1956,1.1486 -cuda,DistillGPT2,16,1.6656,60.5455,17.2280,1.0641 -cuda,ElectraForCausalLM,32,1.8299,45.4841,37.0944,0.9717 -cuda,ElectraForQuestionAnswering,64,2.0289,52.6890,35.9632,1.1928 -cuda,GPT2ForSequenceClassification,4,2.2567,38.2969,30.0527,1.2323 -cuda,LayoutLMForMaskedLM,16,1.5423,68.8018,36.5562,1.0495 -cuda,LayoutLMForSequenceClassification,16,1.7058,53.9355,35.2225,1.1659 -cuda,MBartForCausalLM,4,1.4945,71.4649,32.8653,0.9830 -cuda,MegatronBertForCausalLM,4,1.4328,58.4404,70.6226,1.0951 -cuda,MegatronBertForQuestionAnswering,8,1.5886,85.2533,69.1219,1.1152 -cuda,MobileBertForMaskedLM,64,0.9007,131.7379,107.5275,1.0136 -cuda,MobileBertForQuestionAnswering,128,0.8435,167.9066,106.7049,0.8579 -cuda,PLBartForCausalLM,8,1.5261,68.9224,19.5826,0.9887 -cuda,PLBartForConditionalGeneration,4,1.5298,71.2811,45.6902,1.0495 -cuda,PegasusForCausalLM,32,1.2212,57.5436,33.3863,0.9736 -cuda,PegasusForConditionalGeneration,32,1.2822,106.4678,69.8825,1.0689 -cuda,RobertaForCausalLM,16,1.6128,67.5706,34.7355,1.0496 -cuda,RobertaForQuestionAnswering,16,1.6800,53.6267,33.8527,1.1704 -cuda,Speech2Text2ForCausalLM,256,1.8230,32.9145,18.7201,0.8760 -cuda,T5ForConditionalGeneration,4,1.6592,59.5324,39.4406,1.1814 -cuda,T5Small,4,1.6581,59.5930,37.0471,1.1814 -cuda,TrOCRForCausalLM,32,1.2586,106.2633,32.5330,0.9583 -cuda,XLNetLMHeadModel,8,1.8108,142.8795,84.8197,1.1240 -cuda,YituTechConvBert,16,1.5207,81.4595,53.1565,1.0362 diff --git a/.github/workflows/torchinductor/data/timm_models.csv b/.github/workflows/torchinductor/data/timm_models.csv deleted file mode 100644 index 3591a866d..000000000 --- a/.github/workflows/torchinductor/data/timm_models.csv +++ /dev/null @@ -1,54 +0,0 @@ -dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio -cuda,adv_inception_v3,128,1.5923,102.5292,51.6032,1.0472 
-cuda,beit_base_patch16_224,64,1.3390,75.3027,29.7471,1.0156 -cuda,coat_lite_mini,128,2.0579,53.3689,37.1856,1.0437 -cuda,convmixer_768_32,32,1.0470,275.5328,23.8037,0.9999 -cuda,convnext_base,64,1.5084,80.1811,42.5659,1.0373 -cuda,crossvit_9_240,128,1.5392,37.1806,44.9986,0.9193 -cuda,cspdarknet53,64,1.4721,75.0403,35.2882,1.0547 -cuda,deit_base_distilled_patch16_224,64,1.1432,55.9737,23.4038,0.9816 -cuda,dla102,128,1.5282,123.7284,49.3612,1.0430 -cuda,dm_nfnet_f0,128,1.4354,79.7518,34.8994,1.1038 -cuda,dpn107,32,1.2412,83.8921,58.9111,0.9952 -cuda,eca_botnext26ts_256,128,1.5425,71.2406,28.8920,1.0270 -cuda,ese_vovnet19b_dw,128,1.4647,42.4837,18.0285,1.0135 -cuda,fbnetc_100,128,1.5795,53.8033,33.0222,1.0082 -cuda,gernet_l,128,1.1684,63.4230,26.8687,1.0053 -cuda,ghostnet_100,128,1.7812,54.4211,47.6168,1.0484 -cuda,gluon_inception_v3,128,1.5952,102.5018,50.0857,1.0469 -cuda,gmixer_24_224,128,1.6749,69.2430,42.0841,1.1921 -cuda,gmlp_s16_224,128,1.5886,79.2132,43.0142,1.2343 -cuda,hrnet_w18,128,1.3743,221.5304,134.2573,1.0100 -cuda,inception_v3,128,1.5847,102.8333,49.7648,1.0472 -cuda,jx_nest_base,32,1.3747,71.4190,61.4053,0.9905 -cuda,lcnet_050,128,1.8159,18.0047,18.8249,1.0005 -cuda,mixer_b16_224,128,1.2795,90.9229,21.0438,1.0133 -cuda,mixnet_l,128,1.2273,149.9722,47.7482,1.0129 -cuda,mnasnet_100,128,1.6594,40.0512,26.5165,1.0047 -cuda,mobilenetv2_100,128,1.6085,41.1217,27.4450,1.1731 -cuda,mobilenetv3_large_100,128,1.6610,37.9995,29.8185,1.0052 -cuda,mobilevit_s,64,1.5212,55.4152,53.6475,1.0258 -cuda,nfnet_l0,128,1.4927,65.7078,32.4067,0.9980 -cuda,pit_b_224,64,1.2286,57.9484,26.5321,0.9606 -cuda,pnasnet5large,16,1.0000,198.2494,93.4641,1.3184 -cuda,poolformer_m36,64,1.3486,103.9235,62.3196,1.1942 -cuda,regnety_002,128,1.3030,32.4968,27.2439,1.0014 -cuda,repvgg_a2,128,1.2485,59.7729,26.9209,1.0185 -cuda,res2net101_26w_4s,64,1.0813,94.1773,86.6520,0.9655 -cuda,res2net50_14w_8s,128,1.3251,109.5258,79.9578,0.9830 -cuda,res2next50,128,1.2518,125.5008,43.9754,0.9756 
-cuda,resmlp_12_224,128,1.3060,45.2373,19.3709,1.1048 -cuda,resnest101e,64,1.4346,108.1945,78.1993,1.1037 -cuda,rexnet_100,128,1.4637,55.0121,41.2075,1.0862 -cuda,selecsls42b,128,1.4284,44.6645,23.3892,1.0139 -cuda,spnasnet_100,128,1.5908,45.3189,32.0148,1.0048 -cuda,swin_base_patch4_window7_224,64,1.6164,89.5854,75.5848,0.9299 -cuda,swsl_resnext101_32x16d,32,1.0175,110.0041,45.7853,1.0003 -cuda,tf_efficientnet_b0,128,1.5271,55.7361,34.5551,1.1079 -cuda,tf_mixnet_l,128,1.2369,155.9027,48.6695,1.0921 -cuda,tinynet_a,128,1.3792,53.0640,40.6346,1.1108 -cuda,tnt_s_patch16_224,128,3.1078,104.8486,59.6028,1.0660 -cuda,twins_pcpvt_base,64,1.5921,67.4600,84.4977,1.0909 -cuda,visformer_small,128,1.1952,72.8705,23.7303,1.0410 -cuda,vit_base_patch16_224,64,1.1309,56.4866,22.0208,0.9804 -cuda,volo_d1_224,64,1.6868,72.0957,65.3011,0.9729 diff --git a/.github/workflows/torchinductor/data/torchbench.csv b/.github/workflows/torchinductor/data/torchbench.csv deleted file mode 100644 index 9b8e6ff2b..000000000 --- a/.github/workflows/torchinductor/data/torchbench.csv +++ /dev/null @@ -1,53 +0,0 @@ -dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio -cuda,BERT_pytorch,16,1.7111,24.2741,35.7065,1.3212 -cuda,LearningToPaint,96,1.0513,10.7557,11.1879,0.9896 -cuda,Super_SloMo,6,1.3267,60.4328,28.2097,1.2392 -cuda,alexnet,128,1.1754,8.3246,5.3319,1.0003 -cuda,attention_is_all_you_need_pytorch,256,1.3416,36.4401,39.5927,1.1774 -cuda,dcgan,32,0.9151,2.6249,3.2964,1.0082 -cuda,densenet121,4,0.9225,51.3747,68.5841,0.9930 -cuda,doctr_det_predictor,0,0.0000 -cuda,doctr_reco_predictor,0,0.0000 -cuda,drq,1,0.9500,3.4884,4.8028,0.9687 -cuda,fastNLP_Bert,6,1.4328,34.7753,35.4863,1.2368 -cuda,functorch_dp_cifar10,64,1.2015,8.1625,12.9040,1.0609 -cuda,functorch_maml_omniglot,1,0.9322,2.5844,3.8640,1.0000 -cuda,hf_Albert,8,2.1228,30.3377,26.8282,1.2676 -cuda,hf_Bart,4,1.2899,39.1935,47.2373,1.0080 -cuda,hf_Bert,4,1.3262,26.1063,35.0281,1.0656 
-cuda,hf_Bert_large,4,1.4163,55.1021,67.2825,1.0915 -cuda,hf_DistilBert,8,1.4051,21.7191,18.0399,1.0242 -cuda,hf_GPT2,4,1.6661,26.9039,29.9473,1.1555 -cuda,hf_Longformer,0,0.0000 -cuda,hf_Reformer,4,1.1709,64.6979,15.7035,0.9267 -cuda,hf_T5_large,2,1.7215,107.0798,148.8805,1.1684 -cuda,lennard_jones,1000,0.8428,1.8488,3.0609,1.0001 -cuda,maml_omniglot,32,0.9648,2.6869,3.9775,0.9999 -cuda,mnasnet1_0,32,1.0469,21.6251,25.8232,0.9996 -cuda,mobilenet_v2,96,1.5604,31.9572,27.0225,1.1734 -cuda,nvidia_deeprecommender,256,1.0605,9.2080,4.1318,0.9711 -cuda,phlippe_densenet,128,1.0237,27.5988,28.0400,1.0023 -cuda,phlippe_resnet,128,1.0493,10.9751,10.2485,1.0092 -cuda,pytorch_CycleGAN_and_pix2pix,1,1.3724,8.2225,11.9561,1.0219 -cuda,pytorch_stargan,16,1.1835,11.9178,10.0507,1.0868 -cuda,pytorch_unet,1,1.3787,29.7543,13.7711,1.0100 -cuda,resnet152,32,0.9834,63.2446,67.7935,0.9991 -cuda,resnet18,16,0.9451,9.4977,11.7663,0.9948 -cuda,resnet50,32,1.0513,24.5141,24.6629,1.0021 -cuda,resnext50_32x4d,8,0.9216,22.2460,24.3420,0.9984 -cuda,shufflenet_v2_x1_0,128,1.1943,25.4520,28.8611,1.0951 -cuda,soft_actor_critic,256,0.8691,1.9637,3.3716,0.9996 -cuda,speech_transformer,32,1.2718,35.2922,46.9957,1.0897 -cuda,squeezenet1_1,32,1.1302,8.4540,7.9625,1.0771 -cuda,timm_efficientdet,1,1.3370,80.0377,120.1814,1.2713 -cuda,timm_efficientnet,32,1.1874,27.6302,33.9059,1.0971 -cuda,timm_nfnet,128,1.4525,77.3461,34.3270,1.1056 -cuda,timm_regnet,32,1.0644,50.6953,35.7562,1.0000 -cuda,timm_resnest,32,1.6200,14.7763,17.2245,1.0906 -cuda,timm_vision_transformer,32,1.0800,19.4188,22.0255,0.9966 -cuda,timm_vision_transformer_large,32,1.0081,393.1742,127.8083,0.9735 -cuda,timm_vovnet,32,1.1472,22.4727,22.7328,1.0120 -cuda,torchrec_dlrm,0,0.0000 -cuda,tts_angular,64,0.8974,6.5057,2.5555,0.9973 -cuda,vgg16,64,1.2909,50.7405,6.1510,0.9828 -cuda,yolov3,16,1.2930,54.8069,41.9269,1.0563 diff --git a/.github/workflows/torchinductor/scripts/install_torchinductor.sh 
b/.github/workflows/torchinductor/scripts/install_torchinductor.sh deleted file mode 100755 index 2893ca203..000000000 --- a/.github/workflows/torchinductor/scripts/install_torchinductor.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" - -# torchinductor venv -whoami -python3 -m venv /opt/torchinductor_venv -# shellcheck source=/dev/null -source /opt/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source ./.github/workflows/torchinductor/scripts/common.sh - -# pytorch nightly -pip3 install --force-reinstall --pre torch torchtext torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -# pytorch source to get torchbench for dynamo -cd /opt || exit -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch || exit -# if you are updating an existing checkout -git submodule sync -git submodule update --init --recursive -cd .. - -# required packages -pip3 install expecttest psutil - -# torchbench -pip3 install pyyaml -git clone https://github.com/pytorch/benchmark.git -cd benchmark || exit -python3 install.py -cd .. - -# timm -git clone https://github.com/huggingface/pytorch-image-models.git -cd pytorch-image-models || exit -pip3 install -e . -cd .. - -# build our own triton -cd "$ROOT" || exit -cd python || exit -rm -rf build -pip3 install -e . 
-pip3 uninstall pytorch-triton -y - -# clean up cache -rm -rf /tmp/torchinductor_root/ -rm -rf ~/.triton/cache -rm -rf "$TEST_REPORTS_DIR" - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torchinductor/scripts/run_torchinductor_perf.sh b/.github/workflows/torchinductor/scripts/run_torchinductor_perf.sh deleted file mode 100755 index 6710cfed4..000000000 --- a/.github/workflows/torchinductor/scripts/run_torchinductor_perf.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -INDUCTOR="$ROOT"/.github/workflows/torchinductor - -# shellcheck source=/dev/null -source /opt/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source "$INDUCTOR"/scripts/common.sh - -# lock GPU clocks to 1350 MHz -sudo nvidia-smi -i 0 -pm 1 -sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350 - -cd "$PYTORCH_DIR" || exit -TEST_REPORTS_DIR=$TEST_REPORTS_DIR/perf -mkdir -p "$TEST_REPORTS_DIR" - -for model in "${MODELS[@]}"; do - echo "Running performance test for $model" - python3 benchmarks/dynamo/"$model".py --ci --training --performance --disable-cudagraphs\ - --device cuda --inductor --amp --output "$TEST_REPORTS_DIR"/"$model".csv -done - -cd "$ROOT" || exit -for model in "${MODELS[@]}"; do - echo "Checking performance test for $model" - python3 "$INDUCTOR"/scripts/check_perf.py --new "$TEST_REPORTS_DIR"/"$model".csv --baseline "$INDUCTOR"/data/"$model".csv - EXIT_STATUS=$? 
- if [ "$EXIT_STATUS" -ne 0 ]; then - echo "Performance test for $model failed" - exit "$EXIT_STATUS" - fi -done - -# unlock GPU clocks -sudo nvidia-smi -i 0 -rgc - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 98f05e720..abd6bbb7c 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -8,7 +8,7 @@ jobs: Build-Wheels: - runs-on: [self-hosted, V100] + runs-on: [self-hosted, CPU] permissions: id-token: write contents: read diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 833801cca..000000000 --- a/.isort.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[settings] -known_local_folder=triton -line_length=88 -py_version=36 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1729fc92e..efd1c59a7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,34 +15,29 @@ repos: - id: check-shebang-scripts-are-executable - id: detect-private-key - id: debug-statements - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.3 hooks: - - id: isort - stages: [commit, push, manual] - - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v1.6.0 - hooks: - - id: autopep8 - args: ["-i"] - stages: [commit, push, manual] - - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - # TODO: uncomment this to enable more flake8 plugins - # additional_dependencies: - # - flake8-bugbear - # - flake8-comprehensions - # - flake8-docstrings - # - flake8-pyi - # - flake8-simplify + - id: ruff + files: '^python/.*' + args: ["--fix", "--line-length", "120"] stages: [commit, push, manual] exclude: | (?x)( + ^python/triton/runtime/.*| ^test/| ^docs/conf.py$ ) + + - repo: https://github.com/jlebar/yapf + rev: bf301f5ef7777e137b97219842629ca78eb5ef2a + hooks: + - id: yapf + args: ["-p", "-i"] + stages: [commit, push, manual] + exclude: 
"python/test/unit/language/test_line_info.py" + - repo: https://github.com/pre-commit/mirrors-clang-format rev: v16.0.6 hooks: diff --git a/README.md b/README.md index cafc069b6..aa1f23f6c 100644 --- a/README.md +++ b/README.md @@ -90,10 +90,10 @@ arbitrary LLVM version. 1. Find the version of LLVM that Triton builds against. Check `python/setup.py` for a line like - version = "llvmorg-18-init-7000-g76ce4736721a" + rev = "b1115f8c" This means that the version of Triton you have builds against - [LLVM](https://github.com/llvm/llvm-project) 76ce4736721a. + [LLVM](https://github.com/llvm/llvm-project) b1115f8c. 2. `git checkout` LLVM at this revision. Optionally, make additional modifications to LLVM. @@ -130,6 +130,21 @@ arbitrary LLVM version. Without this, every invocation of `pip install` uses a different symlink to cmake, and this forces ninja to rebuild most of the `.a` files. +- vscode intellisense has some difficulty figuring out how to build Triton's C++ + (probably because, in our build, users don't invoke cmake directly, but + instead use setup.py). Teach vscode how to compile Triton as follows. + + - Do a local build. + - Get the full path to the `compile_commands.json` file produced by the build: + `find python/build -name 'compile_commands.json | xargs readlink -f'` + - In vscode, install the + [C/C++ + extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools), + then open the command palette (`Shift + Command + P` on Mac, or `Shift + + Ctrl + P` on Windows/Linux) and open `C/C++: Edit Configurations (UI)`. + - Open "Advanced Settings" and paste the full path to + `compile_commands.json` into the "Compile Commands" textbox. 
+ # Running tests There currently isn't a turnkey way to run all the Triton tests, but you can diff --git a/docs/conf.py b/docs/conf.py index 54ca52468..b40acaa4b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,6 @@ # -- General configuration ------------------------------------------------ - import os import shutil import sys @@ -121,12 +120,9 @@ def setup(app): return old_documenter(app, obj, parent) sphinx.ext.autosummary.get_documenter = documenter - sphinx.util.inspect.unwrap_all = forward_jit_fn( - sphinx.util.inspect.unwrap_all) - sphinx.util.inspect.signature = forward_jit_fn( - sphinx.util.inspect.signature) - sphinx.util.inspect.object_description = forward_jit_fn( - sphinx.util.inspect.object_description) + sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all) + sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature) + sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description) # Auto Doc @@ -139,7 +135,8 @@ extensions = [ 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion', - 'myst_parser'] + 'myst_parser', +] autosummary_generate = True # versioning config @@ -294,6 +291,6 @@ man_pages = [(master_doc, 'triton', 'Triton Documentation', [author], 1)] # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Triton', 'Triton Documentation', author, - 'Triton', 'One line description of project.', 'Miscellaneous'), + (master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.', + 'Miscellaneous'), ] diff --git a/docs/meetups/07-18-2023.md b/docs/meetups/07-18-2023/notes.md similarity index 100% rename from docs/meetups/07-18-2023.md rename to docs/meetups/07-18-2023/notes.md diff --git a/docs/meetups/Triton_AMD_update_0823.pdf b/docs/meetups/08-22-2023/amd-update.pdf similarity index 100% rename from docs/meetups/Triton_AMD_update_0823.pdf 
rename to docs/meetups/08-22-2023/amd-update.pdf diff --git a/docs/meetups/Intel XPU Backend for Triton - Update - 0823.pptx b/docs/meetups/08-22-2023/intel-xpu-update.pptx similarity index 100% rename from docs/meetups/Intel XPU Backend for Triton - Update - 0823.pptx rename to docs/meetups/08-22-2023/intel-xpu-update.pptx diff --git a/docs/meetups/08-22-2023.md b/docs/meetups/08-22-2023/notes.md similarity index 100% rename from docs/meetups/08-22-2023.md rename to docs/meetups/08-22-2023/notes.md diff --git a/docs/meetups/10-25-2023/intel-xpu-update.pdf b/docs/meetups/10-25-2023/intel-xpu-update.pdf new file mode 100644 index 000000000..defc4b719 Binary files /dev/null and b/docs/meetups/10-25-2023/intel-xpu-update.pdf differ diff --git a/docs/meetups/10-25-2023/notes.md b/docs/meetups/10-25-2023/notes.md new file mode 100644 index 000000000..04777eb3f --- /dev/null +++ b/docs/meetups/10-25-2023/notes.md @@ -0,0 +1,24 @@ +#### Agenda: + +##### Items: +1. H100 updates +2. Triton-Shared layer updates +3. Intel update +4. Open discussion + +##### Minutes: +Recording link [here](https://youtu.be/KZAzpKx1ebI) + +1. H100 updates + - Enabled WGMMA by default, now any matmul can reuse it. + - fp8 formats enabled – 1.3 Petaflops on dense matmul on H100 (gemm performance) + - Enabled Flash Attention using wgmma, resulting in 450 teraflop on fwd pass and 250 on backward pass – still working on perf for flash attention + - fp8 numbers with flash attention running in fp8 with matmul is tricky, because the fp8 layout is significantly different than what is returned by wgmma, still wip + +2. Triton-Shared layer + - Please refer to slides for more details + - Created a repo where you can find the middle layer + - Available as a plugin into triton + +3. 
Intel Update + - Please refer to slides for more details diff --git a/docs/meetups/10-25-2023/triton-shared.pptx b/docs/meetups/10-25-2023/triton-shared.pptx new file mode 100644 index 000000000..ea2f0ea41 Binary files /dev/null and b/docs/meetups/10-25-2023/triton-shared.pptx differ diff --git a/docs/meetups/tritonconf2023.md b/docs/meetups/dev-meetup-2023.md similarity index 100% rename from docs/meetups/tritonconf2023.md rename to docs/meetups/dev-meetup-2023.md diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index 0225a210a..b6a779dbe 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -66,6 +66,16 @@ def TT_AtomicRMWAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + // Program ID dimensions. 
def TT_ProgramDim : I32EnumAttr< "ProgramIDDim", "", diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 9445b5c40..7d7c29609 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -242,7 +242,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape, let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, TT_Type:$val, Optional:$mask, - TT_MemSemanticAttr:$sem); + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); let results = (outs TT_Type:$result); } @@ -264,7 +264,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, }]; let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, - TT_MemSemanticAttr:$sem); + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); let results = (outs TT_Type:$result); } diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index fad35fe00..e0288941c 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -109,6 +109,9 @@ bool isSharedEncoding(Value value); bool isExpensiveCat(CatOp cat, Attribute targetEncoding); +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index c69b62eb5..9f8e78241 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -85,54 +85,6 @@ def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> { }]; } - -// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU. 
-// This is needed because these ops don't -// handle encodings -// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111 -def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise, - SameOperandsAndResultShape, - SameOperandsAndResultEncoding]> { - let summary = "integer comparison operation"; - - let description = [{}]; - - let arguments = (ins Arith_CmpIPredicateAttr:$predicate, - TT_IntLike:$lhs, - TT_IntLike:$rhs); - - let results = (outs TT_BoolLike:$result); -} - -def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise, - SameOperandsAndResultShape, - SameOperandsAndResultEncoding]> { - let summary = "floating-point comparison operation"; - - let description = [{}]; - - let arguments = (ins Arith_CmpFPredicateAttr:$predicate, - TT_FloatLike:$lhs, - TT_FloatLike:$rhs); - - let results = (outs TT_BoolLike:$result); -} - -// TODO: migrate to arith::SelectOp on LLVM16 -def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise, - SameOperandsAndResultShape, - SameOperandsAndResultEncoding]> { - let summary = "select operation"; - - let description = [{}]; - - let arguments = (ins TT_BoolLike:$condition, - TT_Tensor:$true_value, - TT_Tensor:$false_value); - - let results = (outs TT_Type:$result); -} - // TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified def TTG_InsertSliceOp : TTG_Op<"insert_slice", [AttrSizedOperandSegments, diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index abab74741..d3383fac2 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -38,6 +38,8 @@ std::unique_ptr createTritonGPUOptimizeDotOperandsPass(); std::unique_ptr createTritonGPUOptimizeEpiloguePass(); +std::unique_ptr createTritonGPUOptimizeThreadLocalityPass(); + /// Generate the code for registering passes. 
#define GEN_PASS_REGISTRATION #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index ad6720616..7d6824889 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -162,6 +162,19 @@ def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::Modul } +def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", "mlir::ModuleOp"> { + let summary = "Reduce the cost of synchronization between threads in an SM"; + + let description = [{ + Today, this optimizes reduction yielded by loop to be thread-local until after the loop completes. + }]; + + let constructor = "mlir::createTritonGPUOptimizeThreadLocalityPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { let summary = "Reorder instructions"; diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index fe9f9f8c5..8f7ce75f2 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -111,6 +111,11 @@ bool isExpensiveLoadOrStore(Operation *op); bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separatly for the loop to be correct. 
+scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands); + Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); @@ -140,7 +145,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, ArrayRef shape); - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index ed0cb4cd8..fb281c18f 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -270,16 +270,16 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure, } def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods, - AllTypesMatch<["input", "output"]>]> { + AllTypesMatch<["inputs", "outputs"]>]> { let summary = "dot wait"; - let arguments = (ins TT_FpIntTensor:$input, I32Attr:$pendings); - let results = (outs TT_FpIntTensor:$output); + let arguments = (ins Variadic:$inputs, I32Attr:$pendings); + let results = (outs Variadic:$outputs); let description = [{ This operation defining the waiting action for a async dot, MMAv3 .e.g. The subsequent operations should not execute until this operation completes waiting. 
}]; - let assemblyFormat = "$input attr-dict `:` type($input)"; + let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; } def TTNG_StoreAsyncOp : TTNG_Op<"store_async", diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 1ed0c7658..27f3c22e2 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -364,17 +364,23 @@ private: scratchAlignment); } } else if (auto atomicCASOp = dyn_cast(op)) { + // only scalar requires scratch memory + // make it explicit for readability auto value = op->getOperand(0); - auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); - auto elemTy = - value.getType().cast().getPointeeType(); - auto bytes = elemTy.isa() - ? elems * kPtrBitWidth / 8 - : elems * elemTy.getIntOrFloatBitWidth() / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); + if (value.getType().dyn_cast()) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + value.getType().cast().getPointeeType(); + auto bytes = elemTy.isa() + ? 
elems * kPtrBitWidth / 8 + : elems * elemTy.getIntOrFloatBitWidth() / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } } else if (auto callOp = dyn_cast(op)) { auto callable = callOp.resolveCallable(); auto funcOp = dyn_cast(callable); diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index cf09310d4..e9e7fcb30 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -635,10 +635,6 @@ public: } private: - static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) { - return op.getPredicate(); - } - static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { return op.getPredicate(); } @@ -917,13 +913,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) visitors.append(); visitors.append(); visitors.append(); - visitors.append, - CmpOpAxisInfoVisitor>(); + visitors.append>(); visitors.append, LogicalOpAxisInfoVisitor, LogicalOpAxisInfoVisitor>(); - visitors.append, - SelectOpAxisInfoVisitor>(); + visitors.append>(); visitors.append, ShROpAxisInfoVisitor>(); visitors.append, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index adb08ef68..bb6a56943 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -98,7 +98,7 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { auto dim = sliceLayout.getDim(); auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); if (CTAsPerCGA[dim] != 1) - assert(0 && "Layout conversion to be implemented"); + llvm::report_fatal_error("Layout conversion to be implemented"); } // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 14212fce7..aa25a5205 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -676,7 +676,8 @@ private: inNumCTAsEachRep, multiDimRepId, inVec, 
paddedRepShape, origRepShape, outOrd, vals, smemBase); } else { - assert(0 && "ConvertLayout with input layout not implemented"); + llvm::report_fatal_error( + "ConvertLayout with input layout not implemented"); return failure(); } @@ -713,7 +714,8 @@ private: paddedRepShape, origRepShape, outOrd, outVals, smemBase); } else { - assert(0 && "ConvertLayout with output layout not implemented"); + llvm::report_fatal_error( + "ConvertLayout with output layout not implemented"); return failure(); } } @@ -1159,7 +1161,7 @@ private: } return res; } -}; // namespace triton::gpu::ConvertLayoutOp> +}; void populateConvertLayoutOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index e939e2fe9..4d637ab5c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -422,7 +422,6 @@ MMA16816SmemLoader::MMA16816SmemLoader( needTrans = kOrder != order[0]; canUseLdmatrix = elemBytes == 2 || (!needTrans); canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth); - // canUseLdmatrix = false; if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed, diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp index fcf6b2d76..65d595c8d 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp @@ -146,8 +146,55 @@ struct DotWaitOpConversion matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto pendings = op.getPendings(); - rewriter.replaceOpWithNewOp( - op, adaptor.getInput(), pendings); + Location loc = op.getLoc(); + if 
(adaptor.getInputs().size() <= 1) { + Value intput = + adaptor.getInputs().size() == 1 ? adaptor.getInputs()[0] : Value(); + rewriter.replaceOpWithNewOp(op, intput, + pendings); + return success(); + } + std::vector types; + // Pack the inputs into a single struct. + for (Value input : adaptor.getInputs()) { + auto structType = input.getType().dyn_cast(); + if (!structType) + return failure(); + for (Type type : structType.getBody()) + types.push_back(type); + } + auto packedType = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + Value packed = rewriter.create(loc, packedType); + unsigned outputStructIndex = 0; + for (Value input : adaptor.getInputs()) { + auto structType = input.getType().dyn_cast(); + for (unsigned i = 0; i < structType.getBody().size(); ++i) { + Value value = rewriter.create( + loc, structType.getBody()[i], input, i); + packed = rewriter.create( + loc, packedType, packed, value, outputStructIndex++); + } + } + Value packedOutput = + rewriter.create(loc, packed, pendings); + // Unpack the output into the original struct types. 
+ SmallVector outputs; + outputStructIndex = 0; + for (Value input : adaptor.getInputs()) { + auto structType = input.getType().cast(); + Value unpacked = rewriter.create(loc, structType); + for (unsigned i = 0; i < structType.getBody().size(); ++i) { + Value value = rewriter.create( + loc, packedType.getBody()[outputStructIndex], packedOutput, + outputStructIndex); + outputStructIndex++; + unpacked = rewriter.create(loc, structType, + unpacked, value, i); + } + outputs.push_back(unpacked); + } + rewriter.replaceOp(op, outputs); return success(); } }; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index b8c4cc00b..a1889f9bf 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -240,7 +240,11 @@ Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter, static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) { std::string ret; if (!hasNativeFP) { +<<<<<<< HEAD ret = "{ \n" +======= + ret = "{ \n" +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ".reg .b32 a<2>, b<2>, c<4>, d<4>, e112; \n" // if input = 0xf1f2f3f4 "mov.u32 e112, 0x77800000; \n" "prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400 @@ -265,6 +269,7 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) { "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign) "}"; } else { +<<<<<<< HEAD ret = "{ \n" ".reg .b32 a; \n" ".reg .f16 a<2>; \n" @@ -275,6 +280,25 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) { "cvt.bf16.f16 b1, a1; \n" "mov.b32 $0, {b0, b1}; \n" "}"; +======= + ret = + "{ \n" + ".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4 + ".reg .b32 e112; \n" + "mov.u32 e112, 0x77807780; \n" // 2**112 represented as + // bf16x2 + "prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400 + "prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200 + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff + 
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign) + "shr.b32 b0, b0, 3; \n" // b0 >>= 3 + "shr.b32 b1, b1, 3; \n" // shift into bf16 position + "lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0) + "lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign) + "mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4 + "mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112 + "}"; +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 } return ret; } @@ -400,7 +424,7 @@ static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) { "mov.b32 {a0, a1}, $1; \n" "cvt.f32.bf16 b0, a0; \n" "cvt.f32.bf16 b1, a1; \n" - "cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n" + "cvt.rn.satfinite.e5m2x2.f32 $0, b1, b0; \n" "}"; } return ret; @@ -981,7 +1005,7 @@ static const std::string Bf16_to_Fp8E4M3Nv = "mov.b32 {a0, a1}, $1; \n" "cvt.f32.bf16 b0, a0; \n" "cvt.f32.bf16 b1, a1; \n" - "cvt.rn.satfinite.e4m3x2.f32 $0, b0, b1; \n" + "cvt.rn.satfinite.e4m3x2.f32 $0, b1, b0; \n" "}"; /* ----- Packed integer to BF16 ------ */ @@ -1239,8 +1263,118 @@ public: using OpAdaptor = typename SourceOp::Adaptor; explicit ElementwiseOpConversionBase( - TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) - : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} + TritonGPUToLLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit), + axisAnalysisPass(axisAnalysisPass) {} + + // Try to deduplicate the resultVals based on the + // constancy properties of the result discovered by + // the axis analysis pass. If possible, redundant + // computation is eliminated. 
+ SmallVector maybeDeduplicate(SourceOp op, + SmallVector resultVals) const { + if (!isMemoryEffectFree(op)) + // the op has side effects: can't dedup + return resultVals; + SmallVector results = op->getResults(); + if (results.size() == 0 || results.size() > 1) + // there must be exactly 1 result + return resultVals; + Value result = results[0]; + Type type = result.getType(); + if (!type) + return resultVals; + RankedTensorType rtType = type.dyn_cast(); + if (!rtType) + // the result must be a tensor + return resultVals; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + // encoding not available + return resultVals; + if (!encoding.dyn_cast() && + !encoding.dyn_cast()) { + // TODO: constraining the ecndoing type here is necessary + // for avoiding crashes in the triton::gpu::getElemsPerThread + // call below happening in the test_core::test_fp8_dot_acc + return resultVals; + } + + SmallVector elemsPerThread = + triton::gpu::getElemsPerThread(rtType); + int rank = elemsPerThread.size(); + if (product(elemsPerThread) != resultVals.size()) + return resultVals; + AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result); + if (!axisInfo) + // axis info (e.g., constancy) not available + return resultVals; + SmallVector sizePerThread = + triton::gpu::getSizePerThread(encoding); + if (rank != sizePerThread.size()) + return resultVals; + + SmallVector constancy = axisInfo->getConstancy(); + if (rank != constancy.size()) + return resultVals; + bool hasConstancy = false; + for (int i = 0; i < rank; ++i) { + if (constancy[i] > sizePerThread[i]) { + if (constancy[i] % sizePerThread[i] != 0) + // constancy is not evenly covered by sizePerThread + return resultVals; + // can't move the values across different + // "sizePerThread"-sized blocks + constancy[i] = sizePerThread[i]; + } + if (elemsPerThread[i] < 1 || constancy[i] < 1) + return resultVals; + if (!(elemsPerThread[i] % constancy[i] == 0 || + constancy[i] % elemsPerThread[i] == 0)) + // either the 
constancy along each dimension must fit + // into the elemsPerThread or the other way around + return resultVals; + if (constancy[i] > 1) + hasConstancy = true; + } + if (!hasConstancy) + // nothing to deduplicate + return resultVals; + + if (rank > 1) { + // reorder the shape and constancy vectors by the axis order: + // from the fastest-changing to the smallest-changing axis + SmallVector order = triton::gpu::getOrder(encoding); + if (rank != order.size()) + return resultVals; + ArrayRef orderRef(order); + elemsPerThread = reorder(ArrayRef(elemsPerThread), orderRef); + constancy = reorder(ArrayRef(constancy), orderRef); + } + + SmallVector strides(rank, 1); + for (int i = 1; i < rank; ++i) { + strides[i] = strides[i - 1] * elemsPerThread[i - 1]; + } + SmallVector dedupResultVals; + dedupResultVals.reserve(resultVals.size()); + for (int i = 0; i < resultVals.size(); ++i) { + // each coordinate of the orig_idx is "coarsened" using the + // constancy along this dimension: the resulting dedup_idx + // points to the reused value in the original resultsVal + int orig_idx = i; + int dedup_idx = 0; + for (int j = 0; j < rank; ++j) { + int coord_j = orig_idx % elemsPerThread[j]; + dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j]; + orig_idx /= elemsPerThread[j]; + } + dedupResultVals.push_back(resultVals[dedup_idx]); + } + + return dedupResultVals; + } LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, @@ -1283,6 +1417,7 @@ public: auto argTy = op->getOperand(0).getType(); resultVals = reorderValues(resultVals, argTy, resultTy); } + resultVals = maybeDeduplicate(op, resultVals); resultVals = packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); resultVals = this->getTypeConverter()->packMfmaOperand(resultVals, resultTy, rewriter, loc); @@ -1294,6 +1429,9 @@ public: return success(); } +protected: + ModuleAxisInfoAnalysis &axisAnalysisPass; + private: int computeCapability; }; @@ -1325,8 +1463,9 @@ struct FpToFpOpConversion 
triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase; explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, int computeCapability, PatternBenefit benefit = 1) - : ElementwiseOpConversionBase(typeConverter, benefit), + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), computeCapability(computeCapability) {} static Value convertBf16ToFp32(Location loc, @@ -1398,6 +1537,7 @@ struct FpToFpOpConversion #endif } +<<<<<<< HEAD static Value convertFp32ToFp16(Location loc, ConversionPatternRewriter &rewriter, const Value &v) { @@ -1409,8 +1549,13 @@ struct FpToFpOpConversion cvt(res, operand); return builder.launch(rewriter, loc, f16_ty, false); #else +======= + static Value convertFp32ToFp16NZ(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v) { +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 PTXBuilder builder; - auto &cvt = *builder.create("cvt.rn.f16.f32"); + auto &cvt = *builder.create("cvt.rz.f16.f32"); auto res = builder.newOperand("=h"); auto operand = builder.newOperand(v, "r"); cvt(res, operand); @@ -1480,7 +1625,7 @@ struct FpToFpOpConversion int inVecWidthBits = 32; int outVecWidthBits = 32; if (srcTy.isFloat8E4M3FNUZ() || - (computeCapability >= 90 && srcTy.isFloat8E5M2())) { + (computeCapability >= 90 && srcTy.isFloat8E5M2() && dstTy.isF16())) { inVecWidthBits = 16; outVecWidthBits = 32; } @@ -1524,7 +1669,9 @@ struct FpToFpOpConversion if (srcElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E4M3FNUZ() || (computeCapability >= 90 && - (srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) { + ((srcElementType.isFloat8E5M2() && + (dstElementType.isF16() || dstElementType.isF32())) || + dstElementType.isFloat8E5M2()))) { numElements = 2; } bool useFP16IntermediateSrc = @@ -1532,18 +1679,17 @@ struct FpToFpOpConversion !(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2())); bool 
isDstFP32 = dstElementType.isF32(); - auto cvtFunc = - getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType, - isDstFP32 ? f16_ty : dstElementType); + Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType; + Type dstType = isDstFP32 ? f16_ty : dstElementType; + auto cvtFunc = getConversionFunc(srcType, dstType); SmallVector inVals; for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { inVals.push_back(operands[i][0]); } if (useFP16IntermediateSrc) for (Value &v : inVals) - v = convertFp32ToFp16(loc, rewriter, v); - inVals.resize(numElements, - undef(typeConverter->convertType(srcElementType))); + v = convertFp32ToFp16NZ(loc, rewriter, v); + inVals.resize(numElements, undef(typeConverter->convertType(srcType))); SmallVector outVals = cvtFunc(loc, rewriter, inVals); assert(outVals.size() == inVals.size()); outVals.resize(std::min(numElements, operands.size())); @@ -1569,18 +1715,17 @@ Value EmitDualBF16ElementwiseOp(Location loc, } struct CmpIOpConversion - : public ElementwiseOpConversionBase { - using Base = - ElementwiseOpConversionBase; + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; using Base::Base; using Adaptor = typename Base::OpAdaptor; // An interface to support variant DestOp builder. 
- SmallVector - createDestOps(triton::gpu::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, Type elemTy, - MultipleOperandsRange operands, Location loc) const { + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { return {rewriter.create( loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()), operands[0][0], operands[0][1])}; @@ -1611,16 +1756,14 @@ struct CmpIOpConversion }; struct CmpFOpConversion - : public ElementwiseOpConversionBase { - using Base = - ElementwiseOpConversionBase; + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; using Base::Base; using Adaptor = typename Base::OpAdaptor; // An interface to support variant DestOp builder. static SmallVector - createDestOps(triton::gpu::CmpFOp op, OpAdaptor adaptor, + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) { return {rewriter.create( @@ -1812,7 +1955,7 @@ struct FDivOpConversion } else if (64 == bitwidth) { fdiv.o("rn").o("f64"); } else { - assert(0 && bitwidth && "not supported"); + llvm::report_fatal_error("Unsupported bitwidth"); } auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l"); @@ -2171,20 +2314,40 @@ struct IndexCastOpLowering } }; +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(mlir::arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. 
+ assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; + void populateElementwiseOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, int computeCapability, PatternBenefit benefit) { -#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \ - patterns.add>(typeConverter, benefit); - POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp) - POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp) -#undef POPULATE_TERNARY_OP - #define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ - patterns.add>(typeConverter, benefit); + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * @@ -2208,7 +2371,8 @@ void populateElementwiseOpToLLVMPatterns( #undef POPULATE_BINARY_OP #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ - patterns.add>(typeConverter, benefit); + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) @@ -2224,29 +2388,33 @@ void populateElementwiseOpToLLVMPatterns( POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) #undef POPULATE_UNARY_OP - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, 
benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, computeCapability, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, + axisInfoAnalysis, benefit); // ExpOpConversionApprox will try using ex2.approx if the input type is // FP32. 
For other input types, ExpOpConversionApprox will return failure and // ElementwiseOpConversion defined below will call // __nv_expf for higher-precision calculation - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 38cbb8065..dec41ea3d 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -28,6 +28,8 @@ static CUtensorMapDataType getCUtensorMapDataType(Type ty) { return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if (ty.isF32()) { return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else if (ty.getIntOrFloatBitWidth() == 8) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; } else { llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; @@ -930,6 +932,11 @@ private: return -1 - op.getValue().dyn_cast().getValue().getZExtValue(); } + if (!isa(v) && + !isa( + v.getDefiningOp())) + llvm::report_fatal_error( + "Operand of `MakeTensorPtrOp` is not the function's argument"); if (v.getDefiningOp() && isa(v.getDefiningOp())) { return getArgIdx(v.getDefiningOp()->getOperand(0)); @@ -1095,40 +1102,81 @@ struct AtomicCASOpConversion TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType()) : valueTy; auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); + auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); + // vec = 1 for scalar + auto vec = getVectorSize(op.getPtr()); + // tensor + if (TensorTy) { + auto valTy = op.getVal().getType().cast(); + vec = std::min(vec, valTy.getElementType().isF16() ? 
2 : 1); + } + Value mask = getMask(valueTy, rewriter, loc); + auto vecTy = vec_ty(valueElemTy, vec); + SmallVector resultVals(elemsPerThread); - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); - atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); - Value casPtr = ptrElements[0]; - Value casCmp = cmpElements[0]; - Value casVal = valElements[0]; + for (size_t i = 0; i < elemsPerThread; i += vec) { + Value casVal = undef(vecTy); + for (int ii = 0; ii < vec; ++ii) { + Value iiVal = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), ii); + casVal = insert_element(vecTy, casVal, valElements[i + ii], iiVal); + } - PTXBuilder ptxBuilderAtomicCAS; - auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true); - auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l"); - auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r"); - auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r"); - auto &atom = *ptxBuilderAtomicCAS.create("atom"); - std::string semStr; - llvm::raw_string_ostream os(semStr); - os << op.getSem(); - atom.global().o(semStr).o("cas").o("b32"); - atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); - auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); - createBarrier(rewriter, loc, numCTAs); + Value casPtr = ptrElements[i]; + Value casCmp = cmpElements[i]; + casVal = valElements[i]; + PTXBuilder ptxBuilderAtomicCAS; + std::string tyId = valueElemNBits * vec == 64 + ? "l" + : (valueElemNBits * vec == 32 ? 
"r" : "h"); + auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=" + tyId, /*init=*/true); + auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l"); + auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, tyId); + auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, tyId); + auto &atom = *ptxBuilderAtomicCAS.create("atom"); + auto sTy = "b" + std::to_string(valueElemNBits); + std::string semStr; + llvm::raw_string_ostream os(semStr); + os << op.getSem(); + auto scope = stringifyMemSyncScope(op.getScope()).str(); + atom.global().o(semStr).o(scope).o("cas").o(sTy); + atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); - PTXBuilder ptxBuilderStore; - auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); - auto *valOprStore = ptxBuilderStore.newOperand(old, "r"); - auto &st = *ptxBuilderStore.create("st"); - st.shared().o("b32"); - st(dstOprStore, valOprStore).predicate(mask); - auto ASMReturnTy = void_ty(ctx); - ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - createBarrier(rewriter, loc, numCTAs); - Value ret = load(atomPtr); - createBarrier(rewriter, loc, numCTAs); - rewriter.replaceOp(op, {ret}); + if (TensorTy) { + auto retType = vec == 1 ? valueElemTy : vecTy; + auto ret = ptxBuilderAtomicCAS.launch(rewriter, loc, retType); + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? 
ret : extract_element(valueElemTy, ret, i32_val(ii)); + } + } else { + auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); + createBarrier(rewriter, loc, numCTAs); + Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); + atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3)); + // Only threads with mask = True store the result + PTXBuilder ptxBuilderStore; + auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); + auto *valOprStore = ptxBuilderStore.newOperand(old, "r"); + auto &st = *ptxBuilderStore.create("st"); + st.shared().o(sTy); + st(dstOprStore, valOprStore).predicate(mask); + auto ASMReturnTy = void_ty(ctx); + ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); + createBarrier(rewriter, loc, numCTAs); + Value ret = load(atomPtr); + createBarrier(rewriter, loc, numCTAs); + rewriter.replaceOp(op, {ret}); + } + } + + if (TensorTy) { + Type structTy = getTypeConverter()->convertType(TensorTy); + Value resultStruct = getTypeConverter()->packLLElements( + loc, resultVals, rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + } return success(); } #endif // USE_ROCM @@ -1360,7 +1408,8 @@ struct AtomicRMWOpConversion auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); - auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu"); + auto scope = stringifyMemSyncScope(op.getScope()).str(); + auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); auto rmwOp = stringifyRMWOp(atomicRmwAttr).str(); auto sBits = std::to_string(valueElemNBits); switch (atomicRmwAttr) { @@ -2001,6 +2050,11 @@ private: return -1 - op.getValue().dyn_cast().getValue().getZExtValue(); } + if (!isa(v) && + !isa( + v.getDefiningOp())) + llvm::report_fatal_error( + "Operand of `MakeTensorPtrOp` is not the function's argument"); if (v.getDefiningOp() && isa(v.getDefiningOp())) { return getArgIdx(v.getDefiningOp()->getOperand(0)); diff --git 
a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index f59f2b703..34fca8a35 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2,6 +2,8 @@ #include "Utility.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +namespace { + using namespace mlir; using namespace mlir::triton; @@ -10,6 +12,23 @@ using ::mlir::LLVM::getSRegValue; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; +Value llGetPid(int axis, Location loc, ModuleOp moduleOp, + ConversionPatternRewriter &rewriter) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + + // It is not easy to get the compute capability here, so we use numCTAs to + // decide the semantic of GetProgramIdOp. If numCTAs = 1, then + // GetProgramIdOp is converted to "%ctaid", otherwise it is converted to + // "%clusterid". + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + + std::string sreg = numCTAs == 1 ? "%ctaid." : "%clusterid."; + sreg.append(1, 'x' + axis); // 0 -> 'x', 1 -> 'y', 2 -> 'z' + return getSRegValue(rewriter, loc, sreg); +} + struct ReturnOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -91,6 +110,12 @@ struct BroadcastOpConversion } }; +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. 
struct PrintOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -100,14 +125,61 @@ struct PrintOpConversion matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - SmallVector operands; - for (size_t i = 0; i < op.getNumOperands(); i++) { - auto sub_operands = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType()); - for (auto elem : sub_operands) { - operands.push_back(elem); + Value prefixStr = + LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix()); + + auto getPid = [&](int axis) { + return llGetPid(axis, loc, op->getParentOfType(), rewriter); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s"; + llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter); + } else { + for (size_t i = 0; i < op.getNumOperands(); i++) { + // Elements of the tensor that are resident in this GPU thread. + auto elems = getTypeConverter()->unpackLLElements( + loc, adaptor.getOperands()[i], rewriter, + op.getOperand(i).getType()); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. 
+ SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + op.getOperand(i).getType().dyn_cast()) { + indices = + emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(prefixStr, /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, rewriter); + } } } +<<<<<<< HEAD std::string formatStr; llvm::raw_string_ostream os(formatStr); os << op.getPrefix(); @@ -124,26 +196,118 @@ struct PrintOpConversion #else llPrintf(formatStr, operands, rewriter); #endif +======= +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 rewriter.eraseOp(op); return success(); } - std::string getFormatSubstr(Value value) const { + void printTensor(Value prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, + ConversionPatternRewriter &rewriter) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. 
+ constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")"; + + os << "%s"; + printfOperands.push_back(prefixStr); + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + os << getFormatSubstr(elem); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. 
+ if (i == 0) { + formatStrValue = llPrintf(formatStr, printfOperands, rewriter); + } else { + llPrintf(formatStrValue, printfOperands, rewriter); + } + } + } + + std::string getFormatSubstr(Value value, + std::optional width = std::nullopt) const { + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + Type type = value.getType(); if (type.isa()) { - return "%p"; + return prefix + "p"; } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { - return "%f"; + return prefix + "f"; } else if (type.isSignedInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return "%lli"; + return prefix + "lli"; else - return "%i"; + return prefix + "i"; } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return "%llu"; + return prefix + "llu"; else - return "%u"; + return prefix + "u"; } assert(false && "not supported type"); return ""; @@ -199,9 +363,22 @@ struct PrintOpConversion return {newType, newOp}; } - static void llPrintf(StringRef msg, ValueRange args, + // Returns a Value for the format string, which you can reuse. 
+ static Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter) { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + llPrintf(msgValue, args, rewriter); + return msgValue; + } + + static void llPrintf(Value msg, ValueRange args, ConversionPatternRewriter &rewriter) { - assert(!msg.empty() && "printf with empty string not support"); Type int8Ptr = ptr_ty(i8_ty); auto *ctx = rewriter.getContext(); @@ -213,11 +390,6 @@ struct PrintOpConversion Value one = i32_val(1); Value zero = i32_val(0); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value prefixString = - LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline); Value bufferPtr = null(int8Ptr); SmallVector newArgs; @@ -245,7 +417,7 @@ struct PrintOpConversion bufferPtr = bitcast(allocated, int8Ptr); } - SmallVector operands{prefixString, bufferPtr}; + SmallVector operands{msg, bufferPtr}; call(funcOp, operands); } }; @@ -446,6 +618,7 @@ struct GetProgramIdOpConversion LogicalResult matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { +<<<<<<< HEAD #ifdef USE_ROCM Location loc = op->getLoc(); @@ -470,6 +643,10 @@ struct GetProgramIdOpConversion sreg.append(1, 'x' + op.getAxisAsInt()); // 0 -> 'x', 1 -> 'y', 2 -> 'z' Value programId = getSRegValue(rewriter, loc, sreg); +======= + Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(), + op->getParentOfType(), rewriter); +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 rewriter.replaceOp(op, programId); return success(); #endif @@ -770,28 +947,9 @@ struct AsyncBulkCommitGroupOpConversion } }; -namespace mlir { -namespace LLVM { +} // namespace -void 
vprintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter) { - PrintOpConversion::llPrintf(msg, args, rewriter); -} - -void vprintf_array(Value thread, ArrayRef arr, std::string info, - std::string elem_repr, ConversionPatternRewriter &builder) { - std::string fmt = info + " t-%d "; - std::vector new_arr({thread}); - for (int i = 0; i < arr.size(); ++i) { - fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", "); - new_arr.push_back(arr[i]); - } - - vprintf(fmt, new_arr, builder); -} - -} // namespace LLVM -} // namespace mlir +namespace mlir::triton { void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, @@ -818,3 +976,5 @@ void populateTritonGPUToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } + +} // namespace mlir::triton diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index 901907358..b49791543 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -6,6 +6,8 @@ using namespace mlir; using namespace mlir::triton; +namespace mlir::triton { + void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, @@ -13,4 +15,6 @@ void populateTritonGPUToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); +} // namespace mlir::triton + #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 789aa7be0..3e2701af0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -38,19 +38,6 @@ namespace ttng = ::mlir::triton::nvidia_gpu; typedef DenseMap TensorPtrMapT; -namespace mlir { -namespace LLVM { - -// Helper function for 
using printf in LLVM conversion. -void vprintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter); - -void vprintf_array(Value thread, ArrayRef arr, std::string info, - std::string elem_repr, ConversionPatternRewriter &builder); - -} // namespace LLVM -} // namespace mlir - // FuncOpConversion/FuncOpConversionBase is borrowed from // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276 // since it is not exposed on header files in mlir v14 @@ -193,10 +180,10 @@ public: // Key: {layout, shape, withCTAOffset} struct IndexCacheInfo { DenseMap, CacheKeyDenseMapInfo> - *baseIndexCache; + *baseIndexCache = nullptr; DenseMap>, - CacheKeyDenseMapInfo> *indexCache; - OpBuilder::InsertPoint *indexInsertPoint; + CacheKeyDenseMapInfo> *indexCache = nullptr; + OpBuilder::InsertPoint *indexInsertPoint = nullptr; }; explicit ConvertTritonGPUOpToLLVMPatternBase( @@ -835,7 +822,7 @@ public: emitIndicesForDistributedLayout(loc, b, slice, type, withCTAOffset); } else { llvm_unreachable( - "emitIndices for layouts other than blocked & slice not " + "emitIndices for layouts other than blocked, mma, and slice not " "implemented yet"); } if (cache) { diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index ca53a82dd..0e46f7d1e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -64,8 +64,8 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( for (size_t i = 0; i < 2 * shape.size(); ++i) types.push_back(IntegerType::get(ctx, 64)); - types.push_back( - LLVM::LLVMPointerType::get(eleType, type.getAddressSpace())); + types.push_back(LLVM::LLVMPointerType::get(convertType(eleType), + type.getAddressSpace())); return LLVM::LLVMStructType::getLiteral(ctx, types); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp 
b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index d690acf6d..694c0be78 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -445,8 +445,10 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, } Value zero = i32_val(0); - Value globalPtr = - rewriter.create(UnknownLoc::get(ctx), global); + Type globalPtrType = + LLVM::LLVMPointerType::get(globalType, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); Value stringStart = rewriter.create(UnknownLoc::get(ctx), ptr_ty(i8_ty), globalPtr, SmallVector({zero, zero})); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index fdd47f2de..001c83783 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -147,6 +147,8 @@ struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern { matchAndRewrite(ViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) && + "expensive view not supported"); auto resultTy = op.getType().template cast(); auto vals = this->getTypeConverter()->unpackLLElements( loc, adaptor.getSrc(), rewriter, op.getOperand().getType()); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 327a9138e..a13b714c1 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -46,23 +46,6 @@ template struct GenericOpPattern : public OpConversionPattern { } }; -template -class ArithCmpPattern : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor, 
- ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, retType, adaptor.getPredicate(), - adaptor.getLhs(), adaptor.getRhs()), - adaptor.getAttributes()); - return success(); - } -}; - class ArithConstantPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -122,8 +105,9 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, // Cmp - ArithCmpPattern, - ArithCmpPattern, + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, // Cast Ops GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, @@ -132,45 +116,6 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, GenericOpPattern>(typeConverter, context); } -// this shouldn't exist if mlir's SelectOp checked encodings properly -class StdSelectPattern : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - - Value cond = adaptor.getCondition(); - if (llvm::isa(retType) && - !llvm::isa(cond.getType())) { - // triton_gpu.select doesn't support scalar condition values, so add a - // splat - auto retTypeTensor = llvm::cast(retType); - auto retShape = retTypeTensor.getShape(); - auto retEncoding = retTypeTensor.getEncoding(); - Type condTy = - RankedTensorType::get(retShape, cond.getType(), retEncoding); - cond = rewriter.create(op.getLoc(), condTy, cond); - } - - addNamedAttrs( - rewriter.replaceOpWithNewOp( - op, retType, cond, adaptor.getTrueValue(), adaptor.getFalseValue()), - adaptor.getAttributes()); - return success(); - } -}; - -void populateStdPatternsAndLegality(TritonGPUTypeConverter 
&typeConverter, - RewritePatternSet &patterns, - TritonGPUConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - // Rewrite rule - patterns.add(typeConverter, context); -} - void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns, TritonGPUConversionTarget &target) { @@ -529,6 +474,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, TritonFuncOpPattern>(typeConverter, context); @@ -745,7 +691,6 @@ public: // rewrite patterns RewritePatternSet patterns(context); // add rules - populateStdPatternsAndLegality(typeConverter, patterns, target); populateArithPatternsAndLegality(typeConverter, patterns, target); populateMathPatternsAndLegality(typeConverter, patterns, target); populateTritonPatterns(typeConverter, patterns, numCTAs); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1db61b8e4..d76ead693 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -52,7 +52,7 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, } else if (auto dotLayout = layout.dyn_cast()) { return dotLayout.getTotalElemsPerThread(shape, eltTy); } else { - assert(0 && "getElemsPerThread not implemented"); + llvm::report_fatal_error("getElemsPerThread not implemented"); return 0; } } @@ -68,7 +68,7 @@ SmallVector getElemsPerThread(Attribute layout, } else if (auto mfmaLayout = layout.dyn_cast()) { return mfmaLayout.getElemsPerThread(shape, eltTy); } else { - assert(0 && "getElemsPerThread not implemented"); + llvm::report_fatal_error("getElemsPerThread not implemented"); return SmallVector(); } } @@ -129,7 +129,7 @@ SmallVector getThreadsPerWarp(Attribute layout) { threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()]; return threadsPerWarp; } - 
assert(0 && "getThreadsPerWarp not implemented"); + llvm::report_fatal_error("getThreadsPerWarp not implemented"); return {}; } @@ -180,15 +180,17 @@ SmallVector getWarpsPerCTA(Attribute layout) { if (auto sliceLayout = layout.dyn_cast()) { auto parent = sliceLayout.getParent(); auto parentWarpsPerCTA = getWarpsPerCTA(parent); - assert(parentWarpsPerCTA.size() == 2 && - "getWarpsPerCTA only implemented for 2D slice layout"); + assert(parentWarpsPerCTA.size() == 2 || + parentWarpsPerCTA[sliceLayout.getDim()] == 1 && + "getWarpsPerCTA only implemented for 2D slice layout or the " + "slice dim must have 1 warp in the parent layout"); SmallVector warpsPerCTA = parentWarpsPerCTA; warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); for (unsigned i = 0; i < warpsPerCTA.size(); i++) warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()]; return warpsPerCTA; } - assert(0 && "getWarpsPerCTA not implemented"); + llvm::report_fatal_error("getWarpsPerCTA not implemented"); return {}; } @@ -264,7 +266,7 @@ SmallVector getSizePerThread(Attribute layout) { } else if (opIdx == 1) { return {4, 1}; } else { - assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; } } else if (parentLayout.isa()) { @@ -278,12 +280,13 @@ SmallVector getSizePerThread(Attribute layout) { return {}; } } else { - assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not " - "supported yet"); + llvm::report_fatal_error( + "DotOperandEncodingAttr non-MmaEncodingAttr parent not " + "supported yet"); return {}; } } else { - assert(0 && "getSizePerThread not implemented"); + llvm::report_fatal_error("getSizePerThread not implemented"); return {}; } } @@ -337,6 +340,7 @@ SmallVector getThreadsPerCTA(Attribute layout) { threads = {8 * mmaLayout.getWarpsPerCTA()[0], 4 * mmaLayout.getWarpsPerCTA()[1]}; } else +<<<<<<< HEAD assert(0 && "Unimplemented usage of MmaEncodingAttr"); } else if (auto mfmaLayout 
= layout.dyn_cast()) { if (mfmaLayout.getNonKDim() == 32) { @@ -346,8 +350,11 @@ SmallVector getThreadsPerCTA(Attribute layout) { threads = {16 * mfmaLayout.getWarpsPerCTA()[0], 4 * mfmaLayout.getWarpsPerCTA()[1]}; } +======= + llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr"); +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 } else { - assert(0 && "Unimplemented usage of getThreadsPerCTA"); + llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA"); } return threads; @@ -381,11 +388,15 @@ SmallVector getShapePerCTATile(Attribute layout, return {16 * mmaLayout.getWarpsPerCTA()[0], instrShape[1] * mmaLayout.getWarpsPerCTA()[1]}; } +<<<<<<< HEAD assert(0 && "Unexpected MMA layout version found"); } else if (auto mfmaLayout = layout.dyn_cast()) { auto nonKDim = mfmaLayout.getNonKDim(); return {nonKDim * mfmaLayout.getWarpsPerCTA()[0], nonKDim * mfmaLayout.getWarpsPerCTA()[1]}; +======= + llvm::report_fatal_error("Unexpected MMA layout version found"); +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 } else if (auto dotLayout = layout.dyn_cast()) { auto parentLayout = dotLayout.getParent(); assert(parentLayout && "DotOperandEncodingAttr must have a parent"); @@ -401,7 +412,7 @@ SmallVector getShapePerCTATile(Attribute layout, } else if (opIdx == 1) { return {16, parentShapePerCTATile[1]}; } else { - assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } } else if (auto parentMfmaLayout = parentLayout.dyn_cast()) { @@ -416,15 +427,20 @@ SmallVector getShapePerCTATile(Attribute layout, assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); } } else { - assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not " - "supported yet"); + llvm::report_fatal_error( + "DotOperandEncodingAttr non-MmaEncodingAttr parent not " + "supported yet"); } } else { - assert(0 && "Unimplemented usage of getShapePerCTATile"); + llvm::report_fatal_error("Unimplemented 
usage of getShapePerCTATile"); } return shape; } +bool isExpensiveView(Type srcType, Type dstType) { + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + namespace { /* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. @@ -473,7 +489,7 @@ SmallVector getOrder(Attribute layout) { return SmallVector(sharedLayout.getOrder().begin(), sharedLayout.getOrder().end()); } else { - assert(0 && "Unimplemented usage of getOrder"); + llvm::report_fatal_error("Unimplemented usage of getOrder"); } return {}; }; @@ -494,7 +510,7 @@ CTALayoutAttr getCTALayout(Attribute layout) { else if (auto sharedLayout = layout.dyn_cast()) return sharedLayout.getCTALayout(); else - assert(0 && "Unimplemented usage of getCTALayout"); + llvm::report_fatal_error("Unimplemented usage of getCTALayout"); return {}; } @@ -522,7 +538,8 @@ SmallVector getCTAsPerCGA(Attribute layout) { * in the branch where layout is an instance of SliceEncodingAttr. This is * inconvenient but safe. 
*/ - assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); } else if (auto mmaLayout = layout.dyn_cast()) ref = mmaLayout.getCTALayout().getCTAsPerCGA(); #ifdef USE_ROCM @@ -534,7 +551,7 @@ SmallVector getCTAsPerCGA(Attribute layout) { else if (auto sharedLayout = layout.dyn_cast()) ref = sharedLayout.getCTALayout().getCTAsPerCGA(); else - assert(0 && "Unimplemented usage of getCTAsPerCGA"); + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); return SmallVector(ref.begin(), ref.end()); } @@ -589,7 +606,7 @@ SmallVector getCTAOrder(Attribute layout) { } else if (auto sharedLayout = layout.dyn_cast()) { ref = sharedLayout.getCTALayout().getCTAOrder(); } else { - assert(0 && "Unimplemented usage of getCTAOrder"); + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); } return SmallVector(ref.begin(), ref.end()); } @@ -642,9 +659,9 @@ unsigned getNumWarpsPerCTA(Attribute layout) { else if (auto dotLayout = layout.dyn_cast()) return getNumWarpsPerCTA(dotLayout.getParent()); else if (auto sharedLayout = layout.dyn_cast()) - assert(0 && "Cannot get numWarps from SharedEncodingAttr"); + llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); else - assert(0 && "Unimplemented usage of getNumWarpsPerCTA"); + llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA"); return product(warpsPerCTA); } @@ -665,7 +682,7 @@ unsigned getNumCTAs(Attribute layout) { else if (auto sharedLayout = layout.dyn_cast()) CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); else - assert(0 && "Unimplemented usage of getNumCTAs"); + llvm::report_fatal_error("Unimplemented usage of getNumCTAs"); return product(CTAsPerCGA); } @@ -1779,13 +1796,15 @@ struct CanonicalizeConvertFromView Operation *arg = op->getOperand(0).getDefiningOp(); if (!arg) return mlir::failure(); + auto convert = dyn_cast(arg); + if (!convert) + return 
failure(); + if (isExpensiveView(convert.getOperand().getType(), op.getType())) + return failure(); // view(convert) -> view - if (auto convert = dyn_cast(arg)) { - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType(), convert.getOperand()); - return mlir::success(); - } - return mlir::failure(); + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + convert.getOperand()); + return mlir::success(); } }; @@ -1831,6 +1850,8 @@ struct CanonicalizeConvertFromConvert return mlir::failure(); // cvt(view) -> view if (auto view = dyn_cast(arg)) { + if (isExpensiveView(view.getOperand().getType(), op.getType())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), view.getResult()); return mlir::success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index a0887afdc..c9e13866d 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -70,10 +70,15 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; - auto slices = mlir::getSlice(dotOp, {filter}); + auto slices = multiRootGetSlice(dotOp, {filter}); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) - return {(unsigned)numWarps, 1}; + if (isa(op) && (op != dotOp)) { + if (shape[0] >= shape[1]) { + return {(unsigned)numWarps, 1}; + } else { + return {1, (unsigned)numWarps}; + } + } SmallVector ret = {1, 1}; SmallVector shapePerWarp = {16, 8}; @@ -133,8 +138,18 @@ class BlockedToMMA : public mlir::RewritePattern { mlir::TypeID::get()); } - // finds the first different value bitwidth in the chain of - // shape-preserving unary ops that x depends on + // Finds the first different bitwidth in the chain of shape-preserving + // unary ops that x depends on. 
+ // There are two primary scenarios: + // (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic + // operations, then bitcasting to fp32, and finally computing in fp32. + // (2) Downcasting: This might involve loading an fp32, performing arithmetic + // operations, bitcasting to fp16, and finally computing in fp16. + // In the upcasting scenario, element reordering converts the original + // elements distribution to the order of higher precision primitives. As a + // result, kwidth can be the bitwidth of the lower precision primitive. + // Conversely, in the downcasting scenario, no reordering is performed, + // making it directory use the lower precision primitive. static int computeOrigBitWidth(Value x) { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; @@ -143,11 +158,17 @@ class BlockedToMMA : public mlir::RewritePattern { opt.omitBlockArguments = true; opt.filter = bwdFilter; getBackwardSlice(x, &slice, opt); - Operation *firstOp = slice.empty() ? 
nullptr : *slice.begin(); - if (firstOp) - if (Value arg = firstOp->getOperand(0)) - if (RankedTensorType argTy = arg.getType().dyn_cast()) - origBitWidth = argTy.getElementType().getIntOrFloatBitWidth(); + for (auto op : slice) { + if (Value arg = op->getOperand(0)) + if (RankedTensorType argTy = + arg.getType().dyn_cast()) { + auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth(); + if (argBitWidth != origBitWidth) { + origBitWidth = std::min(origBitWidth, argBitWidth); + break; + } + } + } return origBitWidth; } diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index fab606807..91c11384c 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -5,7 +5,10 @@ add_mlir_dialect_library(TritonGPUTransforms DecomposeConversions.cpp OptimizeDotOperands.cpp OptimizeEpilogue.cpp - Pipeline.cpp + OptimizeThreadLocality.cpp + Pipeliner/MatmulLoopPipeline.cpp + Pipeliner/PipelineExpander.cpp + Pipeliner/SoftwarePipeliner.cpp Prefetch.cpp RemoveLayoutConversions.cpp ReorderInstructions.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp new file mode 100644 index 000000000..0675cf1f0 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -0,0 +1,312 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +using namespace mlir; + +class TritonGPUOptimizeThreadLocalityPass + : public TritonGPUOptimizeThreadLocalityBase< + TritonGPUOptimizeThreadLocalityPass> { + void runOnOperation() override { + ModuleOp mod = 
getOperation(); + DenseSet reduceOps; + mod.walk([&](triton::ReduceOp reduce) -> void { + auto srcType = reduce.getOperands()[0].getType().cast(); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto reductionOp = getReductionOp(reduce); + if (!reductionOp || + !isa(reductionOp.value())) + return; + // TODO: relax this restriction + if (!(srcEncoding.isa() && rank > 1)) + return; + for (auto operand : reduce->getOperands()) { + auto def = operand.getDefiningOp(); + if (!isa(def)) + return; + } + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + // Not worth applying this optimization if there is only one element per + // thread on the reduction axis + if (elemsPerThread == 1) + return; + if (!reduce->hasOneUse()) + return; + Operation *user = *(reduce->getUsers().begin()); + if (!user->hasOneUse()) + return; + OpOperand &yieldOpOperand = *(user->getUses().begin()); + auto yieldOp = dyn_cast(yieldOpOperand.getOwner()); + if (!yieldOp) + return; + auto operandNumber = yieldOpOperand.getOperandNumber(); + Block *block = reduce->getBlock(); + Operation *parentOp = block->getParentOp(); + auto forOp = dyn_cast(parentOp); + if (!forOp) + return; + auto argNum = yieldOpOperand.getOperandNumber(); + auto oldAccum = forOp.getInitArgs()[argNum]; + auto cstOp = dyn_cast(oldAccum.getDefiningOp()); + if (!cstOp) + return; + reduceOps.insert(reduce); + }); + + for (auto reduce : reduceOps) { + OpBuilder builder(reduce); + auto srcType = reduce.getOperands()[0].getType().cast(); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + assert(srcEncoding.isa() && + "Thread locality optimization only supports blocked encoding"); + auto blocked = srcEncoding.dyn_cast(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto rank = srcShape.size(); + // create new layouts + auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); + auto 
viewOpTensorShape = getThreadLocalityOptimizedShape(reduce); + auto viewOpTensorType = RankedTensorType::get( + viewOpTensorShape, srcType.getElementType(), blocked3d); + auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank, + blocked3d); + // Get forOp + assert(reduce->hasOneUse()); + OpOperand &use = *(reduce->getUses().begin()); + auto operandNumber = use.getOperandNumber(); + auto oldUpdate = use.getOwner(); + assert(oldUpdate->getNumOperands() == 2); + auto accumOperandNumber = (operandNumber == 0) ? 1 : 0; + auto accumOperand = oldUpdate->getOperand(accumOperandNumber); + assert(accumOperand.isa()); + auto blockArg = accumOperand.dyn_cast(); + auto blockArgNum = blockArg.getArgNumber(); + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + // get oldAccum + auto oldAccum = + forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()]; + // get old loop user + Value loopResult = + forOp.getResult(blockArgNum - forOp.getNumInductionVars()); + assert(loopResult.hasOneUse()); + OpOperand &loopUse = *(loopResult.getUses().begin()); + Operation *loopUser = loopUse.getOwner(); + // get old loop yield + auto oldYield = cast(forOp.getBody()->getTerminator()); + // create newAccum initialization + auto newAccum = + createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d); + // create new loop by copying the old for op signature and appending + // newAccum to the block arguments + auto newLoop = replaceForOpWithNewSignature( + builder, forOp, ValueRange{newAccum->getResult(0)}); + // create thread local reduction (also adds viewOps) + auto newReduce = createReduce(builder, reduce, viewOpTensorType); + + // create new accum update + auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); + // create new yield + auto newYield = createYield(builder, newLoop, oldYield, + newUpdate->getResult(0), blockArgNum); + // create post loop reduction on the original reduce axis + auto newReduce2 = createPostLoopReduce(builder, 
newLoop, reduce); + // add convert_layout to get back to original layout, the result layout + // should now match the layout of the old accumulator (%cst) + Type destType = loopResult.getType(); + auto cvtLayout = createConvertLayout(builder, destType, newReduce2); + // incorporate the original accumulator value into the final result + auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate, + cvtLayout, oldAccum); + // Replace the old loop user with the final result + loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0)); + + // cleanup + oldYield.erase(); + forOp.erase(); + } + }; + +private: + std::optional getReductionOp(triton::ReduceOp reduce) const { + auto numRegions = reduce->getNumRegions(); + if (numRegions != 1) + return std::nullopt; + Region ®ion = reduce->getRegion(0); + auto numBlocks = region.getBlocks().size(); + if (numBlocks != 1) + return std::nullopt; + Block &block = region.front(); + auto blockWithoutTerminator = block.without_terminator(); + auto blockSizeWithoutTerminator = std::distance( + blockWithoutTerminator.begin(), blockWithoutTerminator.end()); + if (blockSizeWithoutTerminator != 1) + return std::nullopt; + Operation *op = &block.front(); + return std::optional(op); + } + Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder, + Operation *oldUpdate, + Operation *cvtLayout, + Value oldAccum) const { + builder.setInsertionPointAfter(cvtLayout); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), oldAccum); + mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0)); + auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping); + return finalOp; + } + Operation *createConvertLayout(OpBuilder &builder, Type destType, + Operation *newReduce) const { + builder.setInsertionPointAfter(newReduce); + auto newCvt = builder.create( + newReduce->getLoc(), destType, newReduce->getResult(0)); + return newCvt; + } + + Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp 
&loop, + triton::ReduceOp &reduce) const { + auto resultIndex = + loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars(); + auto newLoopResult = loop.getResult(resultIndex); + builder.setInsertionPointAfter(loop); + IRMapping mapping; + mapping.map(*(reduce.getOperands().begin()), newLoopResult); + auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping); + return newReduce2; + } + + Operation *createYield(OpBuilder &builder, scf::ForOp &loop, + scf::YieldOp &oldYield, Value newUpdate, + int oldAccumBlockArgNum) const { + builder.setInsertionPoint(oldYield); + SmallVector yieldValues = llvm::to_vector(oldYield.getOperands()); + yieldValues[oldAccumBlockArgNum - 1] = + loop.getBody()->getArgument(oldAccumBlockArgNum); + yieldValues.push_back(newUpdate); + auto newYield = + builder.create(oldYield.getLoc(), yieldValues); + return newYield; + } + + Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop, + Operation *newReduce, Operation *oldUpdate) const { + auto blockArgNum = loop.getBody()->getNumArguments() - 1; + auto newArg = loop.getBody()->getArgument(blockArgNum); + builder.setInsertionPointAfter(newReduce); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), newArg); + mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0)); + auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping); + return newUpdate; + } + + Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce, + Type viewOpTensorType) const { + auto srcType = reduce.getOperands()[0].getType().cast(); + auto rank = srcType.getShape().size(); + builder.setInsertionPointAfter(reduce); + IRMapping mapping; + for (auto operand : reduce.getOperands()) { + auto viewOp = builder.create(reduce.getLoc(), + viewOpTensorType, operand); + mapping.map(operand, viewOp); + } + + auto newReduce = cloneWithInferType(builder, &(*reduce), mapping); + newReduce->setAttr("axis", builder.getI32IntegerAttr(rank)); + auto typeInfer = dyn_cast(newReduce); + if 
(typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newReduce->getContext(), newReduce->getLoc(), + newReduce->getOperands(), newReduce->getAttrDictionary(), + newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newReduce->getResult(i).setType(newTypes[i]); + } + } + return newReduce; + } + + Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce, + Value &oldAccum, SmallVector &shape, + Attribute &slice2d) const { + // Drop the last dimension (thread locality dimension) + SmallVector accumShape(shape.begin(), shape.end() - 1); + auto elemType = + oldAccum.getType().cast().getElementType(); + // Create tensor type for the new accumulator + auto accumType = RankedTensorType::get(accumShape, elemType, slice2d); + // Create new accumulator + builder.setInsertionPointAfter(oldAccum.getDefiningOp()); + auto reductionOp = getReductionOp(reduce); + assert(reductionOp && "Processing a reduce that is not supported!"); + auto neutralVal = mlir::arith::getNeutralElement(reductionOp.value()); + assert(neutralVal && "Could not find neutral value for reduction op!"); + auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value()); + auto newAccum = builder.create(oldAccum.getLoc(), + accumType, denseAttr); + return newAccum; + } + + SmallVector + getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const { + auto srcType = reduce.getOperands()[0].getType().cast(); + auto srcShape = srcType.getShape(); + auto rank = srcShape.size(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto viewOpTensorShape = insertValue(srcShape, rank, 1); + viewOpTensorShape[reduce.getAxis()] /= elemsPerThread; + viewOpTensorShape[rank] = elemsPerThread; + return viewOpTensorShape; + } + + Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const { + auto srcType = 
reduce.getOperands()[0].getType().cast(); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto blocked = srcEncoding.dyn_cast(); + auto sizePerThread3d = + insertValue(blocked.getSizePerThread(), rank, + blocked.getSizePerThread()[reduce.getAxis()]); + sizePerThread3d[reduce.getAxis()] = 1; + auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); + auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); + auto order3d = insertValue(blocked.getOrder(), 0, rank); + auto ctasPerCGA3d = + insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1); + auto ctasSplitNum3d = + insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1); + auto ctaOrder3d = + insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank); + auto ctaLayout3d = triton::gpu::CTALayoutAttr::get( + reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d); + auto blocked3d = triton::gpu::BlockedEncodingAttr::get( + reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, + order3d, ctaLayout3d); + return blocked3d; + } + + template + SmallVector insertValue(ArrayRef vec, unsigned index, int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } +}; + +std::unique_ptr mlir::createTritonGPUOptimizeThreadLocalityPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp deleted file mode 100644 index 86d228988..000000000 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ /dev/null @@ -1,1826 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" 
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/ADT/MapVector.h" -#include "llvm/Support/Debug.h" - -//===----------------------------------------------------------------------===// -// This file implements software pipelining for loops. The implementation here -// is inspired by the pipeline pass in Triton (version 2.0) and SCF's -// LoopPipelining. -// -// We divide the loop body into the following phases: -// a. Pre-load operations: for instance, index computation. -// b. Load operations: loading from global memory to shared memory. -// c. Compute operations: for instance, Triton dot. -// d. Post-load operations: for instance, index computation. -// -// To pipeline the loop, we need to: -// - Hoist the pipelinable load operations for the first numStages-1 iterations -// to the loop pre-header -// - Find all the dependencies of the load operations. -// - Rematerialize the dependencies for their values at the first numStage-1 -// iterations -// - Assemble the loop body (numStage) and prefetch (numStage + 1). -// -// In the prologue, the sequence of operations is the same as the original loop -// body, following the (a) -> (b) -> (c) -> (d) order. In the loop body, -// however, we first execute the compute operations, then pre-load operations, -// post-load operations, and eventually the asynchronous load operations - in -// the (c) -> (a) -> (d) -> (b) order. This is used to better hide the latency -// of the load operations. Because of this, if post-load operations have direct -// dependencies on the load operations, we could repeat the post-load -// operations. More specifically, this occurs when: -// 1. Any load operand has an immediate dependency argument used at numStage-1. -// 2. The argument is first defined at numStage-2. 
-// To avoid the repeat, we peeled off post-load operations in the prologue that -// satisfy the above two conditions. See the example below for the definition of -// immediate and non-immediate dependencies. -// If we have a load that immediately depends on a block argument in the -// current iteration, it is an immediate dependency. Otherwise, it is a -// non-immediate dependency, which means the load depends on a block argument -// in the previous iterations. -// For example: -// scf.for (%arg0, %arg1, %arg2) { -// %0 = load %arg0 <--- immediate dep, this address is initialized before -// numStages-1. -// %1 = load %arg1 -// %2 = add %1, %arg2 -// %3 = load %2 <--- non-immediate dep, %arg1 must be an -// update-to-date value. -// } -// -// Our pipelining pass share some common characteristics with SCF's -// LoopPipelining. However, it is also noteworthy that our pipelining pass has -// the following characteristics different from SCF's LoopPipelining: -// 1. It can handle loop-carried dependencies of distance greater than 1. -// 2. It does not have a complicated epilogue but instead uses masking to handle -// boundary conditions. -// 3. Each operation/loop-carried argument cannot provide values to both -// immediate and non-immediate dependencies. Otherwise, we have to rematerialize -// the operation and arguments, which would likely increase register pressure. -// For example: -// scf.for (%arg0, %arg1, %arg2) { -// %0 = load %arg0 -// %1 = load %arg1, %0 <--- %0 is both a post-load op at numStages-2 and a -// pre-load op at numStages-1, so that we need two versions of %0. 
-// %2 = add %0, %arg2 -// scf.yield %arg0, %2, %arg2 -// } -// -//===----------------------------------------------------------------------===// - -using llvm::MapVector; -using namespace mlir; -namespace tt = mlir::triton; -namespace ttg = mlir::triton::gpu; -/// FIXME(Keren): The pipeline pass shouldn't be aware of nvidia_gpu dialect -namespace ttng = mlir::triton::nvidia_gpu; - -#define GEN_PASS_CLASSES -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -#define int_attr(num) builder.getI64IntegerAttr(num) - -namespace { - -// Pass named attrs (e.g., tt.contiguity) from Triton to Triton -void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { - NamedAttrList attrs = op->getDiscardableAttrs(); - // Collect the attributes to propagate: the ones in dictAttrs and not yet on - // the operation. - SmallVector toPropagate; - for (const NamedAttribute attr : dictAttrs.getValue()) { - if (!attrs.get(attr.getName())) - toPropagate.push_back(attr); - } - // If we found any, let's set them here as a single step. - if (toPropagate.size()) { - attrs.append(toPropagate); - op->setDiscardableAttrs(attrs); - } -} - -struct ConsumerReleaseInfo { - Value iterVar; - Value stageVar; - Value phaseVar; - Value nextIVVar; - Value stepVar; - Value upperBoundVar; - ttg::CTALayoutAttr CTALayout; - DenseMap consumerStageMap; -}; -typedef DenseMap - ConsumerReleaseMap; - -class LoopPipeliner { - /// Cache of ForOp and YieldOp related to this pipeliner. 
- scf::ForOp forOp; - scf::YieldOp yieldOp; - - /// Loads to be pipelined - SetVector validLoads; - /// The value that each load will be mapped to (after layout conversion) - DenseMap loadsMapping; - /// load => buffer - DenseMap loadsBuffer; - /// load => buffer type (with shared layout after swizzling) - DenseMap loadsBufferType; - /// load => buffer at stage N - DenseMap> loadStageBuffer; - /// load => after extract - DenseMap loadsExtract; - - /// XXX(Keren): The following are h100 only and disabled - /// load => full barrier arrive - DenseMap loadsBarrierArvOp; - /// load => mbarriers - DenseMap loadsFullBarriers; - DenseMap loadsEmptyBarriers; - /// load => null value or previous load which can share barrier with - DenseMap loadsCanShareBarriers; - /// Maintains the information to emit consumer_release mbarrier_arrive - ConsumerReleaseMap &consumerReleaseMap; - bool hasHopperDot = false; - // XXX(Keren): why the variable name is hopper dot and why do we need this - // check? - void checkHopperDots(SetVector &ops); - // XXX(Keren): it looks more like an optimization to be, not sure if it should - // exist in the base pipeliner - void checkOpShareBarriers(SetVector &ops); - int numLoadsRequireAsyncWait = 0; - int numLoadsRequireMBarrier = 0; - // Number of buffers to allocate for each input. - int numSharedMemorySlices = 0; - - /// Iterator values - Value nextIV; - Value pipelineIterIdx; - Value curWaitIdx; - - // Only needed when numLoadsRequireMBarrier > 0 - Value loopIterIdx; - Value curPhase; - Value curEmptyPhase; - - /// Yield values - SmallVector nextBuffers; - SmallVector extractSlices; - SmallVector yieldValues; - - /// The number of stages in the pipeline. - /// Stages in the range of [0, numStages-1) are in the prologue. - /// numStages-1 is appended after the loop body. 
- int numStages; - - /// Arg indicies - size_t bufferIdx, loadIdx, depArgsBeginIdx, ivIdx; - DenseMap depArgsIdx; - - /// XXX(Keren): The mode parameter is hacky, should be refactored - // false: legacy mode as a temporary solution for backward compatibility - // true: new mode for hopper - bool mode; - int numWarps; - int numCTAs; - - /// value (in loop) => value at stage N - DenseMap> valueMapping; - /// loop iter arg => value - DenseMap depArgsMapping; - /// forOp value => newForOp value - IRMapping mapping; - /// forOp value => prefetch value - IRMapping nextMapping; - - /// Dependency ops by program order - SmallVector orderedDeps; - - /// arg => source operand defined stages - DenseMap> immediateArgStages; - - /// block arguments that loads depend on - SetVector depArgs; - - /// operation => source operand defined stages - DenseMap> immediateOpStages; - - /// operations that loads depend on - SetVector depOps; - - /// Collect all pipelinable ops - LogicalResult collectOps(SetVector &ops); - - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &opDeps); - - /// Collect all op dependencies - void collectDeps(SetVector &ops, - MapVector> &opDeps); - - /// Check if none of the ops has valid uses - LogicalResult checkOpUses(SetVector &ops); - - /// Check if ops have dependencies that are not pipelinable - void checkOpDeps(SetVector &ops); - - void createBufferTypes(); - - void createOrderedDeps(); - - /// Return the stage at which `v` is defined prior to `stage` - int getValueDefStage(Value v, int stage); - - /// Map `origin` to `newValue` at `stage` - void setValueMapping(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at `stage` according to the association between - /// yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at the next stage according to the association - /// between yieldOp 
and forOp - void setValueMappingYield(scf::ForOp newForOp, Value origin, Value newValue); - - /// Return the value mapped to `origin` at `stage`, if it exists. - Value lookupOrDefault(Value origin, int stage); - - /// Get the load mask for `loadOp`, given the mapped mask `mappedMask` (if - /// exists) and the current iteration's `loopCond`. - Value getLoadMask(tt::LoadOp loadOp, Value mappedMask, Value loopCond, - OpBuilder &builder); - - /// Return an empty buffer of size - ttg::AllocTensorOp allocateEmptyBuffer(tt::LoadOp loadOp, OpBuilder &builder); - - /// Collect all args of the new loop - SmallVector collectNewLoopArgs(); - - /// Clone the forOp and return the new forOp - scf::ForOp cloneForOp(ArrayRef newLoopArgs, OpBuilder &builder); - - /// Prefetch the next iteration for `newForOp` - void prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder); - - /// Check if curIdx is out of bound and wrap value around if necessary - Value getBoundedIterationValue(OpBuilder &builder, Value curIdx, - Value upperBoundIdx, Value curValue, - Value initValue); - - /// Assemble `newForOp`'s yield op - void finalizeYield(scf::ForOp newForOp, OpBuilder &builder); - -public: - LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs, - bool mode, int numSharedMemorySlices, - ConsumerReleaseMap &consumerReleaseMap) - : forOp(forOp), numStages(numStages), numWarps(numWarps), - numCTAs(numCTAs), mode(mode), - numSharedMemorySlices(numSharedMemorySlices), - consumerReleaseMap(consumerReleaseMap) { - // cache yieldOp - yieldOp = cast(forOp.getBody()->getTerminator()); - } - - LoopPipeliner() = delete; - - /// Collect loads to pipeline. 
Return success if we can pipeline this loop - LogicalResult initialize(); - - /// Emit pipelined loads (before loop body) - void emitPrologue(); - - /// emit pipelined loads (after loop body) - void emitEpilogue(); - - /// create the new ForOp (add new args & insert prefetched ops) - scf::ForOp createNewForOp(); - - friend struct PipelinePass; -}; - -/// Collect loads to pipeline. Return success if we can pipeline this loop -LogicalResult LoopPipeliner::collectOps(SetVector &ops) { - ModuleOp moduleOp = forOp->getParentOfType(); - ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); - - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) - if (auto loadOp = dyn_cast(&op)) { - if (isLoadFromTensorPtr(loadOp)) { - ops.insert(loadOp); - } else { - auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - if (auto mask = loadOp.getMask()) - vec = - std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - - auto tensorTy = ptr.getType().dyn_cast(); - if (!tensorTy || tensorTy.getRank() < 2) - continue; - auto ty = - tensorTy.getElementType().cast().getPointeeType(); - unsigned width = vec * ty.getIntOrFloatBitWidth(); - // We do not pipeline all loads for the following reasons: - // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. - // 2. It's likely that pipling small loads won't offer much performance - // improvement and may even hurt performance by increasing register - // pressure. 
- if (width >= 32) - ops.insert(loadOp); - } - } - - if (ops.empty()) - return failure(); - else - return success(); -} - -void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps) { - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getRegion()) - return; - - // Since we only need to peel the loop numStages-1 times, don't worry - // about depends that are too far away - if (stage < 0) - return; - - if (auto arg = v.dyn_cast()) { - if (arg.getArgNumber() > 0) { - deps.insert(v); - collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps); - } - } else { // value - deps.insert(v); - for (Value op : v.getDefiningOp()->getOperands()) - collectValueDep(op, stage, deps); - } -} - -void LoopPipeliner::collectDeps( - SetVector &ops, - MapVector> &valueDeps) { - for (auto op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); - valueDeps[op] = deps; - } - } -} - -LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { - DenseSet invalidOps; - // Collect all ops' dependencies - MapVector> opDeps; - collectDeps(ops, opDeps); - - for (Operation *op : ops) { - if (auto loadOp = dyn_cast(op)) { - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other->getResult(0))) { - isCandidate = false; - break; - } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse() && - !isLoadFromTensorPtr(loadOp)) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - Operation *preUse = nullptr; - - // Advance to the first conversion as long as the use 
resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - use->getResult(0).getType().dyn_cast(); - if (!tensorType.getEncoding().isa()) - break; - preUse = use; - use = *use->getResult(0).getUsers().begin(); - } - - if (auto convertLayout = llvm::dyn_cast(use)) { - if (auto tensorType = convertLayout.getResult() - .getType() - .dyn_cast()) - if (auto dotOpEnc = tensorType.getEncoding() - .dyn_cast()) { - isCandidate = true; - loadsMapping[loadOp] = convertLayout; - } - } else if (preUse && isa(use)) { - isCandidate = false; - // for MMAv3 whose dot take SharedEncoding as operands directly - Operation *post = *loadOp.getResult().getUsers().begin(); - auto newOrder = post->getResult(0) - .getType() - .cast() - .getEncoding() - .cast() - .getOrder(); - auto ty = loadOp.getType().cast(); - auto oldOrder = ttg::getOrder(ty.getEncoding()); - // The operand of MMAv3 is in SharedEncoding and it's order should not - // be changed after FuseTranspositions Pass. So we only pipeline the - // load if the order of the loaded BlockedEncoding is the same as the - // order of the SharedEncoding it is converted to. 
- // TODO: remove this constraint once the LoadOp supports transpose - // fusion - if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { - isCandidate = true; - loadsMapping[loadOp] = preUse->getResult(0); - } - } - } else if (isCandidate && mode && isLoadFromTensorPtr(loadOp)) { - loadsMapping[loadOp] = loadOp.getResult(); - } else - isCandidate = false; - - if (!isCandidate) - invalidOps.insert(loadOp); - else { - validLoads.insert(loadOp); - if (!isLoadFromTensorPtr(loadOp)) - numLoadsRequireAsyncWait++; - else - numLoadsRequireMBarrier++; - } - } - } - - for (Operation *op : invalidOps) - ops.remove(op); - - if (ops.empty()) - return failure(); - else - return success(); -} - -void LoopPipeliner::checkHopperDots(SetVector &ops) { - // dots to be pipelined - SetVector dots; - for (Operation &op : forOp) { - if (auto dotOp = dyn_cast(&op)) { - auto resTy = dotOp.getResult().getType().dyn_cast(); - if (auto resEnc = resTy.getEncoding().dyn_cast()) { - if (resEnc && resEnc.isHopper()) { - // Don't pipeline valid dots that depend on ops other than scf.yield - // and scf.for - auto dot = dotOp.getResult(); - bool valid = true; - - // all users of dot should be scf.yield - if (!dot.hasOneUse()) - valid = false; - if (!isa(*dot.getUsers().begin())) - valid = false; - - // C should be a block argument - auto CArg = dotOp.getOperand(2).dyn_cast(); - if (!CArg || !CArg.hasOneUse()) - valid = false; - - if (valid) - dots.insert(dotOp); - } - } - } - } - - hasHopperDot = true; -} - -void LoopPipeliner::checkOpShareBarriers(SetVector &ops) { - // Check if loads can share barriers - auto canShare = [&](Value load0, Value load1) -> bool { - if (!load0.hasOneUse() || !load1.hasOneUse()) - return false; - auto use0 = *load0.getUsers().begin(); - auto use1 = *load1.getUsers().begin(); - if (!use0->hasOneUse() || !use1->hasOneUse()) - return false; - if (*use0->getUsers().begin() != *use1->getUsers().begin()) - return false; - return true; - }; - // XXX(Keren): the logic 
here is pretty weird and might be incomplete - for (Value loadOp : validLoads) { - Value depLoad; - for (auto oldPair : loadsCanShareBarriers) { - Value oldLoad = oldPair.first; - if (canShare(loadOp, oldLoad)) { - depLoad = oldLoad; - break; - } - } - loadsCanShareBarriers[loadOp] = depLoad; - } -} - -void LoopPipeliner::checkOpDeps(SetVector &ops) { - SetVector nonImmediateDepArgs; - SetVector nonImmediateOps; - for (Operation *op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); - int defStage = getValueDefStage(v, numStages - 1); - assert(defStage >= 0 && - "newLoopArgs has null args without a define op. Consider either " - "rewrite the loop to reduce cross iteration dependencies or " - "increase the num_stages value."); - for (auto dep : deps) { - auto immediate = deps.front().isa(); - if (auto arg = dyn_cast(dep)) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } else { - depOps.insert(dep.getDefiningOp()); - if (immediate) - immediateOpStages[dep.getDefiningOp()].insert(defStage); - else - nonImmediateOps.insert(dep.getDefiningOp()); - } - } - } - } - - // We could remove the following constraints if we can rematerialize in the - // loop. Check if immediateDepArgs and nonImmediateDepArgs are disjoint. - for (auto &[arg, stages] : immediateArgStages) { - assert(stages.size() == 1 && - "Triton doesn't support an argument provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateDepArgs.contains(arg) && - stages.contains(numStages - 2)) && - "Loop-carried arguments provide values for both immediate and " - "non-immediate operands of loads. Please consider removing " - "pre/post load instructions dependency on this argument."); - } - - // Check if immediateOps and nonImmediateOps are disjoint. 
- for (auto &[op, stages] : immediateOpStages) { - assert(stages.size() == 1 && - "Triton doesn't support an operation provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) && - "Operations provide values for both immediate and " - "non-immediate operands of loads. Please consider " - "removing pre/post load instructions dependency on this " - "operation."); - } -} - -// helpers -void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - valueMapping[origin] = SmallVector(numStages); - valueMapping[origin][stage] = newValue; -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue, - int stage) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto value = forOp.getRegionIterArgs()[yieldIdx]; - setValueMapping(value, newValue, stage); - } - } -} - -void LoopPipeliner::setValueMappingYield(scf::ForOp newForOp, Value origin, - Value newValue) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]]; - auto originArg = forOp.getRegionIterArgs()[yieldIdx]; - nextMapping.map(originArg, newValue); - auto newArg = newForOp.getRegionIterArgs()[depYieldIdx]; - if (!depArgsMapping.contains(newArg)) - depArgsMapping[newArg] = newValue; - } - } -} - -Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - return origin; - return valueMapping[origin][stage]; -} - -void LoopPipeliner::createBufferTypes() { - for (auto loadCvt : loadsMapping) { - auto loadOp = loadCvt.first; - Value cvt = loadCvt.second; - auto ty = 
loadOp.getType().cast(); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - bufferShape.insert(bufferShape.begin(), numSharedMemorySlices); - auto CTALayout = ttg::getCTALayout(ty.getEncoding()); - Attribute sharedEnc; - if (auto dotOpEnc = cvt.getType() - .cast() - .getEncoding() - .dyn_cast()) { - // MMAv1 and MMAv2 - bool needTrans = dyn_cast_or_null( - cvt.getDefiningOp()->getOperand(0).getDefiningOp()); - unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); - sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); - } else { - // MMAv3 - sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), - ttg::getOrder(ty.getEncoding()), - CTALayout, ty.getElementType()); - } - // FIXME(Keren): block ptr not handled - loadsBufferType[loadOp] = - RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); - } -} - -void LoopPipeliner::createOrderedDeps() { - for (Operation &op : *forOp.getBody()) { - if (depOps.contains(&op)) - orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) - orderedDeps.push_back(&op); - } - assert(depOps.size() + validLoads.size() == orderedDeps.size() && - "depOps contains invalid values"); -} - -int LoopPipeliner::getValueDefStage(Value v, int stage) { - if (stage < 0) - return -1; - if (auto arg = v.dyn_cast()) { - if (arg.getArgNumber() > 0) - return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1), - stage - 1); - llvm_unreachable("Loop induction variable should not be a dependency"); - } else - return stage; -} - -ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(tt::LoadOp loadOp, - OpBuilder &builder) { - // Allocate a buffer for each pipelined tensor - // shape: e.g. 
(numStages==4), <32x64xbf16> -> <4x32x64xbf16> - Value convertLayout = loadsMapping[loadOp]; - if (auto tensorType = convertLayout.getType().dyn_cast()) - return builder.create(convertLayout.getLoc(), - loadsBufferType[loadOp]); - llvm_unreachable("Async copy's return should be of RankedTensorType"); -} - -LogicalResult LoopPipeliner::initialize() { - // All ops that maybe pipelined - SetVector ops; - - if (collectOps(ops).failed()) - return failure(); - - if (checkOpUses(ops).failed()) - return failure(); - - // XXX(Keren): hopper specific, should be cleaned up - checkHopperDots(ops); - - checkOpShareBarriers(ops); - - checkOpDeps(ops); - - createBufferTypes(); - - createOrderedDeps(); - - return success(); -} - -Value LoopPipeliner::getLoadMask(tt::LoadOp loadOp, Value mappedMask, - Value loopCond, OpBuilder &builder) { - Type maskType = tt::getI1SameShape(loadOp.getType()); - Value mask = loadOp.getMask(); - Value newMask; - if (mask) { - Value cond = loopCond; - if (isa(maskType)) { - cond = builder.create(mask.getLoc(), maskType, loopCond); - } - newMask = builder.create(mask.getLoc(), mappedMask, cond); - } else { - if (isa(maskType)) { - newMask = - builder.create(loopCond.getLoc(), maskType, loopCond); - } else { - newMask = loopCond; - } - } - return newMask; -} - -void LoopPipeliner::emitPrologue() { - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { - OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); - setValueMapping(arg, operand.get(), 0); - } - - // Alloc a vector of MBarriers in size numStages for each load to be pipelined - bool isMcast = false; - for (Value loadOp : validLoads) { - auto load = cast(loadOp.getDefiningOp()); - if (isLoadFromTensorPtr(load)) { - auto loadTy = loadOp.getType().cast(); - auto CTALayout = ttg::CTALayoutAttr::get( - load.getContext(), - /*CTAsPerCGA*/ {static_cast(numCTAs)}, - /*CTASplitNum*/ {1}, - /*CTAOrder*/ {0}); - auto 
sharedEncoding = ttg::SharedEncodingAttr::get( - load.getContext(), 1, 1, 1, {0}, CTALayout, false); - auto mBarriersTy = RankedTensorType::get( - {numStages}, builder.getIntegerType(64), sharedEncoding); - - if (!loadsCanShareBarriers[loadOp]) { - Value fullBarriers = builder.create( - load.getLoc(), mBarriersTy, 1); - loadsFullBarriers[loadOp] = fullBarriers; - } - auto layout = loadTy.getEncoding(); - auto CTASplitNum = ttg::getCTASplitNum(layout); - auto CTAsPerCGA = ttg::getCTAsPerCGA(layout); - if (CTASplitNum != CTAsPerCGA) { - isMcast = true; - // FIXME: numConsumerThreads could be 32 as well instead of 128 - // incase the consumer is not GMMA - unsigned arriveCnt = ttg::getNumWarpsPerCTA(layout); - if (hasHopperDot) - arriveCnt /= 4; - arriveCnt *= - product(CTAsPerCGA) / product(CTASplitNum); - - Value emptyBarriers = builder.create( - load.getLoc(), mBarriersTy, arriveCnt); - loadsEmptyBarriers[loadOp] = emptyBarriers; - } - } - } - - if (isMcast) { - builder.create(forOp.getLoc(), /*relaxed*/ 1); - builder.create(forOp.getLoc()); - } - - // prologue from [0, numStage-1) - Value iv = forOp.getLowerBound(); - pipelineIterIdx = builder.create(iv.getLoc(), 0, 32); - for (int stage = 0; stage < numStages - 1; ++stage) { - // Special handling for induction variable as the increment is implicit - if (stage != 0) - iv = builder.create(iv.getLoc(), iv, forOp.getStep()); - setValueMapping(forOp.getInductionVar(), iv, stage); - - // Special handling for loop condition as there is no condition in ForOp - Value loopCond = builder.create( - iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); - for (Operation *op : orderedDeps) { - Operation *newOp = nullptr; - if (validLoads.contains(op->getResult(0))) { - auto load = cast(op); - // Allocate empty buffer - if (stage == 0) { - loadsBuffer[load] = allocateEmptyBuffer(load, builder); - loadStageBuffer[load] = {loadsBuffer[load]}; - } - // load => copy async - if (auto loadOp = llvm::dyn_cast(op)) { - 
Value newMask = - getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), - loopCond, builder); - - if (mode && isLoadFromTensorPtr(loadOp)) { - auto loc = op->getLoc(); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - Value stageVal = - builder.create(loc, stage, 32); - // producer_acquire - if (loadsEmptyBarriers.count(loadOp)) { - Value emptyBarrier = builder.create( - loc, mBarTy, loadsEmptyBarriers[loadOp], stageVal); - auto trueVal = - builder.create(loc, 1, /*bitWidth*/ 1); - builder.create(loc, emptyBarrier, trueVal); - } - - // producer_commit - Value fullBarrier; - if (!loadsCanShareBarriers[loadOp]) { - fullBarrier = builder.create( - loc, mBarTy, loadsFullBarriers[loadOp], stageVal); - loadsExtract[loadOp] = fullBarrier; - } else { - // Reuse the barrier from previouse load. - fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; - } - - auto loadTy = loadOp.getType().dyn_cast(); - assert(loadTy); - auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); - auto shapePerSlice = - ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); - unsigned elems = - std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), 1, - std::multiplies{}); - elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); - - if (!loadsCanShareBarriers[loadOp]) { - Value _0 = builder.create(loc, 0, 32); - Value threadId = builder.create(loc); - Value pred = builder.create( - loc, arith::CmpIPredicate::eq, threadId, _0); - pred = builder.create(loc, pred, loopCond); - Operation *barrierArvOp = builder.create( - loc, fullBarrier, pred, - /*remoteCtaId*/ nullptr, /*trackAsyncOp*/ false, elems); - loadsBarrierArvOp[loadOp] = barrierArvOp; - } else { - // Increase the transcnt for barrier of previouse load by the - // bytes of current load. 
- Operation *barrierArvOp = - loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; - unsigned base_elems = - barrierArvOp->getAttr("txCount").cast().getInt(); - barrierArvOp->setAttr("txCount", - IntegerAttr::get(builder.getIntegerType(32), - base_elems + elems)); - } - newOp = builder.create( - loc, loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, fullBarrier, - newMask, lookupOrDefault(loadOp.getOther(), stage), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), - /*axis*/ 0); - } else { - newOp = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask, - lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - } - loadStageBuffer[loadOp].push_back(newOp->getResult(0)); - } else - llvm_unreachable("This should be LoadOp"); - } else { - if (auto loadOp = dyn_cast(op)) { - Value newMask = - getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), - loopCond, builder); - newOp = builder.create( - loadOp.getLoc(), loadOp.getResult().getType(), - lookupOrDefault(loadOp.getPtr(), stage), newMask, - lookupOrDefault(loadOp.getOther(), stage), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - addNamedAttrs(newOp, op->getDiscardableAttrDictionary()); - } else - newOp = builder.clone(*op); - // Update loop-carried uses - for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { - auto it = valueMapping.find(op->getOperand(opIdx)); - if (it != valueMapping.end()) { - Value v = it->second[stage]; - assert(v && "Value not found in valueMapping"); - newOp->setOperand(opIdx, v); - } // else, op at opIdx is a loop-invariant value - } - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), 
op->getNumResults())) { - Value originResult = op->getResult(dstIdx); - if (validLoads.contains(originResult)) - break; - setValueMapping(originResult, newOp->getResult(dstIdx), stage); - // Update mapping for loop-carried values (args) - setValueMappingYield(op->getResult(dstIdx), newOp->getResult(dstIdx), - stage + 1); - } - } // for (Operation *op : orderedDeps) - - // Update pipeline index - pipelineIterIdx = builder.create( - iv.getLoc(), pipelineIterIdx, - builder.create(iv.getLoc(), 1, 32)); - Value numSlices = builder.create( - iv.getLoc(), numSharedMemorySlices, 32); - Value _0 = builder.create(iv.getLoc(), 0, 32); - pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx, - numSlices, pipelineIterIdx, _0); - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(arg, valueMapping[arg][stage], stage + 1); - } // for (int stage = 0; stage < numStages - 1; ++stage) - - // async.wait & extract_slice - if (numLoadsRequireAsyncWait > 0) - builder.create(validLoads.front().getLoc(), - validLoads.size() * (numStages - 2)); - for (Value loadOp : validLoads) { - auto bufferType = loadStageBuffer[loadOp][numStages - 1] - .getType() - .cast(); - auto bufferShape = bufferType.getShape(); - auto sliceType = loadsMapping[loadOp].getType().cast(); - sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, - sliceType.getElementType(), - loadsBufferType[loadOp].getEncoding()); - Value extractSlice = builder.create( - loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1], - SmallVector{int_attr(0), int_attr(0), int_attr(0)}, - SmallVector{int_attr(1), - int_attr(sliceType.getShape()[0]), - int_attr(sliceType.getShape()[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - loadsExtract[loadOp] = extractSlice; - } - curWaitIdx = builder.create(iv.getLoc(), 0, 32); - loopIterIdx = builder.create(iv.getLoc(), 0, 32); - curPhase = 
builder.create(iv.getLoc(), 0, 1); - curEmptyPhase = builder.create(iv.getLoc(), 1, 1); -} - -void LoopPipeliner::emitEpilogue() { - // If there's any outstanding async copies, we need to wait for them. - if (numLoadsRequireAsyncWait > 0) { - OpBuilder builder(forOp); - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfter(forOp); - builder.create(forOp.getLoc(), 0); - } -} - -SmallVector LoopPipeliner::collectNewLoopArgs() { - // Order of new args: - // (original args) - // (insertSliceAsync buffer at stage numStages - 1) for each load - // (extracted tensor) for each load - // (depArgs at stage numStages - 1) - // (depArgs at stage numStages - 2) - // ... - // (iv at stage numStages - 2) - // (pipeline iteration index) - // (loop iteration index) - // (wait index) - // (phase index) - // (empty phase index) - - // We need this to update operands for yield - // original block arg => new arg's idx - SmallVector newLoopArgs; - for (auto v : forOp.getInitArgs()) - newLoopArgs.push_back(v); - - bufferIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadStageBuffer[loadOp].back()); - - loadIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadsExtract[loadOp]); - - depArgsBeginIdx = newLoopArgs.size(); - for (auto depArg : depArgs) { - depArgsIdx[depArg] = newLoopArgs.size(); - if (immediateArgStages[depArg].contains(numStages - 2)) - // Peel off post load ops in numStage-1 - newLoopArgs.push_back(valueMapping[depArg][numStages - 2]); - else - newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); - } - - ivIdx = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]); - newLoopArgs.push_back(pipelineIterIdx); - newLoopArgs.push_back(curWaitIdx); - if (numLoadsRequireMBarrier > 0) { - newLoopArgs.push_back(loopIterIdx); - newLoopArgs.push_back(curPhase); - newLoopArgs.push_back(curEmptyPhase); - } - - return newLoopArgs; -} - -scf::ForOp 
LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, - OpBuilder &builder) { - // Clone the original ForOp - auto newForOp = builder.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newLoopArgs); - - // Set mapping on body of the new ForOp - builder.setInsertionPointToStart(newForOp.getBody()); - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - - // Loop iteration args - Value upperBound = newForOp.getUpperBound(); - Value step = newForOp.getStep(); - Value curIV = newForOp.getRegionIterArgs()[ivIdx]; - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - if (numLoadsRequireMBarrier > 0) { - loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; - curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; - curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; - } - - // Clone the loop body, replace original args with args of the new ForOp. 
- SmallVector loadsFromTensorPtr; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (auto cvtOp = dyn_cast(op)) { - auto result = op.getResult(0); - auto cvtDstTy = result.getType().cast(); - auto it = - std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); - if (it != validLoads.end()) { - auto loadArgIdx = std::distance(validLoads.begin(), it); - if (cvtDstTy.getEncoding().isa()) { - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (dotOperand) = cvt %0 - // We replace the use new load use with a convert layout - auto cvt = builder.create( - result.getLoc(), cvtDstTy, - newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - mapping.map(result, cvt.getResult()); - continue; - } else if (cvtDstTy.getEncoding().isa()) { - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (sharedEncoding) = cvt %0 - // We replace the use new load use with insert_slice_async's result - mapping.map(result, - newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - continue; - } - } - } else if (auto loadOp = dyn_cast(op)) { - if (isLoadFromTensorPtr(loadOp)) { - // XXX(Keren): The comparison operator using std::find on tensor ptr - // doesn't work as expected - auto operand = loadOp.getPtr(); - auto tensorTy = - operand.getType().cast().getPointeeType(); - auto loadArgIdx = 0; - for (auto validLoad : validLoads) { - auto defOp = cast(validLoad.getDefiningOp()); - if (isLoadFromTensorPtr(defOp)) { - auto validOperand = defOp.getOperand(0); - auto validTensorTy = - validOperand.getType().cast().getPointeeType(); - if (tensorTy == validTensorTy) - break; - } - loadArgIdx++; - } - // consumer_wait, emitted before the first consumer - auto firstConsumer = getFirstUser(loadOp); - mapping.map(loadOp, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - - // If current load can reuse barriers shared by previous load, then we - // do nothing. 
- if (!loadsCanShareBarriers[loadOp]) { - // emit mbarrier wait before the first consumer of the loaD - OpBuilder mBarBuilder(firstConsumer); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - Value fullBarrier = mBarBuilder.create( - loadOp.getLoc(), mBarTy, loadsFullBarriers[loadOp], curWaitIdx); - mBarBuilder.create(loadOp.getLoc(), fullBarrier, - curPhase); - } - - loadsFromTensorPtr.push_back(loadOp); - continue; - } - } - cloneWithInferType(builder, &op, mapping); - } - - for (Value load : loadsFromTensorPtr) { - // consumer_relase, emitted after the last consumer - // 'the last consumer' might be updated in the following Phase_1 since - // some of the consumers might be pipelined. Thus we maintain this - // information in 'consumerReleaseMap' and move the position of - // consumer_release barrier in a seperate Phase_2 in case necessary. - if (loadsEmptyBarriers.count(load)) { - auto users = mapping.lookup(load).getUsers(); - DenseMap consumerStageMap; - for (Operation *user : users) { - // All the stage is initialized to zero before Phase_1, - // since no consumers has been pipelined yet. 
- consumerStageMap[user] = 0; - } - auto CTALayout = ttg::getCTALayout( - load.getType().cast().getEncoding()); - ConsumerReleaseInfo info{ - loopIterIdx, pipelineIterIdx, curEmptyPhase, curIV, - step, upperBound, CTALayout, consumerStageMap}; - consumerReleaseMap[loadsEmptyBarriers[load]] = info; - } - } - - // Remove redundant conversions - // e.g., %145 = triton_gpu.convert_layout %arg15 : (tensor<128x64xf16, - // #shared1>) -> tensor<128x64xf16, #shared1> - for (Operation &op : newForOp.getBody()->without_terminator()) { - if (auto convert_layout = dyn_cast(op)) { - auto result = op.getResult(0); - auto cvtDstTy = result.getType(); - auto operand = convert_layout.getOperand(); - auto tensorTy = operand.getType(); - if (cvtDstTy == tensorTy) - result.replaceAllUsesWith(operand); - } - } - - return newForOp; -} - -Value LoopPipeliner::getBoundedIterationValue(OpBuilder &builder, Value curIdx, - Value upperBoundIdx, - Value curValue, Value initValue) { - Value cond = builder.create( - curIdx.getLoc(), arith::CmpIPredicate::uge, curIdx, upperBoundIdx); - Value selectValue = builder.create( - curIdx.getLoc(), cond, initValue, curValue); - return selectValue; -} - -void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, - OpBuilder &builder) { - // Map the dep args of the next iteration to the dep args of the current - size_t argIdx = 0; - for (auto depArg : depArgs) { - BlockArgument nextArg = - newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]; - nextMapping.map(depArg, nextArg); - ++argIdx; - } - - // Update loop iteration args - Value curIV = newForOp.getRegionIterArgs()[ivIdx]; - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - if (numLoadsRequireMBarrier > 0) { - loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; - curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; - curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; - } - - // Special handling for iv & loop condition - 
auto idxLoc = curIV.getLoc(); - nextIV = builder.create(idxLoc, curIV, newForOp.getStep()); - Value nextLoopCond = builder.create( - idxLoc, arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); - - // Constants - Value _0 = builder.create(idxLoc, 0, 32); - Value _1 = builder.create(idxLoc, 1, 32); - Value numStagesVal = - builder.create(idxLoc, numStages, 32); - Value numSlices = - builder.create(idxLoc, numSharedMemorySlices, 32); - - // nextWaitIdx - Value waitIdxPlusOne = builder.create(idxLoc, curWaitIdx, _1); - Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne, - numSlices, waitIdxPlusOne, _0); - - // Indices of InsertSliceAsyncOp and ExtractSliceOp - Value insertSliceIndex = pipelineIterIdx; - Value extractSliceIndex = nextWaitIdx; - - // Prefetch load deps - // If a load-dependent instruction that uses a block argument, we - // shouldn't update the new mapping of the block argument in the current - // iteration. - // For example. - // %a = add %arg0, %c - // %b = add %arg0, %d - // - // Update %arg0 will cause the value of %b to be incorrect. - // We do need to use the next iteration value of %arg0 because it could be a - // immediate arg of a load op. - // load %arg0 - // %a = add %arg0, %c - // yield %a - // - // We reroder instructions so %a and yield are actually before load. load - // %arg0 should use the updated %arg0. 
- IRMapping curMapping = nextMapping; - for (Operation *op : orderedDeps) - if (!validLoads.contains(op->getResult(0))) { - if (immediateOpStages[op].contains(numStages - 2)) - // A post load op that provides values for numStage - 2 - curMapping.map(forOp.getInductionVar(), curIV); - else - curMapping.map(forOp.getInductionVar(), nextIV); - Operation *nextOp; - if (auto loadOp = dyn_cast(op)) { - auto newMask = - getLoadMask(loadOp, curMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - nextOp = builder.create( - loadOp.getLoc(), loadOp.getResult().getType(), - curMapping.lookupOrDefault(loadOp.getPtr()), newMask, - curMapping.lookupOrDefault(loadOp.getOther()), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - addNamedAttrs(nextOp, op->getDiscardableAttrDictionary()); - curMapping.map(loadOp.getResult(), nextOp->getResult(0)); - nextMapping.map(loadOp.getResult(), nextOp->getResult(0)); - } else { - nextOp = builder.clone(*op, curMapping); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(newForOp, op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - - // loads -> async loads - for (Operation *op : orderedDeps) { - Operation *nextOp = nullptr; - // Update loading mask - if (validLoads.contains(op->getResult(0))) { - auto loadOp = llvm::cast(op); - auto mask = loadOp.getMask(); - auto newMask = - getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - if (mask) { - // If mask is defined outside the loop, don't update the map more than - // once - if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) - nextMapping.map(mask, newMask); - newMask = nextMapping.lookupOrDefault(loadOp.getMask()); - } - Value insertedVal; - if (mode 
&& isLoadFromTensorPtr(loadOp)) { - auto loc = op->getLoc(); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - - // producer_acquire - if (loadsEmptyBarriers.count(loadOp)) { - auto ifOp = builder.create(loc, ArrayRef{}, - nextLoopCond, false); - builder.setInsertionPointToStart(ifOp.thenBlock()); - Value emptyBarrier = builder.create( - loc, mBarTy, loadsEmptyBarriers[loadOp], insertSliceIndex); - builder.create(loc, emptyBarrier, - curEmptyPhase); - builder.setInsertionPointAfter(ifOp); - } - - // producer_commit - Value fullBarrier; - if (!loadsCanShareBarriers[loadOp]) { - fullBarrier = builder.create( - loc, mBarTy, loadsFullBarriers[loadOp], insertSliceIndex); - loadsExtract[loadOp] = fullBarrier; - } else { - // Reuse the barrier from previouse load. - fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; - } - - auto loadTy = loadOp.getType().dyn_cast(); - assert(loadTy); - auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); - auto shapePerSlice = - ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); - unsigned elems = std::accumulate( - shapePerSlice.begin(), shapePerSlice.end(), 1, std::multiplies{}); - elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); - if (!loadsCanShareBarriers[loadOp]) { - Value _0 = builder.create(loc, 0, 32); - Value threadId = builder.create(loc); - Value pred = builder.create( - loc, arith::CmpIPredicate::eq, threadId, _0); - pred = builder.create(loc, pred, nextLoopCond); - Operation *barrierArvOp = builder.create( - loc, fullBarrier, pred, - /*remoteCtaId*/ nullptr, - /*trackAsyncOp*/ false, elems); - loadsBarrierArvOp[loadOp] = barrierArvOp; - } else { - // Increase the transcnt for barrier of previouse load by the bytes of - // current load. 
- Operation *barrierArvOp = - loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; - unsigned base_elems = - barrierArvOp->getAttr("txCount").cast().getInt(); - barrierArvOp->setAttr( - "txCount", - IntegerAttr::get(builder.getIntegerType(32), base_elems + elems)); - } - insertedVal = builder.create( - loc, loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, fullBarrier, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - } else { - insertedVal = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - } - nextBuffers.push_back(insertedVal); - // Extract slice - auto bufferType = insertedVal.getType().cast(); - auto bufferShape = bufferType.getShape(); - auto sliceType = loadsMapping[loadOp].getType().cast(); - sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, - sliceType.getElementType(), - loadsBufferType[loadOp].getEncoding()); - - nextOp = builder.create( - op->getLoc(), sliceType, insertedVal, - SmallVector{extractSliceIndex, int_attr(0), - int_attr(0)}, - SmallVector{int_attr(1), - int_attr(sliceType.getShape()[0]), - int_attr(sliceType.getShape()[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - extractSlices.push_back(nextOp->getResult(0)); - - // Update mapping of results - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - // If this is a loop-carried value, update the mapping for yield - setValueMappingYield(newForOp, op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - } - - // Some values have not been 
used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(newForOp, arg, - newForOp.getRegionIterArgs()[depArgsIdx[arg]]); - - // async.wait & extract_slice - if (numLoadsRequireAsyncWait > 0) { - Operation *asyncWait = builder.create( - validLoads[0].getLoc(), validLoads.size() * (numStages - 2)); - for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) { - // move extract_slice after asyncWait - it->getDefiningOp()->moveAfter(asyncWait); - } - } - - // Bump pipelineIterIdx - Value pipelineIterIdxPlusOne = - builder.create(idxLoc, pipelineIterIdx, _1); - pipelineIterIdx = getBoundedIterationValue( - builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0); - - // Bump curWaitIdx - curWaitIdx = nextWaitIdx; - - if (numLoadsRequireMBarrier > 0) { - // Bump loopIterIdx - loopIterIdx = builder.create(idxLoc, loopIterIdx, _1); - - Value _1_1b = builder.create(idxLoc, 1, 1); - - // Flip curPhase - Value nextPhase = builder.create(idxLoc, curPhase, _1_1b); - curPhase = getBoundedIterationValue(builder, waitIdxPlusOne, numStagesVal, - curPhase, nextPhase); - - // Flip curEmptyPhase - Value nextEmptyPhase = - builder.create(idxLoc, curEmptyPhase, _1_1b); - curEmptyPhase = - getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal, - curEmptyPhase, nextEmptyPhase); - } -} - -void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) { - SmallVector yieldValues; - for (Value v : yieldOp->getOperands()) - yieldValues.push_back(mapping.lookup(v)); - for (Value nextBuffer : nextBuffers) - yieldValues.push_back(nextBuffer); - for (Value nextSlice : extractSlices) - yieldValues.push_back(nextSlice); - - for (size_t i = depArgsBeginIdx; i < ivIdx; ++i) { - auto arg = newForOp.getRegionIterArgs()[i]; - assert(depArgsMapping.count(arg) && "Missing loop-carried value"); - yieldValues.push_back(depArgsMapping[arg]); - } - - // Loop iteration args - 
yieldValues.push_back(nextIV); - yieldValues.push_back(pipelineIterIdx); - yieldValues.push_back(curWaitIdx); - if (numLoadsRequireMBarrier > 0) { - yieldValues.push_back(loopIterIdx); - yieldValues.push_back(curPhase); - yieldValues.push_back(curEmptyPhase); - } - - builder.setInsertionPointToEnd(newForOp.getBody()); - builder.create(yieldOp->getLoc(), yieldValues); -} - -scf::ForOp LoopPipeliner::createNewForOp() { - OpBuilder builder(forOp); - auto newLoopArgs = collectNewLoopArgs(); - auto newForOp = cloneForOp(newLoopArgs, builder); - prefetchNextIteration(newForOp, builder); - finalizeYield(newForOp, builder); - return newForOp; -} - -// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp -struct PipelinePass : public TritonGPUPipelineBase { - PipelinePass() = default; - PipelinePass(int numStages, int numWarps, int numCTAs, - int computeCapability) { - this->numStages = numStages; - this->numWarps = numWarps; - this->numCTAs = numCTAs; - this->computeCapability = computeCapability; - } - - void runOnOperation() override { - // TODO[goostavz]: mode = 0 is temporary for backward compatible, will be - // deprecated after the refactor of pipeline fully gets done - // TODO[goostavz]: When mode = 1, the mask of prefetch insert_slice in the - // prologue is currently not properly provided. 
Need some second thought on - // the mask definition of InsertSliceOp when the src is ptr - bool mode = - computeCapability >= 90 && ::triton::tools::getBoolEnv("ENABLE_TMA"); - if (this->numStages <= 1) - return; - - // phase 0: pipeline loads in loops - // Pre-processing - // we make sure element-wise ops are done *after* the conversion - // to dot operands - // we can achieve this with simple recursive pattern matching - // MLIRContext *context = &getContext(); - // mlir::RewritePatternSet patterns(context); - // patterns.add(context); - // auto didPreprocess = - // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - - llvm::SmallVector newForOps; - - // Currently we schedule stage 0 after stage `numStages - 1` during - // pipelining therefore we only need `numStages - 1` slice of memory. - // On Hopper we have a separate post-processing that pipelines wgmma so we - // need an extra buffer for each input. - // Note that an alternative would be to keep allocating `numStages` buffers - // and remove the barrier between the loads from shared memory and the - // copies from global to shared. This would require improving existing - // membar analysis. - int numSharedMemorySlices = - computeCapability < 90 ? numStages - 1 : numStages; - - // Do the pipelining - getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps, - this->numCTAs, mode, numSharedMemorySlices, - consumerReleaseMap); - if (pipeliner.initialize().failed()) - return; - - pipeliner.emitPrologue(); - scf::ForOp newForOp = pipeliner.createNewForOp(); - pipeliner.emitEpilogue(); - newForOps.push_back(newForOp); - - // Replace the original loop - for (unsigned i = 0; i < forOp->getNumResults(); ++i) - forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); - forOp->erase(); - }); - - // phase 1: pipeline dots in loops - // A tt.dot suitable for GMMA will be converted to ttg.dot_async. 
And a - // ttg.DotWaitOp will synchronize it lagging just one iteration, which is - // a hueristic rule. - for (auto forOp : newForOps) - asyncLaunchDots(forOp); - - // phase 2: emit consumer_release (empty barrier arrive) logics in case of - // TMA multicast. - // For each load ops, it is emitted after its last consumer, if the consumer - // is another async op, find its associated sync op. Each async load will be - // emitted with a consumer_release action. The merge of redundant mbarriers - // will be processed in the consequent OptimizeBarriers pass. - for (const auto &item : consumerReleaseMap) - emitConsumerRelease(item.first, item.second, numStages); - } - -private: - Value getRemoteCTAId(OpBuilder &b, Location loc, ttg::CTALayoutAttr CTALayout, - Value remoteCTAIdIdx) const; - void updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, int stage); - void asyncLaunchDots(scf::ForOp forOp); - void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info, - int numStages); - - ConsumerReleaseMap consumerReleaseMap; -}; - -void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, - int stage) { - for (auto &item : consumerReleaseMap) { - auto &m = item.second.consumerStageMap; - if (m.count(oldOp)) { - m.erase(oldOp); - m[newOp] = stage; - } - - for (Value operand : oldOp->getOperands()) { - Operation *op = operand.getDefiningOp(); - if (op && isa(op)) { - auto cvt = cast(op); - auto src = cvt.getSrc(); - auto srcEncoding = src.getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - if (srcEncoding == dstEncoding && m.count(op)) { - m.erase(op); - m[newOp] = stage; - } - } - } - } -} - -void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { - Block *loop = forOp.getBody(); - - /// XXX(Keren): Clean up the following duplicate code with checkDotOp - /// dots to be pipelined - SmallVector dots; - SmallVector resultNeedSync; - for (Operation &op : *loop) { - if (auto dotOp = 
dyn_cast(&op)) { - auto resTy = dotOp.getResult().getType().dyn_cast(); - if (auto resEnc = resTy.getEncoding().dyn_cast()) { - if (resEnc && resEnc.isHopper()) { - // Don't pipeline valid dots that depend on ops other than scf.yield - // and scf.for - auto dot = dotOp.getResult(); - bool valid = true; - - // all users of dot should be scf.yield - if (!dot.hasOneUse()) - valid = false; - if (!isa(*dot.getUsers().begin())) - valid = false; - - // C should be a block argument - auto CArg = dotOp.getOperand(2).dyn_cast(); - if (!CArg || !CArg.hasOneUse()) - valid = false; - - if (valid) { - dots.push_back(dotOp); - resultNeedSync.push_back( - dotOp->getUses().begin()->getOperandNumber()); - } - } - } - } - } - - // Early stop: no need to continue if there is no valid dot in the loop. - if (dots.empty()) - return; - - OpBuilder builder(forOp); - // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline - // wgmma ops by one stage. - // This is needed to prevent shared memory inputs to be overriden before the - // operation is completed. - // TODO: merge this with the rest of the pipelining transformation and look at - // a better representation for async dots. - tt::DotOp lastDot = dots.back(); - builder.setInsertionPointAfter(lastDot); - auto dotWait = builder.create( - lastDot.getLoc(), lastDot.getResult(), dots.size()); - - // 1. replace Dot with DotAsync - for (size_t idx = 0; idx < dots.size(); ++idx) { - tt::DotOp dotOp = dots[idx]; - builder.setInsertionPoint(dotOp); - auto dotAsync = builder.create( - dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(), - dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); - dotOp.replaceAllUsesWith(dotAsync.getResult()); - updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1); - dotOp->erase(); - } - - // 2. If there's any outstanding DotAsyncOps, we need to wait for them. 
- builder.setInsertionPointAfter(forOp); - for (unsigned resultIndex : resultNeedSync) { - Value result = forOp->getResult(resultIndex); - if (result.use_empty()) - continue; - auto dotWait = - builder.create(forOp.getLoc(), result, 0); - result.replaceAllUsesExcept(dotWait.getResult(), dotWait); - } -} - -Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc, - ttg::CTALayoutAttr CTALayout, - Value remoteCTAIdIdx) const { - auto CTAsPerCGA = CTALayout.getCTAsPerCGA(); - auto CTAOrder = CTALayout.getCTAOrder(); - auto CTASplitNum = CTALayout.getCTASplitNum(); - - // Short path when bcastMask is a constant - bool isConstMcastMask = true; - for (unsigned s : CTASplitNum) { - if (s > 1) { - isConstMcastMask = false; - break; - } - } - if (isConstMcastMask) - return remoteCTAIdIdx; - - Value linearCTAId = b.create(loc); - SmallVector multiDimCTAId = - delinearize(b, loc, linearCTAId, CTAsPerCGA, CTAOrder); - auto rank = CTAOrder.size(); - int bcastDim = -1; - for (size_t i = 0; i < rank; ++i) { - if (CTAsPerCGA[i] != CTASplitNum[i]) { - assert(bcastDim < 0 && "bcast in multiple dims is not expected"); - bcastDim = i; - } - } - multiDimCTAId[bcastDim] = remoteCTAIdIdx; - return linearize(b, loc, multiDimCTAId, CTAsPerCGA, CTAOrder); -} - -void PipelinePass::emitConsumerRelease(Value mbarTensor, - const ConsumerReleaseInfo &info, - int numStages) { - Value iterVar = info.iterVar; - Value stage = info.stageVar; - Value phase = info.phaseVar; - Value nextIV = info.nextIVVar; - Value step = info.stepVar; - Value upperBound = info.upperBoundVar; - - const auto &consumerStageMap = info.consumerStageMap; - // find the the last consumer among all the consumers with the largest stage. 
- SmallVector consumersWithLargestStage; - int maxStage = 0; - for (const auto &it : consumerStageMap) { - if (it.second > maxStage) { - consumersWithLargestStage.clear(); - consumersWithLargestStage.push_back(it.first); - maxStage = it.second; - } else if (it.second == maxStage) { - consumersWithLargestStage.push_back(it.first); - } - } - assert(consumersWithLargestStage.size() > 0); - DenseMap operationId; - consumersWithLargestStage[0]->getBlock()->walk( - [&](Operation *op) { operationId[op] = operationId.size(); }); - size_t maxId = 0; - Operation *lastUserWithLargestStage; - for (Operation *op : consumersWithLargestStage) { - assert(operationId.find(op) != operationId.end()); - size_t userId = operationId[op]; - if (userId > maxId) { - maxId = userId; - lastUserWithLargestStage = op; - } - } - - OpBuilder b(&getContext()); - b.setInsertionPointAfter(lastUserWithLargestStage); - auto loc = lastUserWithLargestStage->getLoc(); - auto maxStageVal = b.create(loc, maxStage, 32); - - // pred = (iterVar >= maxStage) && - // (threadId % (numConsumerThreads / numRemoteCTAs) == 0); - - // [benzh] maybe we can simplify the logics here - auto cmpOp = arith::CmpIPredicate::sge; - if (maxStage == 0) - cmpOp = arith::CmpIPredicate::sgt; - Value pred = b.create(loc, cmpOp, iterVar, maxStageVal); - - Value threadId = b.create(loc); - auto CTAsPerCGA = info.CTALayout.getCTAsPerCGA(); - auto CTASplitNum = info.CTALayout.getCTASplitNum(); - auto numRemoteCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(), 1, - std::multiplies{}) / - std::accumulate(CTASplitNum.begin(), CTASplitNum.end(), - 1, std::multiplies{}); - auto numConsumerThreads = - isa(lastUserWithLargestStage) ? 
128 : 32; - Value _0 = b.create(loc, 0, 32); - Value numArrives = b.create( - loc, numConsumerThreads / numRemoteCTAs, 32); - pred = b.create( - loc, pred, - b.create( - loc, arith::CmpIPredicate::eq, - b.create(loc, threadId, numArrives), _0)); - // remoteCtaIdIdx = (threadId % numConsumerThreads) / (numConsumerThreads / - // numRemoteCTAs); - Value remoteCTAIdIdx = b.create( - loc, - b.create( - loc, threadId, - b.create(loc, numConsumerThreads, 32)), - numArrives); - Value remoteCTAId = getRemoteCTAId(b, loc, info.CTALayout, remoteCTAIdIdx); - Value emptyBarrier = b.create( - loc, tt::PointerType::get(b.getIntegerType(64), 3), mbarTensor, stage); - - Value newNextIV = b.create(loc, nextIV, step); - Value nextLoopCond = b.create(loc, arith::CmpIPredicate::slt, - newNextIV, upperBound); - auto ifOp = b.create(loc, ArrayRef{}, nextLoopCond, - /*hasElse*/ false); - b.setInsertionPointToStart(ifOp.thenBlock()); - - b.create(loc, emptyBarrier, pred, remoteCTAId, - /*trackAsyncOp*/ false); -} - -} // anonymous namespace - -std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, - int numWarps, - int numCTAs, - int computeCapability) { - return std::make_unique(numStages, numWarps, numCTAs, - computeCapability); -} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp new file mode 100644 index 000000000..e1493140a --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -0,0 +1,814 @@ +#include "PipelineExpander.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" 
+#include "llvm/Support/Debug.h" + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// TODO: We can extra some helpers into common utilities once we add more +// schedules. + +/// Replace the yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands().begin(), + yieldOp->getOperands().end()); + operands.append(newOperands.begin(), newOperands.end()); + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx) { + OpBuilder builder(forOp); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + auto insertOp = builder.create( + loc, alloc.getType(), loadOp.getPtr(), alloc, insertIdx, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile(), /*axis*/ 0); + auto commmit = builder.create(loc); + + // Extract part. 
+ auto allocType = alloc.getType().cast(); + RankedTensorType sliceType = RankedTensorType::get( + {allocType.getShape()[1], allocType.getShape()[2]}, + allocType.getElementType(), allocType.getEncoding()); + auto extract = builder.create( + loc, sliceType, insertOp.getResult(), + SmallVector{extractIdx, int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); + Operation *user = *loadOp.getResult().getUsers().begin(); + auto convertLayout = llvm::cast(user); + auto newCvt = builder.create( + convertLayout->getLoc(), convertLayout.getType(), extract.getResult()); + convertLayout->replaceAllUsesWith(newCvt->getResults()); + convertLayout->erase(); + loadOp.erase(); + + // Fix up the yield op. + appendToYield(forOp, {insertOp}); +} + +static void createTMALoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value phase) { + OpBuilder builder(forOp); + Location loc = loadOp.getLoc(); + auto CTALayout = ttg::CTALayoutAttr::get(loadOp.getContext(), + /*CTAsPerCGA*/ {1}, + /*CTASplitNum*/ {1}, + /*CTAOrder*/ {0}); + auto sharedEncoding = ttg::SharedEncodingAttr::get(loadOp.getContext(), 1, 1, + 1, {0}, CTALayout, false); + int64_t numBuffers = alloc.getType().cast().getShape()[0]; + auto mBarriersTy = RankedTensorType::get( + {numBuffers}, builder.getIntegerType(64), sharedEncoding); + // Allocate an array of mbarrier objects outside the loop. + Value barrierArray = + builder.create(loc, mBarriersTy, 1); + // extract the barrier and emit arriver/copy/wait/extract code sequence. 
+ builder.setInsertionPoint(loadOp); + auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); + Value barrier = builder.create( + loc, mBarTy, barrierArray, insertIdx); + Value zero = builder.create(loc, 0, 32); + Value threadId = builder.create(loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, zero); + + auto loadTy = loadOp.getType().dyn_cast(); + auto loadShape = loadTy.getShape(); + auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); + auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape); + auto elemTy = loadTy.getElementType(); + unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), + 1, std::multiplies{}); + elems *= (elemTy.getIntOrFloatBitWidth() / 8); + builder.create(loc, barrier, pred, + /*remoteCtaId*/ nullptr, + /*trackAsyncOp*/ false, elems); + auto allocType = alloc.getType().cast(); + auto insertOp = builder.create( + loc, allocType, loadOp.getPtr(), alloc, + /*index*/ insertIdx, barrier, loadOp.getMask(), loadOp.getOther(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + /*axis*/ 0); + + RankedTensorType sliceType = RankedTensorType::get( + {allocType.getShape()[1], allocType.getShape()[2]}, + allocType.getElementType(), allocType.getEncoding()); + auto extract = builder.create( + loc, sliceType, insertOp.getResult(), + SmallVector{extractIdx, int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); + + Value barrierWait = builder.create( + loc, mBarTy, barrierArray, extractIdx); + builder.create(loc, barrierWait, phase); + + Operation *user = *loadOp.getResult().getUsers().begin(); + auto convertLayout = llvm::cast(user); + auto newCvt = builder.create( + convertLayout->getLoc(), convertLayout.getType(), extract.getResult()); + convertLayout->replaceAllUsesWith(newCvt->getResults()); + convertLayout->erase(); + 
loadOp.erase(); + + // Fix up the yield op. + appendToYield(forOp, {insertOp}); +} + +/// Create an async load equivalent to the given load. +static void createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value phase) { + if (isLoadFromTensorPtr(loadOp)) { + createTMALoad(forOp, loadOp, alloc, insertIdx, extractIdx, phase); + } else { + createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx); + } +} + +// Return the transitive use of the load which is a dot operand. +static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) { + // We only pipeline loads that have one covert_layout (to dot_op) use + // TODO: lift this constraint in the future + bool isCandidate = false; + if (!loadOp.getResult().hasOneUse()) + return Value(); + + Operation *use = *loadOp.getResult().getUsers().begin(); + if (auto convertLayout = llvm::dyn_cast(use)) { + auto tensorType = + convertLayout.getResult().getType().cast(); + if (auto sharedEnc = + tensorType.getEncoding().dyn_cast()) { + if (sharedEnc.getHasLeadingOffset()) { + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); + auto ty = loadOp.getType().cast(); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { + // The operand of MMAv3 is in SharedEncoding and it's order should + // not be changed after FuseTranspositions Pass. So we only pipeline + // the load if the order of the loaded BlockedEncoding is the same + // as the order of the SharedEncoding it is converted to. 
+ // TODO: remove this constraint once the LoadOp supports transpose + // fusion + hasMMAV3 = true; + return convertLayout.getResult(); + } + } + } + } + // Advance to the first conversion as long as the use resides in shared + // memory and it has a single use itself + while (use) { + if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) + break; + auto tensorType = use->getResult(0).getType().dyn_cast(); + if (!tensorType.getEncoding().isa()) + break; + use = *use->getResult(0).getUsers().begin(); + } + + if (auto convertLayout = llvm::dyn_cast(use)) { + if (auto tensorType = + convertLayout.getResult().getType().dyn_cast()) { + if (auto dotOpEnc = tensorType.getEncoding() + .dyn_cast()) { + return convertLayout.getResult(); + } + } + } + return Value(); +} + +namespace { +struct LoadDotOperand { + LoadDotOperand(tt::LoadOp load, Value dotOperand) + : load(load), dotOperand(dotOperand) {} + tt::LoadOp load; + Value dotOperand; +}; +} // namespace + +/// Collect loads to pipeline. Return success if we can pipeline this loop +static void collectOpsToPipeline(scf::ForOp forOp, + SmallVectorImpl &ops, + bool &hasMMAV3) { + ModuleOp moduleOp = forOp->getParentOfType(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // We cannot use forOp.walk(...) here because we only want to visit the + // operations in the loop body block. Nested blocks are handled separately. + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) { + bool candidate = false; + if (isLoadFromTensorPtr(loadOp)) { + // Map to TMA load. 
+ candidate = true; + } else { + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = + std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy || tensorTy.getRank() < 2) + continue; + auto ty = + tensorTy.getElementType().cast().getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + if (width >= 32) + candidate = true; + } + if (!candidate) + continue; + Value dotOperand = loadDotOperand(loadOp, hasMMAV3); + if (!dotOperand) + continue; + ops.emplace_back(loadOp, dotOperand); + } + } +} + +// Create an allocation that can old distance number of loadOp shapes. 
+static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, Value dotOperand, + unsigned distance) { + OpBuilder builder(forOp); + auto ty = loadOp.getType().cast(); + if (!loadOp.getResult().hasOneUse()) + return Value(); + Attribute sharedEnc; + auto CTALayout = ttg::getCTALayout(ty.getEncoding()); + auto tensorType = dotOperand.getType().cast(); + if (auto dotOpEnc = + tensorType.getEncoding().dyn_cast()) { + auto convertLayout = dotOperand.getDefiningOp(); + bool needTrans = dyn_cast_or_null( + convertLayout->getOperand(0).getDefiningOp()); + unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); + sharedEnc = ttg::SharedEncodingAttr::get( + ty.getContext(), dotOpEnc, ty.getShape(), + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); + } else { + // MMAv3 + sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), + ttg::getOrder(ty.getEncoding()), + CTALayout, ty.getElementType()); + } + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type allocType = + RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); + Value alloc = builder.create( + loadOp.getLoc(), allocType); + return alloc; +} + +// Convert load ops into their asyn version and apply multi-buffering based on +// the number of stages. +static void createAsynOps(scf::ForOp &forOp, ArrayRef loads, + int numStages, bool hasMMAV3) { + struct AsyncLoad { + AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} + tt::LoadOp loadOp; + Value alloc; + }; + int numBuffers = numStages - 1; + // For MMAv3 we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. + // TODO: Improve modeling of wgmma pipelining. 
+ if (hasMMAV3) + numBuffers++; + SmallVector asyncLoads; + SmallVector newOperands; + bool needsMbarrierPhase = false; + bool needsAsyncWait = false; + for (const LoadDotOperand &loadOperand : loads) { + tt::LoadOp loadOp = loadOperand.load; + Value dotOperand = loadOperand.dotOperand; + Value alloc = createAlloc(forOp, loadOp, dotOperand, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + newOperands.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); + if (isLoadFromTensorPtr(loadOp)) + needsMbarrierPhase = true; + else + needsAsyncWait = true; + } + + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + Value phase; + if (needsMbarrierPhase) { + phase = builder.create(loc, 0, 1); + newOperands.push_back(phase); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + for (int i = 0; i < asyncLoads.size(); i++) { + asyncLoads[i].alloc = newForOp.getBody()->getArgument(newOperandIndex + i); + } + insertIdx = + newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size()); + extractIdx = + newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size() + 1); + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. 
+ builder.setInsertionPoint(asyncLoads.front().loadOp); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + if (needsMbarrierPhase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + + asyncLoads.size() + 2); + Value oneI1 = builder.create(loc, 1, 1); + Value nextPhase = builder.create(loc, phase, oneI1); + phase = builder.create(loc, cndExt, phase, nextPhase); + } + + bool firstLoad = true; + for (AsyncLoad &asyncLoad : asyncLoads) { + createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx, + extractIdx, phase); + firstLoad = false; + } + // Insert a waitOp after the first async copy. This does make the assumption + // that the wait will be scheduled in a different stage that all the async + // copy but we cannot guarantee that one wait is enough otherwise. + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op)) { + OpBuilder builder(op.getContext()); + builder.setInsertionPointAfter(&op); + builder.create(op.getLoc(), 0); + break; + } + } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (needsMbarrierPhase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. + appendToYield(forOp, newYieldOperands); +} + +// Combine the current mask with the given predicate. 
+static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (maskType.isa()) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +static Operation *predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto insertOp = dyn_cast(op)) { + rewriter.setInsertionPoint(insertOp); + Value mask = getPredMask(rewriter, insertOp.getSrc().getType(), + insertOp.getMask(), pred); + insertOp.getMaskMutable().assign(mask); + return op; + } + if (auto insertOp = dyn_cast(op)) { + rewriter.setInsertionPoint(insertOp); + Value mask = getPredMask( + rewriter, + insertOp.getSrc().getType().cast().getPointeeType(), + insertOp.getMask(), pred); + insertOp.getMaskMutable().assign(mask); + return op; + } + if (auto arriveOp = dyn_cast(op)) { + rewriter.setInsertionPoint(arriveOp); + Value mask = getPredMask(rewriter, rewriter.getIntegerType(1), + arriveOp.getPred(), pred); + arriveOp.getPredMutable().assign(mask); + return op; + } + if (isa(op)) { + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + + assert("don't know how to predicate this op" && false); + return op; +} + +static void setWaitNum(Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration, unsigned numLoadsInStage) { + if (auto waitOp = dyn_cast(op)) { + waitOp.setNum(numLoadsInStage); + } +} + +/// Helper to recursively add dependencies to the same 
stage. +static void addDep(Operation *op, DenseSet &deps, + bool includeArg = true, + DenseSet *filter = nullptr) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = v.dyn_cast()) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the shedule with the given stage based on the filter +// function. +static void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages, bool prefetchExtract) { + SmallVector insertOps; + SmallVector extractOps; + // Find the insert/extract ops that will go respectively in stage 0 and stage + // `numStages - 2`. All the other operations will go in stage `numStages - 1`. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + if (prefetchExtract) { + if (isa(op)) + extractOps.emplace_back(&op); + } + } + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + addDep(op, insertAndDeps, false); + } + + // Find depenencies with distance of 1. 
+ SmallVector distanceOneUsers; + for (Operation *op : insertAndDeps) { + for (Value operand : op->getOperands()) { + if (auto arg = operand.dyn_cast()) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && insertAndDeps.count(defOp) == 0) { + distanceOneUsers.push_back(defOp); + } + } + } + } + } + // Schedule loads with a distance of 1 in stage 0 + for (Operation *op : distanceOneUsers) { + if (isa(op)) { + addDep(op, insertAndDeps, true); + } + } + // For the rest of the ops we can move then into stage 1 so that they can be + // closer to their uses. + DenseSet stage1deps; + for (Operation *op : distanceOneUsers) { + if (!isa(op)) { + addDep(op, stage1deps, true, &insertAndDeps); + } + } + + DenseSet extractAndDeps; + for (Operation *op : extractOps) { + addDep(op, extractAndDeps, true, &insertAndDeps); + } + std::vector> schedule; + // Schedule stage `numStage - 1` first. + addOps(forOp, numStages - 1, schedule, [&](Operation *op) { + return insertAndDeps.count(op) == 0 && stage1deps.count(op) == 0 && + extractAndDeps.count(op) == 0; + }); + + // Schedule some dependencies with distance of 1 into stage 1 to reduce + // pressure. + addOps(forOp, 1, schedule, + [&](Operation *op) { return stage1deps.count(op); }); + + // Then Schedule stage 0. + addOps(forOp, 0, schedule, + [&](Operation *op) { return insertAndDeps.count(op); }); + + // Finally schedule the extract ops in stage `numStage - 2` so that they get + // pre-fetched and play well with pretech pass. + addOps(forOp, numStages - 2, schedule, + [&](Operation *op) { return extractAndDeps.count(op); }); + return schedule; +} + +bool mlir::triton::preProcessLoopAndGetSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + // 1. 
First collect "interesting" operations with a stage where to schedule + // them. This gives a coarse scheduling for the loop. + SmallVector loads; + bool hasMMAV3 = false; + collectOpsToPipeline(forOp, loads, hasMMAV3); + if (loads.empty()) + return false; + bool hasAsynCp = llvm::any_of(loads, [](LoadDotOperand &load) { + return !isLoadFromTensorPtr(load.load); + }); + // 2. Convert the loads into async loads and create the allocs. + createAsynOps(forOp, loads, numStages, hasMMAV3); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages, /*prefetchExtract=*/!hasMMAV3); + + // 4. Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = predicateOp; + options.supportDynamicLoops = true; + unsigned numLoadsInStage = (numStages - 2) * loads.size(); + options.annotateFn = + [numLoadsInStage](Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration) { + return setWaitNum(op, part, iteration, numLoadsInStage); + }; + + if (hasAsynCp) { + // Insert a wait 0 after the loop + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), 0); + } + return true; +} + +/// MMA V3 post-processing. 
+static bool selfDepend(tt::DotOp dotOp, scf::ForOp forOp, + Operation **firstUse) { + std::function dependOn = + [&dependOn](Value v, int argId, scf::ForOp forOp) { + auto op = v.getDefiningOp(); + if (isa(v)) { + auto iterArgs = forOp.getRegionIterArgs(); + auto iter = std::find(iterArgs.begin(), iterArgs.end(), v); + if (iter != iterArgs.end()) + return std::distance(iterArgs.begin(), iter) == argId; + } else { + if (!op) + return false; + for (auto operand : op->getOperands()) { + if (dependOn(operand, argId, forOp)) + return true; + } + } + return false; + }; + auto result = dotOp.getResult(); + auto yieldOp = forOp.getBody()->getTerminator(); + int argIdx = -1; + auto iter = std::find(yieldOp->getOperands().begin(), + yieldOp->getOperands().end(), result); + if (iter != yieldOp->getOperands().end()) + argIdx = std::distance(yieldOp->getOperands().begin(), iter); + if (argIdx == -1) + return false; + for (auto operand : dotOp.getOperands()) { + if (dependOn(operand, argIdx, forOp)) { + auto iterArgs = forOp.getRegionIterArgs(); + *firstUse = iterArgs[argIdx].use_begin().getUser(); + return true; + } + } + return false; +} + +static void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, + bool hasDotWait0) { + if (hasDotWait0) { + dotWaitOp->erase(); + } +} + +void mlir::triton::asyncLaunchDots(scf::ForOp forOp) { + Block *loop = forOp.getBody(); + auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) { + if (!op) + return -1l; + auto lastOp = op; + while (op->getBlock()->getParentOp() != forOp) { + lastOp = op; + op = op->getBlock()->getParentOp(); + } + return std::distance(lastOp->getBlock()->getParent()->begin(), + lastOp->getBlock()->getIterator()); + }; + /// XXX(Keren): Clean up the following duplicate code with checkDotOp + /// dots to be pipelined + bool hasSyncDot = false; + bool hasDotWait0 = false; + SmallVector allDots; + SmallVector dots; + SmallVector resultNeedSync; + for (Operation &op : *loop) { + if (auto dotWaitOp = dyn_cast(&op)) { 
+ auto attr = dotWaitOp->getAttrOfType("pendings"); + auto pendingCount = attr.getInt(); + if (pendingCount == 0) + hasDotWait0 = true; + } + if (auto dotOp = dyn_cast(&op)) { + allDots.push_back(dotOp); + } + } + for (Operation &op : *loop) { + if (auto dotOp = dyn_cast(&op)) { + auto resTy = dotOp.getResult().getType().dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) { + if (resEnc && resEnc.isHopper()) { + auto dot = dotOp.getResult(); + bool valid = true; + + // all users of dot should be scf.yield + if (!dot.hasOneUse()) + valid = false; + if (!isa(*dot.getUsers().begin())) + valid = false; + + Operation *firstUse = nullptr; + auto depend = selfDepend(dotOp, forOp, &firstUse); + bool selfDirectDepend = (dotOp == firstUse); + for (auto tempInAll : allDots) { + auto iter = std::find(dots.begin(), dots.end(), tempInAll); + if (iter != dots.end()) + continue; + auto db = getBlockNumInFor(tempInAll, forOp); + auto fb = getBlockNumInFor(firstUse, forOp); + if (db < fb || + (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse))) + hasSyncDot = true; + } + auto CArg = dotOp.getOperand(2); + if (!(selfDirectDepend || + (depend && !selfDirectDepend && hasSyncDot)) || + !CArg.hasOneUse()) + valid = false; + + if (valid) { + dots.push_back(dotOp); + resultNeedSync.push_back( + dotOp->getUses().begin()->getOperandNumber()); + } + } + } + } + } + + // Early stop: no need to continue if there is no valid dot in the loop. + if (dots.empty()) + return; + + OpBuilder builder(forOp); + // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline + // wgmma ops by one stage. + // This is needed to prevent shared memory inputs to be overriden before the + // operation is completed. + // TODO: merge this with the rest of the pipelining transformation and look at + // a better representation for async dots. 
+ tt::DotOp lastDot = dots.back(); + auto loc = lastDot.getLoc(); + builder.setInsertionPointAfter(lastDot); + auto dotWait = builder.create( + lastDot.getLoc(), lastDot.getResult(), dots.size()); + + // 1. replace Dot with DotAsync + for (size_t idx = 0; idx < dots.size(); ++idx) { + tt::DotOp dotOp = dots[idx]; + builder.setInsertionPoint(dotOp); + auto dotAsync = builder.create( + dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); + dotOp.replaceAllUsesWith(dotAsync.getResult()); + dotOp->erase(); + } + + hasDotWait0 = hasDotWait0 || hasSyncDot; + + // 2. If there's any outstanding DotAsyncOps, we need to wait for them. + builder.setInsertionPointAfter(forOp); + SmallVector waitOperands; + for (int i = 0; i < resultNeedSync.size(); ++i) { + Value result = forOp->getResult(resultNeedSync[i]); + if (result.use_empty()) + continue; + waitOperands.push_back(result); + } + if (!waitOperands.empty()) { + auto dotWait = builder.create(forOp.getLoc(), + waitOperands, 0); + for (int i = 0; i < resultNeedSync.size(); ++i) { + Value result = forOp->getResult(resultNeedSync[i]); + result.replaceAllUsesExcept(dotWait.getResult(i), dotWait); + } + } + + // 3. potentially remove redundant dot_wait after dot_async if having mutiple + // DotOp in the loop + removeExtraWait(dotWait, hasDotWait0); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 000000000..18df341ac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,704 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. +// -Fix bug when a value yield is used outside the loop and the value def is not +// in the last stage. If we are not peeling the epilgue we need to remap the +// output correctly. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" + +#include "PipelineExpander.h" + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. 
+ struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + std::pair getDefiningOpAndDistance(Value value); + +public: + /// Initalize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + void emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. 
+ LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + llvm::SmallVector emitEpilogue(RewriterBase &rewriter); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. 
+ for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Only support loop carried dependency with a distance of 1. This means the + // source of all the scf.yield operands needs to be defined by operations in + // the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || !stages.contains(def); + })) { + LDBG("--only support loop carried dependency with a distance of 1 -> BAIL"); + return false; + } + annotateFn = options.annotateFn; + return true; +} + +/// Clone `op` and call `callback` on the cloned op's oeprands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. 
+static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + for (OpOperand &operand : clone->getOpOperands()) + callback(&operand); + clone->walk([&](Operation *nested) { + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (BlockArgument &arg : forOp.getRegionIterArgs()) { + OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + for (int64_t i = 0; i < maxStage; i++) { + Value predicate; + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicate = rewriter.create(loc, arith::CmpIPredicate::slt, + iv, ub); + } + + // special handling for induction variable as the increment is implicit. 
+ // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + if (predicate) { + newOp = predicateFn(rewriter, newOp, predicate); + assert(newOp && "failed to predicate op."); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + if (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. 
+ if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. 
+ for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilogue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + newUb = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, -maxStage)))); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. 
+ if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. 
+ if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. 
+ for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yielOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yielOperand.get()); + // When we don't peel the epilogue the yield value is used outside the loop + // we need to make sure we return the version from numStages - defStage. + if (!peelEpilogue && + !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) { + auto [def, distance] = getDefiningOpAndDistance(yielOperand.get()); + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value pred = predicates[defStage->second]; + if (pred) { + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yielOperand.getOperandNumber() + 1]); + } + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. 
+ for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + if (defStage > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +llvm::SmallVector +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) { + llvm::SmallVector returnValues(forOp->getNumResults()); + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. 
+ for (int64_t i = 0; i < maxStage; i++) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totlaNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totlaNumIteration, minusI))); + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. 
+ if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } + return returnValues; +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. 
+ if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h new file mode 100644 index 000000000..0a3d736c6 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h @@ -0,0 +1,101 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. 
+ using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated, the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. 
+}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parentheses as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h new file mode 100644 index 000000000..67ee2ca83 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h @@ -0,0 +1,27 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "PipelineExpander.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include + +namespace mlir { +namespace triton { + +/// This fills out the pipelining options including schedule and annotations for +/// wait ops. 
This also does pre-processing by converting some of the loads into +/// async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 000000000..573151115 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,88 @@ +#include "PipelineExpander.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that creates a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. 
This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static void pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::triton::PipeliningOption options; + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return; + + bool foundSchedule = false; + foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options); + + // TODO: add more pipelines strategy. + if (!foundSchedule) + return; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); + + if (succeeded(newForOp)) + mlir::triton::asyncLaunchDots(newForOp.value()); +} + +namespace { +struct PipelinePass : public TritonGPUPipelineBase { + PipelinePass() = default; + PipelinePass(int numStages, int numWarps, int numCTAs, + int computeCapability) { + this->numStages = numStages; + this->numWarps = numWarps; + this->numCTAs = numCTAs; + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + if (this->numStages <= 1) + return; + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + pipelineLoop(forOp, numStages); + } + } +}; +} // anonymous namespace + +std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, + int numWarps, + int numCTAs, + int computeCapability) { + return std::make_unique(numStages, numWarps, numCTAs, + computeCapability); +} diff --git 
a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 9af04cf59..ee6b2fca1 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -332,9 +332,6 @@ SmallVector LayoutPropagation::propagateToUsers(Value value, setEncoding({afterArg, result}, info, changed, user); continue; } - // Workaround: don't propagate through truncI - if (isa(user)) - continue; if (user->hasTrait() || user->hasTrait() || isa(loop.getInitArgs()); - operands.append(newIterOperands.begin(), newIterOperands.end()); - scf::ForOp newLoop = rewriter.create( - loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), - operands); - newLoop.getBody()->erase(); - - newLoop.getRegion().getBlocks().splice( - newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); - for (Value operand : newIterOperands) - newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); - - for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( - loop.getNumResults()))) - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - return newLoop; -} - static void rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, IRMapping &mapping) { diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index c345f1b87..1693a3ef7 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -98,8 +98,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( // Some ops from SCF are illegal addIllegalOp(); - // We have custom versions of some arith operators - addIllegalOp(); addDynamicallyLegalDialect()) return "orangered"; - else - assert(0 && "Unrecognized layout"); + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } } 
else { return "white"; } @@ -342,11 +344,39 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { } return true; } + if (auto view = dyn_cast(op)) { + auto viewDstType = view.getType().cast(); + RankedTensorType newDstType = RankedTensorType::get( + viewDstType.getShape(), viewDstType.getElementType(), targetEncoding); + return !triton::gpu::isExpensiveView(view.getOperand().getType(), + newDstType); + } return isa(op); + triton::MakeRangeOp, triton::SplatOp>(op); } -// +scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + return newLoop; +} Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping) { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index f11f02781..fee287203 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder, build(builder, state, MutexType::get(builder.getContext())); } +///--- DotWaitOp --- +LogicalResult DotWaitOp::inferReturnTypes( + ::mlir::MLIRContext *context, 
::std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + for (Value operand : operands) + inferredReturnTypes.push_back(operand.getType()); + return mlir::success(); +} + } // namespace nvidia_gpu } // namespace triton } // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp index af2799301..63bdf66e4 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -68,7 +68,8 @@ Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef shape, replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout)); } else { // Other layouts are generated by passes after PlanCTAPass - assert(0 && "replaceCTALayout not implemented"); + llvm::report_fatal_error("replaceCTALayout not implemented"); + return layout; } } @@ -393,7 +394,8 @@ bool CTAPlanner::propagateBackward(CastOp cast) { Value output = cast.getResult(0); unsigned numUsers = getNumUsers(input); if (numUsers == 0) { - assert(0 && "Unreachable branch"); + llvm::report_fatal_error("Unreachable branch"); + return false; } else if (numUsers == 1) { Type outTy = output.getType(); if (auto ptrTy = outTy.dyn_cast()) @@ -649,7 +651,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const { return true; if (auto externElementwiseOp = dyn_cast(op)) return externElementwiseOp.getPure(); - if (llvm::isa(op)) + if (llvm::isa(op)) return true; return false; } @@ -711,7 +713,7 @@ bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims, bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims, Attribute newSrcLayout) { - assert(0 && "processExpandDimsForward not implemented yet"); + llvm::report_fatal_error("processExpandDimsForward not implemented yet"); return true; 
} @@ -827,7 +829,7 @@ int findResultIndex(Operation *op, Value result) { for (int i = 0; i < op->getNumResults(); ++i) if (op->getResult(i) == result) return i; - assert(0 && "Invalid index of op result"); + llvm::report_fatal_error("Invalid index of op result"); return -1; } @@ -849,7 +851,7 @@ bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) { auto newType = cast.getResult(0).getType(); return processForOp(forOp, index, newType); } else { - assert(0 && "Unexpected parent op of block argument"); + llvm::report_fatal_error("Unexpected parent op of block argument"); return true; } } @@ -869,7 +871,7 @@ bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) { else if (auto forOp = llvm::dyn_cast(yieldOp->getParentOp())) return processForOp(forOp, index, newType); else - assert(0 && "Unexpected parent op of YieldOp"); + llvm::report_fatal_error("Unexpected parent op of YieldOp"); return true; } @@ -936,7 +938,8 @@ bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) { Operation *clonedOp = builder.clone(*defOp); newInput = clonedOp->getResult(0); } else { - assert(0 && "Layout conflict for block arg"); // TODO + llvm::report_fatal_error("Layout conflict for block arg"); // TODO + return false; } } first = false; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp index 5a2d3beaa..ba42896ee 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -55,7 +55,7 @@ bool isDivisible(Value v, unsigned divisor) { auto func = dyn_cast(parentOp); assert(func); if (auto attr = func.getArgAttrOfType(blockArg.getArgNumber(), - "tt.max_divisibility")) + "tt.divisibility")) return attr.getValue().getZExtValue() % divisor == 0; return false; } else if (v.getParentBlock()->isEntryBlock() && (!v.isa())) { @@ -98,13 +98,8 @@ bool 
shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) { return !(boxDimSwizzle && strideDivisible && enableTMA); } -// TODO: When encoding exists use triton::gpu::CmpIOp as arith::CmpIOp doesn't -// play well with encoding attributes. Move back to arith::CmpIOp when this pass -// moves back to triton IR level. Value createCmpOp(OpBuilder &builder, Location loc, RankedTensorType type, arith::CmpIPredicate pred, Value lhs, Value rhs) { - if (type.getEncoding()) - return builder.create(loc, type, pred, lhs, rhs); return builder.create(loc, type, pred, lhs, rhs); } @@ -358,12 +353,17 @@ class TritonGPURewriteTensorPointerPass : public TritonGPURewriteTensorPointerBase< TritonGPURewriteTensorPointerPass> { private: - int computeCapability; + // int computeCapability; DenseMap rewritedInfo; public: - explicit TritonGPURewriteTensorPointerPass(int computeCapability) - : computeCapability(computeCapability) {} + // explicit TritonGPURewriteTensorPointerPass(int computeCapability) + // : computeCapability(computeCapability) {} + + TritonGPURewriteTensorPointerPass() = default; + TritonGPURewriteTensorPointerPass(int computeCapability) { + this->computeCapability = computeCapability; + } static bool needRewrite(Operation *op, const DenseSet &valueToRemove) { if (auto ifOp = dyn_cast(op)) { @@ -763,17 +763,16 @@ public: ModuleOp mod = getOperation(); DenseSet valueToRemove; - mod.walk([&valueToRemove, - computeCapability = this->computeCapability](Operation *op) { + mod.walk([&valueToRemove, this](Operation *op) { if (auto makeTensorPtrOp = dyn_cast(op)) { - if (shouldRemove(makeTensorPtrOp, computeCapability)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(op->getResult(0)); } if (llvm::isa(op)) { auto src = op->getOperand(0); if (tt::isTensorPointerType(src.getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(src); - if (shouldRemove(makeTensorPtrOp, computeCapability)) { + if (shouldRemove(makeTensorPtrOp, 
this->computeCapability)) { valueToRemove.insert(op->getResult(0)); } } @@ -782,7 +781,7 @@ public: auto src = op->getOperand(0); if (tt::isTensorPointerType(src.getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(src); - if (shouldRemove(makeTensorPtrOp, computeCapability)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(src); } } @@ -791,7 +790,7 @@ public: for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) { if (tt::isTensorPointerType(iterOperands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); - if (shouldRemove(makeTensorPtrOp, computeCapability)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(iterOperands[i]); } } @@ -800,7 +799,7 @@ public: for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) { if (tt::isTensorPointerType(operands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]); - if (shouldRemove(makeTensorPtrOp, computeCapability)) + if (shouldRemove(makeTensorPtrOp, this->computeCapability)) valueToRemove.insert(operands[i]); } } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp index e21f6ddd3..150be31e4 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap> &map, Value result = forOp->getResult(resultIndex); auto dotWait = builder.createWithAgentIds( forOp.getLoc(), result, 0); - result.replaceAllUsesExcept(dotWait.getResult(), dotWait); + result.replaceAllUsesExcept(dotWait.getResult(0), dotWait); // 3. 
insert ConsumerReleaseOp for outstanding DotAsyncOps zero = builder.createWithAgentIds(loc, 0, 32); diff --git a/llvm-hash.txt b/llvm-hash.txt new file mode 100644 index 000000000..3eba2629b --- /dev/null +++ b/llvm-hash.txt @@ -0,0 +1 @@ +49af6502c6dcb4a7f7520178bd14df396f78240c diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..e1ecf8228 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] + +[tool.yapf] +based_on_style = "pep8" +column_limit = 120 +disable_split_list_with_comment = true +each_dict_entry_on_separate_line=false +split_before_named_assigns = false +split_complex_comprehension = true + +[tool.yapfignore] +ignore_patterns = [ + # This exclusion is also specified in .pre-commit-config.yaml. + # - We put it here because if you run yapf directly, we want it to skip the + # file. + # - We also put it in .pre-commit-config because yapf raises an error if + # pre-commit runs it but all of the files it might touch are ignored! 
+ "python/test/unit/language/test_line_info.py" +] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["E501", "E701", "E731", "E741"] diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py index 34cf12630..79457d6a4 100644 --- a/python/examples/copy_strided.py +++ b/python/examples/copy_strided.py @@ -4,8 +4,8 @@ import triton.language as tl # triton kernel @triton.jit -def kernel(X, stride_xm, - Z, stride_zn, +def kernel(X, stride_xm, # + Z, stride_zn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) diff --git a/python/examples/empty.py b/python/examples/empty.py index df313fb85..bff6d1e94 100644 --- a/python/examples/empty.py +++ b/python/examples/empty.py @@ -10,4 +10,4 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr): X = torch.randn(1, device="cuda") -pgm = kernel[(1,)](X, 1, 1, BLOCK=1024) +pgm = kernel[(1, )](X, 1, 1, BLOCK=1024) diff --git a/python/pyproject.toml b/python/pyproject.toml index 8bd8093e7..315aa7da9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -2,7 +2,14 @@ [build-system] requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] +# We're incrementally switching from autopep8 to ruff. 
[tool.autopep8] aggressive = 1 -ignore = "E501,E701,E731,W690" +ignore = "E501,E701,E731,W690,W503" max_line_length = 88 + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["E501", "E701", "E731", "E741"] diff --git a/python/setup.py b/python/setup.py index c5ddff6e0..a8e56a5c6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -55,6 +55,7 @@ class Package(NamedTuple): lib_flag: str syspath_var_name: str + # pybind11 @@ -63,6 +64,7 @@ def get_pybind11_package_info(): url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz" return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH") + # llvm @@ -74,6 +76,8 @@ def get_llvm_package_info(): arch = 'arm64' if system == "Darwin": arch = platform.machine() + if arch == "x86_64": + arch = "x64" system_suffix = f"macos-{arch}" elif system == "Linux": # TODO: arm64 @@ -84,7 +88,7 @@ def get_llvm_package_info(): return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") # use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") # release_suffix = "assert" if use_assert_enabled_llvm else "release" - rev = "b1115f8c" + rev = "49af6502" name = f"llvm-{rev}-{system_suffix}" url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz" return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") @@ -119,10 +123,13 @@ def get_thirdparty_packages(triton_cache_path): thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib") return thirdparty_cmake_args + # ---- package data --- -def download_and_copy(src_path, version, url_func): +def download_and_copy(src_path, variable, version, url_func): + if variable in os.environ: + return base_dir = os.path.dirname(__file__) arch = platform.machine() if arch == "x86_64": @@ -148,7 +155,7 @@ def download_and_copy(src_path, version, url_func): src_path = os.path.join(temp_dir, src_path) 
os.makedirs(os.path.split(dst_path)[0], exist_ok=True) shutil.copy(src_path, dst_path) - return dst_suffix + # ---- cmake extension ---- @@ -167,18 +174,21 @@ def get_cmake_dir(): class CMakeClean(clean): + def initialize_options(self): clean.initialize_options(self) self.build_temp = get_cmake_dir() class CMakeBuildPy(build_py): + def run(self) -> None: self.run_command('build_ext') return super().run() class CMakeExtension(Extension): + def __init__(self, name, path, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) @@ -201,7 +211,8 @@ class CMakeBuild(build_ext): try: out = subprocess.check_output(["cmake", "--version"]) except OSError: - raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions)) + raise RuntimeError("CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions)) match = re.search(r"version\s*(?P\d+)\.(?P\d+)([\d.]+)?", out.decode()) cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor")) @@ -228,8 +239,10 @@ class CMakeBuild(build_ext): # python directories python_include_dir = sysconfig.get_path("platinclude") cmake_args = [ - "-G", "Ninja", # Ninja is much faster than make - "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path + "-G", + "Ninja", # Ninja is much faster than make + "-DCMAKE_MAKE_PROGRAM=" + + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, @@ -263,12 +276,28 @@ class CMakeBuild(build_ext): build_args += ['-j' + max_jobs] if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"): - cmake_args += ["-DCMAKE_C_COMPILER=clang", - "-DCMAKE_CXX_COMPILER=clang++", - "-DCMAKE_LINKER=lld", - "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld", - 
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", - "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"] + cmake_args += [ + "-DCMAKE_C_COMPILER=clang", + "-DCMAKE_CXX_COMPILER=clang++", + "-DCMAKE_LINKER=lld", + "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld", + "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", + "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", + ] + + # Note that asan doesn't work with binaries that use the GPU, so this is + # only useful for tools like triton-opt that don't run code on the GPU. + # + # I tried and gave up getting msan to work. It seems that libstdc++'s + # std::string does not play nicely with clang's msan (I didn't try + # gcc's). I was unable to configure clang to ignore the error, and I + # also wasn't able to get libc++ to work, but that doesn't mean it's + # impossible. :) + if check_env_flag("TRITON_BUILD_WITH_ASAN"): + cmake_args += [ + "-DCMAKE_C_FLAGS=-fsanitize=address", + "-DCMAKE_CXX_FLAGS=-fsanitize=address", + ] if check_env_flag("TRITON_BUILD_WITH_CCACHE"): cmake_args += [ @@ -282,9 +311,27 @@ class CMakeBuild(build_ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") -download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2") -download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2") +download_and_copy( + src_path="bin/ptxas", + variable="TRITON_PTXAS_PATH", + version="12.1.105", + url_func=lambda arch, version: + f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", +) +download_and_copy( + 
src_path="bin/cuobjdump", + variable="TRITON_CUOBJDUMP_PATH", + version="12.1.111", + url_func=lambda arch, version: + f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", +) +download_and_copy( + src_path="bin/nvdisasm", + variable="TRITON_NVDISASM_PATH", + version="12.1.105", + url_func=lambda arch, version: + f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", +) setup( name="triton", @@ -307,10 +354,14 @@ setup( "triton/third_party", "triton/tools", ], +<<<<<<< HEAD long_description_content_type="text/markdown", install_requires=[ "filelock" ], +======= + install_requires=["filelock"], +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, diff --git a/python/src/triton.cc b/python/src/triton.cc index fe335d93a..e21d15f69 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -229,6 +229,12 @@ void init_triton_ir(py::module &&m) { .value("RELAXED", mlir::triton::MemSemantic::RELAXED) .export_values(); + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", mlir::triton::MemSyncScope::GPU) + .value("CTA", mlir::triton::MemSyncScope::CTA) + .value("SYSTEM", mlir::triton::MemSyncScope::SYSTEM) + .export_values(); + py::enum_(m, "EVICTION_POLICY", py::module_local()) .value("NORMAL", mlir::triton::EvictionPolicy::NORMAL) @@ -1527,7 +1533,8 @@ void init_triton_ir(py::module &&m) { // // atomic .def("create_atomic_cas", [](TritonOpBuilder &self, mlir::Value &ptr, mlir::Value &cmp, - mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value { + mlir::Value &val, mlir::triton::MemSemantic sem, + mlir::triton::MemSyncScope scope) -> mlir::Value { mlir::Type dstType; if (auto srcTensorType = ptr.getType().dyn_cast()) { @@ -1542,12 +1549,13 @@ void init_triton_ir(py::module &&m) { dstType 
= ptrType.getPointeeType(); } return self.create(dstType, ptr, cmp, - val, sem); + val, sem, scope); }) .def("create_atomic_rmw", [](TritonOpBuilder &self, mlir::triton::RMWOp rmwOp, mlir::Value &ptr, mlir::Value &val, mlir::Value &mask, - mlir::triton::MemSemantic sem) -> mlir::Value { + mlir::triton::MemSemantic sem, + mlir::triton::MemSyncScope scope) -> mlir::Value { mlir::Type dstType; if (auto srcTensorType = ptr.getType().dyn_cast()) { @@ -1561,8 +1569,8 @@ void init_triton_ir(py::module &&m) { .cast(); dstType = ptrType.getPointeeType(); } - return self.create(dstType, rmwOp, ptr, - val, mask, sem); + return self.create( + dstType, rmwOp, ptr, val, mask, sem, scope); }) // External .def("create_extern_elementwise", @@ -1764,6 +1772,10 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCoalescePass()); }) + .def("add_tritongpu_optimize_thread_locality_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPUOptimizeThreadLocalityPass()); + }) .def("add_symbol_dce_pass", [](mlir::PassManager &self) { self.addPass(mlir::createSymbolDCEPass()); diff --git a/python/test/backend/test_device_backend.py b/python/test/backend/test_device_backend.py index 238398cdf..bc73d8376 100644 --- a/python/test/backend/test_device_backend.py +++ b/python/test/backend/test_device_backend.py @@ -13,12 +13,11 @@ import torch import triton import triton.language as tl -from triton.common.backend import BaseBackend, register_backend +from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend) from triton.common.build import quiet from triton.compiler.make_launcher import make_so_cache_key from triton.runtime.cache import get_cache_manager from triton.runtime.driver import DriverBase -from triton.runtime.jit import version_key def build_for_backend(name, src, srcdir): @@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir): class ExtensionUtils: + def __new__(cls): if not 
hasattr(cls, 'instance'): cls.instance = super(ExtensionUtils, cls).__new__(cls) @@ -110,6 +110,7 @@ class ExtensionUtils: class ExtensionDriver(DriverBase): + def __new__(cls): if not hasattr(cls, 'instance'): cls.instance = super(ExtensionDriver, cls).__new__(cls) @@ -125,6 +126,7 @@ class ExtensionBackend(BaseBackend): def __init__(self, device_type: str) -> None: super(ExtensionBackend, self).__init__(device_type) self.driver = ExtensionDriver() + self.version_key = None def add_stages(self, arch, extern_libs, stages): filter_in_stages = ["ast", "ttir", "ttgir"] @@ -163,9 +165,14 @@ class ExtensionBackend(BaseBackend): def get_architecture_descriptor(self, **kwargs): return "" + def get_version_key(self): + if self.version_key is None: + self.version_key = compute_core_version_key() + return self.version_key + def make_launcher_stub(self, name, signature, constants): # name of files that are cached - so_cache_key = make_so_cache_key(version_key(), signature, constants) + so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants) so_cache_manager = get_cache_manager(so_cache_key) so_name = f"{name}.so" # retrieve stub from cache if it exists @@ -250,13 +257,13 @@ def test_dummy_backend(): inp = torch.randn(10) out = torch.randn(10) - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) launch_counter = getattr(mod, "launch_counter") for _ in range(100): - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) assert launch_counter() > 0 diff --git a/python/test/backend/third_party_backends/conftest.py b/python/test/backend/third_party_backends/conftest.py index 62ee6c689..d939bc001 100644 --- a/python/test/backend/third_party_backends/conftest.py +++ b/python/test/backend/third_party_backends/conftest.py @@ -4,9 +4,7 @@ 
import pytest def pytest_addoption(parser): - parser.addoption( - "--backend", action="store", default="", help="Codegen backend" - ) + parser.addoption("--backend", action="store", default="", help="Codegen backend") @pytest.fixture diff --git a/python/test/backend/third_party_backends/test_xpu_backend.py b/python/test/backend/third_party_backends/test_xpu_backend.py index f00261f19..e6850efdd 100644 --- a/python/test/backend/third_party_backends/test_xpu_backend.py +++ b/python/test/backend/third_party_backends/test_xpu_backend.py @@ -24,10 +24,10 @@ def test_xpu_backend(cmdopt): if has_ipex: for _ in range(1000): - x = torch.randn((65536,), device="xpu", dtype=torch.float32) - y = torch.randn((65536,), device="xpu", dtype=torch.float32) - z = torch.zeros((65536,), device="xpu", dtype=torch.float32) - kernel[(65536,)](x, y, z, num_warps=32) + x = torch.randn((65536, ), device="xpu", dtype=torch.float32) + y = torch.randn((65536, ), device="xpu", dtype=torch.float32) + z = torch.zeros((65536, ), device="xpu", dtype=torch.float32) + kernel[(65536, )](x, y, z, num_warps=32) assert torch.all(x + y == z) else: return diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py new file mode 100644 index 000000000..1477bc5a4 --- /dev/null +++ b/python/test/regression/test_cast_matmul.py @@ -0,0 +1,100 @@ +""" +issue: https://github.com/openai/triton/issues/2523 +fused type convert and matmul, based on triton matmul; the differences from matmul are: +1. force C's dtype=dot_out_dtype to ["float16", "float32"] +2. 
accept A and B with dtype=["float32", "float64"] + +""" +import pytest +import torch + +import triton.language as tl +from triton import cdiv, jit + +input_dtypes = ["float32", "float64"] +out_dtypes = ["float16", "float32"] + + +@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype", + [(M, K, N, w, x, o) # + for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] # + for w in input_dtypes + for x in input_dtypes # + for o in out_dtypes]) +def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): + if x_dtype == w_dtype: + pytest.skip("skip same dtype") + device = torch.cuda.current_device() + x_dtype = getattr(torch, x_dtype) + w_dtype = getattr(torch, w_dtype) + a = torch.randn((M, K), device=device, dtype=x_dtype) + b = torch.randn((K, N), device=device, dtype=w_dtype) + torch_dtype = getattr(torch, out_dtype) + triton_dtype = getattr(tl, out_dtype) # <- here force dot_out_dtype + out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype)) + out_triton = torch.empty((M, N), device=device, dtype=torch_dtype) + + allow_tf32 = True + # launch kernel + BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32 + grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1) + + @jit + def matmul_kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + allow_tf32: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): + # matrix multiplication + pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, 
BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + matmul_kernel[grid]( + a, b, out_triton, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, # + allow_tf32=allow_tf32, # + GROUP_M=8, # + BLOCK_M=BLOCK_M, # + BLOCK_N=BLOCK_N, # + BLOCK_K=BLOCK_K) + + torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01) diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py index b873db7a3..e0eb56601 100644 --- a/python/test/regression/test_functional_regressions.py +++ b/python/test/regression/test_functional_regressions.py @@ -14,18 +14,14 @@ def test_chained_matmul(): return torch.einsum('MN,NK->MK', intermediate, c) @triton.jit - def chained_matmul_kernel( - A, # shape: (m, k) - B, # shape: (n, k) - C, # shape: (n, k) - out, # shape: (m, k) - m, n, k: tl.constexpr, - block_m: tl.constexpr, - block_n: 
tl.constexpr, - block_k: tl.constexpr): + def chained_matmul_kernel(A, # shape: (m, k) + B, # shape: (n, k) + C, # shape: (n, k) + out, # shape: (m, k) + m, n, k: tl.constexpr, # + block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): - tl.static_assert(block_k == k, - f"expected block_k == k but got {block_k} != {k}") + tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}") block_ix = tl.program_id(0) a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \ @@ -55,35 +51,33 @@ def test_chained_matmul(): m, n, k = 32, 64, 128 block_m, block_n, block_k = 16, 32, k - grid = (triton.cdiv(m, block_m),) - a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, - device='cuda') - b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, - device='cuda') + grid = (triton.cdiv(m, block_m), ) + a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda') + b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda') c = torch.randint_like(b, low=0, high=2) triton_result = torch.zeros_like(a) torch_result = chained_matmul_reference(a, b, c) - chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k, - block_m=block_m, block_n=block_n, - block_k=block_k) + chained_matmul_kernel[grid]( + a, b, c, triton_result, m, n, k, # + block_m=block_m, block_n=block_n, block_k=block_k) assert (torch_result == triton_result).all() def test_vecmat(): + @triton.jit def batched_vecmat( - # inputs - A, # shape: [dim_m, dim_k] - B, # shape: [dim_m, dim_n, dim_k] - # dimensions + # inputs + A, # shape: [dim_m, dim_k] + B, # shape: [dim_m, dim_n, dim_k] + # dimensions dim_m, dim_n, dim_k, - # outputs - output, - # block information - block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr - ): + # outputs + output, + # block information + block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): m_index = tl.program_id(0) n_index = tl.program_id(1) 
# Output tile @@ -125,9 +119,10 @@ def test_vecmat(): grid = (M // block_m, N // block_n) - batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri, - block_m=block_m, block_n=block_n, block_k=block_k, - num_warps=4, num_stages=1) + batched_vecmat[grid]( + A_tri, B_tri, M, N, K, C_tri, # + block_m=block_m, block_n=block_n, block_k=block_k, # + num_warps=4, num_stages=1) A_expanded = A[:, np.newaxis, :] A_broadcasted = np.broadcast_to(A_expanded, (M, N, K)) @@ -137,18 +132,18 @@ def test_vecmat(): np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3) -@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"]) +@pytest.mark.parametrize("type", + ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"]) def test_iv_dependent_matmul(type): + @triton.jit - def kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - type: tl.constexpr - ): + def kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + type: tl.constexpr): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n @@ -216,15 +211,16 @@ def test_iv_dependent_matmul(type): b = torch.rand((K, N), device='cuda') torch_output = torch.mm(a, b) - triton_output = torch.empty_like( - torch_output, device=torch_output.device) + triton_output = torch.empty_like(torch_output, device=torch_output.device) def grid(META): - return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) num_stages = 4 if type == "post_load_three_iters" else 3 - kernel[grid](a, b, 
triton_output, M, N, K, a.stride(0), a.stride(1), - b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1), - BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, - type=type, num_stages=num_stages) + kernel[grid]( + a, b, triton_output, M, N, K, # + a.stride(0), a.stride(1), b.stride(0), b.stride(1), # + triton_output.stride(0), triton_output.stride(1), # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, # + num_stages=num_stages) torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index ec9966d29..c3a70a6a9 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -26,7 +26,6 @@ sm_clocks = {'v100': 1350, 'a100': 1350} mem_clocks = {'v100': 877, 'a100': 1215} matmul_data = { - # NOTE: 'a100': { # square (512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05}, @@ -49,10 +48,9 @@ matmul_data = { } -@pytest.mark.parametrize('M, N, K, dtype_str', - [(M, N, K, dtype_str) - for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16']]) +@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str) + for M, N, K in matmul_data[DEVICE_NAME].keys() + for dtype_str in ['float16']]) def test_matmul(M, N, K, dtype_str): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) @@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str): @triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr): +def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -136,36 +133,36 @@ def test_elementwise(N, dtype_str): print_perf(ms, cur_gpu_util, ref_gpu_util) triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) + 
####################### # Flash-Attention ####################### - flash_attention_data = { "a100": { (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542, (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471, (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155, - (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203, - (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202, - (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108, + (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232, + (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231, + (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138, (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306, (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266, (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098, (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134, (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135, - (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066, + (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092, (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541, (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471, (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150, - (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.263, + (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.291, (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255, (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144, (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306, (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266, (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098, (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159, - (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136, + (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159, (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088, } } @@ -221,8 
+218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str): @triton.jit -def _sum(x_ptr, y_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr): +def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -260,8 +256,8 @@ def test_reductions(N, dtype_str): y = torch.randn_like(z) else: info = torch.iinfo(dtype) - x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda') - y = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda') + x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') + y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024) ms = triton.testing.do_bench_cudagraph(fn) diff --git a/python/test/tools/compare_files.py b/python/test/tools/compare_files.py index 1c8de084d..d74e1da02 100644 --- a/python/test/tools/compare_files.py +++ b/python/test/tools/compare_files.py @@ -9,6 +9,7 @@ import yaml class ComparisonResult: + def __init__(self, name: str, numComparisons: int, diffs: List[str] = None, errors: List[str] = None): self.name = name self.numComparisons = numComparisons @@ -142,7 +143,8 @@ def doFilesMatch(path1: str, path2: str) -> bool: return True -def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], args) -> ComparisonResult: +def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], + args) -> ComparisonResult: """ Compare files with the given name in all hashes in both paths Return the first mismatching files as a tuple (file1, file2), otherwise, return an empty tuple diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py index 60006613b..fc8db664c 
100644 --- a/python/test/unit/hopper/test_flashattention.py +++ b/python/test/unit/hopper/test_flashattention.py @@ -18,7 +18,6 @@ # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - """ Fused Attention =============== @@ -35,18 +34,15 @@ import triton.language as tl @triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, M, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, D0, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _fwd_kernel(Q, K, V, sm_scale, # + L, M, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, D0, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -61,31 +57,38 @@ def _fwd_kernel( stride_qh_2d = stride_qh // stride_qm // stride_qk - q_tile_ptr = tl.make_block_ptr(base=Q, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_hz * stride_qh_2d + start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - k_tile_ptr = tl.make_block_ptr(base=K, - shape=(D0, BLOCK_DMODEL), - strides=(stride_kn, stride_kk), - offsets=(off_hz * stride_qh_2d, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - v_tile_ptr = tl.make_block_ptr(base=V, - shape=(D0, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(off_hz * stride_qh_2d, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - out_tile_ptr = tl.make_block_ptr(base=Out, - shape=(D0, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(off_hz * 
stride_qh_2d + start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + out_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(D0, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) # load q: it will stay in SRAM throughout q = tl.load(q_tile_ptr) @@ -96,8 +99,7 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= ( - start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # compute new m m_curr = tl.maximum(tl.max(qk, 1), m_prev) # correct old l @@ -133,11 +135,9 @@ def _fwd_kernel( @triton.jit -def _bwd_preprocess( - Out, DO, L, - NewDO, Delta, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, -): +def _bwd_preprocess(Out, DO, L, # + NewDO, Delta, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load @@ -153,19 +153,14 @@ def _bwd_preprocess( @triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, - DQ, DK, DV, - L, M, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, 
stride_vn, - Z, H, N_CTX, D0, - num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _bwd_kernel(Q, K, V, sm_scale, Out, DO, # + DQ, DK, DV, # + L, M, # + D, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + Z, H, N_CTX, D0, # + num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): off_hz = tl.program_id(0) off_z = off_hz // H off_h = off_hz % H @@ -173,55 +168,62 @@ def _bwd_kernel( stride_qz_2d = stride_qz // stride_qm // stride_qk stride_qh_2d = stride_qh // stride_qm // stride_qk - q_tile_ptr = tl.make_block_ptr(base=Q, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - k_tile_ptr = tl.make_block_ptr(base=K, - shape=(D0, BLOCK_DMODEL), - strides=(stride_kn, stride_kk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - v_tile_ptr = tl.make_block_ptr(base=V, - shape=(D0, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - do_tile_ptr = tl.make_block_ptr(base=DO, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - dq_tile_ptr = tl.make_block_ptr(base=DQ, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - dk_tile_ptr = tl.make_block_ptr(base=DK, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - 
dv_tile_ptr = tl.make_block_ptr(base=DV, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + do_tile_ptr = tl.make_block_ptr( + base=DO, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dq_tile_ptr = tl.make_block_ptr( + base=DQ, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dk_tile_ptr = tl.make_block_ptr( + base=DK, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dv_tile_ptr = tl.make_block_ptr( + base=DV, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) # offset pointers for batch/head DQ += off_z * stride_qz + off_h * stride_qh for start_n in range(0, num_block): @@ -250,8 +252,7 @@ def _bwd_kernel( # recompute p = softmax(qk, dim=-1).T # NOTE: 
`do` is pre-divided by `l`; no normalization here qk = tl.dot(q, tl.trans(k)) - qk = tl.where(offs_m_curr[:, None] >= ( - offs_n[None, :]), qk, float("-inf")) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) m = tl.load(m_ptrs + offs_m_curr) p = tl.exp(qk * sm_scale - m[:, None]) # compute dv @@ -301,29 +302,21 @@ class _attention(torch.autograd.Function): assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) - L = torch.empty( - (q.shape[0] * q.shape[1], q.shape[2]), - device=q.device, - dtype=torch.float32) - m = torch.empty( - (q.shape[0] * q.shape[1], q.shape[2]), - device=q.device, - dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 D0 = q.shape[0] * q.shape[1] * q.shape[2] _fwd_kernel[grid]( - q, k, v, sm_scale, - L, m, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], D0, - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, num_warps=num_warps, - num_stages=2, - ) + q, k, v, sm_scale, # + L, m, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, # + num_warps=num_warps, num_stages=2) ctx.save_for_backward(q, k, v, o, L, m) ctx.grid = grid @@ -343,25 +336,22 @@ class _attention(torch.autograd.Function): delta = torch.empty_like(l) D0 = q.shape[0] * q.shape[1] * q.shape[2] _bwd_preprocess[(ctx.grid[0] 
* ctx.grid[1], )]( - o, do, l, - do_scaled, delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, - ) - _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], D0, - ctx.grid[0], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, - ) + o, do, l, # + do_scaled, delta, # + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL) + _bwd_kernel[(ctx.grid[1], )]( + q, k, v, ctx.sm_scale, # + o, do_scaled, # + dq, dk, dv, # + l, m, # + delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + ctx.grid[0], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=8, num_stages=1) return dq, dk, dv, None @@ -380,15 +370,9 @@ attention = _attention.apply @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty( - (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.1, std=0.2).requires_grad_() - k = torch.empty( - (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.4, std=0.2).requires_grad_() - v = torch.empty( - (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.3, std=0.2).requires_grad_() + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() sm_scale = 0.2 dout = 
torch.randn_like(q) # reference implementation @@ -427,22 +411,25 @@ except BaseException: BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 14)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={ - 'H': N_HEADS, - 'BATCH': BATCH, - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode} -) for mode in ['fwd', 'bwd']] +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 14)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + }, + ) for mode in ['fwd', 'bwd'] +] @triton.testing.perf_report(configs) @@ -463,9 +450,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms if provider == "flash": - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros( - (BATCH + 1,), device=device, dtype=torch.int32) + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) cu_seqlens[1:] = lengths.cumsum(0) qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) diff --git a/python/test/unit/hopper/test_gemm.py 
b/python/test/unit/hopper/test_gemm.py index a3e3f80b9..cc6ebd0b0 100644 --- a/python/test/unit/hopper/test_gemm.py +++ b/python/test/unit/hopper/test_gemm.py @@ -32,19 +32,30 @@ import triton.language as tl @triton.jit -def matmul_no_scf_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr -): - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) +def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr # + ): + a_block_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), + order=(0, 1), + ) a = tl.load(a_block_ptr) b = tl.load(b_block_ptr) @@ -54,8 +65,8 @@ def matmul_no_scf_kernel( c = c.to(tl.float16) if USE_TMA_EPILOGUE: - c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) tl.store(c_block_ptr, c) else: offs_m = tl.arange(0, BLOCK_M) @@ -64,33 +75,30 @@ def matmul_no_scf_kernel( 
tl.store(c_ptrs, c) -@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS', - itertools.chain( - *[ - [ - # numCTAs = 1, no TMA multicast: - [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - # static mask, cluster 4x1 - [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - # dynamic mask, cluster 2x2 - [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - # small M, N - [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - ] for USE_TMA_EPILOGUE in [True, False] - for ENABLE_WS in [False, True] - ])) +@pytest.mark.parametrize( + 'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS', + itertools.chain(*[[ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 16, 1, 4, False, True, "float32", 
USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # static mask, cluster 4x1 + [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # dynamic mask, cluster 2x2 + [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + ] for USE_TMA_EPILOGUE in [True, False] for ENABLE_WS in [False, True]])) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS): if (TRANS_A): @@ -107,46 +115,41 @@ def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE else: c = torch.empty((M, N), device=a.device, dtype=torch.float32) - matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, - num_warps=NUM_WARPS, - num_ctas=NUM_CTAS, - FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), - USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, - enable_warp_specialization=ENABLE_WS) + matmul_no_scf_kernel[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + 
stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, # + num_ctas=NUM_CTAS, # + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, # + enable_warp_specialization=ENABLE_WS) a_f32 = a.to(torch.float32) b_f32 = b.to(torch.float32) golden = torch.matmul(a_f32, b_f32) torch.set_printoptions(profile="full") - assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) @triton.jit -def matmul_kernel( - a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_wm, stride_wn, - stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, - B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, - W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, - Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr -): +def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_wm, stride_wn, # + stride_zm, stride_zn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, # + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, # + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, # + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # + W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, # + Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr # + ): 
pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_m = tl.cdiv(M, BLOCK_M) @@ -159,13 +162,31 @@ def matmul_kernel( block_offset_m = pid_m * BLOCK_M block_offset_n = pid_n * BLOCK_N - a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1)) - b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1)) + a_tile_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(A_ORDER_0, A_ORDER_1), + ) + b_tile_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), + block_shape=(BLOCK_K, BLOCK_N), + order=(B_ORDER_0, B_ORDER_1), + ) # for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix - w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn), - offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1)) + w_tile_ptr = tl.make_block_ptr( + base=w_ptr, + shape=(N, N), + strides=(stride_wm, stride_wn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_N), + order=(W_ORDER_0, W_ORDER_1), + ) z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) offs_m = block_offset_m + tl.arange(0, BLOCK_M) @@ -204,141 +225,151 @@ def matmul_kernel( if USE_TMA_STORE: z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), - offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1)) + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), + order=(Z_ORDER_0, Z_ORDER_1)) tl.store(z_block_ptr, z, boundary_check=(0, 1)) else: tl.store(z_ptrs, z, mask=mask) 
-@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', - [ - # corner shapes - (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) - for shape_w_c in [ - [4096, 1, 1024, False, False, True], - [2048, 204, 1000, True, False, True], - [4096, 1, 1024, False, False, False], - [2048, 204, 1000, True, False, False], - ] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for enable_ws in [False, True] - ] + [ - # softmax epilogue - (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 16, 4, 1, 64, 64, 64], - [128, 128, 64, 4, 1, None, None, None], - [16, 16, 64, 4, 1, 16, 16, 64], - [64, 64, 32, 8, 1, 64, 64, 64], - [128, 128, 64, 4, 1, 128, 128, 128], - ] - for epilogue in ['softmax'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for trans_output in [False,] - for num_stages in [3] - for enable_ws in [False, True] - ] + [ - # loop over epilogues besides of softmax - (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 16, 4, 1, 128, 128, 64], - *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], - # for chain-dot - [128, 128, 64, 4, 1, None, None, None], - [64, 64, 16, 4, 1, None, None, None], - # small BLOCK_M and BLOCK_K - [16, 16, 64, 4, 1, 128, 128, 64], - *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], - # repeat - [64, 64, 32, 8, 1, 128, 256, 64], - [64, 64, 16, 8, 2, 128, 128, 64], - # irregular shape - [128, 128, 64, 4, 1, 500, 200, 128], - [128, 128, 64, 4, 2, 513, 193, 192], - ] - for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 
'chain-dot'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for trans_output in [False,] - for num_stages in [3] - for enable_ws in [False, True] - if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) - ] + [ - # loop over tile shapes and transpose combinations - (*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 32, 4, 1, 128, 256, 64], - [128, 128, 16, 4, 4, 512, 256, 64], - [128, 256, 32, 4, 8, 256, 256, 192], - [512, 256, 32, 4, 8, 1024, 256, 192], - # BLOCK_K >= 128 - [64, 128, 128, 4, 1, 512, 256, 256], - [128, 128, 128, 4, 1, 256, 256, 192], - [128, 128, 128, 4, 2, 256, 256, 192], - # small BLOCK_M and BLOCK_K - [16, 32, 32, 4, 1, 128, 256, 64], - [32, 32, 16, 4, 1, 256, 256, 192], - [16, 32, 64, 4, 4, 512, 256, 64], - ] - for out_dtype in ['float32',] - for use_tma_store in [False,] - for trans_a in [False, True] - for trans_b in [False, True] - for trans_output in [False, True] - for num_stages in [3] - for enable_ws in [False, True] - ] + [ - # loop over instr shapes & pipeline stages - (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for n in [16, 32, 64, 128, 256] - for trans_output in [False,] - for out_dtype in ['float32',] - for use_tma_store in [False,] - for num_stages in [2, 4, 5, 7] - for enable_ws in [False, True] - ] + [ - # irregular shapes - (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [128, 128, 64, 4, 1], - [256, 128, 64, 4, 2], - [128, 128, 128, 4, 2], - ] - for shape in [ - [512, 360, 1024], - [360, 4096, 512], - ] - for trans_output in [False,] - for out_dtype in ['float32',] - for use_tma_store in [False, True] - for num_stages in [3, 4] - for enable_ws in [False, True] - ]) 
-@pytest.mark.skipif(torch.cuda.get_device_capability() - [0] < 9, reason="Requires compute capability >= 9") -def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): +@pytest.mark.parametrize( + 'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) + for shape_w_c in [ + [4096, 1, 1024, False, False, True], + [2048, 204, 1000, True, False, True], + [4096, 1, 1024, False, False, False], + [2048, 204, 1000, True, False, False], + ] + for out_dtype in ['float16', 'float32'] # + for use_tma_store in [False, True] # + for enable_ws in [False, True] + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] + for epilogue in ['softmax'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False] + for trans_b in [True] + for trans_output in [False] + for num_stages in [3] + for enable_ws in [False, True] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] + for num_warps in [4, 8] + for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] + for num_warps in [4, 8] 
+ for num_ctas in [1, 2]], + # repeat + [64, 64, 32, 8, 1, 128, 256, 64], + [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 2, 513, 193, 192], + ] + for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False] + for trans_b in [True] + for trans_output in [False] + for num_stages in [3] + for enable_ws in [False, True] + if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] + for out_dtype in ['float32'] + for use_tma_store in [False] + for trans_a in [False, True] + for trans_b in [False, True] + for trans_output in [False, True] + for num_stages in [3] + for enable_ws in [False, True] + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, + enable_ws) + for n in [16, 32, 64, 128, 256] + for trans_output in [False] + for out_dtype in ['float32'] + for use_tma_store in [False] + for num_stages in [2, 4, 5, 7] + for enable_ws in [False, True] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [128, 128, 64, 4, 1], + [256, 128, 
64, 4, 2], + [128, 128, 128, 4, 2], + ] + for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] + for trans_output in [False] + for out_dtype in ['float32'] + for use_tma_store in [False, True] + for num_stages in [3, 4] + for enable_ws in [False, True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, + out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ - '16-32-64-4-4-512-256-64-True-False', - '16-32-64-4-4-512-256-64-True-True', - '16-32-64-4-4-512-256-64-False-False', - '16-32-64-4-4-512-256-64-False-True', + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', ]: pytest.skip('shapePerCTA[1] < 16 not supported') if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ - '16-32-64-4-1-256-256-256-False', - '16-32-64-4-2-256-256-256-False', - '16-32-64-4-2-256-256-256-True', - '16-32-64-8-2-256-256-256-False', - '16-32-64-8-2-256-256-256-True', + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', ]: pytest.skip('Known legacy issue, ldmatrix can only support x4') + enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + if NUM_CTAS > 1 and enable_tma in ["on", "true", "1"]: + pytest.skip('multi-CTA with TMA not supported in MaterializeLoadStore') M = BLOCK_M if M is None else M N = BLOCK_N if N is None else N @@ -410,38 +441,38 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, else: ref = d return ref + golden = process_epilogue(dot, bias, w, epilogue) def grid(META): - return 
(triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) - pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_wm=w.stride(0), stride_wn=w.stride(1), - stride_zm=z.stride(0), stride_zn=z.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, - out_dtype=out_dtype, - USE_TMA_STORE=USE_TMA_STORE, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], - B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], - W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], - Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], - num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, - enable_warp_specialization=ENABLE_WS) + return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) + + pgm = matmul_kernel[grid]( + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_wm=w.stride(0), stride_wn=w.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # + out_dtype=out_dtype, # + USE_TMA_STORE=USE_TMA_STORE, # + ADD_MATRIX=epilogue == 'add-matrix', # + ADD_ROWS=epilogue == 'add-rows', # + ADD_COLS=epilogue == 'add-cols', # + DO_SOFTMAX=epilogue == 'softmax', # + CHAIN_DOT=epilogue == 'chain-dot', # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # + W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # + Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, # + enable_warp_specialization=ENABLE_WS) torch.set_printoptions(profile="full") 
golden = torch.nn.functional.normalize(golden) z = torch.nn.functional.normalize(z) - assert_close(z, golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: diff --git a/python/test/unit/hopper/test_gemm_fusion.py b/python/test/unit/hopper/test_gemm_fusion.py index 1fd53d5c4..4b439efa8 100644 --- a/python/test/unit/hopper/test_gemm_fusion.py +++ b/python/test/unit/hopper/test_gemm_fusion.py @@ -27,16 +27,20 @@ import triton.language as tl @triton.jit -def gemm_fusion_kernel(A, B, C, E, - M, N, K, - stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, +def gemm_fusion_kernel(A, B, C, E, # + M, N, K, # + stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): pid = tl.program_id(0) - a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) - c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) - e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) 
+ c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) a = tl.load(a_tile_ptr) @@ -57,66 +61,70 @@ def gemm_fusion_kernel(A, B, C, E, def test_gemm_fusion(): M, N, K = 4096, 4096, 64 BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 - A = torch.empty( - (M, K), dtype=torch.float16, device='cuda').normal_( - mean=0.1, std=0.2) - B = torch.empty( - (N, K), dtype=torch.float16, device='cuda').normal_( - mean=0.1, std=0.2) - C = torch.empty( - (N, K), dtype=torch.float16, device='cuda').normal_( - mean=0.1, std=0.2) + A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) E = torch.empty((M, K), dtype=torch.float16, device='cuda') ref_out = torch.matmul(torch.matmul(A, B.T), C) num_warps = 4 grid = (triton.cdiv(M, BLOCK_M), 1) - gemm_fusion_kernel[grid](A, B, C, E, M, N, K, - A.stride(0), A.stride(1), B.stride(0), B.stride( - 1), C.stride(0), C.stride(1), E.stride(0), E.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps) + gemm_fusion_kernel[grid]( + A, B, C, E, M, N, K, # + A.stride(0), A.stride(1), # + B.stride(0), B.stride(1), # + C.stride(0), C.stride(1), # + E.stride(0), E.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=num_warps) torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) @triton.jit -def batched_gemm_fusion( - Q, K, V, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, NH, N_CTX, - 
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def batched_gemm_fusion(Q, K, V, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, NH, N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr): start_m = tl.program_id(0) off_hz = tl.program_id(1) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - q_tile_ptr = tl.make_block_ptr(base=Q, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_qz, stride_qh, stride_qm, stride_qk), - offsets=(off_hz // NH, off_hz % NH, start_m, 0), - block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), - order=(3, 2, 1, 0)) - k_tile_ptr = tl.make_block_ptr(base=K, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_kz, stride_kh, stride_kn, stride_kk), - offsets=(off_hz // NH, off_hz % NH, 0, 0), - block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), - order=(3, 2, 1, 0)) - v_tile_ptr = tl.make_block_ptr(base=V, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_vz, stride_vh, stride_vk, stride_vn), - offsets=(off_hz // NH, off_hz % NH, 0, 0), - block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), - order=(3, 2, 1, 0)) - o_tile_ptr = tl.make_block_ptr(base=Out, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_oz, stride_oh, stride_om, stride_on), - offsets=(off_hz // NH, off_hz % NH, start_m, 0), - block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), - order=(3, 2, 1, 0)) + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_qz, stride_qh, stride_qm, stride_qk), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_kz, stride_kh, stride_kn, stride_kk), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, 
BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_vz, stride_vh, stride_vk, stride_vn), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + o_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_oz, stride_oh, stride_om, stride_on), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3)) q = tl.view(q, (BLOCK_M, BLOCK_DMODEL)) @@ -155,12 +163,13 @@ def test_batched_gemm_fusion(): ref_out = torch.matmul(torch.matmul(A, BT), C) num_warps = 4 grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH) - batched_gemm_fusion[grid](A, B, C, E, - A.stride(0), A.stride(1), A.stride(2), A.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(3), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - E.stride(0), E.stride(1), E.stride(2), E.stride(3), - Z, NH, N_CTX, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) + batched_gemm_fusion[grid]( + A, B, C, E, # + A.stride(0), A.stride(1), A.stride(2), A.stride(3), # + B.stride(0), B.stride(1), B.stride(2), B.stride(3), # + C.stride(0), C.stride(1), C.stride(2), C.stride(3), # + E.stride(0), E.stride(1), E.stride(2), E.stride(3), # + Z, NH, N_CTX, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) diff --git a/python/test/unit/hopper/test_mixed_io.py b/python/test/unit/hopper/test_mixed_io.py index cecabbaa7..68ee474a4 100644 --- a/python/test/unit/hopper/test_mixed_io.py +++ b/python/test/unit/hopper/test_mixed_io.py @@ -24,10 +24,8 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - x_block_ptr = tl.make_block_ptr( - base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), - 
block_shape=(BLOCK_SIZE, ), order=(0, ) - ) + x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero') y = tl.load(y_ptr + offsets, mask=mask) @@ -36,9 +34,7 @@ def add_kernel( @pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str', - [(98432, 1024, dtype_str) - for dtype_str in ['float16', 'float32'] - ]) + [(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']]) def test_add(SIZE, BLOCK_SIZE, dtype_str): dtype = dtype_mapping[dtype_str] output = torch.empty(SIZE, device='cuda', dtype=dtype) @@ -46,7 +42,8 @@ def test_add(SIZE, BLOCK_SIZE, dtype_str): y = torch.randn(SIZE, device='cuda', dtype=dtype) def grid(meta): - return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) + return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE) output_torch = x + y @@ -64,25 +61,20 @@ def load_reduce_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - x_ptr = tl.make_block_ptr( - base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) - ) + x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) x = tl.load(x_ptr) y = tl.max(x, axis=1) tl.store(y_ptr + tl.arange(0, BLOCK_M), y) -@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', - [(128, 64, dtype_str) - for dtype_str in ['float16'] - ]) +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']]) def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str): dtype = dtype_mapping[dtype_str] x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype) y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype) - load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), 
BLOCK_M, BLOCK_N) + load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N) golden = x.max(dim=1)[0] torch.set_printoptions(profile='full') diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py b/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py index 868c052d6..ea1776998 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py @@ -18,7 +18,6 @@ # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - """ Fused Attention =============== @@ -40,18 +39,17 @@ import triton.language as tl key=['Q', 'K', 'V'], ) @triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, M, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _fwd_kernel(Q, K, V, sm_scale, # + L, M, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr # + ): start_m = tl.program_id(0) off_hz = tl.program_id(1) # initialize offsets @@ -116,11 +114,10 @@ def _fwd_kernel( @triton.jit -def _bwd_preprocess( - Out, DO, L, - NewDO, Delta, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, -): +def _bwd_preprocess(Out, DO, L, # + NewDO, Delta, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr # + ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load @@ 
-136,19 +133,18 @@ def _bwd_preprocess( @triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, - DQ, DK, DV, - L, M, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, - num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _bwd_kernel(Q, K, V, sm_scale, Out, DO, # + DQ, DK, DV, # + L, M, # + D, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + Z, H, N_CTX, # + num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + ): off_hz = tl.program_id(0) off_z = off_hz // H off_h = off_hz % H @@ -240,16 +236,16 @@ class _attention(torch.autograd.Function): assert num_warps == 4 _fwd_kernel[grid]( - q, k, v, sm_scale, - L, m, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, + q, k, v, sm_scale, # + L, m, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=Lk # ) ctx.save_for_backward(q, k, v, o, L, m) @@ -269,24 +265,23 @@ class _attention(torch.autograd.Function): do_scaled = torch.empty_like(do) delta = torch.empty_like(l) _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( - o, do, l, - do_scaled, delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, - ) - _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), 
q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, + o, do, l, # + do_scaled, delta, # + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL) + _bwd_kernel[(ctx.grid[1], )]( + q, k, v, ctx.sm_scale, # + o, do_scaled, # + dq, dk, dv, # + l, m, # + delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + ctx.grid[0], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=8, num_stages=1 # ) return dq, dk, dv, None @@ -339,19 +334,19 @@ except BaseException: BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - # x_vals=[2**i for i in range(10, 14)], - x_vals=[2**i for i in range(10, 11)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} - # ) for mode in ['fwd', 'bwd']] -) for mode in ['fwd']] +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], + # x_vals=[2**i for i in range(10, 14)], + x_vals=[2**i + for i in range(10, 11)], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': 
D_HEAD, 'dtype': torch.float16, 'mode': mode} + # ) for mode in ['fwd', 'bwd']] + ) + for mode in ['fwd'] +] @triton.testing.perf_report(configs) @@ -374,9 +369,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms if provider == "flash": - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros( - (BATCH + 1,), device=device, dtype=torch.int32) + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) cu_seqlens[1:] = lengths.cumsum(0) qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index 32c04c33b..340709a6a 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -29,14 +29,14 @@ import triton.language as tl @triton.jit -def static_persistent_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -68,14 +68,14 @@ def static_persistent_matmul_kernel( @triton.jit -def static_persistent_tma_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: 
tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_tma_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -88,8 +88,10 @@ def static_persistent_tma_matmul_kernel( block_offset_m = pre_pid_m * BLOCK_M block_offset_n = pre_pid_n * BLOCK_N - a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) for tile_id in range(start_tile, num_tiles, NUM_SM): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles @@ -114,21 +116,23 @@ def static_persistent_tma_matmul_kernel( pre_pid_n = pid_n -@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', - [(*shape, use_tma) - for shape in [ - [4096, 4096, 64, 64, 64, 16, 4, 1, False, True], - [4096, 4096, 64, 64, 64, 32, 4, 1, False, True], - [4096, 4096, 64, 256, 64, 16, 4, 1, False, True], - [4096, 4096, 64, 128, 128, 16, 4, 1, False, True], - # TODO: fix issue for 8-warp persistent kernel - # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True], - # [4096, 4096, 64, 128, 256, 16, 8, 1, False, True], - ] - for use_tma in [False, True] - ]) 
+@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', [( + *shape, use_tma +) for shape in [ + [4096, 4096, 64, 64, 64, 16, 4, 1, False, True], + [4096, 4096, 64, 64, 64, 32, 4, 1, False, True + ], + [4096, 4096, 64, 256, 64, 16, 4, 1, False, True + ], + [4096, 4096, 64, 128, 128, 16, 4, 1, False, True + ], + # TODO: fix issue for 8-warp persistent kernel + # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True], + # [4096, 4096, 64, 128, 256, 16, 8, 1, False, True], +] for use_tma in [False, True]]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") -def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): +def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, + TRANS_A, TRANS_B, USE_TMA): if (TRANS_A): a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T else: @@ -141,25 +145,33 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: - static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), + stride_ak=a.stride(1), stride_bk=b.stride(0), 
stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, + num_ctas=NUM_CTAS) else: - static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), + stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, + num_ctas=NUM_CTAS) th_c = torch.matmul(a, b) torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit -def warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +def warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # ): tid = tl.program_id(axis=0) n_tiles = tl.cdiv(N, BLOCK_N) @@ -193,13 +205,13 @@ def warp_specialized_matmul_kernel( @triton.jit -def tma_warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +def tma_warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # ): tid = 
tl.program_id(axis=0) n_tiles = tl.cdiv(N, BLOCK_N) @@ -232,8 +244,7 @@ def tma_warp_specialized_matmul_kernel( @pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', - [(*shape, use_tma) - for shape in [ + [(*shape, use_tma) for shape in [ [2048, 2048, 64, 64, 64, 16, 1, False, True], [4096, 4096, 64, 64, 64, 16, 1, False, True], [128, 4096, 64, 64, 64, 16, 1, False, True], @@ -257,9 +268,7 @@ def tma_warp_specialized_matmul_kernel( [4096, 4096, 128, 256, 128, 64, 4, False, True], [4096, 4096, 256, 128, 256, 64, 4, False, True], [4096, 4096, 256, 256, 256, 64, 4, False, True], - ] - for use_tma in [False, True] - ]) + ] for use_tma in [False, True]]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): if (TRANS_A): @@ -274,29 +283,29 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K c = torch.empty((M, N), device=a.device, dtype=torch.float32) - grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), ) if USE_TMA: tma_warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, - num_warps=4, - num_ctas=NUM_CTAS, + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=4, # + num_ctas=NUM_CTAS, # enable_warp_specialization=True) else: warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, - num_warps=4, - num_ctas=NUM_CTAS, + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + 
c.stride(0), c.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=4, # + num_ctas=NUM_CTAS, # enable_warp_specialization=True) th_c = torch.matmul(a, b) @@ -304,14 +313,14 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K @triton.jit -def static_persistent_warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -343,14 +352,14 @@ def static_persistent_warp_specialized_matmul_kernel( @triton.jit -def static_persistent_tma_warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_tma_warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -363,8 +372,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel( block_offset_m = pre_pid_m * BLOCK_M block_offset_n = pre_pid_n * BLOCK_N - a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), 
block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) for tile_id in range(start_tile, num_tiles, NUM_SM): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles @@ -390,8 +401,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel( @pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', - [(*shape, use_tma) - for shape in [ + [(*shape, use_tma) for shape in [ [2048, 2048, 64, 64, 64, 16, 1, False, True], [4096, 4096, 64, 64, 64, 16, 1, False, True], [128, 4096, 64, 64, 64, 16, 1, False, True], @@ -415,11 +425,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel( [4096, 4096, 128, 256, 128, 64, 4, False, True], [4096, 4096, 256, 128, 256, 64, 4, False, True], [4096, 4096, 256, 256, 256, 64, 4, False, True], - ] - for use_tma in [False, True] - ]) + ] for use_tma in [False, True]]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") -def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): +def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, + USE_TMA): if (TRANS_A): a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T else: @@ -432,27 +441,22 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) + grid = lambda META: 
(min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: static_persistent_tma_warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, - num_warps=4, num_ctas=NUM_CTAS, + a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, + BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) else: static_persistent_warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, - num_warps=4, num_ctas=NUM_CTAS, + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, # + num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) th_c = torch.matmul(a, b) @@ -460,16 +464,15 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N @triton.jit -def static_persistent_matmul_no_scf_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, - NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr, -): +def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, # + NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr # + ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) n_tiles = tl.cdiv(N, BLOCK_N) @@ -487,7 +490,8 @@ def 
static_persistent_matmul_no_scf_kernel( offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) if USE_TMA_EPILOGUE: c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0)) for tile_id in range(start_tile, num_tiles, NUM_SM): pid_m = tile_id // n_tiles @@ -524,29 +528,27 @@ def static_persistent_matmul_no_scf_kernel( pre_pid_n = pid_n -@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD', - itertools.chain( - *[ - [ - # numCTAs = 1, no TMA multicast: - [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - # small M, N - [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - ] for USE_TMA_EPILOGUE in [True, False] - for USE_TMA_LOAD in [True, False] - ])) +@pytest.mark.parametrize( + 'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD', + itertools.chain(*[[ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, 
"float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + ] for USE_TMA_EPILOGUE in [True, False] for USE_TMA_LOAD in [True, False]])) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") -def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD): +def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, + USE_TMA_EPILOGUE, USE_TMA_LOAD): if (TRANS_A): a = torch.randn((K, M), device='cuda', dtype=torch.float16).T else: @@ -564,46 +566,42 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count # TODO: set `enable_warp_specialization=False` will lead to compilation error. 
- static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, - num_warps=NUM_WARPS, - num_ctas=NUM_CTAS, - FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), - USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, - USE_TMA_LOAD=USE_TMA_LOAD, - enable_warp_specialization=True) + static_persistent_matmul_no_scf_kernel[(num_SMs, )]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, # + num_warps=NUM_WARPS, # + num_ctas=NUM_CTAS, # + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, # + USE_TMA_LOAD=USE_TMA_LOAD, # + enable_warp_specialization=True) a_f32 = a.to(torch.float32) b_f32 = b.to(torch.float32) golden = torch.matmul(a_f32, b_f32) torch.set_printoptions(profile="full") - assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) @triton.jit -def full_static_persistent_matmul_kernel( - a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_wm, stride_wn, - stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, - B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, - NUM_SM: tl.constexpr -): +def 
full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_wm, stride_wn, # + stride_zm, stride_zn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, # + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, # + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, # + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # + NUM_SM: tl.constexpr # + ): start_pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_m = tl.cdiv(M, BLOCK_M) @@ -618,15 +616,18 @@ def full_static_persistent_matmul_kernel( pre_block_offset_m = pre_pid_m * BLOCK_M pre_block_offset_n = pre_pid_n * BLOCK_N a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1)) + offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), + order=(A_ORDER_0, A_ORDER_1)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1)) + offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), + order=(B_ORDER_0, B_ORDER_1)) w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn), offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1)) if USE_TMA_STORE: z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), - offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + offsets=(pre_block_offset_m, pre_block_offset_n), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) for tile_id in range(start_pid, num_tiles, NUM_SM): group_id = tile_id // 
num_pid_in_group @@ -694,136 +695,120 @@ def full_static_persistent_matmul_kernel( pre_pid_n = pid_n -@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', - [ - # corner shapes - (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) - for shape_w_c in [ - [4096, 1, 1024, False, False], - [2048, 204, 1000, True, False], - [16, 524288, 32, False, True], - ] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for enable_ws in [True] - ] + [ - # softmax epilogue - (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - # softmax works for one CTA - for shape_w_c in [ - [64, 64, 16, 4, 1, 64, 64, 64], - [128, 128, 64, 4, 1, None, None, None], - [16, 16, 64, 4, 1, 16, 16, 64], - # TODO: enable when num_warps != 4 is supported. - # [64, 64, 32, 8, 1, 64, 64, 64], - [128, 128, 64, 4, 1, 128, 128, 128], - ] - for epilogue in ['softmax'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for num_stages in [3] - for enable_ws in [True] - ] + [ - # loop over tile shapes and transpose combinations - (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 32, 4, 1, 128, 256, 64], - [128, 128, 16, 4, 4, 512, 256, 64], - [128, 256, 32, 4, 8, 256, 256, 192], - [512, 256, 32, 4, 8, 1024, 256, 192], - # BLOCK_K >= 128 - [64, 128, 128, 4, 1, 512, 256, 256], - [128, 128, 128, 4, 1, 256, 256, 192], - [128, 128, 128, 4, 2, 256, 256, 192], - # small BLOCK_M and BLOCK_K - [16, 32, 32, 4, 1, 128, 256, 64], - [32, 32, 16, 4, 1, 256, 256, 192], - [16, 32, 64, 4, 4, 512, 256, 64], - ] - for out_dtype in ['float32',] - for use_tma_store in [False,] - for trans_a in [False, True] - for trans_b in [False, True] - for num_stages in [3] - for enable_ws in [True] - ] + [ - # 
loop over epilogues besides of softmax - (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 16, 4, 1, 128, 128, 64], - *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]], - # for chain-dot - [128, 128, 64, 4, 1, None, None, None], - [64, 64, 16, 4, 1, None, None, None], - # small BLOCK_M and BLOCK_K - [16, 16, 64, 4, 1, 128, 128, 64], - *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]], - # # TODO: enable when num_warps != 4 is supported. - # # repeat - # # [64, 64, 32, 8, 1, 128, 256, 64], - # # [64, 64, 16, 8, 2, 128, 128, 64], - # irregular shape - [128, 128, 64, 4, 1, 500, 200, 128], - [128, 128, 64, 4, 1, 513, 193, 192], - ] - for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for num_stages in [3] - for enable_ws in [True] - if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1])) - ] + [ - # loop over instr shapes & pipeline stages - (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for n in [16, 32, 64, 128, 256] - for out_dtype in ['float32'] - for use_tma_store in [False,] - for num_stages in [2, 4, 5, 7] - for enable_ws in [True] - ] + [ - # irregular shapes - (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [128, 128, 64, 4, 1], - [256, 128, 64, 4, 2], - [128, 128, 128, 4, 2] - ] - for shape in [ - [512, 360, 1024], - [360, 4096, 512], - ] - for out_dtype in ['float32'] - for use_tma_store in [False, True] - for num_stages in [3, 4] - for enable_ws in [True] - ] - ) -@pytest.mark.skipif(torch.cuda.get_device_capability() - [0] < 9, reason="Requires compute capability >= 9") -def 
test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): - if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [ - '128-128-128-4-1-256-256-192-none-float32-True-3-True', - ]: +@pytest.mark.parametrize( + 'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) for shape_w_c in [ + [4096, 1, 1024, False, False], + [2048, 204, 1000, True, False], + [16, 524288, 32, False, True], + ] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for enable_ws in [True] + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + # softmax works for one CTA + for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + # TODO: enable when num_warps != 4 is supported. 
+ # [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] + for epilogue in ['softmax'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False] + for trans_b in [True] + for num_stages in [3] + for enable_ws in [True] + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in + [False, True] for num_stages in [3] for enable_ws in [True] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]], + # # TODO: enable when num_warps != 4 is supported. 
+ # # repeat + # # [64, 64, 32, 8, 1, 128, 256, 64], + # # [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 1, 513, 193, 192], + ] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] for out_dtype in + ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in [False] for trans_b in [True] for + num_stages in [3] for enable_ws in [True] if not (epilogue == 'chain-dot' and + (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1])) + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for n in [16, 32, 64, 128, 256] + for out_dtype in ['float32'] + for use_tma_store in [False] + for num_stages in [2, 4, 5, 7] + for enable_ws in [True] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [[128, 128, 64, 4, 1], [256, 128, 64, 4, 2], [128, 128, 128, 4, 2]] + for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] + for out_dtype in ['float32'] + for use_tma_store in [False, True] + for num_stages in [3, 4] + for enable_ws in [True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, + epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): + if '-'.join( + map(str, [ + BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, + ENABLE_WS + ])) in [ + '128-128-128-4-1-256-256-192-none-float32-True-3-True', + ]: pytest.skip('out of resource: shared memory, Required: 263168') if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ - '16-32-64-4-4-512-256-64-True-False', - 
'16-32-64-4-4-512-256-64-True-True', - '16-32-64-4-4-512-256-64-False-False', - '16-32-64-4-4-512-256-64-False-True', + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', ]: pytest.skip('shapePerCTA[1] < 16 not supported') if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ - '16-32-64-4-1-256-256-256-False', - '16-32-64-4-2-256-256-256-False', - '16-32-64-4-2-256-256-256-True', - '16-32-64-8-2-256-256-256-False', - '16-32-64-8-2-256-256-256-True', + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', ]: pytest.skip('Known legacy issue, ldmatrix can only support x4') @@ -893,37 +878,36 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR else: ref = d return ref + golden = process_epilogue(dot, bias, w, epilogue) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count def grid(META): - return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) + return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + full_static_persistent_matmul_kernel[grid]( - a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_wm=w.stride(0), stride_wn=w.stride(1), - stride_zm=z.stride(0), stride_zn=z.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, - out_dtype=out_dtype, - USE_TMA_STORE=USE_TMA_STORE, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], - B_ORDER_0=b_order[0], 
B_ORDER_1=b_order[1], - num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, - enable_warp_specialization=ENABLE_WS, + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_wm=w.stride(0), stride_wn=w.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # + out_dtype=out_dtype, # + USE_TMA_STORE=USE_TMA_STORE, # + ADD_MATRIX=epilogue == 'add-matrix', # + ADD_ROWS=epilogue == 'add-rows', # + ADD_COLS=epilogue == 'add-cols', # + DO_SOFTMAX=epilogue == 'softmax', # + CHAIN_DOT=epilogue == 'chain-dot', # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, # + enable_warp_specialization=ENABLE_WS, # NUM_SM=num_SMs) torch.set_printoptions(profile="full") golden = torch.nn.functional.normalize(golden) z = torch.nn.functional.normalize(z) - assert_close(z, golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/test_tma_store_gemm.py b/python/test/unit/hopper/test_tma_store_gemm.py index 6d912d89c..b2fc3e874 100644 --- a/python/test/unit/hopper/test_tma_store_gemm.py +++ b/python/test/unit/hopper/test_tma_store_gemm.py @@ -19,7 +19,6 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- import pytest import torch from torch.testing import assert_close @@ -29,21 +28,21 @@ import triton.language as tl @triton.jit -def matmul_tma_load_store( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - OUTPUT_F16: tl.constexpr +def matmul_tma_load_store( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + OUTPUT_F16: tl.constexpr # ): - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) - c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) a = tl.load(a_block_ptr) b = tl.load(b_block_ptr) @@ -78,15 +77,15 @@ def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F if OUTPUT_F16: c = torch.empty((M, N), device=a.device, dtype=torch.float16) - matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - 
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, - num_warps=NUM_WARPS, - num_ctas=NUM_CTAS, - OUTPUT_F16=OUTPUT_F16) + matmul_tma_load_store[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, # + OUTPUT_F16=OUTPUT_F16) golden = torch.matmul(a, b) torch.set_printoptions(profile="full") assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/ttgir_tests/test_tma.py b/python/test/unit/hopper/ttgir_tests/test_tma.py index d48d2aa42..0ee725b4b 100644 --- a/python/test/unit/hopper/ttgir_tests/test_tma.py +++ b/python/test/unit/hopper/ttgir_tests/test_tma.py @@ -54,17 +54,13 @@ def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B): ttgir_path = os.path.dirname(__file__) + "/" + TTGIR kernel = triton.compile(ttgir_path) - kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(), - SIZE_M, SIZE_N, SIZE_K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0)) + kernel[(1, 1, 1)]( # + a.data_ptr(), b.data_ptr(), c.data_ptr(), # + SIZE_M, SIZE_N, SIZE_K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0)) golden = torch.matmul(a, b) torch.set_printoptions(profile="full", sci_mode=False) - assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/language/assert_helper.py b/python/test/unit/language/assert_helper.py index 4bc9a367e..de802ec3d 100644 --- a/python/test/unit/language/assert_helper.py +++ b/python/test/unit/language/assert_helper.py @@ -15,9 +15,9 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr): @triton.jit -def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr): +def kernel_assert_passes(X, Y, BLOCK: tl.constexpr): x = tl.load(X + 
tl.arange(0, BLOCK)) - # Trivial assert + # Trivial assert, should not be an error. tl.device_assert(0 == 0, "x != 0") tl.store(Y + tl.arange(0, BLOCK), x) @@ -48,6 +48,7 @@ def test_assert(func: str): x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda') y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_assert": +<<<<<<< HEAD kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0]) kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0]) elif func == "no_debug": @@ -55,8 +56,32 @@ def test_assert(func: str): kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0]) elif func == "assert": kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0]) +======= + kernel_device_assert[(1, )](x, y, BLOCK=shape[0]) + if func == "device_assert_passes": + # Assert passes; no error. + kernel_assert_passes[(1, )](x, y, BLOCK=shape[0]) + elif func == "no_debug": + # TRITON_DEBUG=1 can override the debug flag + kernel_device_assert_no_debug[(1, )](x, y, BLOCK=shape[0]) + elif func == "assert": + kernel_assert[(1, )](x, y, BLOCK=shape[0]) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 elif func == "static_assert": - kernel_static_assert[(1,)](x, y, BLOCK=shape[0]) + kernel_static_assert[(1, )](x, y, BLOCK=shape[0]) + elif func == "double_assert": + # Launching a different kernel after the first one asserted used to + # segfault. What seems to have happened is: + # - The first kernel is enqueued but doesn't run yet. + # - We go to launch the second kernel. Because this is the first time + # we're running it, we have to load the kernel into the GPU. + # - Loading the kernel takes some time, during which the first launch + # completes. + # - Now the GPU is in an error state. We need to detect this inside + # the kernel-launch/loading code and bail out properly. If we don't, + # we segfault. 
+ kernel_device_assert[(1, )](x, y, BLOCK=shape[0]) + kernel_assert_passes[(1, )](x, y, BLOCK=shape[0]) assert_close(y, x) @@ -116,11 +141,19 @@ def test_assert_nested(caller: str, callee: str): x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda') y = torch.zeros(shape, dtype=x.dtype, device="cuda") if caller == "none": +<<<<<<< HEAD kernel_device_assert_nested[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee) elif caller == "true": kernel_device_assert_nested_true[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee) elif caller == "false": kernel_device_assert_nested_false[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee) +======= + kernel_device_assert_nested[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) + elif caller == "true": + kernel_device_assert_nested_true[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) + elif caller == "false": + kernel_device_assert_nested_false[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 assert_close(y, x) diff --git a/python/test/unit/language/conftest.py b/python/test/unit/language/conftest.py index f9e96688b..7a02d322b 100644 --- a/python/test/unit/language/conftest.py +++ b/python/test/unit/language/conftest.py @@ -4,9 +4,7 @@ import pytest def pytest_addoption(parser): - parser.addoption( - "--device", action="store", default='cuda' - ) + parser.addoption("--device", action="store", default='cuda') @pytest.fixture diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index feb0d219d..6776f09c1 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -1,4 +1,5 @@ import sys +import uuid import torch from torch.testing import assert_close @@ -10,21 +11,49 @@ import triton.language as tl @triton.jit def kernel_device_print(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - tl.device_print("", x) + tl.device_print("x: ", x) tl.store(Y + 
tl.arange(0, BLOCK), x) @triton.jit def kernel_print(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - print("", x) + # Triton should add a space after this prefix. + print("x:", x) tl.store(Y + tl.arange(0, BLOCK), x) @triton.jit -def kernel_static_print(X, Y, BLOCK: tl.constexpr): +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - tl.static_print(x) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. 
+ x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) tl.store(Y + tl.arange(0, BLOCK), x) @@ -33,21 +62,36 @@ def kernel_no_arg_print(): print("", tl.program_id(0)) +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + def test_print(func: str, data_type: str): shape = (128, ) - # limit the range of integers so that the sum does not overflow x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_print": - kernel_device_print[(1,)](x, y, BLOCK=shape[0]) + kernel_device_print[(1, )](x, y, BLOCK=shape[0]) elif func == "print": - kernel_print[(1,)](x, y, BLOCK=shape[0]) + kernel_print[(1, )](x, y, BLOCK=shape[0]) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, BLOCK=shape[0]) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, BLOCK=shape[0]) elif func == "static_print": - kernel_static_print[(1,)](x, y, BLOCK=shape[0]) + kernel_static_print[(1, )](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4()) elif func == "no_arg_print": - kernel_no_arg_print[(1,)](num_warps=4) + kernel_no_arg_print[(1, )](num_warps=4) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=4) + else: + assert f"Unknown kernel: {func}" - if func != "no_arg_print": + if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args": assert_close(y, x) diff --git a/python/test/unit/language/test_annotations.py b/python/test/unit/language/test_annotations.py index 0e18c950c..26bb40664 100644 --- a/python/test/unit/language/test_annotations.py +++ b/python/test/unit/language/test_annotations.py @@ -1,4 +1,3 @@ - from __future__ import annotations import torch @@ -14,8 +13,8 @@ def 
test_annotations(device): pass x = torch.empty(1, device=device) - _kernel[(1,)](x, x.shape[0], 32) + _kernel[(1, )](x, x.shape[0], 32) try: - _kernel[(1,)](x.shape[0], x.shape[0], 32) + _kernel[(1, )](x.shape[0], x.shape[0], 32) except AttributeError: pass diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index b1802b908..78b80cde9 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -17,10 +17,12 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.store(b_block_ptr, a, boundary_check=(0, )) -@pytest.mark.parametrize("dtype_str, n, padding_option", - [(dtype_str, n, padding) for dtype_str in ("bool", "int16", "float16") - for n in (64, 128, 256, 512, 1024) - for padding in ("zero", "nan")]) +@pytest.mark.parametrize("dtype_str, n, padding_option", [ # + (dtype_str, n, padding) + for dtype_str in ("bool", "int16", "float16") + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan") # +]) def test_block_copy(dtype_str, n, padding_option): capability = torch.cuda.get_device_capability() if torch.version.hip is None and capability[0] >= 9: @@ -35,31 +37,31 @@ def test_block_copy(dtype_str, n, padding_option): a = torch.randn((n, ), device="cuda", dtype=dtype) b = torch.zeros((n, ), device="cuda", dtype=dtype) - grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),) + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) - assert torch.all(a[0: n // 2] == b[0: n // 2]) + assert torch.all(a[0:n // 2] == b[0:n // 2]) if padding_option == "zero": - assert torch.all(b[n // 2: n] == 0) + assert torch.all(b[n // 2:n] == 0) else: - assert torch.all(torch.isnan(b[n // 2: n])) + assert torch.all(torch.isnan(b[n // 2:n])) @triton.jit -def matmul_no_scf_with_advance_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - 
stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # ): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) # Below two lines are just for testing negative offsets for the `advance` API, which could be removed a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) @@ -71,14 +73,12 @@ def matmul_no_scf_with_advance_kernel( tl.store(c_ptrs, c) -@pytest.mark.parametrize("shape, num_warps", [ - (shape, num_warps) - for shape in [ +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ [64, 64, 16], [64, 64, 32], [64, 64, 64], - ] - for num_warps in [4, 8] + ] for num_warps in [4, 8] ]) def test_block_ptr_matmul_no_scf(shape, num_warps): capability = torch.cuda.get_device_capability() @@ -91,12 +91,13 @@ def test_block_ptr_matmul_no_scf(shape, num_warps): c = torch.empty((m, n), device="cuda", dtype=torch.float32) grid = lambda META: (1, ) - matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, - M=m, N=n, K=k, - 
stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, - num_warps=num_warps) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) golden = torch.matmul(a, b) torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c8cab4a99..ece5a9caf 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,6 +1,5 @@ # flake8: noqa: F821,F841 import itertools -import os import re from typing import Optional, Union @@ -10,7 +9,6 @@ import torch from numpy.random import RandomState import triton -import triton._C.libtriton.triton as _triton import triton.language as tl from triton.common.build import is_hip from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret @@ -63,8 +61,7 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h elif dtype_str in float_dtypes: return rs.normal(0, 1, shape).astype(dtype_str) elif dtype_str == 'bfloat16': - return (rs.normal(0, 1, shape).astype('float32').view('uint32') - & np.uint32(0xffff0000)).view('float32') + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') elif dtype_str in ['bool', 'int1', 'bool_']: return rs.normal(0, 1, shape) > 0.0 else: @@ -137,6 +134,7 @@ def check_type_supported(dtype, device): class MmaLayout: + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): self.version = version self.warps_per_cta = str(warps_per_cta) @@ -150,6 +148,7 @@ class MmaLayout: class BlockedLayout: + 
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): self.sz_per_thread = str(size_per_thread) self.threads_per_warp = str(threads_per_warp) @@ -164,6 +163,7 @@ class BlockedLayout: class SharedLayout: + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): self.vec = str(vec) self.per_phase = str(per_phase) @@ -184,6 +184,7 @@ def test_empty_kernel(dtype_x, device): @triton.jit def kernel(X, SIZE: tl.constexpr): pass + check_type_supported(dtype_x, device) x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) kernel[(1, )](x, SIZE=SIZE, num_warps=4) @@ -248,7 +249,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: return overrides.get(key) -def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, y_low=None, y_high=None): +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + y_low=None, y_high=None): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -280,8 +282,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y= x_tri = to_triton(x, device=device, dst_type=dtype_x) y_tri = to_triton(y, device=device, dst_type=dtype_y) z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) - kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, - num_warps=4, num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) @@ -310,12 +311,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ('uint64', 'float64'), ] + # --------------- # test binary ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ +@pytest.mark.parametrize("dtype_x, 
dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['+', '-', '*', '/', '%'] for dtype_x in dtypes_with_bfloat16 @@ -327,7 +329,8 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'): + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): # Triton promotes 16-bit floating-point / and % to 32-bit because there # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. @@ -340,27 +343,14 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): with pytest.raises(AssertionError, match='Not equal to tolerance'): - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) - elif (op in ('%', '/') and - ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or - (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): with pytest.raises(triton.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) else: - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) @pytest.mark.parametrize("dtype, 
order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) @@ -382,14 +372,15 @@ def test_addptr(dtype, order, device): x_tri = to_triton(x, dst_type=dtype, device=device) y_tri = to_triton(y, dst_type=dtype, device=device) y = x - kernel[1,](x_tri, y_tri, order, SIZE) + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) np.testing.assert_allclose(y, to_numpy(y_tri)) -@pytest.mark.parametrize("dtype_x, dtype_y", - [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + - [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] - ) +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_floordiv(dtype_x, dtype_y, num_ctas, device): # Triton has IEEE, not numpy/torch, semantics for %, and those carry @@ -397,13 +388,7 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): # reference result for //. 
expr = 'x // y' numpy_expr = '((x - np.fmod(x, y)) / y)' - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) def test_unsigned_name_mangling(device='cuda'): @@ -432,10 +417,7 @@ def test_unsigned_name_mangling(device='cuda'): # triton result x_tri = to_triton(x, device=device, dst_type=dtype_x) y_tri = to_triton(y, device=device, dst_type=dtype_y) - actual = tuple( - to_triton(np.empty_like(e), device=device) - for e in expect - ) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) # Bitwise op, so expect exact equality @@ -443,47 +425,9 @@ def test_unsigned_name_mangling(device='cuda'): assert (expect[1] == to_numpy(actual[1])).all() -def test_unsigned_name_mangling(device): - # Test that uint32 and int32 are mangled differently by the compiler - SIZE = 128 - # define the kernel / launch-grid - - @triton.jit - def kernel(O1, O2, X, Y, SIZE: tl.constexpr): - off = tl.arange(0, SIZE) - x = tl.load(X + off) - y = tl.load(Y + off) - out1 = tl.abs(x) # uint32 -> nop - out2 = tl.abs(-y) # int32 -> should have an effect - tl.store(O1 + off, out1) - tl.store(O2 + off, out2) - - dtype_x = 'uint32' - dtype_y = 'int32' - # inputs - rs = RandomState(17) - x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) - y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) - # reference result - expect = (np.abs(x), np.abs(-y)) - # triton result - x_tri = to_triton(x, device=device, dst_type=dtype_x) - y_tri = to_triton(y, device=device, dst_type=dtype_y) - actual = tuple( - to_triton(np.empty_like(e), device=device) - for e in expect - ) - kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) - - # Bitwise op, so expect exact equality - assert (expect[0] == to_numpy(actual[0])).all() - assert (expect[1] == to_numpy(actual[1])).all() - - -# 
--------------- # test bitwise ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['&', '|', '^'] for dtype_x in dtypes + dtypes_with_bfloat16 @@ -504,16 +448,10 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): # The CompilationError must have been caused by a C++ exception with this text. assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes @@ -537,22 +475,14 @@ def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", - # real - [ - (dtype_x, dtype_y, op, 'real', 'real') - for op in ops - for dtype_x in dtypes - for dtype_y in dtypes - ] + - # NaNs - [('float32', 'float32', op, mode_x, mode_y) - for op in ops - for mode_x, mode_y in [('nan', 'real'), - ('real', 'nan'), - ('nan', 'nan')] - - ]) +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): expr = f'x {op} y' @@ -570,6 +500,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def 
test_broadcast(dtype, device): + @triton.jit def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): offset1 = tl.arange(0, M) @@ -590,9 +521,10 @@ def test_broadcast(dtype, device): y_tri = to_triton(y, device=device, dst_type=dtype) y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) - broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + # ---------- # test slice # ---------- @@ -620,7 +552,7 @@ def test_slice(device): t = scalar[None, None] tl.static_assert(t.shape == [1, 1]) - slice_kernel[(1,)](XBLOCK=32) + slice_kernel[(1, )](XBLOCK=32) # ------------------ @@ -636,13 +568,14 @@ def test_invalid_slice(device): dst[10:] with pytest.raises(triton.CompilationError, match='unsupported tensor index'): - _kernel[(1,)](dst=dst) + _kernel[(1, )](dst=dst) # ---------------- # test expand_dims # ---------------- def test_expand_dims(device): + @triton.jit def expand_dims_kernel(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -681,10 +614,11 @@ def test_expand_dims(device): N = 32 dummy_tensor = torch.empty((), device=device) - expand_dims_kernel[(1,)](dummy_tensor, N) + expand_dims_kernel[(1, )](dummy_tensor, N) def test_expand_dims_error_cases(device): + @triton.jit def dim_out_of_range1(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -722,19 +656,19 @@ def test_expand_dims_error_cases(device): dummy_tensor = torch.empty((), device=device) with pytest.raises(triton.CompilationError, match="invalid axis -3"): - dim_out_of_range1[(1,)](dummy_tensor, N) + dim_out_of_range1[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match="invalid axis 2"): - dim_out_of_range2[(1,)](dummy_tensor, N) + dim_out_of_range2[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match="invalid axis 1"): - 
dim_out_of_range3[(1,)](dummy_tensor, N) + dim_out_of_range3[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): - duplicate_dim1[(1,)](dummy_tensor, N) + duplicate_dim1[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): - duplicate_dim2[(1,)](dummy_tensor, N) + duplicate_dim2[(1, )](dummy_tensor, N) # ---------------------------- @@ -748,7 +682,7 @@ def test_invalid_pid_axis(device): pid = tl.program_id(20) with pytest.raises(triton.CompilationError, match=r"program_id axis must be 0, 1, or 2 but got 20"): - _kernel[(1,)](dst) + _kernel[(1, )](dst) # --------------- @@ -764,10 +698,8 @@ def test_where(dtype, num_ctas, device): check_type_supported(dtype, device) @triton.jit - def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr, - TEST_POINTERS: tl.constexpr, - TEST_SCALAR_POINTERS: tl.constexpr): + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements decide = tl.load(cond_ptr + offsets, mask=mask) @@ -796,17 +728,20 @@ def test_where(dtype, num_ctas, device): y_tri = to_triton(y, device=device, dst_type=dtype) z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) - grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) - where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) assert (z == to_numpy(z_tri)).all() if select_ptrs: - 
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) z = np.where(cond[0], x, y) assert (z == to_numpy(z_tri)).all() @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): + @triton.jit def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] @@ -835,44 +770,45 @@ def test_where_broadcast(num_ctas, device): cond_tri = to_triton(mask, device=device) x_tri = to_triton(x, device=device, dst_type=dtype) z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) - where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) assert (z == to_numpy(z_tri)).all() - where_scalar_condition[(1,)](x_tri, z_tri, SIZE, num_ctas=num_ctas) + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) z = np.where(0, x, 0) assert (z == to_numpy(z_tri)).all() + # --------------- # test unary ops # --------------- -@pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16 -] + [ - (dtype_x, ' ~x') for dtype_x in int_dtypes -]) +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_unary_op(dtype_x, expr, num_ctas, device): _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + # ---------------- # test math ops # ---------------- -@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin'] for x in ['x', '3.0']]) +@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for 
expr in ['exp', 'log', 'cos', 'sin'] + for x in ['x', '3.0']]) def test_math_op(dtype_x, expr, device, x): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + # ---------------- # test abs # ---------------- -@pytest.mark.parametrize("dtype_x", [ - (dtype_x) - for dtype_x in dtypes_with_bfloat16 -]) +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) @@ -896,7 +832,7 @@ def test_abs_fp8(in_dtype, device): f8 = triton.reinterpret(f8_tensor, in_dtype) n_elements = f8_tensor.numel() out_f8 = torch.empty_like(f8_tensor) - abs_kernel[(1,)](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) expect = f32_tensor.abs() @@ -921,13 +857,9 @@ def make_ptr_str(name, shape): # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` -@pytest.mark.parametrize("expr, dtype_str", [ - (f'x[{s}]', d) - for s in ['None, :', ':, None', - 'None, :, :', - ':, :, None'] - for d in ['int32', 'uint32', 'uint16'] -]) +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_index1d(expr, dtype_str, num_ctas, device): rank_x = expr.count(':') @@ -971,8 +903,7 @@ def test_index1d(expr, dtype_str, num_ctas, device): def catch_compilation_error(kernel): try: - kernel[(1, )](z_tri, x_tri, num_warps=1, - SIZE=shape_x[0], num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) except triton.CompilationError as e: np.testing.assert_(True) except BaseException: @@ -995,6 +926,7 @@ def tuples_fn(a, b): def 
test_tuples(device): + @triton.jit def with_fn(X, Y, A, B, C): x = tl.load(X) @@ -1103,7 +1035,7 @@ def test_noinline(mode, device): z = torch.ones((16, 16), device=device, dtype=torch.float32) else: z = torch.tensor([0.0], device=device, dtype=torch.float32) - kernel[(1,)](x, y, z, num_warps=1) + kernel[(1, )](x, y, z, num_warps=1) if mode == "simple": assert torch.equal(z, x + y) elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": @@ -1116,18 +1048,31 @@ def test_noinline(mode, device): # --------------- # test atomics # --------------- -@pytest.mark.parametrize("op, dtype_x_str, mode, sem", itertools.chain.from_iterable([ - [ +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ ('add', 'float16', mode, sem), - ('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem), - ('add', 'uint64', mode, sem), ('add', 'int64', mode, sem), ('add', 'float64', mode, sem), - ('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem), - ('max', 'uint64', mode, sem), ('max', 'int64', mode, sem), ('max', 'float64', mode, sem), - ('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem), - ('min', 'uint64', mode, sem), ('min', 'int64', mode, sem), ('min', 'float64', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), ] - for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] - for sem in [None, 
'acquire', 'release', 'acq_rel', 'relaxed']])) + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) def test_atomic_rmw(op, dtype_x_str, mode, sem, device): check_cuda_only(device) @@ -1186,18 +1131,22 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_atomic_rmw_predicate(num_ctas, device): + @triton.jit def kernel(X): val = tl.program_id(0) if val < 64: tl.atomic_max(X, val) - x = torch.zeros((1,), device=device, dtype=torch.int32) - kernel[(4096,)](x, num_ctas=num_ctas) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) assert x.item() == 63 -@pytest.mark.parametrize("shape, axis, num_ctas", - [(shape, axis, num_ctas) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] for axis in [0, 1] for num_ctas in num_ctas_list]) +@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list]) def test_tensor_atomic_rmw(shape, axis, num_ctas, device): shape0, shape1 = shape # triton kernel @@ -1212,6 +1161,7 @@ def test_tensor_atomic_rmw(shape, axis, num_ctas, device): tl.atomic_add(Z + off0, z) else: tl.atomic_add(Z + off1, z) + rs = RandomState(17) x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) # reference result @@ -1220,7 +1170,7 @@ def test_tensor_atomic_rmw(shape, axis, num_ctas, device): x_tri = to_triton(x, device=device) z_shape = (shape0, ) if axis == 1 else (shape1, ) z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) - kernel[(1,)](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) @@ -1236,8 +1186,9 @@ def test_tensor_atomic_rmw_block(num_ctas, device): val 
= offs.to(tl.float32) x = X + offs tl.atomic_min(x, val) + x = torch.ones((8, 8), device=device, dtype=torch.float32) - kernel[(2,)](x, shape[0], shape[1], num_ctas=num_ctas) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) assert torch.min(x).item() == 0.0 @@ -1249,8 +1200,8 @@ def test_atomic_cas(sem, num_ctas, device): def change_value(Lock): tl.atomic_cas(Lock, 0, 1) - Lock = torch.zeros((1,), device=device, dtype=torch.int32) - change_value[(1,)](Lock) + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) assert (Lock[0] == 1) @@ -1266,10 +1217,10 @@ def test_atomic_cas(sem, num_ctas, device): # release lock tl.atomic_xchg(Lock, 0) - Lock = torch.zeros((1,), device=device, dtype=torch.int32) - data = torch.zeros((128,), device=device, dtype=torch.float32) - ref = torch.full((128,), 64.0) - h = serialized_add[(64,)](data, Lock, SEM=sem, num_ctas=num_ctas) + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 64.0) + h = serialized_add[(64, )](data, Lock, SEM=sem, num_ctas=num_ctas) sem_str = "acq_rel" if sem is None else sem np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) if is_hip(): @@ -1277,36 +1228,49 @@ def test_atomic_cas(sem, num_ctas, device): assert f"atom.global.{sem_str}" in h.asm["ptx"] +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64) + tl.atomic_cas(X + offsets, t1, t2) + + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) + Y = torch.tensor([2, 1, 2, 1, 
2, 1, 2, 1], device=device, dtype=torch.int64) + + change_value[(2, )](X, 4) + assert (torch.equal(X, Y)) + + # --------------- # test cast # --------------- -@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [ - (dtype_x, dtype_z, False, 1024) - for dtype_x in dtypes - for dtype_z in dtypes -] + [ - ('float32', 'bfloat16', False, 1024), - ('bfloat16', 'float32', False, 1024), - ('float32', 'int32', True, 1024), - ('float32', 'int1', False, 1024), - ('int8', 'bfloat16', False, 1024), -] + [ - (f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64] -] + [ - (f'int{x}', f'uint{x}', True, 1024) for x in [8, 16, 32, 64] -] + (([ - (dtype_x, dtype_z, False, size) - for dtype_x in torch_float8_dtypes - for dtype_z in ["float16", "float32", "bfloat16"] - for size in [1024, 32] -] + [ - (dtype_x, dtype_z, False, size) - for dtype_z in torch_float8_dtypes - for dtype_x in ["float16", "float32", "bfloat16"] - for size in [1024, 32] -]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] + # + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): # bfloat16 on cc < 80 will not be tested @@ -1330,8 +1294,11 @@ 
def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): if dtype_z in uint_dtypes: x = np.absolute(x) x_tri = to_triton(x, device=device) - + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) # triton kernel + @triton.jit def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr): x_ptr = X + tl.arange(0, SIZE) @@ -1343,14 +1310,15 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' # triton result if dtype_z.startswith('bfloat'): - z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device) + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) elif dtype_z.startswith('float8'): - z_tri = torch.empty((size,), dtype=torch.float, device=device) + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas) # torch result - if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith('float8') or dtype_x.startswith('float8'): + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): assert bitcast is False z_ref = x_tri.to(z_tri.dtype) torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) @@ -1362,7 +1330,8 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) -@pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for 
num_warps in [4, 8]]) def test_cat(dtype_str, num_warps, device): check_type_supported(dtype_str, device) @@ -1377,7 +1346,7 @@ def test_cat(dtype_str, num_warps, device): x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) z_ref = torch.cat([x, y], dim=0).sum() - z = torch.zeros((256,), dtype=getattr(torch, dtype_str), device=device) + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) kernel[(1, )](x, y, z, N=128, num_warps=num_warps) assert z.sum() == z_ref # check if there's no duplicate value in z @@ -1388,8 +1357,8 @@ def test_cat(dtype_str, num_warps, device): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_store_constant(dtype_str, num_ctas, device): check_type_supported(dtype_str, device) - """Tests that boolean True is stored as 1""" + @triton.jit def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -1402,12 +1371,13 @@ def test_store_constant(dtype_str, num_ctas, device): block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) - kernel[(1,)](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) assert torch.all(output == ref) def test_load_store_same_ptr(device): + @triton.jit() def kernel(in_out_ptr): pid = tl.program_id(axis=0) @@ -1416,11 +1386,11 @@ def test_load_store_same_ptr(device): tl.store(in_out_ptr + pid, out) for _ in range(1000): - x = torch.ones((65536,), device=device, dtype=torch.float32) + x = torch.ones((65536, ), device=device, dtype=torch.float32) if is_hip(): - kernel[(65536,)](x, num_warps=16) # threads per Warp for ROCM is 64 + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 else: - 
kernel[(65536,)](x, num_warps=32) + kernel[(65536, )](x, num_warps=32) assert torch.all(x == 2) @@ -1435,13 +1405,15 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() - output = torch.where(exp == 0, - # subnormal - ((-1.0) ** sign) * (2.0 ** (1 - exp_bias)) * (frac / (2.0 ** dtype.fp_mantissa_width)), - # normal - ((-1.0) ** sign) * (2.0 ** (exp - exp_bias)) * (1.0 + frac / (2.0 ** dtype.fp_mantissa_width))).float() + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() - extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width # special cases, exp is 0b11..1 if dtype in [tl.float8e4nv, tl.float8e4b15]: # float8e4m3nv does not have infinities @@ -1449,8 +1421,9 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): output[fp == 0b11111111] = torch.nan else: output = torch.where(exp == (1 << exp_width) - 1, - ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32), - output) + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | + (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) return output @@ -1459,7 +1432,7 @@ def test_convert_float16_to_float32(in_dtype, device): """Tests that check convert_float_to_float32 function""" check_type_supported(in_dtype, device) - f16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16).view(in_dtype) + 
f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) f32_output = convert_float_to_float32(f16_input) nan = f16_input.isnan() @@ -1488,6 +1461,7 @@ def serialize_fp8(np_data, in_dtype): else: return np_data + # inverse of `serialize_fp8` @@ -1533,13 +1507,13 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda() # check that non-subnormal fp8 are correctly converted to fp16 tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda") - copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) + copy_kernel[(1, )](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) ref_fp8 = torch.from_numpy(ref_fp8).cuda() ref_fp16 = convert_float_to_float32(ref_fp8, in_dtype) assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal]) # check that values are properly converted back to float8 ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8) - copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) + copy_kernel[(1, )](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) if in_dtype == tl.float8e4b15: assert torch.all(tri_fp8[:127] == ref_fp8[:127]) assert torch.all(tri_fp8[128:255] == ref_fp8[128:255]) @@ -1548,6 +1522,7 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): else: assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal]) + # --------------- # test reduce # --------------- @@ -1561,22 +1536,21 @@ def get_reduced_dtype(dtype_str, op): return dtype_str -@pytest.mark.parametrize("op, dtype_str, shape", - [(op, dtype, shape) - for op in ['min', 'max', - 'min-with-indices', - 'max-with-indices', - 'argmin-tie-break-left', - 'argmax-tie-break-left', - 'sum'] - for dtype in dtypes_with_bfloat16 - for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("op, dtype_str, 
shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_reduce1d(op, dtype_str, shape, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested if is_hip(): - pytest.skip(f"test_reduce1d not supported on HIP") + pytest.skip("test_reduce1d not supported on HIP") # triton kernel @triton.jit @@ -1596,14 +1570,18 @@ def test_reduce1d(op, dtype_str, shape, num_ctas, device): # input rs = RandomState(17) # limit the range of integers so that the sum does not overflow - x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, - 'max-with-indices': np.max, - 'min-with-indices': np.min, - 'argmin-tie-break-fast': np.argmin, - 'argmin-tie-break-left': np.argmin, - 'argmax-tie-break-fast': np.argmax, - 'argmax-tie-break-left': np.argmax}[op] + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-fast': np.argmin, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-fast': np.argmax, + 'argmax-tie-break-left': np.argmax, + }[op] if 'tie-break-left' in op: x[3:10] = numpy_op(x) x_tri = to_triton(x, device=device) @@ -1619,9 +1597,8 @@ def test_reduce1d(op, dtype_str, shape, num_ctas, device): else: z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) # triton result - z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), - device=device, dst_type=z_tri_dtype_str) - kernel[(1,)](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, 
num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum': @@ -1636,12 +1613,10 @@ def test_reduce1d(op, dtype_str, shape, num_ctas, device): # TODO: [Qingyi] Fix argmin / argmax -reduce_configs1 = [ - (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16 - for op in ['min', 'max', 'sum', 'argmin', 'argmax'] - for axis in [1] -] - +reduce_configs1 = [(op, dtype, (1, 1024), axis) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] # shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory # exceeds the limit of 99KB @@ -1651,24 +1626,16 @@ reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] if torch.cuda.is_available() and 'V100' in torch.cuda.get_device_name(0): reduce2d_shapes += [(128, 256) and (32, 1024)] - -reduce_configs2 = [ - (op, 'float32', shape, axis) - for op in ['min', 'max', 'sum', 'argmin', 'argmax'] - for shape in reduce2d_shapes - for axis in [0, 1] -] + [ - (op, 'float32', [16, 32], None) - for op in ['min', 'max', 'sum'] -] +reduce_configs2 = [(op, 'float32', shape, axis) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None) for op in ['min', 'max', 'sum']] reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] -reduce_configs3 = [ - (op, 'float32', shape, axis) - for op in ['min', 'max', 'sum', 'argmin', 'argmax'] - for shape in reduce3d_shapes - for axis in [0, 1, 2] -] +reduce_configs3 = [(op, 'float32', shape, axis) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3) @@ -1677,16 +1644,18 @@ def test_reduce(op, dtype_str, shape, axis, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested if is_hip(): - pytest.skip(f"test_reduce2d not supported 
on HIP") + pytest.skip("test_reduce2d not supported on HIP") # triton kernel @triton.jit - def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr): + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) range_k = tl.arange(0, BLOCK_K) if IS_3D: - x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :]) + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) else: x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) z = GENERATE_TEST_HERE @@ -1713,8 +1682,7 @@ def test_reduce(op, dtype_str, shape, axis, num_ctas, device): # limit the range of integers so that the sum does not overflow x = numpy_random(shape, dtype_str=dtype_str, rs=rs) x_tri = to_triton(x, device=device) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, - 'argmin': np.argmin, 'argmax': np.argmax}[op] + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] z_dtype_str = get_reduced_dtype(dtype_str, op) z_tri_dtype_str = z_dtype_str # numpy result @@ -1728,13 +1696,12 @@ def test_reduce(op, dtype_str, shape, axis, num_ctas, device): z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result ret_numel = 1 if axis is None else shape[1 - axis] - z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis) - z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), - device=device, dst_type=z_tri_dtype_str) + z_shape = (1, ) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis) + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) BLOCK_K = 1 
if len(shape) == 2 else shape[2] IS_3D = bool(len(shape) == 3) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], - BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas) + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum': @@ -1754,14 +1721,12 @@ def test_reduce(op, dtype_str, shape, axis, num_ctas, device): scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] -scan_configs = [ - (op, type, shape, axis, num_warps) - for num_warps in [4, 16] - for type in ['int32', 'float32'] - for axis in [1, 0] - for shape in scan2d_shapes - for op in ['cumsum', 'cumprod', 'get_first_element'] -] +scan_configs = [(op, type, shape, axis, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32'] + for axis in [1, 0] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element']] @triton.jit @@ -1807,7 +1772,7 @@ def test_scan2d(op, dtype_str, shape, axis, num_warps, device): z_ref[:, 1:] = x[:, 0:1] # triton result z_tri = to_triton(z, device=device) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) z_tri = to_numpy(z_tri) # compare if dtype_str == 'float32': @@ -1825,7 +1790,6 @@ scan_layouts = [ BlockedLayout([4, 1], [4, 8], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([2, 2], [4, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([2, 2], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 1], [4, 8], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -1834,6 +1798,54 @@ scan_layouts = [ ] +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) 
+@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_locality(op, BLOCK_N, N, num_pid_n): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_n = tl.num_programs(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * num_pid_n + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device="cuda") + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device="cuda") + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N) + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) @@ -1895,9 
+1907,12 @@ layouts = [ BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 16, 16]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), ] @@ -1915,13 +1930,10 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] arith_op = { - "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} }[reduce_op][dtype_str] - numpy_op = { - "max": np.max, - "sum": np.sum - }[reduce_op] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] rdims_1d = f"{N}" if axis == 0 else f"{M}" rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" store_range = "%7" if axis == 0 else "%1" @@ -1998,7 +2010,8 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, layouts = [ BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [1, 32], 
[2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) ] @@ -2047,7 +2060,8 @@ def test_store_op(M, src_layout, device): layouts = [ BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) ] @@ -2124,12 +2138,12 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): op_str = "" if op == "sum": - op_str = f""" + op_str = """ %13 = arith.addi %arg2, %arg3 : i32 tt.reduce.return %13 : i32""" elif op == "max": - op_str = f""" - %13 = "{GPU_DIALECT}.cmpi"(%arg2, %arg3) <{{predicate = 4 : i64}}> : (i32, i32) -> i1 + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 %14 = arith.select %13, %arg2, %arg3 : i32 tt.reduce.return %14 : i32""" ir = f""" @@ -2170,7 +2184,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') - z = np.zeros((1,)).astype('int32') + z = np.zeros((1, )).astype('int32') x_tri = torch.tensor(x, device=device) z_tri = torch.tensor(z, device=device) @@ -2202,7 +2216,7 @@ def test_generic_reduction(device): out_mean = torch.empty((), device=device) out_var = torch.empty((), device=device) - var_mean_kernel[(1,)](x, out_mean, out_var, BLOCK=SIZE) + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) torch.testing.assert_close(out_mean, expect_mean) @@ -2214,42 +2228,37 @@ def 
test_generic_reduction(device): # --------------- -@pytest.mark.parametrize("dtype_str, shape, perm", - [(dtype, shape, perm) - # TODO: bfloat16 - for dtype in ['float8e4b15', 'float16', 'float32'] - for shape in [(64, 64), (128, 128)] - for perm in [(1, 0)]]) +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_permute(dtype_str, shape, perm, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested if is_hip(): - pytest.skip(f"test_permute is not supported in HIP") + pytest.skip("test_permute is not supported in HIP") # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn tl.store(Zs, tl.load(Xs)) + # input x = numpy_random(shape, dtype_str=dtype_str) # triton result z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) x_tri = to_triton(x, device=device, dst_type=dtype_str) - pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - z_tri, z_tri.stride(1), z_tri.stride(0), - BLOCK_M=shape[0], BLOCK_N=shape[1], - num_ctas=num_ctas) - pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), - z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), - BLOCK_M=shape[0], BLOCK_N=shape[1], - num_ctas=num_ctas) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), 
z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) # numpy result if dtype_str == 'float8e4b15': ty = tl.float8e4b15 @@ -2279,37 +2288,25 @@ def test_permute(dtype_str, shape, perm, num_ctas, device): # --------------- -@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", - [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) - for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] - for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] - for allow_tf32 in [True, False] - for in_dtype, out_dtype in [('float16', 'float16'), - ('float16', 'float32'), - ('float32', 'float32')] - if not (allow_tf32 and (in_dtype in ['float16']))] + - - [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) - for shape_nw in [[128, 256, 32, 8], - [128, 16, 32, 4], - [32, 128, 64, 4], - [128, 128, 64, 4], - [64, 128, 128, 4], - [32, 128, 64, 2], - [64, 64, 32, 4], - [32, 32, 128, 16], - [128, 128, 64, 2], - [64, 128, 128, 2]] - for allow_tf32 in [True] - for col_a in [True, False] - for col_b in [True, False] - for in_dtype, out_dtype in [('int8', 'int8'), - ('float16', 'float16'), - ('float16', 'float32'), - ('float32', 'float32')]] + - - [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') - for col_a in [True, False] for col_b in [True, False]]) +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", + [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for allow_tf32 in [True, False] + for 
in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (allow_tf32 and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for allow_tf32 in [True] + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', + 'float32')]] + + [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') + for col_a in [True, False] + for col_b in [True, False]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) @@ -2358,16 +2355,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xk, - Y, stride_yk, stride_yn, - W, stride_wn, stride_wl, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - ALLOW_TF32: tl.constexpr, - DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - COL_A: tl.constexpr, COL_B: tl.constexpr, - out_dtype: tl.constexpr = tl.float32): + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, ALLOW_TF32: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) 
off_l = tl.arange(0, BLOCK_N) @@ -2397,6 +2388,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o w = tl.load(Ws) z = tl.dot(z.to(w.dtype), w, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) tl.store(Zs, z) + # input rs = RandomState(17) if col_a: @@ -2438,20 +2430,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o else: out_dtype = tl.float32 - pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - y_tri, y_tri.stride(0), y_tri.stride(1), - w_tri, w_tri.stride(0), w_tri.stride(1), - z_tri, z_tri.stride(0), z_tri.stride(1), - COL_A=col_a, COL_B=col_b, - BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - ALLOW_TF32=allow_tf32, - num_warps=num_warps, num_ctas=num_ctas, - out_dtype=out_dtype) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), COL_A=col_a, + COL_B=col_b, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', ADD_COLS=epilogue == 'add-cols', + DO_SOFTMAX=epilogue == 'softmax', CHAIN_DOT=epilogue == 'chain-dot', ALLOW_TF32=allow_tf32, + num_warps=num_warps, num_ctas=num_ctas, out_dtype=out_dtype) if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32): if is_hip(): @@ -2462,7 +2446,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o end = ptx.find("cvt.rn.f16.f32") red_code = ptx[start:end] assert len(red_code) > 0 - import os # skip this check on hopper because there are some functions whose name contain "shared" in ptx. # TODO: we should eliminate these unused functions in ptx code. 
@@ -2471,8 +2454,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o assert "bar.sync" not in red_code # torch result if in_dtype == 'int8': - z_ref = np.matmul(x.astype(np.float32), - y.astype(np.float32())).astype(np.int32) + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) else: z_ref = np.matmul(x, y) @@ -2532,9 +2514,8 @@ def test_dot_mulbroadcastred(in_dtype, device): pytest.skip("Requires sm >= 80 to run") @triton.jit - def kernel(Z, X, Y, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr): + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): pidn = tl.program_id(1) pidm = tl.program_id(0) offm = tl.arange(0, BM)[:, None] @@ -2550,6 +2531,7 @@ def test_dot_mulbroadcastred(in_dtype, device): t = tl.sum(x * y, axis=1) acc = t + acc tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + M, N, K = 256, 192, 160 BM, BN, BK = 128, 32, 32 rs = RandomState(17) @@ -2573,7 +2555,6 @@ def test_dot_mulbroadcastred(in_dtype, device): # as the loaded value is in rowmajor. But MMAv3 requires it's second # operand is in colmajor because transpose is not supported for MMAv3 # with float32 input. 
- import os if capability[0] >= 9: assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir'] else: @@ -2581,7 +2562,7 @@ def test_dot_mulbroadcastred(in_dtype, device): @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) -@pytest.mark.parametrize("shape", [(), (1,), (128,)]) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) def test_full(dtype_str, shape, device): if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): # PyTorch only has unsigned 8, but not 16, 32, or 64 @@ -2609,30 +2590,28 @@ def test_full(dtype_str, shape, device): 'SHAPE': str(list(shape)), }) out_static = torch.zeros((128), dtype=dtype, device=device) - kernel_static_patched[(1,)](out_static) + kernel_static_patched[(1, )](out_static) assert torch.all(out_static == 2) kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) out_dynamic = torch.zeros((128), dtype=dtype, device=device) - kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_dynamic == 2) -@pytest.mark.parametrize("literal, dtype_str", - [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), - ('float("inf")', "f32"), ('float("-inf")', "f32"), - ('float("nan")', "f32"), ('float("-nan")', "f32"), - (0., "f32"), - (5, "i32"), (2**40, "i64"),]) +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) def test_constexpr(literal, dtype_str, device): + @triton.jit def kernel(out_ptr): val = GENERATE_TEST_HERE tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) - out = torch.zeros((1,), dtype=torch.float32, device=device) - h = kernel_patched[(1,)](out) + out = 
torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None @@ -2651,14 +2630,16 @@ def test_dot_without_load(dtype_str, device): c = tl.dot(a, b, allow_tf32=ALLOW_TF32) out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] tl.store(out_ptr, c) + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) out_ref = torch.matmul(a, b) out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) - kernel[(1,)](out, ALLOW_TF32=allow_tf32) + kernel[(1, )](out, ALLOW_TF32=allow_tf32) assert torch.all(out == out_ref) + # --------------- # test arange # --------------- @@ -2671,21 +2652,25 @@ def test_arange(start, num_ctas, device): z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @triton.jit - def _kernel(z, BLOCK: tl.constexpr, - START: tl.constexpr, END: tl.constexpr): + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): off = tl.arange(0, BLOCK) val = tl.arange(START, END) tl.store(z + off, val) - _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + # --------------- # test load # --------------- -@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]]) +@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4]]) 
@pytest.mark.parametrize("num_ctas", num_ctas_list) def test_masked_load(dtype_str, size, size_diff, num_ctas, device): dtype = getattr(torch, dtype_str) @@ -2694,12 +2679,12 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device): input_size = size - size_diff output_size = size if dtype_str == 'bool': - input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device) + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) elif dtype_str in int_dtypes or dtype_str in uint_dtypes: - input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device) + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) else: input = torch.rand(input_size, dtype=dtype, device=device) - output = torch.zeros((output_size,), dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) @triton.jit def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): @@ -2712,12 +2697,13 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device): mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None" kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) - kernel[(1,)](input, output, input_size, output_size, num_ctas=num_ctas) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) - reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device))) + reference_out = torch.cat((input, torch.ones((size_diff, ), dtype=dtype, device=device))) # print((output - reference_out).nonzero()) torch.testing.assert_close(output, reference_out) + # Testing masked loads with an intermate copy to shared memory run. 
@@ -2738,9 +2724,7 @@ def test_masked_load_shared_memory(dtype, device): out = torch.zeros((M, N), dtype=dtype, device=device) @triton.jit - def _kernel(in1_ptr, in2_ptr, output_ptr, - in_stride, in2_stride, out_stride, - in_numel, in2_numel, out_numel, + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): M_offsets = tl.arange(0, M) @@ -2761,14 +2745,8 @@ def test_masked_load_shared_memory(dtype, device): output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) - pgm = _kernel[(1,)](in1, in2, out, - in1.stride()[0], - in2.stride()[0], - out.stride()[0], - in1.numel(), - in2.numel(), - out.numel(), - M=M, N=N, K=K) + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) reference_out = torch.matmul(in1, in2) torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) @@ -2785,7 +2763,7 @@ def test_load_cache_modifier(cache, device): x = tl.load(src + offsets, cache_modifier=CACHE) tl.store(dst + offsets, x) - pgm = _kernel[(1,)](dst, src, CACHE=cache) + pgm = _kernel[(1, )](dst, src, CACHE=cache) if is_hip(): return @@ -2813,8 +2791,8 @@ def test_vectorization(N, num_ctas, device): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - pgm = _kernel[(1,)]( - dst, src, N=N, BLOCK_SIZE=block_size) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) if is_hip(): return @@ -2841,7 +2819,8 @@ def test_vectorization_hints(has_hints, device): tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + + pgm 
= _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) if is_hip(): return @@ -2851,6 +2830,7 @@ def test_vectorization_hints(has_hints, device): else: assert "ld.global.v4.b32" not in ptx + # --------------- # test store # --------------- @@ -2869,7 +2849,7 @@ def test_store_cache_modifier(cache): if is_hip(): return - pgm = _kernel[(1,)](dst, src, CACHE=cache) + pgm = _kernel[(1, )](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] if cache == '': assert 'st.global.wb' not in ptx @@ -2897,6 +2877,7 @@ def test_store_cache_modifier(cache): assert 'st.global.cs' not in ptx assert 'st.global.wt' in ptx + # --------------- # test if # --------------- @@ -2930,52 +2911,56 @@ def test_default(device): tl.store(ret0, _impl()) tl.store(ret1, _impl(value)) - _kernel[(1,)](ret0, ret1, value) + _kernel[(1, )](ret0, ret1, value) assert ret0.item() == 10 assert ret1.item() == value - _kernel[(1,)](ret0, ret1) + _kernel[(1, )](ret0, ret1) assert ret0.item() == 10 assert ret1.item() == 3 + # --------------- # test noop # ---------------- def test_noop(device): + @triton.jit def kernel(x): pass - x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) kernel[(1, )](x) @pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) def test_pointer_arguments(device): + @triton.jit def kernel(x): pass + pin_memory = 'pinned' in device x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) if device == "cpu": with pytest.raises(ValueError): - kernel[(1,)](x) + kernel[(1, )](x) else: kernel[(1, )](x) -@pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), - (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), - (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') -]) +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 
'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) def test_value_specialization(value: int, value_type: str, device) -> None: spec_type = None def cache_hook(*args, **kwargs): nonlocal spec_type spec_type = kwargs["compile"]["signature"][0] + JITFunction.cache_hook = cache_hook @triton.jit @@ -2988,15 +2973,13 @@ def test_value_specialization(value: int, value_type: str, device) -> None: JITFunction.cache_hook = None assert spec_type == value_type + # -------------------- # value specialization # -------------------- -@pytest.mark.parametrize( - "value, overflow", - [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] -) +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: @triton.jit @@ -3016,12 +2999,14 @@ def test_value_specialization_overflow(value: int, overflow: bool, device) -> No # test constexpr # ---------------- + @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @pytest.mark.parametrize("is_rhs_constexpr", [True, False]) def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): if is_hip(): - if (is_rhs_constexpr, is_lhs_constexpr, op) in [(False, False, "<<"), (False, False, ">>"), (False, True, "<<")]: + if (is_rhs_constexpr, is_lhs_constexpr, op) in [(False, False, "<<"), (False, False, ">>"), + (False, True, "<<")]: pytest.skip(f"test_bin_op_constexpr[{is_lhs_constexpr}-{is_rhs_constexpr}-{op}] is not supported in HIP") @triton.jit @@ -3034,19 +3019,19 @@ def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): if op in ['<<', '>>', '&', '^', '|']: # int op x_str = "3" if is_lhs_constexpr else "x" y_str = "4" if is_rhs_constexpr else "y" - x = numpy_random((1,), 
dtype_str="int32") - y = numpy_random((1,), dtype_str="int32") + x = numpy_random((1, ), dtype_str="int32") + y = numpy_random((1, ), dtype_str="int32") else: x_str = "3.14" if is_lhs_constexpr else "x" y_str = "4.13" if is_rhs_constexpr else "y" - x = numpy_random((1,), dtype_str="float32") - y = numpy_random((1,), dtype_str="float32") + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) z = np.array(eval(f"{x_str} {op} {y_str}")) x_tri = to_triton(x, device=device) y_tri = to_triton(y, device=device) - z_tri = to_triton(np.empty((1,), dtype=z.dtype), device=device) - kernel[(1,)](z_tri, x_tri, y_tri) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) np.testing.assert_allclose(z, to_numpy(z_tri)) @@ -3058,7 +3043,7 @@ def test_constexpr_shape(device): tl.store(X + off, off) x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) - kernel[(1,)](x_tri) + kernel[(1, )](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) @@ -3071,22 +3056,25 @@ def test_constexpr_scalar_shape(device): tl.store(X + off, val) x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) - kernel[(1,)](x_tri, 32) + kernel[(1, )](x_tri, 32) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) @triton.jit def static_assert_func(): - tl.static_assert(tl.constexpr(False), f"Assert is firing because the constexpr progation did not work properly") + tl.static_assert(tl.constexpr(False), "Assert is firing because the constexpr progation did not work properly") def test_constexpr_propagation(): + @triton.jit def _kernel(COND: tl.constexpr): NEW_COND = COND if NEW_COND: static_assert_func() - _kernel[(1,)](False) + + _kernel[(1, )](False) + # ------------- # test call @@ -3127,11 +3115,11 @@ def test_call(type, num_ctas, device): vecmul_kernel(ptr, n_elements, num2, type) size = 
1024 - rand_val = numpy_random((size,), dtype_str="float32") + rand_val = numpy_random((size, ), dtype_str="float32") rand_val_tri = to_triton(rand_val, device=device) err_msg = "" try: - kernel[(size // 128,)](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) except Exception as e: err_msg = str(e) @@ -3141,14 +3129,16 @@ def test_call(type, num_ctas, device): ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 np.testing.assert_equal(to_numpy(rand_val_tri), ans) + # ------------- # test if # ------------- -# TODO(Keren): if_exp_dynamic - -@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"]) +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) def test_if(if_type, device): @triton.jit @@ -3156,16 +3146,22 @@ def test_if(if_type, device): pid = tl.program_id(0) cond = tl.load(Cond) if IfType == "if": - if pid % 2 == 0: + if pid % 2 == 0: # eq tl.store(Ret, tl.load(XTrue)) - else: + elif 1 == pid % 2: # req tl.store(Ret, tl.load(XFalse)) elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) elif IfType == "if_exp_static": tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) elif IfType == "if_and_dynamic": - if BoolVar and pid % 2 == 0: + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne tl.store(Ret, tl.load(XTrue)) else: tl.store(Ret, tl.load(XFalse)) @@ -3180,7 +3176,7 @@ def test_if(if_type, device): x_false = torch.tensor([1.51], dtype=torch.float32, device=device) ret = torch.zeros(1, dtype=torch.float32, 
device=device) - kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1) + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) assert torch.equal(ret, x_true) @@ -3192,23 +3188,22 @@ def test_num_warps_pow2(device): pass with pytest.raises(AssertionError, match='must be a power of 2'): - _kernel[(1,)](dst=dst, num_warps=3) - _kernel[(1,)](dst=dst, num_warps=1) - _kernel[(1,)](dst=dst, num_warps=2) - _kernel[(1,)](dst=dst, num_warps=4) + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + # ------------- # test extern # ------------- -@pytest.mark.parametrize("dtype_str, expr, lib_path", - [('int32', 'math.ffs', ''), - ('float32', 'math.log2', ''), - ('float32', 'math.scalbn', ''), - ('float32', 'math.pow', tl.math.libdevice_path()), - ('float64', 'math.pow_dtype', tl.math.libdevice_path()), - ('float64', 'math.norm4d', '')]) +@pytest.mark.parametrize("dtype_str, expr, lib_path", [('int32', 'math.ffs', ''), ('float32', 'math.log2', ''), + ('float32', 'math.scalbn', ''), + ('float32', 'math.pow', tl.math.libdevice_path()), + ('float64', 'math.pow_dtype', tl.math.libdevice_path()), + ('float64', 'math.norm4d', '')]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): @@ -3239,7 +3234,7 @@ def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): y_ref = x * pow(2, 2) elif expr == 'math.pow_dtype': x = np.abs(x) - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.math.pow(x, 0.5)'}) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.math.pow(x, 0.5)'}) y_ref = np.power(x, 0.5) elif expr == 'math.pow': # numpy does not allow negative factors in power, so we use abs() @@ -3256,8 +3251,8 @@ def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): x_tri = to_triton(x, device=device) # triton result - y_tri = to_triton(numpy_random((shape[0],), 
dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) + y_tri = to_triton(numpy_random((shape[0], ), dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) # compare if expr == 'math.ffs': np.testing.assert_equal(y_ref, to_numpy(y_tri)) @@ -3265,10 +3260,8 @@ def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) -@pytest.mark.parametrize("dtype_str, expr, lib_path", - [('float32', 'math.pow', ''), - ('float64', 'math.pow_dtype', ''), - ('float64', 'math.pow', tl.math.libdevice_path())]) +@pytest.mark.parametrize("dtype_str, expr, lib_path", [('float32', 'math.pow', ''), ('float64', 'math.pow_dtype', ''), + ('float64', 'math.pow', tl.math.libdevice_path())]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): @@ -3281,7 +3274,7 @@ def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): shape = (128, ) rs = RandomState(17) # limit the range of integers so that the sum does not overflow - x = numpy_random((1,), dtype_str=dtype_str, rs=rs) + x = numpy_random((1, ), dtype_str=dtype_str, rs=rs) y_ref = np.zeros(shape, dtype=x.dtype) # numpy does not allow negative factors in power, so we use abs() @@ -3296,8 +3289,8 @@ def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): # triton result x_tri = to_triton(x, device=device)[0].item() - y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) + y_tri = to_triton(numpy_random((shape[0], ), dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) # compare 
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) @@ -3306,6 +3299,7 @@ def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): # test inline asm # ----------------------- + @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_inline_asm(num_ctas, device): check_cuda_only(device) @@ -3318,7 +3312,8 @@ def test_inline_asm(num_ctas, device): x = tl.load(X + tl.arange(0, BLOCK)) y = tl.load(Y + tl.arange(0, BLOCK)) s = tl.full([BLOCK], n, tl.int32) - z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, is_pure=True, pack=1) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) tl.store(Z + tl.arange(0, BLOCK), z) shape = (128, ) @@ -3329,7 +3324,7 @@ def test_inline_asm(num_ctas, device): y_tri = to_triton(y, device=device) n = 17 z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) y_ref = (y << n) | (x >> (32 - n)) # compare np.testing.assert_equal(y_ref, to_numpy(z_tri)) @@ -3346,9 +3341,11 @@ def test_inline_asm_packed(num_ctas, device): def kernel(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) # shift 4x8bits values together. 
- y = tl.inline_asm_elementwise("and.b32 $0, $1, 0x1F1F1F1F; \ - shl.b32 $0, $0, 3;", - "=r,r", [x,], dtype=tl.int8, is_pure=True, pack=4) + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) tl.store(Y + tl.arange(0, BLOCK), y) shape = (512, ) @@ -3356,19 +3353,19 @@ def test_inline_asm_packed(num_ctas, device): x = numpy_random(shape, dtype_str='uint8', rs=rs) x_tri = to_triton(x, device=device) y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) y_ref = x << 3 # compare np.testing.assert_equal(y_ref, to_numpy(y_tri)) + # ----------------------- # test control flow # ----------------------- @pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), - (15, -16, -1), (15, -16, -2), (15, -16, -3), - (-18, -22, -1), (22, 18, -1)]) + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) def test_for_iv(lo, hi, iv, device): @triton.jit @@ -3381,8 +3378,8 @@ def test_for_iv(lo, hi, iv, device): lo = 2**35 hi = 2**35 + 20 - out = to_triton(np.zeros((1,), dtype=np.int64), device=device) - kernel[(1,)](out, lo, hi, iv) + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) assert out[0] == sum(range(lo, hi, iv)) @@ -3396,17 +3393,17 @@ def test_if_else(device): val = tl.load(FalseVal) tl.store(Out, val) - out = to_triton(np.zeros((1,), dtype=np.int32), device=device) - true_val = to_triton(np.full((1,), 1, dtype=np.int32), device=device) - false_val = to_triton(np.full((1,), 2, dtype=np.int32), device=device) - cond = to_triton(np.zeros((1,), dtype=np.int32), device=device) + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) 
+ false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) # True cond[0] = True - kernel[(1,)](cond, true_val, false_val, out) + kernel[(1, )](cond, true_val, false_val, out) assert to_numpy(out)[0] == true_val[0] # False cond[0] = False - kernel[(1,)](cond, true_val, false_val, out) + kernel[(1, )](cond, true_val, false_val, out) assert to_numpy(out)[0] == false_val[0] @@ -3425,15 +3422,15 @@ def test_if_return(mode, device): return tl.store(Out, 1) - out = to_triton(np.zeros((1,), dtype=np.int32), device=device) - exit_early = to_triton(np.zeros((1,), dtype=np.int32), device=device) + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) # exit early path taken exit_early[0] = 1 - kernel[(1,)](exit_early, out, True, mode) + kernel[(1, )](exit_early, out, True, mode) assert to_numpy(out)[0] == 0 # exit early path not taken exit_early[0] = 0 - kernel[(1,)](exit_early, out, False, mode) + kernel[(1, )](exit_early, out, False, mode) assert to_numpy(out)[0] == 1 @@ -3469,10 +3466,11 @@ def add_fn_static_cond(x, cond: tl.constexpr): # TODO(Keren): if_exp -@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit", - "jit", "jit_if", "jit_expr", - "jit_static_cond", "jit_noinline", "jit_extern"]) +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) def test_if_call(call_type, device): + @triton.jit def kernel(Out, call_type: tl.constexpr): pid = tl.program_id(0) @@ -3531,8 +3529,8 @@ def test_if_call(call_type, device): tl.store(Out, o) - out = to_triton(np.zeros((1,), dtype=np.int32), device=device) - kernel[(1,)](out, call_type) + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) assert to_numpy(out)[0] == 1 @@ -3556,14 +3554,14 @@ def 
test_nested_if_else_return(_cond1, _cond2, _cond3, device): val = tl.load(Val3) tl.store(Out, val) - out = to_triton(np.full((1,), -1, dtype=np.int32), device=device) - cond1 = to_triton(np.full((1,), _cond1, dtype=np.int32), device=device) - cond2 = to_triton(np.full((1,), _cond2, dtype=np.int32), device=device) - cond3 = to_triton(np.full((1,), _cond3, dtype=np.int32), device=device) - val1 = to_triton(np.full((1,), 1, dtype=np.int32), device=device) - val2 = to_triton(np.full((1,), 2, dtype=np.int32), device=device) - val3 = to_triton(np.full((1,), 3, dtype=np.int32), device=device) - kernel[(1,)](cond1, cond2, cond3, val1, val2, val3, out) + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) targets = { (True, True, True): val1[0], (True, True, False): val1[0], @@ -3592,19 +3590,20 @@ def test_while(device): tl.store(OutI, curr_i) tl.store(OutJ, j) - out_i = to_triton(np.zeros((1,), dtype=np.int32), device=device) - out_j = to_triton(np.zeros((1,), dtype=np.int32), device=device) - init_i = to_triton(np.full((1,), 1, dtype=np.int32), device=device) - out_init_i = to_triton(np.full((1,), 0, dtype=np.int32), device=device) - bound = to_triton(np.full((1,), 10, dtype=np.int32), device=device) - cut_off = to_triton(np.full((1,), 5, dtype=np.int32), device=device) - kernel[(1,)](init_i, bound, cut_off, out_i, out_init_i, out_j) + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + 
init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) assert out_init_i[0] == init_i[0] assert out_i[0] == init_i[0] + 1 assert out_j[0] == bound[0] -def test_while(device): +def test_while2(device): + @triton.jit def nested_while(data, countPtr): for i in range(10): @@ -3614,8 +3613,8 @@ def test_while(device): count = count - 2 counter = torch.tensor([8], dtype=torch.int32, device=device) - data = torch.zeros((1,), device=device, dtype=torch.float32) - nested_while[(1,)](data, counter) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) assert data[0] == 40 @@ -3643,6 +3642,7 @@ def test_while(device): # test extra # ----------------------- + def test_num_threads(device): if is_hip(): pytest.skip("test_num_threads is not supported in HIP") @@ -3655,8 +3655,8 @@ def test_num_threads(device): tl.store(Out + offs, 1) num_threads = 256 - out = to_triton(np.zeros((num_threads,), dtype=np.int32), device=device) - kernel[(1,)](out, num_warps=num_threads // 32) + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) assert torch.sum(out) == 256 @@ -3674,9 +3674,9 @@ def test_globaltimer(device): end = tl.extra.cuda.globaltimer() tl.store(Out2, end - start) - out1 = to_triton(np.zeros((128,), dtype=np.int64), device=device) - out2 = to_triton(np.zeros((1,), dtype=np.int64), device=device) - h = kernel[(1,)](out1, out2) + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) assert out2[0] > 0 assert h.asm["ptx"].count("%globaltimer") == 2 @@ -3690,17 
+3690,17 @@ def test_smid(device): def kernel(Out): tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) - out = to_triton(np.zeros((1024,), dtype=np.int32), device=device) - h = kernel[(out.shape[0],)](out) + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) assert out.sort()[0].unique().shape[0] > 0 assert h.asm["ptx"].count("%smid") == 1 + # ----------------------- # test layout conversions # ----------------------- # TODO: backend should be tested separately - layouts = [ # MmaLayout(1, [1, 4], [1, 1], [0, 1]), # MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), @@ -3800,15 +3800,17 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): def test_load_scalar_with_mask(device): + @triton.jit def kernel(Input, Index, Out, N: int): index = tl.load(Index) scalar = tl.load(Input + index, mask=index < N, other=0) tl.store(Out, scalar, mask=index < N) + Index = torch.tensor([0], dtype=torch.int32, device=device) Input = torch.tensor([0], dtype=torch.int32, device=device) Out = torch.empty_like(Index, device=device) - kernel[(1,)](Input, Index, Out, Index.numel()) + kernel[(1, )](Input, Index, Out, Index.numel()) assert Out.data[0] == 0 @@ -3816,6 +3818,7 @@ def test_load_scalar_with_mask(device): # maybe delete it later after ptxas has been fixed @pytest.mark.parametrize("dtype_str", ['float16', 'int16']) def test_ptx_cast(dtype_str, device): + @triton.jit def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): xoffset = tl.program_id(0) * XBLOCK @@ -3847,15 +3850,17 @@ def test_ptx_cast(dtype_str, device): s0 = 4 buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) - kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) 
assert buf14.to(torch.float32).mean() == -2.0 + # ----------------------- # test fp8 -> fp32 dot # ----------------------- def f8_to_f16(x, dtype): + @triton.jit def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) @@ -3865,21 +3870,21 @@ def f8_to_f16(x, dtype): tl.store(Y + offs, x, mask=mask) ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) - grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) dtype = getattr(tl, dtype) kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) return ret @triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - low_precision_acc: tl.constexpr, +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr # ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -3917,9 +3922,7 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device): a = to_triton(A, device='cuda', dst_type=in_type_str) b = to_triton(B, device='cuda', dst_type=in_type_str) grid = (triton.cdiv(M, BLOCK_M), 1) - matmul_kernel[grid](a, b, C, M, N, K, - a.stride(0), a.stride(1), b.stride(0), b.stride( - 1), C.stride(0), C.stride(1), + matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) torch_a = torch.from_numpy(A) th_a = f8_to_f16(torch_a.cuda(), in_type_str) @@ -3933,6 +3936,7 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device): else: torch.testing.assert_close(ref_out, C) + # ----------------------- # test 
enable_fp_fusion # ----------------------- @@ -3946,8 +3950,8 @@ def test_enable_fp_fusion(enable_fp_fusion): ptrs = data + tl.arange(0, 128) tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) - data = torch.randn((128,), device='cuda', dtype=torch.float32) - h = mul_add[(1,)](data, enable_fp_fusion=enable_fp_fusion) + data = torch.randn((128, ), device='cuda', dtype=torch.float32) + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None assert found_fma == enable_fp_fusion diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index e12adff1e..7f6784d0b 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -12,6 +12,7 @@ import triton.language as tl class PhiloxConfig: + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) @@ -40,6 +41,7 @@ PHILOX_64 = PhiloxConfig( class CustomPhilox4x: + def __init__(self, seed, config): self._config = config seed = self._into_pieces(seed) @@ -92,6 +94,7 @@ class CustomPhilox4x: class CustomPhilox(CustomPhilox4x): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.buffer = [] @@ -111,10 +114,9 @@ BLOCK = 1024 # test generation of random uint32 -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in ['10', '4,53', '10000'] - for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] - ) +@pytest.mark.parametrize('size, seed', [(size, seed) + for size in ['10', '4,53', '10000'] + for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]) def test_randint(size, seed, device): size = list(map(int, size.split(','))) @@ -123,10 +125,11 @@ def test_randint(size, seed, device): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) rand = tl.randint(seed, offset) 
tl.store(X + offset, rand, mask=offset < N) + # triton result x = torch.empty(size, dtype=torch.int32, device=device) N = x.numel() - grid = (triton.cdiv(N, BLOCK),) + grid = (triton.cdiv(N, BLOCK), ) kernel[grid](x, N, seed) out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist() # reference result @@ -134,44 +137,44 @@ def test_randint(size, seed, device): out_ref = [gen.random_raw()[0] for _ in out_tri] assert out_tri == out_ref + # test uniform PRNG -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000] - for seed in [0, 42, 124, 54]] - ) +@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]]) def test_rand(size, seed, device): + @triton.jit def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) rand = tl.rand(seed, offset) tl.store(X + offset, rand, mask=offset < N) + # triton result x = torch.empty(size, dtype=torch.float32, device=device) N = x.numel() - grid = (triton.cdiv(N, BLOCK),) + grid = (triton.cdiv(N, BLOCK), ) kernel[grid](x, N, seed) assert all((x >= 0) & (x <= 1)) assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + # test normal PRNG -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000] - for seed in [0, 42, 124, 54]] - ) +@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]]) def test_randn(size, seed, device): + @triton.jit def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) rand = tl.randn(seed, offset) tl.store(X + offset, rand, mask=offset < N) + # triton result x = torch.empty(size, dtype=torch.float32, device=device) N = x.numel() - grid = (triton.cdiv(N, BLOCK),) + grid = (triton.cdiv(N, BLOCK), ) kernel[grid](x, N, seed) assert abs(x.mean()) < 1e-2 assert abs(x.std() - 1) < 1e-2 @@ -179,7 +182,9 @@ def test_randn(size, seed, device): # tl.rand() should never produce >=1.0 + def test_rand_limits(device): 
+ @triton.jit def kernel(input, output, n: tl.constexpr): idx = tl.arange(0, n) @@ -192,7 +197,7 @@ def test_rand_limits(device): torch.iinfo(torch.int32).max, ], dtype=torch.int32, device=device) output = torch.empty(2, dtype=torch.float32, device=device) - kernel[(1,)](min_max_int32, output, 2) + kernel[(1, )](min_max_int32, output, 2) assert output[0] == output[1] assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 78b8d09fb..08bc63a3e 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -1,6 +1,8 @@ +import itertools import os import subprocess import sys +from collections import Counter import pytest @@ -9,53 +11,88 @@ print_path = os.path.join(dir_path, "print_helper.py") assert_path = os.path.join(dir_path, "assert_helper.py") # TODO: bfloat16 after LLVM-15 -assert_types = ["device_assert", "assert", "static_assert", "no_debug"] +assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] -@pytest.mark.parametrize("func_type, data_type", - [("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32"), ("no_arg_print", "int32")]) +# TODO: Print with multiple operands +@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), +]) def test_print(func_type: str, data_type: str): proc = 
subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False) outs, _ = proc.communicate() - outs = outs.split() - new_lines = set() - for line in outs: - try: - value = line - if func_type != "static_print": - value = int(float(line)) - new_lines.add(value) - except Exception as e: - print(e) - if func_type != "static_print" and func_type != "no_arg_print": + outs = [line for line in outs.decode("UTF-8").split("\n") if line] + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type == "print" or func_type == "device_print": for i in range(128): - assert i in new_lines - else: - assert len(new_lines) == 1 + line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[" int32[constexpr[128]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = 128 + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = 128 + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(128)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(128): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + + actual_lines = Counter() + for line in outs: + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) @pytest.mark.parametrize("func_type", assert_types) def test_assert(func_type: str): os.environ["TRITON_DEBUG"] = "1" - proc = 
subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) + proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + shell=False) _, errs = proc.communicate() errs = errs.splitlines() num_errs = 0 for err in errs: if "x != 0" in err.decode("utf-8"): num_errs += 1 + + # Check for segfaults. + assert all("segmentation fault" not in line.decode("utf-8").lower() for line in errs) + os.environ["TRITON_DEBUG"] = "0" - if func_type != "static_assert": - assert num_errs == 127 - else: + if func_type == "static_assert" or func_type == "device_assert_passes": assert num_errs == 0 + else: + assert num_errs == 127 @pytest.mark.parametrize("caller_type, callee_type", nested_types) def test_assert_nested(caller_type, callee_type): - proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) + proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) _, errs = proc.communicate() errs = errs.splitlines() num_errs = 0 diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 7e6f820a3..acc5e30c6 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -68,8 +68,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= b_ref = do_mask(b_ref) if is_dds else b_ref a_ref.retain_grad() b_ref.retain_grad() - c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, - b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref) c_ref.backward(dc_ref) c_ref = do_sparsify(c_ref) if is_sdd else c_ref da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad @@ 
-172,7 +171,7 @@ def test_attention_fwd_bwd( value.retain_grad() attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) # ad hoc loss - loss = (attn_out ** 2).mean() + loss = (attn_out**2).mean() loss.backward() grads = [query.grad, key.grad, value.grad] @@ -189,7 +188,7 @@ def test_attention_fwd_bwd( probs = torch.softmax(scores, dim=-1) torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) # ad hoc loss - torch_loss = (torch_attn_out ** 2).mean() + torch_loss = (torch_attn_out**2).mean() torch_loss.backward() torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] @@ -209,8 +208,10 @@ def triton_attention( value: torch.Tensor, scale: float, ): - sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) - sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, + device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, + device=value.device) sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) w = sparse_dot_sdd_nt(query, key) diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index f6ae42ac3..5bffd2ad8 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -5,14 +5,13 @@ import triton import triton.ops -@pytest.mark.parametrize("M, N, dtype, mode", - [ - (M, N, dtype, mode) for M in [1024, 821] - for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32'] - for mode in ['forward', 'backward'] - ] - ) +@pytest.mark.parametrize("M, N, dtype, mode", [ # + (M, N, dtype, mode) + for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + 
for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] +]) def test_op(M, N, dtype, mode): capability = torch.cuda.get_device_capability() if capability[0] < 8 and dtype == "bfloat16": diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 09f739bcb..5783141da 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -5,10 +5,12 @@ import triton import triton.ops -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16), - (2, 4, 512, 32), - (2, 4, 512, 64), - (2, 4, 512, 128)]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ # + (2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128), +]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) @@ -56,6 +58,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): # # triton implementation tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) +<<<<<<< HEAD # print(ref_out) # print(tri_out) if torch.version.hip is None: @@ -70,3 +73,74 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) +======= + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0) + torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) + + +try: + from flash_attn.flash_attn_interface import 
flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'casual': casual, + 'seq_par': seq_par, + }) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False] +] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + sm_scale = 1.3 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if provider == "triton": + fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) 
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index 579d0ad93..2fdfe235e 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -8,7 +8,8 @@ import triton.language as tl def test_normalization_with_remat(): @triton.jit - def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr): xnumel = 512 rnumel = 4096 xoffset = tl.program_id(0) * XBLOCK @@ -52,7 +53,7 @@ def test_normalization_with_remat(): arg115_1 = torch.rand(64, device="cuda") arg8_1 = torch.rand(64, device="cuda") arg9_1 = torch.rand(64, device="cuda") - triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) + triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) @@ -148,7 +149,7 @@ def test_avg_pool_bw(): inp = torch.ones(8, 2048, 8, 8, device="cuda", dtype=torch.half) out = torch.ones_like(inp) * 3 numel = inp.numel() - triton_[(numel // 1024,)](inp, out, 1024) + triton_[(numel // 1024, )](inp, out, 1024) out_ref = torch.ones_like(inp) out_ref[:, :, 1:7, 0::7] = 2 / 3 out_ref[:, :, 0::7, 1:7] = 2 / 3 @@ -159,6 +160,7 @@ def test_avg_pool_bw(): @pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) @pytest.mark.parametrize("num_warps", [1, 4]) def test_scan2d_broadcast(RBLOCK, num_warps): + @triton.jit(debug=True) def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): rindex = tl.arange(0, 
RBLOCK)[None, :] @@ -172,12 +174,13 @@ def test_scan2d_broadcast(RBLOCK, num_warps): XBLOCK = 4 input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda') output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda') - fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps) ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) torch.testing.assert_close(output, ref) def test_scan2d_for(): + @triton.jit def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): rbase = tl.arange(0, RBLOCK)[None, :] @@ -190,6 +193,6 @@ def test_scan2d_for(): RBLOCK = 8 out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64) - fn[(1,)](out0, RBLOCK, RBLOCK) + fn[(1, )](out0, RBLOCK, RBLOCK) ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1 torch.testing.assert_close(out0, ref) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 642b0982b..801cb8a41 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -19,7 +19,7 @@ def f8_to_f16(x, dtype): tl.store(Y + offs, x, mask=mask) ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) - grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) dtype = getattr(tl, dtype) kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) return ret @@ -28,87 +28,88 @@ def f8_to_f16(x, dtype): @pytest.mark.parametrize( "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM", itertools.chain( - *[ - [ - # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, 
DTYPE, DTYPE, True, True), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in 
[False, True] for BT in [False, True] - ], + *[[ + # 1 warp + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # 2 warp + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # 4 warp + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # 8 warp + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # variable input + (128, 128, 32, 1, 4, 2, 
256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]], # n-stage - *[ - [ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] - ], + *[[ + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + ] + for DTYPE in ["float16", "bfloat16", "float32"] + for AT in [False, True] + for BT in [False, True] + for STAGES in [4]], # mixed-precision - *[ - [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"), - ("float8e4nv", "float8e4nv"), - ("float8e5", "float8e4nv"), - ("float8e5", "float8e5"), - ("float8e4b15", "float8e4b15"), - ("float8e4nv", "float16"), - ("float16", "float8e5"), - ("float16", "float32"), - ("float32", 
"float16"), - ("bfloat16", "float32"), - ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False] - ], + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float8e5"), + ("float8e4nv", "float8e4nv"), + ("float8e5", "float8e4nv"), + ("float8e5", "float8e5"), + ("float8e4b15", "float8e4b15"), + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]], # mixed-precision block layout - *[ - [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), - ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"), - ("float16", "float8e5"), - ("float16", "float32"), - ("float32", "float16"), - ("bfloat16", "float32"), - ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] - ], + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True]], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM): +def test_op(BLOCK_M, BLOCK_N, 
BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, + F8_FASTACCUM): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -147,7 +148,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype] exponents = torch.randint(-10, 0, size=(m, n)) - ret = (2. ** exponents).to(dtype).to("cuda") + ret = (2.**exponents).to(dtype).to("cuda") return ret # allocate/transpose inputs diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py index c425a3669..198717a32 100644 --- a/python/test/unit/runtime/test_autotuner.py +++ b/python/test/unit/runtime/test_autotuner.py @@ -17,6 +17,25 @@ def test_kwargs(): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) _kernel[grid](dst, src, N) _kernel[grid](dst=dst, src=src, N=N) + + +def test_restore(): + N = 1024 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src']) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py 
index f75fa7c32..c8e2e91fa 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -80,11 +80,12 @@ def test_reuse(): def inc_counter(*args, **kwargs): nonlocal counter counter += 1 + JITFunction.cache_hook = inc_counter reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') for i in range(10): - kernel[(1,)](x, 1, BLOCK=1024) + kernel[(1, )](x, 1, BLOCK=1024) assert counter == 1 @@ -95,17 +96,19 @@ def test_specialize(mode): def inc_counter(*args, **kwargs): nonlocal counter counter += 1 + JITFunction.cache_hook = inc_counter reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') function = {'enable': kernel, 'disable': kernel_nospec}[mode] target = {'enable': 4, 'disable': 1}[mode] for i in [1, 2, 4, 8, 16, 32]: - function[(1,)](x, i, BLOCK=512) + function[(1, )](x, i, BLOCK=512) assert counter == target def test_annotation(): + @triton.jit def kernel(X, i: tl.int32): tl.store(X, i) @@ -113,14 +116,15 @@ def test_annotation(): x = torch.empty(1, dtype=torch.int32, device='cuda') device = torch.cuda.current_device() - kernel[(1,)](x, 1) - kernel[(1,)](x, 8) - kernel[(1,)](x, 16) - kernel[(1,)](x, 17) + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) assert len(kernel.cache[device]) == 4 def test_constexpr_not_callable() -> None: + @triton.jit def kernel(X, c: tl.constexpr): tl.store(X, 2) @@ -141,11 +145,11 @@ def test_constexpr_not_callable() -> None: def test_jit_warmup_cache() -> None: + @triton.jit def kernel_add(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) - tl.store(o + idx, - tl.load(a + idx) + tl.load(b + idx)) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) args = [ torch.randn(32, dtype=torch.float32, device="cuda"), @@ -155,31 +159,31 @@ def test_jit_warmup_cache() -> None: ] device = torch.cuda.current_device() assert len(kernel_add.cache[device]) == 0 - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, 
grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 1 - kernel_add.warmup(*args, grid=(1,)) + kernel_add.warmup(*args, grid=(1, )) assert len(kernel_add.cache[device]) == 1 - kernel_add.warmup(*args, grid=(1,)) + kernel_add.warmup(*args, grid=(1, )) assert len(kernel_add.cache[device]) == 1 def test_jit_debug() -> None: + @triton.jit def kernel_add(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) tl.device_assert(idx < 32, "idx < 32") - tl.store(o + idx, - tl.load(a + idx) + tl.load(b + idx)) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) device = torch.cuda.current_device() assert len(kernel_add.cache[device]) == 0 - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 1 kernel_add.debug = False - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 2 kernel_add.debug = True - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 3 bins = list(kernel_add.cache[device].values()) assert bins[2].asm['ttir'] != bins[1].asm['ttir'] @@ -192,13 +196,14 @@ def add_fn(a, b, o, N: tl.constexpr): def test_jit_noinline() -> None: + @triton.jit def kernel_add_device(a, b, o, N: tl.constexpr): add_fn(a, b, o, N) device = torch.cuda.current_device() assert len(kernel_add_device.cache[device]) == 0 - kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add_device.cache[device]) == 1 bins = 
list(kernel_add_device.cache[device].values()) inline_ttir = bins[0].asm['ttir'] @@ -206,7 +211,7 @@ def test_jit_noinline() -> None: add_fn.hash = None kernel_add_device.hash = None kernel_add_device.cache[device].clear() - kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add_device.cache[device]) == 1 bins = list(kernel_add_device.cache[device].values()) noinline_ttir = bins[0].asm['ttir'] @@ -214,6 +219,7 @@ def test_jit_noinline() -> None: def test_memory_leak() -> None: + @triton.jit def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): xnumel = 10 diff --git a/python/test/unit/runtime/test_launch.py b/python/test/unit/runtime/test_launch.py index d3f9fd01b..00009f230 100644 --- a/python/test/unit/runtime/test_launch.py +++ b/python/test/unit/runtime/test_launch.py @@ -31,11 +31,11 @@ def test_memory_leak() -> None: try: inp = torch.randn(10, device='cuda') out = torch.randn(10, device='cuda') - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) gc.collect() begin, _ = tracemalloc.get_traced_memory() for _ in range(100): - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) gc.collect() end, _ = tracemalloc.get_traced_memory() assert end - begin < 30000 diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index f1958ffe2..cbddc0d0c 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -17,9 +17,11 @@ def reset_tmp_dir(): shutil.rmtree(tmpdir, ignore_errors=True) -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]) +instance_descriptor = namedtuple("instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]) +<<<<<<< HEAD def get_device_type(): try: 
import torch @@ -36,10 +38,15 @@ def get_device_type(): def compile_fn(config, device_type, cc): +======= +def compile_fn(config, cc): + +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + triton.compile( fn=kernel_sub, signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, @@ -57,15 +64,24 @@ def test_compile_in_subproc() -> None: config = instance_descriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') +<<<<<<< HEAD proc = multiprocessing.Process( target=compile_fn, args=(config, device_type, cc)) +======= + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 proc.start() proc.join() assert proc.exitcode == 0 +<<<<<<< HEAD def compile_fn_dot(config, device_type, cc): +======= +def compile_fn_dot(config, cc): + +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 @triton.jit def kernel_dot(Z): offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] @@ -90,9 +106,13 @@ def test_compile_in_forked_subproc() -> None: config = instance_descriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' +<<<<<<< HEAD proc = multiprocessing.Process( target=compile_fn_dot, args=(config, device_type, cc)) +======= + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc)) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 proc.start() proc.join() assert proc.exitcode == 0 diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 06a7ed2b1..92b5562e9 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -59,7 +59,7 @@ def kernel(C, A, B, M, N, K, tl.store(c_ptrs, c) """ -test_utils_src = ''' +test_utils_src = """ #include #include #include @@ -93,23 +93,26 @@ static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) { 
index++; } fclose(file); -}''' +}""" def gen_kernel_library(dir, libname): c_files = glob.glob(os.path.join(dir, "*.c")) - subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(), - "-c", "-fPIC"], - check=True, cwd=dir) + subprocess.run( + ["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"], + check=True, + cwd=dir, + ) o_files = glob.glob(os.path.join(dir, "*.o")) - subprocess.run(["gcc"] + o_files + ["-shared", - "-o", libname, - "-L", libcuda_dirs()[0]], - check=True, cwd=dir) + subprocess.run( + ["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]], + check=True, + cwd=dir, + ) def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): - test_src = f''' + test_src = f""" int main(int argc, char **argv) {{ int M = {M}, N = {N}, K = {K}; @@ -165,17 +168,29 @@ int main(int argc, char **argv) {{ cuMemFree(C); cuCtxDestroy(ctx); }} -''' +""" src = test_utils_src + test_src with open(os.path.join(dir, "test.c"), "w") as file: file.write(src) - subprocess.run(["gcc"] + ["test.c", - "-I", cuda_include_dir(), - "-L", libcuda_dirs()[0], - "-l", "cuda", - "-L", dir, - "-l", "kernel", - "-o", exe], check=True, cwd=dir) + subprocess.run( + ["gcc"] + [ + "test.c", + "-I", + cuda_include_dir(), + "-L", + libcuda_dirs()[0], + "-l", + "cuda", + "-L", + dir, + "-l", + "kernel", + "-o", + exe, + ], + check=True, + cwd=dir, + ) def write_triton_kernels(dir, src, util_src): @@ -190,16 +205,67 @@ def write_triton_kernels(dir, src, util_src): return kernel_path -def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): +def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path): compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") + subprocess.run( + [ + sys.executable, + compiler_path, + "-n", + kernel_name, + "--signature", + signature, + "--out-name", + out_name, + "-o", + out_path, + "-w", + str(num_warps), + "-g", + grid, + kernel_path, + ], + check=True, + cwd=dir, + ) + + +# 
Edge case kernel with no specialization +def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK): + # compile all desired configs + sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}" + name = f"matmul_{dtype}" + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) + + +def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints): # compile all desired configs for ha in ha_hb_hints: for hb in ha_hb_hints: - sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}' + sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}" name = f"matmul_{dtype}" - grid = f'M/{BM}, N/{BN}, 1' - subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", "-g", grid, kernel_path], check=True, cwd=dir) + grid = f"M/{BM}, N/{BN}, 1" + _compile_kernel( + dir=dir, + signature=sig, + kernel_name="kernel", + out_name=name, + out_path=name, + num_warps=1, + grid=grid, + kernel_path=kernel_path, + ) def link_aot_kernels(dir): @@ -221,11 +287,42 @@ def generate_matmul_test_data(dir, M, N, K): return a, b, a_path, b_path, c_path +# Test edge case where the provided kernel signature has no specializations +def test_compile_link_matmul_no_specialization(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + dtype = "fp16" + BM, BN, BK = 16, 16, 16 + + kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) + compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK) + link_aot_kernels(tmp_dir) + + # compile test case + M, N, K = 16, 16, 16 + gen_kernel_library(tmp_dir, "libkernel.so") + gen_test_bin(tmp_dir, M, N, K) + + # 
initialize test data + a, b, a_path, b_path, c_path = generate_matmul_test_data(tmp_dir, M, N, K) + + # run test case + env = os.environ.copy() + env["LD_LIBRARY_PATH"] = tmp_dir + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) + + def test_compile_link_matmul(): np.random.seed(3) with tempfile.TemporaryDirectory() as tmp_dir: - dtype = "fp16" BM, BN, BK = 16, 16, 16 @@ -250,7 +347,7 @@ def test_compile_link_matmul(): c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) c_tri = c.reshape((M, N)).view(np.float32) c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) - np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.0) def test_launcher_has_no_available_kernel(): @@ -275,7 +372,13 @@ def test_launcher_has_no_available_kernel(): # run test case env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - result = subprocess.run(["./test", a_path, b_path, c_path], env=env, cwd=tmp_dir, capture_output=True, text=True) + result = subprocess.run( + ["./test", a_path, b_path, c_path], + env=env, + cwd=tmp_dir, + capture_output=True, + text=True, + ) # It should fail since the launcher requires all the strides be 1 while they are not. 
assert result.returncode == -6 @@ -286,7 +389,6 @@ def test_compile_link_autotune_matmul(): np.random.seed(3) with tempfile.TemporaryDirectory() as tmp_dir: - dtype = "fp16" kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) @@ -319,7 +421,12 @@ def test_compile_link_autotune_matmul(): env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - subprocess.run([f"./{test_name}", a_path, b_path, c_path], check=True, cwd=tmp_dir, env=env) + subprocess.run( + [f"./{test_name}", a_path, b_path, c_path], + check=True, + cwd=tmp_dir, + env=env, + ) # read data and compare against reference c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index aca365dad..55484acd5 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -45,12 +45,12 @@ __all__ = [ "tools", ] - # ------------------------------------- # misc. utilities that don't fit well # into any specific module # ------------------------------------- + def cdiv(x: int, y: int): return (x + y - 1) // y diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index b76953f5d..486788483 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -1,5 +1,5 @@ - import functools +import hashlib import importlib import importlib.util import os @@ -10,8 +10,12 @@ from typing import Dict from ..runtime.driver import DriverBase +TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +TRITON_VERSION = "2.1.0" + class BaseBackend: + def __init__(self, device_type: str) -> None: self.device_type = device_type @@ -104,7 +108,7 @@ def get_backend(device_type: str): def _path_to_binary(binary: str): base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ - os.environ.get("TRITON_PTXAS_PATH", ""), + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), os.path.join(base_dir, "third_party", "cuda", "bin", binary) ] @@ -132,3 +136,48 @@ def 
path_to_cuobjdump(): @functools.lru_cache() def path_to_nvdisasm(): return _path_to_binary("nvdisasm") + + +@functools.lru_cache() +def compute_core_version_key(): + import pkgutil + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha1(f.read()).hexdigest()] + # compiler + compiler_path = os.path.join(TRITON_PATH, 'compiler') + for lib in pkgutil.iter_modules([compiler_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha1(f.read()).hexdigest()] + # backend + libtriton_hash = hashlib.sha1() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha1(f.read()).hexdigest()] + return '-'.join(TRITON_VERSION) + '-'.join(contents) + + +_cached_cuda_version_key = None + + +def get_cuda_version_key(): + global _cached_cuda_version_key + if _cached_cuda_version_key is None: + key = compute_core_version_key() + try: + ptxas = path_to_ptxas()[0] + ptxas_version = subprocess.check_output([ptxas, "--version"]) + except RuntimeError: + ptxas_version = b"NO_PTXAS" + _cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest() + return _cached_cuda_version_key diff --git a/python/triton/common/build.py b/python/triton/common/build.py index e7948e7b2..536586ddd 100644 --- a/python/triton/common/build.py +++ b/python/triton/common/build.py @@ -92,9 +92,15 @@ def _build(name, src, srcdir): py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] if is_hip(): - ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", 
"-o", so]) + ret = subprocess.check_call([ + cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", + f"-L{hip_lib_dir}", "-lamdhip64", "-o", so + ]) else: - cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so] + cc_cmd = [ + cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", + "-o", so + ] cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs] ret = subprocess.check_call(cc_cmd) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index 89f46a1fe..fd0665e1e 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,5 +1,8 @@ -from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, - get_arch_default_num_warps, instance_descriptor) +from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps, + instance_descriptor) from .errors import CompilationError -__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"] +__all__ = [ + "compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", + "get_arch_default_num_stages" +] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 1bba2d079..8dbc1587a 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -10,8 +10,7 @@ from .._C.libtriton.triton import ir from ..language import constexpr, tensor # ideally we wouldn't need any runtime component from ..runtime import JITFunction -from .errors import (CompilationError, CompileTimeAssertionFailure, - UnsupportedLanguageConstruct) +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) def mangle_ty(ty): @@ -68,7 +67,10 @@ def 
_check_fn_args(node, fn, args): if fn.noinline: for idx, arg in enumerate(args): if not _is_constexpr(arg) and not _is_triton_scalar(arg): - raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}') + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) def _get_fn_file_line(fn): @@ -89,6 +91,7 @@ _condition_types = {bool, int, type(None)} # Python types accepted for conditio class enter_sub_region: + def __init__(self, generator): self.generator = generator @@ -109,6 +112,7 @@ class enter_sub_region: # Check if the given syntax node has an "early" return class ContainsReturnChecker(ast.NodeVisitor): + def __init__(self, gscope): self.gscope = gscope @@ -199,9 +203,10 @@ class ContainsReturnChecker(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, - module=None, is_kernel=False, function_types: Optional[Dict] = None, - debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None, + is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False, + file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -237,8 +242,10 @@ class CodeGenerator(ast.NodeVisitor): )) def _define_name_lookup(self): + def local_lookup(name: str, absent): - value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + value = self.lscope.get(name, absent) if value is not absent and name not in 
self.local_defs: self.global_uses[name] = value return value @@ -255,8 +262,7 @@ class CodeGenerator(ast.NodeVisitor): return name_lookup - def set_value(self, name: str, - value: Union[tensor, constexpr]) -> None: + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: ''' This function: called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) 1. record local defined name (FIXME: should consider control flow) @@ -338,7 +344,8 @@ class CodeGenerator(ast.NodeVisitor): self.visit(init_node) # initialize function visibility = "public" if self.is_kernel else "private" - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() arg_values = [] @@ -469,12 +476,23 @@ class CodeGenerator(ast.NodeVisitor): rhs = self.visit(node.right) method_name = self._method_name_for_bin_op.get(type(node.op)) if method_name is None: - raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { - ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__', - ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__', - ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__', + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + 
ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', } def visit_then_else_blocks(self, node, liveins, then_block, else_block): @@ -508,7 +526,8 @@ class CodeGenerator(ast.NodeVisitor): if name in then_defs or name in else_defs: names.append(name) ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) - ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type()) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) # variable defined in then but not in else if name in then_defs and name not in else_defs: else_defs[name] = liveins[name] @@ -602,8 +621,7 @@ class CodeGenerator(ast.NodeVisitor): contains_return = ContainsReturnChecker(self.gscope).visit(node) if self.scf_stack and contains_return: raise UnsupportedLanguageConstruct( - None, node, - "Cannot have `return` statements inside `while` or `for` statements in triton " + None, node, "Cannot have `return` statements inside `while` or `for` statements in triton " "(note that this also applies to `return` statements that are inside functions " "transitively called from within `while`/`for` statements)") elif self.scf_stack or not contains_return: @@ -612,10 +630,13 @@ class CodeGenerator(ast.NodeVisitor): self.visit_if_top_level(cond, node) else: cond = _unwrap_if_constexpr(cond) - if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: raise UnsupportedLanguageConstruct( - None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( - ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) + None, node, + "`if` conditionals can only accept values of type {{{}}}, not 
objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) if cond: self.visit_compound_statement(node.body) else: @@ -624,15 +645,52 @@ class CodeGenerator(ast.NodeVisitor): def visit_IfExp(self, node): cond = self.visit(node.test) if _is_triton_tensor(cond): - raise UnsupportedLanguageConstruct( - None, node, - "Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead") + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + 
self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None else: cond = _unwrap_if_constexpr(cond) - if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: raise UnsupportedLanguageConstruct( - None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( - ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) + None, node, + "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) if cond: return self.visit(node.body) else: @@ -654,8 +712,10 @@ class CodeGenerator(ast.NodeVisitor): return constexpr(lhs_value is not rhs_value) method_name = self._method_name_for_comp_op.get(type(node.ops[0])) if method_name is None: - raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' } @@ -664,11 +724,15 @@ class CodeGenerator(ast.NodeVisitor): op = self.visit(node.operand) fn = self._method_name_for_unary_op.get(type(node.op)) if fn is None: - raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST unary operator '{}' is not 
(currently) implemented.".format(node.op.__name__)) if _is_triton_tensor(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() - _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'} + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } def visit_While(self, node): with enter_sub_region(self) as sr: @@ -763,9 +827,7 @@ class CodeGenerator(ast.NodeVisitor): iter_args = [self.visit(arg) for arg in node.iter.args] if IteratorClass == language.static_range: iterator = IteratorClass(*iter_args) - static_range = range(iterator.start.value, - iterator.end.value, - iterator.step.value) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) for i in static_range: self.lscope[node.target.id] = constexpr(i) self.visit_compound_statement(node.body) @@ -902,8 +964,7 @@ class CodeGenerator(ast.NodeVisitor): def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_tensor(arg) - else constexpr(arg) for arg in args] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] # generate function def attributes = dict() constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] @@ -921,8 +982,9 @@ class CodeGenerator(ast.NodeVisitor): debug = self.debug if fn.debug is None else fn.debug file_name, begin_line = _get_fn_file_line(fn) generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline, - file_name=file_name, begin_line=begin_line, target=self.builder.target) + function_name=fn_name, function_types=self.function_ret_types, debug=debug, + 
noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + target=self.builder.target) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -950,7 +1012,7 @@ class CodeGenerator(ast.NodeVisitor): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] - if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if fn is language.core.device_assert: # TODO: this should not be so hardcoded if not self.debug: return if isinstance(fn, JITFunction): @@ -971,16 +1033,21 @@ class CodeGenerator(ast.NodeVisitor): def visit_BoolOp(self, node: ast.BoolOp): if len(node.values) != 2: - raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + raise UnsupportedLanguageConstruct( + None, node, + "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") lhs = self.visit(node.values[0]) rhs = self.visit(node.values[1]) method_name = self._method_name_for_bool_op.get(type(node.op)) if method_name is None: - raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} if sys.version_info < (3, 8): + def visit_NameConstant(self, node): return constexpr(node.value) @@ -1013,7 +1080,9 @@ class CodeGenerator(ast.NodeVisitor): evaluated = self.visit(value.value) if not _is_constexpr(evaluated): raise UnsupportedLanguageConstruct( - None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + 
str(type(evaluated))) + None, node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) else: raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) @@ -1055,7 +1124,9 @@ class CodeGenerator(ast.NodeVisitor): passed = _unwrap_if_constexpr(self.visit(node.args[0])) if not isinstance(passed, bool): - raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values") + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) if not passed: if arg_count == 1: message = "" @@ -1144,10 +1215,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target): file_name, begin_line = _get_fn_file_line(fn) prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, - function_name=function_name, attributes=new_attrs, - is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line, - target=target) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name, + begin_line=begin_line, target=target) try: generator.visit(fn.parse()) except CompilationError as e: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index b849ad265..3bd9c07f3 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -11,25 +11,21 @@ from typing import Any from dataclasses import dataclass -from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, - compile_ptx_to_cubin, get_env_vars, get_num_warps, - 
get_shared_memory_size, ir, runtime, - translate_llvmir_to_ptx, +from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, + get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx, translate_triton_gpu_to_llvmir) -from ..common.backend import get_backend, path_to_ptxas +from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas from ..common.build import is_hip # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver -from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, - get_device_capability, version_key) +from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability) from ..tools.disasm import get_sass from .code_generator import ast_to_ttir from .make_launcher import make_stub -from .utils import (InfoFromBackendForTensorMap, TensorMapManager, - get_ids_of_tensormaps, parse_tma_info) +from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info) CUDA_DEFAULT_WARP_SIZE = 32 @@ -45,6 +41,7 @@ def _is_cuda(target): class LazyDict(dict): + def __getitem__(self, key): val = dict.__getitem__(self, key) if callable(val): @@ -103,8 +100,13 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, target): return mod +<<<<<<< HEAD def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type): +======= +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, + enable_persistent, optimize_epilogue): +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 is_cuda = _is_cuda(target) if is_cuda: capability = target.capability @@ -128,9 +130,13 @@ def 
optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() pm.add_tritongpu_optimize_dot_operands_pass() +<<<<<<< HEAD if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0: pm.add_tritongpu_stream_pipeline_pass() pm.add_canonicalizer_pass() +======= + pm.add_cse_pass() +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ws_enabled = False # `num_warps` does not mean the total number of warps of a CTA when # warp specialization is enabled. @@ -174,6 +180,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, if is_cuda and capability // 10 >= 9: pm.add_tritongpu_fence_insertion_pass() pm.add_tritongpu_ws_fixup_missing_attrs_pass() + pm.add_tritongpu_optimize_thread_locality_pass() + pm.add_canonicalizer_pass() pm.run(mod) return mod @@ -197,6 +205,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0): # PTX translation + @functools.lru_cache() def ptx_get_version(cuda_version) -> int: ''' @@ -261,7 +270,11 @@ def convert_type_repr(x): return x -def make_hash(fn, target, env_vars, **kwargs): +def make_hash(fn, target, env_vars, device_backend, **kwargs): + if device_backend is None: + version_key = get_cuda_version_key() + else: + version_key = device_backend.get_version_key() if isinstance(fn, JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] @@ -275,16 +288,21 @@ def make_hash(fn, target, env_vars, **kwargs): enable_persistent = kwargs.get("enable_persistent", False) debug = kwargs.get("debug", False) # Get unique key for the compiled code - get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), + sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in 
sorted(env_vars.keys())] +<<<<<<< HEAD key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" +======= + key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) ignore_version = kwargs.get('ignore_version', False) if (ignore_version): return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest() - return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest() + return hashlib.md5((Path(fn).read_text() + version_key).encode("utf-8")).hexdigest() # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, @@ -321,12 +339,14 @@ else: def _get_jsonable_constants(constants): + def _is_jsonable(x): try: json.dumps(x) return True except (TypeError, OverflowError): return False + serialized_constants = {} for constant in constants: if _is_jsonable(constants[constant]): @@ -341,7 +361,9 @@ def parse_mlir_module(path, context): return module -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()]) +instance_descriptor = namedtuple("instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], + defaults=[set(), set(), set(), set()]) def is_hip(): @@ -385,10 +407,16 @@ def get_arch_default_num_stages(device_type, capability=None): def add_cuda_stages(target, extern_libs, stages): +<<<<<<< HEAD stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, 
target)) stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target)) +======= + + stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target)) + stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target)) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 def compile(fn, **kwargs): @@ -434,7 +462,8 @@ def compile(fn, **kwargs): # build architecture descriptor if device_type == "cuda": _device_backend = get_backend(device_type) - target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion) + target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, + enable_fp_fusion=enable_fp_fusion) else: _device_backend = get_backend(device_type) assert _device_backend @@ -443,11 +472,12 @@ def compile(fn, **kwargs): # build compilation stages stages = dict() stages["ast"] = (lambda path: fn, None) - stages["ttir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) + stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir( + ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) if is_cuda: - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir( + ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, + enable_warp_specialization, enable_persistent, optimize_epilogue)) stages["llir"] = (lambda path: Path(path).read_text(), lambda src: 
ttgir_to_llir(src, extern_libs, target, tma_infos)) add_cuda_stages(target, extern_libs, stages) @@ -507,18 +537,21 @@ def compile(fn, **kwargs): if ir_name == 'ttgir': num_warps_matches = re.findall(ttgir_num_warps_pattern, src) assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" + assert "num_warps" not in kwargs or int( + num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" num_warps = int(num_warps_matches[0]) param_tys = [convert_type_repr(ty) for ty in types] signature = {k: v for k, v in enumerate(param_tys)} first_stage = list(stages.keys()).index(ir_name) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs)) # managers used to dump and override IR for debugging enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" - fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True)) - fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True)) + fn_override_manager = get_override_manager( + make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) + fn_dump_manager = get_dump_manager( + make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) # determine name and extension type of provided function if isinstance(fn, JITFunction): @@ -531,9 +564,7 @@ def compile(fn, **kwargs): metadata_filename = f"{name}.json" # The group is addressed by the metadata - metadata_group = fn_cache_manager.get_group( - metadata_filename - ) or {} + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = 
metadata_group.get(metadata_filename) @@ -541,9 +572,9 @@ def compile(fn, **kwargs): with open(metadata_path) as f: metadata = json.load(f) if 'tensormaps_info' in metadata: - metadata['tensormaps_info'] = [ - InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] + metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] else: +<<<<<<< HEAD metadata = {"num_warps": num_warps, "warp_size": warp_size, "num_ctas": num_ctas, @@ -555,6 +586,18 @@ def compile(fn, **kwargs): "constants": _get_jsonable_constants(constants), "debug": debug, "target": target, } +======= + metadata = { + "num_warps": num_warps, + "num_ctas": num_ctas, + "num_stages": num_stages, + "enable_warp_specialization": enable_warp_specialization, + "enable_persistent": enable_persistent, + "constants": _get_jsonable_constants(constants), + "debug": debug, + "target": target, + } +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 metadata.update(get_env_vars()) if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" @@ -626,10 +669,7 @@ def compile(fn, **kwargs): ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () if "clusterDims" not in metadata: - metadata["clusterDims"] = [ - cluster_info.clusterDimX, - cluster_info.clusterDimY, - cluster_info.clusterDimZ] + metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ] if len(tma_infos) > 0: metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args) @@ -643,7 +683,10 @@ def compile(fn, **kwargs): fn.tensormaps_info = metadata["tensormaps_info"] ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () - ids = {"ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs} + ids = { + "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": 
ids_of_folded_args, "ids_of_const_exprs": + ids_of_const_exprs + } # cache manager if is_cuda: so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) @@ -651,7 +694,8 @@ def compile(fn, **kwargs): so_path = _device_backend.make_launcher_stub(name, signature, constants, ids) # write-back metadata, if it didn't come from the cache if metadata_path is None: - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) # return handle to compiled kernel @@ -701,10 +745,7 @@ class CompiledKernel: if self.device_type in ["cuda"]: device = get_current_device() - bin_path = { - driver.HIP: "hsaco_path", - driver.CUDA: "cubin" - }[driver.backend] + bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend] max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] fn_load_binary = driver.utils.load_binary else: @@ -752,4 +793,5 @@ class CompiledKernel: self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) + return runner diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index c7dd75ec7..52a8f74a1 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -3,9 +3,9 @@ import os import tempfile from ..common import _build +from ..common.backend import get_cuda_version_key from ..common.build import is_hip from ..runtime.cache import get_cache_manager -from ..runtime.jit import version_key from .utils import generate_cu_signature # ----- stub 
-------- @@ -23,7 +23,7 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): def make_stub(name, signature, constants, ids, **kwargs): # name of files that are cached - so_cache_key = make_so_cache_key(version_key(), signature, constants, ids, **kwargs) + so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs) so_cache_manager = get_cache_manager(so_cache_key) so_name = f"{name}.so" # retrieve stub from cache if it exists @@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs): else: return cache_path + # ----- source code generation -------- @@ -100,7 +101,10 @@ def generate_launcher(constants, signature, ids): # generate glue code folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] - params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)] + params = [ + i for i in signature.keys() + if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs) + ] src = f""" #include \"cuda.h\" #include diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index d4b24a93e..ef629c75a 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -158,19 +158,21 @@ class InfoFromBackendForTensorMap: # dtype:cuda.CUtensorMapDataType | int def bytes_from_type(self, dtype): - return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2, - 
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype] + return { + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4 + }[dtype] def getTensorMapDataType(self): return self.tensorDataType @@ -259,22 +261,29 @@ class InfoFromBackendForTensorMap: self.getInterleave(), self.getSwizzle(), self.getL2Promotion(), - self.getOobFill() + self.getOobFill(), ) # make hashable to use as partial key in cache def __hash__(self): - return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType, - self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, 
self.l2Promotion, self.oobFill)) + return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), + tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims), + tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill)) def __eq__(self, other): if not isinstance(other, self.__class__): return False - return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == ( - other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill) + return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, + self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, + self.l2Promotion, + self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, + other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, + other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, + other.oobFill) class TensorMapManager: + def __init__(self): self.tensormaps_device = {} @@ -286,8 +295,7 @@ class TensorMapManager: t_tensormap = e.tensormap(args) TENSORMAP_SIZE_IN_BYTES = 128 t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) - driver.utils.cuMemcpyHtoD( - t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) + driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) self.tensormaps_device[key] = t_tensormap_device return int(self.tensormaps_device[key]) diff --git a/python/triton/language/__init__.py 
b/python/triton/language/__init__.py index dedb1b919..5c06ac951 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -111,7 +111,6 @@ from .random import ( uint32_to_uniform_float, ) - __all__ = [ "TRITON_MAX_TENSOR_NUMEL", "abs", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index a262e7d66..c81c372ce 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -22,10 +22,8 @@ def builtin(fn: T) -> T: @wraps(fn) def wrapper(*args, **kwargs): if "_builder" not in kwargs or kwargs["_builder"] is None: - raise ValueError( - "Did you forget to add @triton.jit ? " - "(`_builder` argument must be provided outside of JIT functions.)" - ) + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) setattr(wrapper, TRITON_BUILTIN, True) @@ -54,7 +52,7 @@ def _to_tensor(x, builder): else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - min_float32 = 2 ** -126 + min_float32 = 2**-126 max_float32 = (2 - 2**-23) * 2**127 abs_x = __builtins__['abs'](x) if abs_x == float("inf") or\ @@ -243,7 +241,7 @@ class dtype: return not self.__eq__(other) def __hash__(self): - return hash((self.name,)) + return hash((self.name, )) @property def scalar(self): @@ -297,6 +295,7 @@ class dtype: class pointer_type(dtype): + def __init__(self, element_ty: dtype, address_space: int = 1): if not isinstance(element_ty, dtype): raise TypeError('element_ty is a {type(element_ty).__name__}.') @@ -331,6 +330,7 @@ class pointer_type(dtype): class block_type(dtype): + def __init__(self, element_ty: dtype, shape: List): self.element_ty = element_ty @@ -381,6 +381,7 @@ class block_type(dtype): class function_type(dtype): + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: self.ret_types = ret_types self.param_types = param_types @@ -531,7 +532,7 @@ class 
constexpr: return constexpr(~self.value) def __pow__(self, other): - return constexpr(self.value ** other.value) + return constexpr(self.value**other.value) def __rshift__(self, other): return constexpr(self.value >> other.value) @@ -547,6 +548,7 @@ class constexpr: class tensor: + def __init__(self, handle, type: dtype): # IR handle self.handle = handle @@ -740,11 +742,21 @@ class tensor: other = _to_tensor(other, _builder) return semantic.equal(self, other, _builder) + @builtin + def __req__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + @builtin def __ne__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.not_equal(self, other, _builder) + @builtin + def __rne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + @builtin def logical_and(self, other, _builder=None): other = _to_tensor(other, _builder) @@ -1023,6 +1035,7 @@ def expand_dims(input, axis, _builder=None): ret = semantic.expand_dims(ret, a, _builder) return ret + # ----------------------- # Linear Algebra # ----------------------- @@ -1171,6 +1184,7 @@ def advance(base: tensor, offsets, _builder=None): """ return semantic.advance(base, offsets, _builder) + # ----------------------- # Atomic Memory Operations # ----------------------- @@ -1196,6 +1210,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), "ACQUIRE", "RELEASE", or "RELAXED") :type sem: str + :param scope: Scope of threads that observe synchronizing effect of the + atomic operation ("GPU" (default), "CTA", or "SYSTEM") + :type scope: str """ func.__doc__ = docstr return func @@ -1205,73 +1222,82 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: @builtin @_add_atomic_docstr("compare-and-swap", has_cmp=True) -def atomic_cas(pointer, cmp, val, sem=None, 
_builder=None): +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): cmp = _to_tensor(cmp, _builder) val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_cas(pointer, cmp, val, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) @builtin @_add_atomic_docstr("exchange") -def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None): +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_xchg(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) @builtin @_add_atomic_docstr("add") -def atomic_add(pointer, val, mask=None, sem=None, _builder=None): +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_add(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) @builtin @_add_atomic_docstr("max") -def atomic_max(pointer, val, mask=None, sem=None, _builder=None): +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_max(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) @builtin @_add_atomic_docstr("min") -def atomic_min(pointer, val, mask=None, sem=None, _builder=None): +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_min(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_min(pointer, 
val, mask, sem, scope, _builder) @builtin @_add_atomic_docstr("logical and") -def atomic_and(pointer, val, mask=None, sem=None, _builder=None): +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_and(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) @builtin @_add_atomic_docstr("logical or") -def atomic_or(pointer, val, mask=None, sem=None, _builder=None): +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_or(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) @builtin @_add_atomic_docstr("logical xor") -def atomic_xor(pointer, val, mask=None, sem=None, _builder=None): +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): val = _to_tensor(val, _builder) sem = _constexpr_to_value(sem) - return semantic.atomic_xor(pointer, val, mask, sem, _builder) + scope = _constexpr_to_value(scope) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) # ----------------------- # Conditioning # ----------------------- + @builtin def where(condition, x, y, _builder=None): """ @@ -1299,6 +1325,7 @@ def where(condition, x, y, _builder=None): # Math # ----------------------- + @builtin def umulhi(x, y, _builder=None): """ @@ -1392,6 +1419,7 @@ def abs(x, _builder=None): # Reductions # ----------------------- + def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: def _decorator(func: T) -> T: @@ -1430,8 +1458,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None): """ if isinstance(input, tensor): - return reduce((input,), axis, combine_fn, - 
_builder=_builder, _generator=_generator)[0] + return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): in_scalar_tys = [t.type.scalar for t in input] @@ -1441,14 +1468,14 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None): with _insertion_guard(_builder): param_types = [ty.to_ir(_builder) for ty in prototype.param_types] block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) - for i, ty in enumerate(prototype.param_types)] + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] else: handles = [r.handle for r in results] _builder.create_reduce_ret(*handles) + if axis is not None: axis = _constexpr_to_value(axis) return semantic.reduction(input, axis, make_combine_region, _builder) @@ -1483,8 +1510,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None index = expand_dims(index, axes_to_expand, _builder=_builder) index = broadcast_to(index, input.shape, _builder=_builder) - rvalue, rindices = reduce((input, index), axis, combine_fn, - _builder=_builder, _generator=_generator) + rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator) return rvalue, rindices @@ -1492,6 +1518,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None # Scans # ----------------------- + def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: def _decorator(func: T) -> T: @@ -1516,8 +1543,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None): """ if isinstance(input, tensor): - return associative_scan((input,), axis, combine_fn, - _builder=_builder, _generator=_generator)[0] + return associative_scan((input, ), axis, 
combine_fn, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): in_scalar_tys = [t.type.scalar for t in input] @@ -1527,17 +1553,18 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None): with _insertion_guard(_builder): param_types = [ty.to_ir(_builder) for ty in prototype.param_types] block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) - for i, ty in enumerate(prototype.param_types)] + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] else: handles = [r.handle for r in results] _builder.create_scan_ret(*handles) + axis = _constexpr_to_value(axis) return semantic.associative_scan(input, axis, make_combine_region, _builder) + # ----------------------- # Compiler Hint Ops # ----------------------- @@ -1600,6 +1627,8 @@ def max_constancy(input, values, _builder=None): raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") values = [x.value for x in values] return semantic.max_constancy(input, values) + + # ----------------------- # Debugging functions # ----------------------- @@ -1739,12 +1768,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur broadcast_arg = dispatch_args[0] # Get the broadcast shape over all the arguments for i, item in enumerate(dispatch_args): - _, broadcast_arg = semantic.binary_op_type_checking_impl( - item, broadcast_arg, _builder, arithmetic_check=False) + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=False) # Change the shape of each argument based on the broadcast shape for i in range(len(dispatch_args)): - dispatch_args[i], _ = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False) + 
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=False) ret_shape = broadcast_arg.shape res_ty = block_type(dtype, ret_shape) call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack) @@ -1757,7 +1786,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur class static_range: - """ Iterator that counts upward forever. @@ -1801,7 +1829,9 @@ class static_range: # Extern functions # ----------------------- -def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None): + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): ''' Dispatch a function to a library :param func: the function to dispatch @@ -1843,7 +1873,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) -def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None): +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): ''' Dispatch an elementwise function to a library :param lib_name: the name of the library @@ -1872,12 +1903,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol broadcast_arg = dispatch_args[0] # Get the broadcast shape over all the arguments for i, item in enumerate(dispatch_args): - _, broadcast_arg = semantic.binary_op_type_checking_impl( - item, broadcast_arg, _builder, arithmetic_check=arithmetic_check) + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) # Change the shape of each argument based 
on the broadcast shape for i in range(len(dispatch_args)): - dispatch_args[i], _ = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check) + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) if not all_scalar: ret_shape = broadcast_arg.shape func = getattr(_builder, "create_extern_elementwise") diff --git a/python/triton/language/extra/cuda.py b/python/triton/language/extra/cuda.py index 8c4114739..9400ae797 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -3,16 +3,14 @@ from .. import core @core.extern def globaltimer(_builder=None): - return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], - dtype=core.int64, is_pure=False, - pack=1, _builder=_builder) + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) @core.extern def smid(_builder=None): - return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], - dtype=core.int32, is_pure=True, - pack=1, _builder=_builder) + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, + _builder=_builder) @core.builtin diff --git a/python/triton/language/math.py b/python/triton/language/math.py index 6f8b0aced..1cbad660d 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -18,25 +18,27 @@ def libdevice_path(): @core.extern def clz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), - (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), + 
(core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def popc(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), - (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def byte_perm(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern @@ -73,1471 +75,1602 @@ def max(arg0, arg1, _builder=None): @core.extern def mulhi(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), - (core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), - (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + 
(core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) @core.extern def mul24(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) @core.extern def brev(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), - (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def sad(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), 
core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) @core.extern def abs(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), - (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), - (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def floor(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp64h(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rsqrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), - (core.dtype("fp64"),): 
("__nv_rsqrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ceil(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), - (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def trunc(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), - (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def exp2(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def saturatef(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fma_rn(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma_rz(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma_rd(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), - (core.dtype("fp64"), 
core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma_ru(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fast_dividef(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def div_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return 
core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def div_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def div_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def div_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): 
("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), - }, 
is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def add_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def add_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), 
core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def add_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def add_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", 
core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_rn(arg0, 
_builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): 
("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) 
+ return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def int2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def uint2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_rz(arg0, 
_builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", 
core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def hiloint2double(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def double2loint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2hiint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_rn(arg0, _builder=None): - return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, 
_builder=_builder) @core.extern def float2ull_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + 
(core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), - }, is_pure=True, 
_builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_rn(arg0, 
_builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", 
core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def int_as_float(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def float_as_int(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def uint_as_float(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def float_as_uint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def longlong_as_double(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def double_as_longlong(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def fast_sinf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), 
[arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_cosf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_log2f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_logf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_expf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_tanf(arg0, _builder=None): - return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_exp10f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_log10f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_powf(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def hadd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), - }, is_pure=True, _builder=_builder) + return 
core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) @core.extern def rhadd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) @core.extern def sub_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sub_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): 
("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sub_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sub_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rsqrt_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ffs(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), - 
(core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def rint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def llrint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), - (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def nearbyint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def isnan(arg0, _builder=None): 
- return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def signbit(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def copysign(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def finitef(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): 
("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def isinf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def nextafter(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sin(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cos(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), - }, 
is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sinpi(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cospi(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def tan(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log2(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_log2f", 
core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def exp(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_exp", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def exp10(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cosh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sinh(arg0, _builder=None): - return core.extern_elementwise("libdevice", 
libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def tanh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def atan2(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def atan(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): 
("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def asin(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def acos(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log10(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): 
("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log1p(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def acosh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def asinh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def atanh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return 
core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def expm1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def hypot(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rhypot(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern 
def norm3d(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rnorm3d(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def norm4d(arg0, arg1, arg2, arg3, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, arg3, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), 
core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, arg3, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cbrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcbrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): 
("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def j0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def j1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def y0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def y1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + 
(core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def yn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), - (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def jn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), - (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cyl_bessel_i0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cyl_bessel_i1(arg0, _builder=None): - return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfinv(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfc(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", 
core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfcx(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfcinv(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def normcdfinv(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def normcdf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", 
libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def lgamma(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ldexp(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def scalbn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fmod(arg0, arg1, 
_builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def remainder(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def pow(arg0, arg1, _builder=None): - return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def tgamma(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def round(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def llround(arg0, _builder=None): - return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), - (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def fdim(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ilogb(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def logb(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): 
("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def isfinited(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 7af60855b..a8fb520e0 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # two_to_the_minus_32: tl.constexpr = 2.328306e-10 # return x * two_to_the_minus_32 + @jit def uint32_to_uniform_float(x): """ @@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): u4 = uint32_to_uniform_float(i4) return u1, u2, u3, u4 + # ------------------- # randn # ------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index dea078713..468061361 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -19,10 +19,12 @@ def _is_cuda(target): from ..compiler.compiler import CudaTargetDescriptor return isinstance(target, CudaTargetDescriptor) + # Create custom exception that prints message "hello" class IncompatibleTypeErrorImpl(Exception): + def __init__(self, type_a, type_b): self.type_a = type_a self.type_b = type_b @@ -34,6 +36,7 @@ class IncompatibleTypeErrorImpl(Exception): # Programming Model # ===----------------------------------------------------------------------===## + def program_id(axis: int, builder: ir.builder) -> tl.tensor: if axis not in (0, 1, 2): raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") 
@@ -45,6 +48,7 @@ def num_programs(axis: int, builder: ir.builder) -> tl.tensor: raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + # ===----------------------------------------------------------------------===// # Implicit Casting Utilities # ===----------------------------------------------------------------------===// @@ -95,10 +99,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t # 5 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: - raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" + raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" "this is unlikely to result in a useful answer. Cast them to the same signedness.") return integer_promote_impl(a_ty, b_ty) + # ===----------------------------------------------------------------------===// # Binary Operators # ===----------------------------------------------------------------------===// @@ -116,12 +122,9 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) - raise IncompatibleTypeErrorImpl(type_a, type_b) -def binary_op_type_checking_impl(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder, - allow_lhs_ptr=False, allow_rhs_ptr=False, - arithmetic_check=True, div_or_mod=False - ) -> Tuple[tl.tensor, tl.tensor]: +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: # implicit broadcasting lhs, rhs = broadcast_impl_value(lhs, rhs, builder) # implicit typecasting @@ -136,9 +139,7 @@ def binary_op_type_checking_impl(lhs: tl.tensor, return lhs, rhs -def add(input: 
tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -162,15 +163,12 @@ def add(input: tl.tensor, assert False -def sub(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, False) scalar_ty = input.type.scalar # ptr - offset if scalar_ty.is_ptr(): - return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), - input.type) + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) # float - float if scalar_ty.is_floating(): return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) @@ -180,9 +178,7 @@ def sub(input: tl.tensor, assert False -def mul(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float * float @@ -194,9 +190,7 @@ def mul(input: tl.tensor, assert False -def truediv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -222,9 +216,7 @@ def truediv(input: tl.tensor, return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) -def floordiv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def floordiv(input: tl.tensor, other: tl.tensor, builder: 
ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -239,10 +231,7 @@ def floordiv(input: tl.tensor, assert False -def fdiv(input: tl.tensor, - other: tl.tensor, - ieee_rounding: bool, - builder: ir.builder) -> tl.tensor: +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): @@ -252,18 +241,14 @@ def fdiv(input: tl.tensor, return tl.tensor(ret, input.type) -def mod(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float % float if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other - ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), - other, builder), - builder) + ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), other, builder), builder) return ret # % int elif scalar_ty.is_int(): @@ -277,13 +262,13 @@ def mod(input: tl.tensor, return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) assert False + ############## # bitwise ops ############## -def bitwise_op_type_checking_impl(input: tl.tensor, - other: tl.tensor, +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) input_sca_ty = input.type.scalar @@ -298,23 +283,17 @@ def bitwise_op_type_checking_impl(input: tl.tensor, return input, other -def and_(input: tl.tensor, - 
other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_and(input.handle, other.handle), input.type) -def or_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_or(input.handle, other.handle), input.type) -def xor_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) @@ -341,26 +320,21 @@ def not_(input: tl.tensor, builder: ir.builder): return invert(input, builder) -def lshr(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) -def ashr(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) -def shl(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + # 
===----------------------------------------------------------------------===// # Unary Operators # ===----------------------------------------------------------------------===// @@ -370,8 +344,7 @@ def plus(input: tl.tensor) -> tl.tensor: return input -def minus(input: tl.tensor, - builder: ir.builder) -> tl.tensor: +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: input_sca_ty = input.type.scalar if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") @@ -379,8 +352,7 @@ def minus(input: tl.tensor, return sub(_0, input, builder) -def invert(input: tl.tensor, - builder: tl.tensor) -> tl.tensor: +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: input_sca_ty = input.type.scalar if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") @@ -398,9 +370,7 @@ def _bool_like(v: tl.tensor) -> tl.block_type: return tl.block_type(tl.int1, shape) -def greater_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float > float @@ -415,9 +385,7 @@ def greater_than(input: tl.tensor, assert False -def greater_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float >= float @@ -432,9 +400,7 @@ def greater_equal(input: tl.tensor, assert False -def less_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) 
scalar_ty = input.type.scalar # float < float @@ -449,9 +415,7 @@ def less_than(input: tl.tensor, assert False -def less_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float < float @@ -466,9 +430,7 @@ def less_equal(input: tl.tensor, assert False -def equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float == float @@ -480,9 +442,7 @@ def equal(input: tl.tensor, assert False -def not_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float == float @@ -493,6 +453,7 @@ def not_equal(input: tl.tensor, return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) assert False + # ===----------------------------------------------------------------------===// # Block Creation # ===----------------------------------------------------------------------===// @@ -541,6 +502,7 @@ def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: # Shape Manipulation # ===----------------------------------------------------------------------===// + def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: assert not value.type.is_block(), "Cannot splat a block tensor" if len(shape) == 0: @@ -549,9 +511,7 @@ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) -def view(input: tl.tensor, 
- dst_shape: List[int], - builder: ir.builder) -> tl.tensor: +def view(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor: # TODO: disable when TritonToTritonGPU handles views properly # assert len(input.shape) == len(dst_shape) @@ -564,9 +524,7 @@ def view(input: tl.tensor, return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty) -def reshape(input: tl.tensor, - dst_shape: List[int], - builder: ir.builder) -> tl.tensor: +def reshape(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor: raise ValueError("`reshape` is not supported yet. Please use `view` instead if applicable. " "Note that view may reorder elements in an implementation- and context- dependent way.") @@ -596,9 +554,7 @@ def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_trans(input.handle), ret_type) -def broadcast_impl_shape(input: tl.tensor, - shape: List[int], - builder: ir.builder) -> tl.tensor: +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: if not input.type.is_block(): ret_ty = tl.block_type(input.type, shape) return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) @@ -616,9 +572,7 @@ def broadcast_impl_shape(input: tl.tensor, return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) -def broadcast_impl_value(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder) -> tl.tensor: +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: lhs_ty = lhs.type rhs_ty = rhs.type @@ -638,13 +592,15 @@ def broadcast_impl_value(lhs: tl.tensor, if len(lhs_shape) < len(rhs_shape): # Add new axes to lhs for dim in range(len(lhs_shape), len(rhs_shape)): - lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) lhs_ty = lhs.type lhs_shape = 
lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for dim in range(len(rhs_shape), len(lhs_shape)): - rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) @@ -670,14 +626,13 @@ def broadcast_impl_value(lhs: tl.tensor, # (scalar, scalar) => returns original blocks return lhs, rhs + ####### # cast ####### -def bitcast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type if src_ty.is_block(): dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) @@ -693,13 +648,10 @@ def bitcast(input: tl.tensor, if src_bits != dst_bits: raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " "data-type of size " + str(dst_bits)) - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) -def cast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type if isinstance(dst_ty, tl.constexpr): dst_ty = dst_ty.value @@ -718,8 +670,7 @@ def cast(input: tl.tensor, # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()): - return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) # bf16 <=> (not fp32) if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or 
\ @@ -733,9 +684,7 @@ def cast(input: tl.tensor, dst_sca_ty.is_floating() and \ src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth if truncate_fp: - return tl.tensor(builder.create_fp_trunc(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) # Standard floating types' casting: extension # fp32 => fp64 @@ -745,9 +694,7 @@ def cast(input: tl.tensor, dst_sca_ty.is_floating() and \ src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth if ext_fp: - return tl.tensor(builder.create_fp_ext(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) # Casting between integer types if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ @@ -758,9 +705,7 @@ def cast(input: tl.tensor, _0 = tl.tensor(builder.get_null_value(ty), input.dtype) return not_equal(input, _0, builder) else: - return tl.tensor(builder.create_int_cast(input.handle, - dst_ty.to_ir(builder), sign_extend), - dst_ty) + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) # Casting standard floating types to integer types if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): @@ -769,35 +714,24 @@ def cast(input: tl.tensor, _0 = tl.tensor(builder.get_null_value(ty), input.dtype) return not_equal(input, _0, builder) elif dst_sca_ty.is_int_signed(): - return tl.tensor(builder.create_fp_to_si(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) else: - return tl.tensor(builder.create_fp_to_ui(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) # Casting integer types to standard floating types if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): 
- return tl.tensor(builder.create_ui_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) else: - return tl.tensor(builder.create_si_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) # Casting pointer types to integer types if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): bitwidth = dst_sca_ty.int_bitwidth if bitwidth == 64: - return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) if bitwidth == 1: - return not_equal(cast(input, tl.int64, builder), - tl.tensor(builder.get_int64(0), tl.int64), - builder) + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) # Casting integer types to pointer types if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): @@ -809,6 +743,7 @@ def cast(input: tl.tensor, assert False, f'cannot cast {input} to {dst_ty}' + # ===----------------------------------------------------------------------===// # Memory Operators # ===----------------------------------------------------------------------===// @@ -882,6 +817,20 @@ def _str_to_sem(sem_option): return sem +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + def _canonicalize_boundary_check(boundary_check, block_shape): if boundary_check: if not hasattr(boundary_check, "__iter__"): @@ -913,8 +862,8 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti boundary_check = 
_canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) # Build IR - return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, - is_volatile), dst_ty) + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): @@ -970,19 +919,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ if not mask: return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) else: - return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, - eviction, is_volatile), dst_ty) + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) -def load(ptr: tl.tensor, - mask: Optional[tl.tensor], - other: Optional[tl.tensor], - boundary_check, - padding_option: str, - cache_modifier: str, - eviction_policy: str, - is_volatile: bool, - builder: ir.builder) -> tl.tensor: +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str, + cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor: # Cache, eviction and padding options cache = _str_to_load_cache_modifier(cache_modifier) eviction = _str_to_eviction_policy(eviction_policy) @@ -1007,7 +950,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde if not val.type.is_block(): val = broadcast_impl_shape(val, block_shape, builder) assert val.type.is_block(), "Value argument must be block type or a scalar" - assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert block_shape == val.type.get_block_shapes( + ), 
f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" elt_ty = ptr.type.element_ty.element_ty @@ -1065,13 +1009,8 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) -def store(ptr: tl.tensor, - val: tl.tensor, - mask: Optional[tl.tensor], - boundary_check, - cache_modifier: str, - eviction_policy: str, - builder: ir.builder) -> tl.tensor: +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: # Cache and eviction options cache = _str_to_store_cache_modifier(cache_modifier) eviction = _str_to_eviction_policy(eviction_policy) @@ -1089,22 +1028,16 @@ def store(ptr: tl.tensor, ######### -def atomic_cas(ptr: tl.tensor, - cmp: tl.tensor, - val: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: sem = _str_to_sem(sem) + scope = _str_to_scope(scope) element_ty = ptr.type.scalar.element_ty if element_ty.primitive_bitwidth not in [16, 32, 64]: raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") - return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem), val.type) + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) -def atom_red_typechecking_impl(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - op: str, +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: if not ptr.type.scalar.is_ptr(): raise 
ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) @@ -1129,24 +1062,18 @@ def atom_red_typechecking_impl(ptr: tl.tensor, return ptr, val, mask -def atomic_max(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) sem = _str_to_sem(sem) + scope = _str_to_scope(scope) sca_ty = val.type.scalar # direct call to atomic_max for integers if sca_ty.is_int(): if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, - ptr.handle, - val.handle, - mask.handle, - sem), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) else: +<<<<<<< HEAD return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, @@ -1154,6 +1081,10 @@ def atomic_max(ptr: tl.tensor, sem), val.type) # ROCM TODO: implement atomic_max/min for f32 as they are supported by MI cards. 
+======= + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 # for float # return atomic_smax(i_ptr, i_val) if val >= 0 # return atomic_umin(i_ptr, i_val) if val < 0 @@ -1167,36 +1098,29 @@ def atomic_max(ptr: tl.tensor, i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) pos = greater_equal(val, zero, builder) neg = less_than(val, zero, builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, + and_(mask, neg, builder).handle, sem, scope), i_val.type) ret = where(pos, pos_ret, neg_ret, builder) return bitcast(ret, sca_ty, builder) -def atomic_min(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) sem = _str_to_sem(sem) + scope = _str_to_scope(scope) sca_ty = val.type.scalar # direct call to atomic_min for integers if sca_ty.is_int(): if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - ptr.handle, - val.handle, - mask.handle, - sem), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, - ptr.handle, - val.handle, 
- mask.handle, - sem), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) # for float # return atomic_smin(i_ptr, i_val) if val >= 0 # return atomic_umax(i_ptr, i_val) if val < 0 @@ -1210,72 +1134,57 @@ def atomic_min(ptr: tl.tensor, i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) pos = greater_equal(val, zero, builder) neg = less_than(val, zero, builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - i_ptr.handle, - i_val.handle, - and_(mask, pos, builder).handle, - sem), - i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - i_ptr.handle, - i_val.handle, - and_(mask, neg, builder).handle, - sem), - i_val.type) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle, + and_(mask, neg, builder).handle, sem, scope), i_val.type) ret = where(pos, pos_ret, neg_ret, builder) return bitcast(ret, sca_ty, builder) -def atomic_add(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) sem = _str_to_sem(sem) + scope = _str_to_scope(scope) sca_ty = val.type.scalar op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD - return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type) + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) -def atomic_and(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_and(ptr: tl.tensor, val: 
tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) sem = _str_to_sem(sem) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) -def atomic_or(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) sem = _str_to_sem(sem) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) -def atomic_xor(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - builder: ir.builder) -> tl.tensor: +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) sem = _str_to_sem(sem) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) -def atomic_xchg(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) sem = _str_to_sem(sem) - return 
tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + # ===----------------------------------------------------------------------===// # Linear Algebra @@ -1334,6 +1243,7 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: return False return True +<<<<<<< HEAD def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, @@ -1341,6 +1251,12 @@ def dot(lhs: tl.tensor, max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +======= + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): # Checks for non-cuda archs if not _is_cuda(target): @@ -1351,22 +1267,30 @@ def dot(lhs: tl.tensor, # Checks for cuda archs if target.capability < 90: - assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): return assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" 
else: - assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" - assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15( + ), "Dot op does not support fp8e4b15 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4( + ), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" if lhs_dtype.is_int() or rhs_dtype.is_int(): assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" - assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): - assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" - assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})" + assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5( + ), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" + assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5( + ), f"Only supports fp8e4nv or fp8e5. 
Second operand ({rhs_dtype})" else: - assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" - assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" assert lhs.type.is_block() and rhs.type.is_block() @@ -1374,7 +1298,8 @@ def dot(lhs: tl.tensor, assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" - assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})" + assert lhs.shape[1].value == rhs.shape[ + 0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})" assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \ and rhs.shape[1].value >= 16, \ f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" @@ -1397,7 +1322,8 @@ def dot(lhs: tl.tensor, _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 elif out_dtype.is_bf16(): - raise ValueError("out_dtype=bfloat16 is unsupported. 
Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): _0 = builder.get_fp32(0) ret_scalar_ty = tl.float32 @@ -1423,6 +1349,7 @@ def dot(lhs: tl.tensor, else: _0 = builder.create_splat(builder.get_fp32(0), [M, N]) ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) +<<<<<<< HEAD ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc), ret_ty) return cast(ret, ret_scalar_ty, builder) @@ -1433,6 +1360,12 @@ def dot(lhs: tl.tensor, if max_num_imprecise_acc is None: max_num_imprecise_acc = 2**30 +======= + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) + return cast(ret, ret_scalar_ty, builder) + if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, + ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 if lhs.type.scalar.is_int(): ret_dot_scalar_ty = tl.int32 _0 = builder.create_splat(builder.get_int32(0), [M, N]) @@ -1440,8 +1373,12 @@ def dot(lhs: tl.tensor, ret_dot_scalar_ty = tl.float32 _0 = builder.create_splat(builder.get_fp32(0), [M, N]) ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N]) +<<<<<<< HEAD ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc), ret_ty) +======= + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 return cast(ret, ret_scalar_ty, builder) _0 = builder.create_splat(_0, [M, N]) @@ -1453,23 +1390,21 @@ def dot(lhs: tl.tensor, assert acc.type == ret_ty # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 - if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()): + if 
not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() + and ret_scalar_ty.is_fp32()): max_num_imprecise_acc = 0 if max_num_imprecise_acc is None: max_num_imprecise_acc = 2**30 - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), - ret_ty) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty) # ===----------------------------------------------------------------------===// # Indexing # ===----------------------------------------------------------------------===// -def where(condition: tl.tensor, - x: tl.tensor, - y: tl.tensor, - builder: ir.builder) -> tl.tensor: + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: condition = cast(condition, tl.int1, builder) if condition.type.is_block(): condition, x = broadcast_impl_value(condition, x, builder) @@ -1482,14 +1417,13 @@ def where(condition: tl.tensor, ret_ty = x.type return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + # ===----------------------------------------------------------------------===// # Reduction # ===----------------------------------------------------------------------=== -def reduction( - inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder -) -> Tuple[tl.tensor, ...]: +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: if axis is None: new_inputs = [] for i in range(len(inputs)): @@ -1515,10 +1449,7 @@ def reduction( region_builder_fn(reduce_op) reduce_op.verify() - return tuple( - wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) - for i in range(len(inputs)) - ) + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs))) # 
===----------------------------------------------------------------------=== @@ -1526,9 +1457,8 @@ def reduction( # ===----------------------------------------------------------------------=== -def associative_scan( - inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder -) -> Tuple[tl.tensor, ...]: +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, + builder: ir.builder) -> Tuple[tl.tensor, ...]: if len(inputs) != 1: raise ValueError("Current implementation only support single tensor input") shape = inputs[0].type.shape @@ -1541,16 +1471,14 @@ def associative_scan( region_builder_fn(scan_op) scan_op.verify() - return tuple( - wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) - for i in range(len(inputs)) - ) + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs))) # ===----------------------------------------------------------------------=== # Math # ===----------------------------------------------------------------------=== + def _check_dtype(dtypes: List[str]) -> T: """ We're following libdevice's convention to check accepted data types for math functions. @@ -1559,7 +1487,9 @@ def _check_dtype(dtypes: List[str]) -> T: We should let the users know that they are using and invoke explicit cast to convert the data type to the supported one. """ + def wrapper(fn): + @wraps(fn) def check(*args, **kwargs): # concatenate args and kwargs @@ -1568,6 +1498,7 @@ def _check_dtype(dtypes: List[str]) -> T: if arg.type.scalar.name not in dtypes: raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") return fn(*args, **kwargs) + return check return wrapper @@ -1653,6 +1584,15 @@ def debug_barrier(builder: ir.builder) -> tl.tensor: def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". 
+ if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + new_args = [] for arg in args: new_args.append(arg.handle) @@ -1662,8 +1602,8 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl. def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: cond_ty = cond.type if not cond_ty.is_block(): - cond_ty = tl.block_type(cond_ty.scalar, (1,)) - cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty) + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 8ef52cb9c..acda0ca7a 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -123,6 +123,7 @@ def maximum(x, y): """ return math.max(x, y) + # max and argmax @@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2): @jit -@core._add_reduction_docstr("maximum", - return_indices_arg="return_indices", +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", tie_break_arg="return_indices_tie_break_left") def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True): input = core._promote_reduction_input(input) @@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True): (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) return ret + # min and argmin @@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2): @jit -@core._add_reduction_docstr("minimum", - return_indices_arg="return_indices", +@core._add_reduction_docstr("minimum", 
return_indices_arg="return_indices", tie_break_arg="return_indices_tie_break_left") def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True): input = core._promote_reduction_input(input) @@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr @jit -@core._add_reduction_docstr("minimum index", - tie_break_arg="tie_break_left") +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") def argmin(input, axis, tie_break_left=True): _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) return ret @@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True): def _sum_combine(a, b): return a + b + # sum @@ -247,6 +247,7 @@ def sum(input, axis=None): def _xor_combine(a, b): return a ^ b + # xor sum @@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None): raise ValueError("xor_sum only supported for integers") input = core._promote_reduction_input(input, _builder=_builder) - return core.reduce(input, axis, _xor_combine, - _builder=_builder, _generator=_generator) + return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator) + # cumsum @@ -271,6 +272,7 @@ def cumsum(input, axis=0): input = core._promote_reduction_input(input) return core.associative_scan(input, axis, _sum_combine) + # cumprod diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index eaf4f2f40..098e15438 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -17,15 +17,14 @@ from ... 
import language as tl 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, }) @jit -def _sdd_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ak, - stride_zb, stride_hb, stride_bk, stride_nb, - stride_zc, stride_hc, stride_mc, stride_nc, - K, grid_offset, lut, - TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, - BLOCK: tl.constexpr, EVEN_K: tl.constexpr -): +def _sdd_kernel(A, B, C, # + stride_za, stride_ha, stride_ma, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_nb, # + stride_zc, stride_hc, stride_mc, stride_nc, # + K, grid_offset, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + BLOCK: tl.constexpr, EVEN_K: tl.constexpr # + ): # ------------ # # - Prologue - # # ------------ # @@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out= c = out grid = [c.shape[1], 1, c.shape[0]] _sdd_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(2), c.stride(3), - Ka, 0, lut, - TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, - num_warps=4, + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(2), c.stride(3), # + Ka, 0, lut, # + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, # + num_warps=4 # ) return c @@ -120,6 +119,7 @@ def sdd_lut(layout, block, device): lut = lut.contiguous() return lut, None + # ----------------------------- # Dense = Sparse x Dense (DSD) # This operation uses a look-up table that contains pre-computed pointer increments @@ -128,15 +128,14 @@ def sdd_lut(layout, block, device): @jit -def _dsd_kernel( - A, B, C, - stride_az, stride_ha, 
stride_am, stride_ak, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_cm, stride_cn, - DS0, DS1, lut, - TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr -): +def _dsd_kernel(A, B, C, # + stride_az, stride_ha, stride_am, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_bn, # + stride_zc, stride_hc, stride_cm, stride_cn, # + DS0, DS1, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr # + ): # ------------ # # - Prologue - # # ------------ # @@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N # compute output grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0] _dsd_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), - BS3, AS1, lut, - TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, - num_warps=4, GROUP_SIZE_M=4, + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), # + BS3, AS1, lut, # + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, # + num_warps=4, GROUP_SIZE_M=4 # ) # exit() return c @@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device): # create locks return lut, width + # ----------------------------- # Dense = Dense x Sparse (DDS) # ----------------------------- @@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device): def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, 
block, lut, width, out=None): return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + ############## # MAIN API # ############## @@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function): fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} @staticmethod - def forward( - ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, - c_lut, c_width, da_lut, da_width, db_lut, db_width, out - ): + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, + db_width, out): c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) # save for backward ctx.save_for_backward(a, b) @@ -385,15 +384,13 @@ class _matmul(torch.autograd.Function): # gradients w.r.t. a if ctx.needs_input_grad[0]: mode_da = mode[1] + mode[0] + mode[2] - da = _matmul.fn[mode_da]( - dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width, - ) + da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_width) # gradients w.r.t. 
b if ctx.needs_input_grad[1]: mode_db = mode[2] + mode[1] + mode[0] - db = _matmul.fn[mode_db]( - a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width, - ) + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_width) dout = dc if ctx.has_out else None return da, db, None, None, None, \ None, None, None, None, \ @@ -427,11 +424,9 @@ class matmul: self.db_lut, self.db_width = sdd_lut(layout, block, device) def __call__(self, a, b, out=None): - c = _matmul.apply( - a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, - self.c_lut, self.c_width, - self.da_lut, self.da_width, - self.db_lut, self.db_width, - out - ) + c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, # + self.c_lut, self.c_width, # + self.da_lut, self.da_width, # + self.db_lut, self.db_width, # + out) return c diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index c045b11a5..bcffff26b 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -18,14 +18,13 @@ def num_warps(n): @jit -def _blocksparse_softmax_fwd( - Out, A, stride_xz, LUT, - R, extent, stride_zr, stride_hr, # relative attention - scale, is_causal, - ROW_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - IS_DENSE: tl.constexpr, -): +def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, # + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr # + ): h = tl.program_id(0) m = tl.program_id(1) z = tl.program_id(2) @@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd( @jit -def _blocksparse_softmax_bwd( - DA, stride_zdx, - DOut, stride_zdout, - Out, stride_zout, - scale, - LUT, - DR, extent, stride_zr, stride_hr, stride_er, - is_causal, - ROW_SIZE: tl.constexpr, - 
BLOCK_SIZE: tl.constexpr, - IS_DENSE: tl.constexpr, -): +def _blocksparse_softmax_bwd(DA, stride_zdx, # + DOut, stride_zdout, # + Out, stride_zout, # + scale, # + LUT, # + DR, extent, stride_zr, stride_hr, stride_er, # + is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr): h = tl.program_id(0) m = tl.program_id(1) z = tl.program_id(2) @@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd( class _softmax(torch.autograd.Function): + @staticmethod def make_lut(layout, block, device): _empty = torch.tensor([], dtype=torch.int64, device=layout.device) @@ -151,10 +149,7 @@ class _softmax(torch.autograd.Function): return lut, int(total_sizes.max()) @staticmethod - def forward( - ctx, a, scale, rel_logits, is_causal, - spdims, block, lut, maxlut, is_dense - ): + def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense): if scale is not None and isinstance(scale, torch.Tensor): assert scale.device.type == "cpu" scale = scale.item() @@ -165,14 +160,14 @@ class _softmax(torch.autograd.Function): # enqueue kernel out = torch.empty_like(a) _blocksparse_softmax_fwd[grid]( - out, a, a.stride(0), lut, - rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn - scale, - is_causal, - BLOCK_SIZE=block, - ROW_SIZE=next_power_of_2(maxlut), - IS_DENSE=is_dense, - num_warps=num_warps(maxlut) + out, a, a.stride(0), lut, # + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn# + scale, # + is_causal, # + BLOCK_SIZE=block, # + ROW_SIZE=next_power_of_2(maxlut), # + IS_DENSE=is_dense, # + num_warps=num_warps(maxlut) # ) # save to context # ctx.mark_dirty(x) @@ -201,28 +196,23 @@ class _softmax(torch.autograd.Function): grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) da = torch.empty_like(dout) _blocksparse_softmax_bwd[grid]( - da, da.stride(0), - dout, dout.stride(0), - out, out.stride(0), - ctx.scale, - lut, - dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], 
ctx.rel_strides[2], - ctx.is_causal, - BLOCK_SIZE=ctx.block, - ROW_SIZE=next_power_of_2(ctx.maxlut), - IS_DENSE=ctx.is_dense, - num_warps=num_warps(ctx.maxlut) + da, da.stride(0), # + dout, dout.stride(0), # + out, out.stride(0), # + ctx.scale, # + lut, # + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], # + ctx.is_causal, # + BLOCK_SIZE=ctx.block, # + ROW_SIZE=next_power_of_2(ctx.maxlut), # + IS_DENSE=ctx.is_dense, # + num_warps=num_warps(ctx.maxlut) # ) - return (da, None, None, dr, None, - None, None, None, None, None, - None, - None, None, None, - None, - None, None, None - ) + return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None) class softmax: + def __init__(self, layout, block, device, is_dense=False): self.spdims = layout.shape self.layout = layout @@ -233,8 +223,6 @@ class softmax: def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): if rel_logits is not None and rel_logits.dtype != a.dtype: raise ValueError(f"relative position embedding must be {a.dtype}") - a = _softmax.apply( - a, scale, rel_logits, is_causal, - self.spdims, self.block, self.lut, self.maxlut, self.is_dense, - ) + a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut, + self.is_dense) return a diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 5b0bcf836..88e8dae50 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): class _cross_entropy(torch.autograd.Function): + @classmethod def forward(cls, ctx, logits, indices): # make sure we can use triton diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 74bc240b1..185606efb 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -15,20 +15,19 @@ from .. 
import language as tl @jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, - Z_H_N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, -): +def _fwd_kernel(Q, K, V, sm_scale, # + L, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + IS_CAUSAL: tl.constexpr # + ): start_m = tl.program_id(0) off_hz = tl.program_id(1) qvk_offset = off_hz * stride_qh @@ -40,7 +39,7 @@ def _fwd_kernel( strides=(stride_kk, stride_kn), offsets=(0, vk_offset), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) + order=(0, 1), ) V_block_ptr = tl.make_block_ptr( base=V, @@ -48,7 +47,7 @@ def _fwd_kernel( strides=(stride_vn, stride_vk), offsets=(vk_offset, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), ) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -104,7 +103,7 @@ def _fwd_kernel( strides=(stride_om, stride_on), offsets=(vk_offset + start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), ) # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) @@ -112,9 +111,11 @@ def _fwd_kernel( @jit def _bwd_preprocess( - Out, DO, + Out, + DO, Delta, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) @@ -128,40 +129,48 @@ def _bwd_preprocess( @jit -def _bwd_kernel_one_col_block( - Q, 
K, V, sm_scale, qk_scale, - Out, DO, - DQ, DK, DV, - L, - D, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - off_hz, start_n, num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - MMA_V3: tl.constexpr -): - if SEQUENCE_PARALLEL: - DQ += stride_dqa.to(tl.int64) * start_n +def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): if CAUSAL: lo = start_n * BLOCK_M else: lo = 0 + + Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm + DQ_offset = off_z * stride_qz + off_h * stride_qh + K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn + V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn + if SEQUENCE_PARALLEL: + DQ_offset += stride_dqa.to(tl.int64) * start_n + DQ_offset = DQ_offset // stride_qm + + Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0)) + K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0)) + DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0)) + DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + # 
initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) # pointer to row-wise quantities in value-like data D_ptrs = D + off_hz * N_CTX l_ptrs = L + off_hz * N_CTX @@ -169,17 +178,17 @@ def _bwd_kernel_one_col_block( dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) # loop over rows for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): offs_m_curr = start_m + offs_m # load q, k, v, do on-chip - q = tl.load(q_ptrs) + q = tl.load(Q_block_ptr) # recompute p = softmax(qk, dim=-1).T # NOTE: `do` is pre-divided by `l`; no normalization here if CAUSAL: - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf")) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf")) else: qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) @@ -187,7 +196,7 @@ def _bwd_kernel_one_col_block( l_i = tl.load(l_ptrs + offs_m_curr) p = tl.math.exp2(qk - l_i[:, None]) # compute dv - do = tl.load(do_ptrs) + do = tl.load(DO_block_ptr) dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) @@ -199,97 +208,156 @@ def _bwd_kernel_one_col_block( dk += tl.dot(tl.trans(ds), q, allow_tf32=True) # compute dq if 
not SEQUENCE_PARALLEL: - dq = tl.load(dq_ptrs) + dq = tl.load(DQ_block_ptr) dq += tl.dot(ds, k, allow_tf32=True) - tl.store(dq_ptrs, dq) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) elif SEQUENCE_PARALLEL: if MMA_V3: dq = tl.dot(ds, k, allow_tf32=True) else: # not work with mma v3, becuase M % 64 != 0 dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) - tl.store(dq_ptrs, dq) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm + DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0)) + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) + tl.store(DV_block_ptr, dv.to(V.dtype.element_ty)) + tl.store(DK_block_ptr, dk.to(K.dtype.element_ty)) @jit -def _bwd_kernel( - # fmt: off - Q, K, V, sm_scale, - Out, DO, - DQ, DK, DV, - L, - D, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - MMA_V3: tl.constexpr - # fmt: on -): +def _bwd_kernel(Q, K, V, sm_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + Z_H_N_CTX, # + SQ_Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): qk_scale = sm_scale * 1.44269504 off_hz = 
tl.program_id(0) off_z = off_hz // H off_h = off_hz % H - # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_kz + off_h * stride_kh - DV += off_z * stride_vz + off_h * stride_vh + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + if SEQUENCE_PARALLEL: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + else: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + DK_block_ptr = tl.make_block_ptr( + base=DK, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DV_block_ptr = tl.make_block_ptr( + base=DV, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) num_block_n 
= tl.cdiv(N_CTX, BLOCK_N) if not SEQUENCE_PARALLEL: for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - Q, K, V, sm_scale, qk_scale, Out, DO, - DQ, DK, DV, - L, - D, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - off_hz, start_n, num_block_n, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - MMA_V3=MMA_V3 - ) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) else: start_n = tl.program_id(1) - _bwd_kernel_one_col_block( - Q, K, V, sm_scale, qk_scale, Out, DO, - DQ, DK, DV, - L, - D, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - off_hz, start_n, num_block_n, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - MMA_V3=MMA_V3 - ) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, 
BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) class _attention(torch.autograd.Function): @@ -315,19 +383,20 @@ class _attention(torch.autograd.Function): L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( - q, k, v, sm_scale, - L, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - q.shape[0] * q.shape[1] * q.shape[2], - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - IS_CAUSAL=causal, - num_warps=num_warps, - num_stages=4) + q, k, v, sm_scale, # + L, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, # + IS_CAUSAL=causal, # + num_warps=num_warps, # + num_stages=4 # + ) ctx.save_for_backward(q, k, v, o, L) ctx.grid = grid @@ -348,35 +417,39 @@ class _attention(torch.autograd.Function): do = do.contiguous() if sequence_parallel: replicas = cdiv(seq_len_kv, BLOCK) - new_dq_shape = (replicas,) + q.shape + new_dq_shape = (replicas, ) + q.shape dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) else: - dq = torch.zeros_like(q, dtype=torch.float32) + dq = torch.zeros_like(q, dtype=q.dtype) dk = torch.empty_like(k) dv = torch.empty_like(v) delta = torch.empty_like(L) _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( - o, do, + o, + do, delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, ) 
_bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( - q, k, v, ctx.sm_scale, - o, do, - dq, dk, dv, - L, - delta, - o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - SEQUENCE_PARALLEL=sequence_parallel, - CAUSAL=ctx.causal, - MMA_V3=MMA_V3, - num_warps=8, - num_stages=1, + q, k, v, ctx.sm_scale, # + o, do, # + dq, dk, dv, # + L, # + delta, # + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + SEQUENCE_PARALLEL=sequence_parallel, # + CAUSAL=ctx.causal, # + MMA_V3=MMA_V3, # + num_warps=8, # + num_stages=1 # ) if len(dq.shape) == 5: diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 63490504d..c8f681a4a 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -37,8 +37,9 @@ def get_configs_io_bound(): num_stages=num_stages, num_warps=num_warps)) # split_k for split_k in [2, 4, 8, 16]: - configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) return configs @@ -69,22 +70,22 @@ def get_configs_io_bound(): prune_configs_by={ 'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, - 'top_k': 10 + 'top_k': 10, }, 
) @heuristics({ 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, }) @jit -def _kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - dot_out_dtype: tl.constexpr, - allow_tf32: tl.constexpr, - fp8_fast_accum: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr +def _kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + allow_tf32: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # ): # matrix multiplication pid = tl.program_id(0) @@ -184,14 +185,15 @@ class _matmul(torch.autograd.Function): ab_dtype = False # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - dot_out_dtype=dot_out_dtype, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - GROUP_M=8, AB_DTYPE=ab_dtype) + _kernel[grid]( + a, b, c, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + dot_out_dtype=dot_out_dtype, # + allow_tf32=allow_tf32, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index abe5325ee..1e07b0a02 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -5,8 +5,7 @@ import torch from .. 
import cdiv from .._C.libtriton.triton import runtime from ..runtime import driver -from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, - nvsmi) +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi) def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): @@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): total_warps = num_ctas * min(num_warps, 4) num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device) + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops( + dtype, cur_sm_clock, backend, device) return tflops @@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype): def estimate_matmul_time( - # backend, device, - num_warps, num_stages, - A, B, C, - M, N, K, - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, - debug=False, **kwargs + # backend, device, + num_warps, num_stages, # + A, B, C, # + M, N, K, # + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, # + debug=False, **kwargs # ): ''' return estimated running time in ms = max(compute, loading) + store ''' @@ -149,8 +149,9 @@ def early_config_prune(configs, named_args): optimal_num_stages = ldgsts_latency / mma_cycles # nearest stages, prefer large #stages - nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) - if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + nearest = heapq.nsmallest( + 2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) for n in nearest: pruned_configs.append(n[0]) diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index a4291ab31..e785018e0 100644 --- 
a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,8 +1,6 @@ -from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, - heuristics) +from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics) from .driver import driver -from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret, - version_key) +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret __all__ = [ "driver", @@ -12,7 +10,6 @@ __all__ = [ "heuristics", "JITFunction", "KernelInterface", - "version_key", "reinterpret", "TensorWrapper", "OutOfResources", diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 07803c9c3..0ba6da07e 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -9,11 +9,10 @@ from .jit import KernelInterface class OutOfResources(Exception): + def __init__(self, required, limit, name): - self.message = f'out of resource: {name}, '\ - f'Required: {required}, '\ - f'Hardware limit: {limit}' - self.message += '. Reducing block sizes or `num_stages` may help.' + self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. 
" + + "Reducing block sizes or `num_stages` may help.") self.required = required self.limit = limit self.name = name @@ -25,38 +24,77 @@ class OutOfResources(Exception): class Autotuner(KernelInterface): +<<<<<<< HEAD def __init__(self, fn, arg_names, configs, key, verbose, reset_to_zero, prune_configs_by: Dict = None, warmup=25, rep=100): ''' +======= + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + ): + """ +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 :param prune_configs_by: a dict of functions that are used to prune configs, fields: 'perf_model': performance model used to predicate running time with different configs, returns running time 'top_k': number of configs to bench 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. - ''' + """ if not configs: self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] else: self.configs = configs self.key_idx = [arg_names.index(k) for k in key] self.cache = {} - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] if reset_to_zero is not None: self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] - def _hook(args): + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args: 0 + if len(self.reset_idx) > 0 or len(self.restore_idx) > 0: + + def _pre_hook(args, reset_only=False): for i in self.reset_idx: args[i].zero_() - self.hook = _hook - self.arg_names = arg_names - # prune configs + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + 
self.pre_hook = _pre_hook + if len(self.restore_idx) > 0: + + def _post_hook(args): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + # Prune configs if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] + perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"] + if "early_config_prune" in prune_configs_by: + early_config_prune = prune_configs_by["early_config_prune"] else: perf_model, top_k, early_config_prune = None, None, None self.perf_model, self.configs_top_k = perf_model, top_k self.early_config_prune = early_config_prune + self.fn = fn self.warmup = warmup self.rep = rep @@ -67,10 +105,8 @@ class Autotuner(KernelInterface): # as kwargs and by the autotuner conflicts = meta.keys() & config.kwargs.keys() if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." - ) + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") # augment meta-parameters with tunable ones current = dict(meta, **config.kwargs) full_nargs = {**self.nargs, **current} @@ -78,16 +114,22 @@ class Autotuner(KernelInterface): def kernel_call(): if config.pre_hook: config.pre_hook(full_nargs) - self.hook(args) - self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, - num_ctas=config.num_ctas, - enable_warp_specialization=config.enable_warp_specialization, - # enable_persistent=False, - **current) + self.pre_hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, + # enable_persistent=False, + **current, + ) + self.post_hook(args) + try: return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8)) except OutOfResources: - return [float('inf'), float('inf'), float('inf')] + return [float("inf"), float("inf"), float("inf")] def get_best_config(self): return self.best_config @@ -110,12 +152,11 @@ class Autotuner(KernelInterface): # prune configs pruned_configs = self.prune_configs(kwargs) bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs} + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} bench_end = time.time() self.bench_time = bench_end - bench_start self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) + self.pre_hook(args, reset_only=True) self.configs_timings = timings if self.verbose: print(str(key) + ": " + str(self.cache[key])) @@ -126,9 +167,15 @@ class Autotuner(KernelInterface): full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} if config.pre_hook is not None: config.pre_hook(full_nargs) - ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, - num_ctas=config.num_ctas, - 
enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs) + ret = self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, + **kwargs, + **config.kwargs, + ) self.nargs = None return ret @@ -142,17 +189,20 @@ class Autotuner(KernelInterface): top_k = int(len(self.configs) * top_k) if len(pruned_configs) > top_k: est_timing = { - config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, - num_warps=config.num_warps, - num_ctas=config.num_ctas, - enable_warp_specialization=config.enable_warp_specialization, - enable_persistent=config.enable_persistent) + config: + self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, + enable_persistent=config.enable_persistent, + ) for config in pruned_configs } - pruned_configs = sorted( - est_timing.keys(), - key=lambda x: est_timing[x])[ - :top_k] + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] return pruned_configs def warmup(self, *args, **kwargs): @@ -195,13 +245,14 @@ class Config: self.num_ctas = num_ctas self.num_stages = num_stages self.enable_warp_specialization = enable_warp_specialization - # TODO[shuhaoj]: May make enable_persistent configurable in future if necessay. + # TODO[shuhaoj]: May make enable_persistent configurable in future if necessary. 
self.enable_persistent = False self.pre_hook = pre_hook def __str__(self): res = [] for k, v in self.kwargs.items(): +<<<<<<< HEAD res.append(f'{k}: {v}') res.append(f'num_warps: {self.num_warps}') ## Comment out Hopper specific parameters @@ -214,6 +265,18 @@ class Config: def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=False, warmup=25, rep=100): +======= + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"enable_warp_specialization: {self.enable_warp_specialization}") + res.append(f"enable_persistent: {self.enable_persistent}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100): +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 """ Decorator for auto-tuning a :code:`triton.jit`'d function. @@ -244,6 +307,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. :type warmup: int :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. 
@@ -251,8 +316,13 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, verbose=Fa :param verbose: a boolean that controls whether the best_config for each key is printed :type verbose: bool """ + def decorator(fn): +<<<<<<< HEAD return Autotuner(fn, fn.arg_names, configs, key, verbose, reset_to_zero, prune_configs_by, warmup, rep) +======= + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 return decorator @@ -286,6 +356,7 @@ def heuristics(values): each such function takes a list of positional arguments as input. :type values: dict[str, Callable[[list[Any]], Any]] """ + def decorator(fn): return Heuristics(fn, fn.arg_names, values) diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 278310473..089c9ffa6 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -1,27 +1,42 @@ #include "cuda.h" #include +#include #define PY_SSIZE_T_CLEAN #include -static inline void gpuAssert(CUresult code, const char *file, int line) { - if (code != CUDA_SUCCESS) { - const char *prefix = "Triton Error [CUDA]: "; - const char *str; - cuGetErrorString(code, &str); - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - } +// Raises a Python exception and returns false if code is not CUDA_SUCCESS. 
+static bool gpuAssert(CUresult code, const char *file, int line) { + if (code == CUDA_SUCCESS) + return true; + + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; } -#define CUDA_CHECK(ans) \ - { \ - { gpuAssert((ans), __FILE__, __LINE__); } \ - } +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) #define ADD_ENUM_ITEM(value) \ do { \ @@ -200,16 +215,16 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { int sm_clock_rate; int mem_clock_rate; int mem_bus_width; - CUDA_CHECK(cuDeviceGetAttribute( + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); - CUDA_CHECK(cuDeviceGetAttribute( + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); - CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, - CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); - CUDA_CHECK(cuDeviceGetAttribute( + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); - CUDA_CHECK(cuDeviceGetAttribute( + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); 
return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", @@ -237,33 +252,37 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { CUcontext pctx = 0; Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK(cuCtxGetCurrent(&pctx)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); if (!pctx) { - CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK(cuCtxSetCurrent(pctx)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); } - CUDA_CHECK(cuModuleLoadData(&mod, data)); - CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuModuleGetFunction(&fun, mod, name)); // get allocated registers and spilled registers from the function - CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); - CUDA_CHECK( + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); n_spills /= 4; // set dynamic shared memory if necessary int shared_optin; - CUDA_CHECK(cuDeviceGetAttribute( + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); if (shared > 49152 && shared_optin > 49152) { - CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); int shared_total, shared_static; - CUDA_CHECK(cuDeviceGetAttribute( + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device)); - CUDA_CHECK(cuFuncGetAttribute(&shared_static, - CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); - CUDA_CHECK( 
+ CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); } @@ -286,7 +305,7 @@ static PyObject *memAlloc(PyObject *self, PyObject *args) { } Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK(cuMemAlloc(&dptr, bytesize)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemAlloc(&dptr, bytesize)); Py_END_ALLOW_THREADS; return PyLong_FromUnsignedLongLong((unsigned long long)dptr); @@ -307,7 +326,8 @@ static PyObject *memcpyHtoD(PyObject *self, PyObject *args) { srcHost = (const void *)srcHostPtr; Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK(cuMemcpyHtoD(dstDevice, srcHost, byteCount)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuMemcpyHtoD(dstDevice, srcHost, byteCount)); Py_END_ALLOW_THREADS; Py_RETURN_NONE; @@ -321,7 +341,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) { } Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK(cuMemFree(dptr)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuMemFree(dptr)); Py_END_ALLOW_THREADS; Py_RETURN_NONE; @@ -411,7 +431,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { } // Call the function Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK(cuTensorMapEncodeTiledHandle( + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuTensorMapEncodeTiledHandle( tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill)); diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index e4721cbe3..3799eaf08 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -19,6 +19,7 @@ def default_dump_dir(): class CacheManager(ABC): + def __init__(self, key): pass @@ -44,20 +45,21 @@ class CacheManager(ABC): class FileCacheManager(CacheManager): + def __init__(self, key, override=False, dump=False): self.key = key 
self.lock_path = None - if (dump): + if dump: self.cache_dir = default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) - elif (override): + elif override: self.cache_dir = default_override_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) else: # create cache directory if it doesn't exist - self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir() + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() if self.cache_dir: self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") @@ -93,9 +95,8 @@ class FileCacheManager(CacheManager): result = {} for c in child_paths: p = self._make_path(c) - if not os.path.exists(p): - raise Exception(f"Group file {p} does not exist from group {grp_filename} ") - result[c] = p + if os.path.exists(p): + result[c] = p return result # Note a group of pushed files as being part of a group @@ -142,6 +143,7 @@ def get_cache_manager(key) -> CacheManager: if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: import importlib + module_path, clz_nme = user_cache_manager.split(":") module = importlib.import_module(module_path) __cache_cls = getattr(module, clz_nme) diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 5b778be4d..767a567c4 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -9,7 +9,6 @@ from .cache import get_cache_manager class DriverBase(metaclass=abc.ABCMeta): - CUDA = 0 HIP = 1 @@ -19,6 +18,8 @@ class DriverBase(metaclass=abc.ABCMeta): def __init__(self) -> None: pass + + # ----------------------------- # CUDA # ----------------------------- @@ -27,7 +28,7 @@ class DriverBase(metaclass=abc.ABCMeta): class CudaUtils(object): def __new__(cls): - if not hasattr(cls, 'instance'): + if not hasattr(cls, "instance"): 
cls.instance = super(CudaUtils, cls).__new__(cls) return cls.instance @@ -47,6 +48,7 @@ class CudaUtils(object): with open(so, "rb") as f: cache_path = cache.put(f.read(), fname, binary=True) import importlib.util + spec = importlib.util.spec_from_file_location("cuda_utils", cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -66,7 +68,7 @@ class CudaUtils(object): class CudaDriver(DriverBase): def __new__(cls): - if not hasattr(cls, 'instance'): + if not hasattr(cls, "instance"): cls.instance = super(CudaDriver, cls).__new__(cls) return cls.instance @@ -74,14 +76,16 @@ class CudaDriver(DriverBase): self.utils = CudaUtils() self.backend = self.CUDA + # ----------------------------- # HIP # ----------------------------- class HIPUtils(object): + def __new__(cls): - if not hasattr(cls, 'instance'): + if not hasattr(cls, "instance"): cls.instance = super(HIPUtils, cls).__new__(cls) return cls.instance @@ -101,6 +105,7 @@ class HIPUtils(object): with open(so, "rb") as f: cache_path = cache.put(f.read(), fname, binary=True) import importlib.util + spec = importlib.util.spec_from_file_location("hip_utils", cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -111,7 +116,7 @@ class HIPUtils(object): class HIPDriver(DriverBase): def __new__(cls): - if not hasattr(cls, 'instance'): + if not hasattr(cls, "instance"): cls.instance = super(HIPDriver, cls).__new__(cls) return cls.instance @@ -123,7 +128,7 @@ class HIPDriver(DriverBase): class UnsupportedDriver(DriverBase): def __new__(cls): - if not hasattr(cls, 'instance'): + if not hasattr(cls, "instance"): cls.instance = super(UnsupportedDriver, cls).__new__(cls) return cls.instance @@ -131,12 +136,14 @@ class UnsupportedDriver(DriverBase): self.utils = None self.backend = None + # ----------------------------- # Driver # ----------------------------- class LazyProxy: + def __init__(self, init_fn): self._init_fn = init_fn self._obj = None @@ -150,7 +157,7 
@@ class LazyProxy: return getattr(self._obj, name) def __setattr__(self, name, value): - if name in ['_init_fn', '_obj']: + if name in ["_init_fn", "_obj"]: super().__setattr__(name, value) else: self._initialize_obj() @@ -172,6 +179,7 @@ class LazyProxy: def initialize_driver(): import torch + if torch.version.hip is not None: return HIPDriver() elif torch.cuda.is_available(): diff --git a/python/triton/runtime/errors.py b/python/triton/runtime/errors.py index 4ff900574..a5d69aba6 100644 --- a/python/triton/runtime/errors.py +++ b/python/triton/runtime/errors.py @@ -1,10 +1,8 @@ - class OutOfResources(Exception): + def __init__(self, required, limit, name): - self.message = f'out of resource: {name}, '\ - f'Required: {required}, '\ - f'Hardware limit: {limit}' - self.message += '. Reducing block sizes or `num_stages` may help.' + self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}" + self.message += ". Reducing block sizes or `num_stages` may help." 
self.required = required self.limit = limit self.name = name diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index ac364ffa7..c8b70bfb0 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -74,11 +74,15 @@ class BlockPointerHandle: def wrap_ret(compute_ret_ty): + def wrapper(fn): + def wrapped(*args, **kwargs): ret = fn(*args, **kwargs) return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs)) + return wrapped + return wrapper @@ -249,11 +253,13 @@ class Builder: # ternary functions def ternary_op(self, lhs, rhs, other, op): return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) # unary functions def unary_op(self, arg, op): return TensorHandle(op(arg.data), arg.dtype) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) create_cos = lambda self, arg: self.unary_op(arg, np.cos) create_sin = lambda self, arg: self.unary_op(arg, np.sin) @@ -279,7 +285,8 @@ class Builder: dtype_tt = ptr.dtype.element_ty return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) - def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): ptrs, masks = ptr.materialize_pointers(boundary_check) assert padding_option is None other = None @@ -297,6 +304,7 @@ class Builder: def create_int_to_ptr(self, val, dst_ty): return TensorHandle(val.data.astype(np.uint64), dst_ty) + # def create_cat(self, lhs, rhs): # pass @@ -360,7 +368,10 @@ class Builder: def patch_attr(obj, name, member, builder): - new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) + new_member = lambda *args, 
member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) setattr(obj, name, new_member) @@ -384,8 +395,8 @@ def _patch_lang_core(lang, builder): def _new_reduce(input, axis, combine_fn): fn = combine_fn.fn.__name__ mapping = { - 'maximum': np.max, - '_sum_combine': np.sum, + "maximum": np.max, + "_sum_combine": np.sum, } ret = mapping[fn](input.handle.data, axis=axis) ret_type = tl.block_type(input.dtype, ret.shape) @@ -397,15 +408,16 @@ def _patch_lang_core(lang, builder): def _patch_lang_math(lang, builder): math = lang.math mapping = { - 'abs': 'abs', - 'acos': 'arccos', - 'asin': 'arcsin', - 'exp2': 'exp2', - 'log2': 'log2', - 'max': 'maximum', + "abs": "abs", + "acos": "arccos", + "asin": "arcsin", + "exp2": "exp2", + "log2": "log2", + "max": "maximum", } def make_numpy(name): + def impl(*args, **kwargs): ret_type = args[0].type # TODO: incorrect ret_dtype = args[0].dtype # TODO: incorrect @@ -414,15 +426,18 @@ def _patch_lang_math(lang, builder): ret = getattr(np, mapping[name])(*args, **kwargs) ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type) return ret + return impl def make_fallback(name): + def fallback(*args, **kwargs): raise NotImplementedError(f""" {name} not supported in interpreter mode: no known numpy implementation. If you think that {name} in fact does have a numpy implementation, please add it to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math. 
""") + return fallback for name, member in inspect.getmembers(math): @@ -438,7 +453,7 @@ def _implicit_cvt(arg): ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) handle = TensorHandle(np.array([arg], dtype=np.int32), ty) return tl.tensor(handle, ty) - if hasattr(arg, 'data_ptr'): + if hasattr(arg, "data_ptr"): ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) return tl.tensor(handle, ty) @@ -453,28 +468,29 @@ def _unwrap(tensor): builder = Builder() -RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion'] +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_warp_specialization", "enable_fp_fusion"] class GridExecutor: def __init__(self, fn, arg_names, grid): from .jit import _normalize_ty # TODO: modularize + self.fn = fn self.arg_names = arg_names self.grid = grid __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} - self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr'] + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] def _patch_lang(self, builder): lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] assert len(lang) == 1, "triton.language must be visible from within jit'd function" - _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_tensor(getattr(lang[0], "tensor"), builder) _patch_lang_core(lang[0], builder) _patch_lang_math(lang[0], builder) def __call__(self, *args_dev, **kwargs): - args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] + args_hst = [_unwrap(arg).cpu() if hasattr(arg, "data_ptr") else arg for arg in args_dev] # removes reserved keywords from kwargs kwargs = {k: v for k, v in kwargs.items() if k not in 
RESERVED_KWS} # remaps core language functions to interpreted ones @@ -486,7 +502,7 @@ class GridExecutor: # iterate through grid grid = self.grid(args) if callable(self.grid) else self.grid assert len(grid) <= 3 - grid = grid + (1,) * (3 - len(grid)) + grid = grid + (1, ) * (3 - len(grid)) builder.set_grid_dim(*grid) for x in range(grid[0]): for y in range(grid[1]): @@ -495,7 +511,7 @@ class GridExecutor: self.fn(**args) # copy arguments back to propagate side-effects for arg_dev, arg_hst in zip(args_dev, args_hst): - if hasattr(arg_dev, 'data_ptr'): + if hasattr(arg_dev, "data_ptr"): _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device)) @@ -504,17 +520,18 @@ class InterpretedFunction: def _patch_lang(self, builder): lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] assert len(lang) == 1, "triton.language must be visible from within jit'd function" - _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_tensor(getattr(lang[0], "tensor"), builder) _patch_lang_core(lang[0], builder) def __init__(self, fn) -> None: self.fn = fn def run(*args, **kwargs): - grid = kwargs['grid'] - kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']} + grid = kwargs["grid"] + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ["grid"]} return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + self.run = run signature = inspect.signature(fn) self.arg_names = [v.name for v in signature.parameters.values()] diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 3120e1b7a..92a14682b 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -5,48 +5,48 @@ import functools import hashlib import inspect import os -import subprocess import textwrap from collections import defaultdict, namedtuple -from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, - overload) +from functools import cached_property +from typing 
import Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast, overload from .._C.libtriton.triton import TMAInfos -from ..common.backend import get_backend, path_to_ptxas -from ..language.core import dtype +from ..common.backend import get_backend, get_cuda_version_key from .interpreter import InterpretedFunction -TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -TRITON_VERSION = "2.1.0" - def get_cuda_stream(idx=None): if idx is None: idx = get_current_device() try: from torch._C import _cuda_getCurrentRawStream + return _cuda_getCurrentRawStream(idx) except ImportError: import torch + return torch.cuda.current_stream(idx).cuda_stream def get_current_device(): import torch + return torch.cuda.current_device() def set_current_device(idx): import torch + torch.cuda.set_device(idx) def get_device_capability(idx): import torch + return torch.cuda.get_device_capability(idx) -T = TypeVar('T') +T = TypeVar("T") # ----------------------------------------------------------------------------- # Dependencies Finder @@ -72,7 +72,8 @@ class DependenciesFinder(ast.NodeVisitor): lhs = self.visit(node.value) while isinstance(lhs, ast.Attribute): lhs = self.visit(lhs.value) - if lhs is None or (getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton")): + if lhs is None or (getattr(lhs, "__name__", "") == "triton" + or getattr(lhs, "__name__", "").endswith(".triton")): return None return getattr(lhs, node.attr) @@ -82,55 +83,26 @@ class DependenciesFinder(ast.NodeVisitor): return if inspect.isbuiltin(func): return - if func.__module__ and (func.__module__.startswith('triton.') or '.triton.' in func.__module__): + if func.__module__ and (func.__module__.startswith("triton.") or ".triton." in func.__module__): return - assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. 
Decorate it with @triton.jit to fix this" + assert isinstance( + func, JITFunction + ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' if func.hash is None: tree = ast.parse(func.src) finder = DependenciesFinder(func.__globals__, func.src) finder.visit(tree) func.hash = finder.ret - noinline = str(getattr(func, 'noinline', False)) + noinline = str(getattr(func, "noinline", False)) self.ret = (self.ret + func.hash + noinline).encode("utf-8") self.ret = hashlib.sha1(self.ret).hexdigest() + # ----------------------------------------------------------------------------- # JITFunction # ----------------------------------------------------------------------------- -@functools.lru_cache() -def version_key(): - import pkgutil - contents = [] - # frontend - with open(__file__, "rb") as f: - contents += [hashlib.sha1(f.read()).hexdigest()] - # compiler - compiler_path = os.path.join(TRITON_PATH, 'compiler') - for lib in pkgutil.iter_modules([compiler_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.sha1(f.read()).hexdigest()] - # backend - libtriton_hash = hashlib.sha1() - with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: - while True: - chunk = f.read(1024 ** 2) - if not chunk: - break - libtriton_hash.update(chunk) - contents.append(libtriton_hash.hexdigest()) - # language - language_path = os.path.join(TRITON_PATH, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.sha1(f.read()).hexdigest()] - # ptxas version - ptxas = path_to_ptxas()[0] - ptxas_version = hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest() - return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents) - - def _normalize_ty(ty) -> str: if isinstance(ty, type): return ty.__name__ @@ -139,6 
+111,85 @@ def _normalize_ty(ty) -> str: return repr(ty) +class KernelParam: + """Represents a parameter to a @jit'ed function. + + A parameter is just the name plus metadata; a parameter plus a value is a + KernelArg. + """ + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +class KernelArg: + """Represents an argument to a @jit'ed function. + + An argument is a parameter plus a value. + """ + + def __init__(self, value, param): + self.value = value + self.param = param + + @property + def name(self): + return self.param.name + + def signature_key(self): + annotation = self.param.annotation + if "Tensor" in annotation: + return self.value.dtype + elif annotation == "bool": + return "i1" + elif annotation == "float": + return "fp32" + else: + return JITFunction._key_of(self.value) + + def specialization_key(self): + assert not self.param.do_not_specialize + + try: + return (self.value.data_ptr() % JITFunction.divisibility == 0, ) + except AttributeError: + pass + + if isinstance(self.value, int): + # bool is a subclass of int, so we don't check explicitly above. 
+ return ( + self.value % JITFunction.divisibility == 0, + self.value % JITFunction.divisibility_8 == 0, + self.value == 1, + ) + + return (False, ) + + class KernelInterface(Generic[T]): run: T @@ -152,7 +203,6 @@ class KernelInterface(Generic[T]): class JITFunction(KernelInterface[T]): - # Hook for inspecting compiled functions and modules cache_hook = None divisibility = 16 @@ -169,44 +219,44 @@ class JITFunction(KernelInterface[T]): elif isinstance(arg, bool): return "i1" elif isinstance(arg, int): - if -2**31 <= arg and arg <= 2**31 - 1: + if -(2**31) <= arg and arg <= 2**31 - 1: return "i32" elif 2**63 <= arg and arg <= 2**64 - 1: return "u64" else: return "i64" elif isinstance(arg, float): - return 'fp32' + return "fp32" elif arg is None: return None else: - raise TypeError(f'Unsupported type {type(arg)} for {arg}') + raise TypeError(f"Unsupported type {type(arg)} for {arg}") @staticmethod def _device_of(arg): - if hasattr(arg, "device"): - if hasattr(arg.device, 'type'): - return arg.device.type - - return '' + try: + return arg.device.type + except AttributeError: + return "" @staticmethod def _pinned_memory_of(arg): - if hasattr(arg, "is_pinned"): - if isinstance(arg.is_pinned, Callable): - return arg.is_pinned() - - return False + try: + return arg.is_pinned() + except (AttributeError, TypeError): + return False @staticmethod def _spec_of(arg): if hasattr(arg, "data_ptr"): - return (arg.data_ptr() % JITFunction.divisibility == 0) + return arg.data_ptr() % JITFunction.divisibility == 0 elif isinstance(arg, int): return (arg % 16 == 0, arg == 1) return (arg is None, ) + # TODO(jlebar): Fold this into the KernelArg class. 
def _get_config(self, *args): + def is_divisible_by_16(x): if hasattr(x, "data_ptr"): return x.data_ptr() % JITFunction.divisibility == 0 @@ -222,28 +272,38 @@ class JITFunction(KernelInterface[T]): if x is None: return True return False - divisible_by_16 = {i for i, arg in enumerate( - args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} - divisible_by_8 = {i for i, arg in enumerate( - args) if is_divisible_by_8(arg) and i not in self.do_not_specialize} + + divisible_by_16 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_16(arg) and not param.do_not_specialize + } + divisible_by_8 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_8(arg) and not param.do_not_specialize + } equal_to_1 = { - i for i, arg in enumerate(args) if isinstance( - arg, int) and not isinstance( - arg, bool) and arg == 1 and i not in self.do_not_specialize} + param.num + for param, arg in zip(self.params, args) + if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize + } # folded equal_to_1 and None # TODO: method to collect all folded args - none_args = {i for i, arg in enumerate(args) if arg is None and i not in self.do_not_specialize} + none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize} ids_of_folded_args = equal_to_1 | none_args - return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( - tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8)) + return namedtuple("instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( # + tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), + tuple(divisible_by_8)) # return _triton.code_gen.instance_descriptor(divisible_by_16, # equal_to_1) @staticmethod def _type_of(key): - # None are nullptr -- implicitly converted to 
*i8 + # `None` is nullptr. Implicitly convert to *i8. if key is None: - return '*i8' + return "*i8" dtype_str = str(key).split(".")[-1] tys = { "bool": "i1", @@ -281,21 +341,46 @@ class JITFunction(KernelInterface[T]): constants = dict(zip(self.constexprs, constexpr_key)) return constants +<<<<<<< HEAD def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization,enable_fp_fusion, extern_libs, configs): +======= + def _call_hook( + self, + key, + signature, + device, + constants, + num_warps, + num_ctas, + num_stages, + enable_warp_specialization, + enable_fp_fusion, + extern_libs, + configs, + ): +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 if JITFunction.cache_hook is None: return False + name = self.fn.__name__ module = self.fn.__module__ +<<<<<<< HEAD arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs}), enable_fp_fusion={enable_fp_fusion}]({arg_reprs})" +======= + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})" +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 key = str(key) class LegacyCompiler: + def __init__(self, module, name): self.module = module self.name = name pass +<<<<<<< HEAD kwargs = dict(signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs) @@ 
-326,18 +411,43 @@ class JITFunction(KernelInterface[T]): return 'fp32' else: return self._key_of(arg) +======= + kwargs = dict( + signature=signature, + device=device, + constants=constants, + num_warps=num_warps, + num_ctas=num_ctas, + num_stages=num_stages, + enable_warp_specialization=enable_warp_specialization, + enable_fp_fusion=enable_fp_fusion, + extern_libs=extern_libs, + configs=configs, + ) + + return JITFunction.cache_hook( + key=key, + repr=repr, + fn=LegacyCompiler(module, name), + compile={"key": key, **kwargs}, + is_manual_warmup=False, + already_compiled=False, + ) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str: - device_types = [device_type for device_type in device_types if device_type != ''] + device_types = [device_type for device_type in device_types if device_type != ""] # Return cuda if one of the input tensors is cuda - if 'cuda' in device_types: + if "cuda" in device_types: import torch - return 'hip' if torch.version.hip else 'cuda' - is_cpu = all(device_type == 'cpu' for device_type in device_types) + return "hip" if torch.version.hip else "cuda" + + is_cpu = all(device_type == "cpu" for device_type in device_types) is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags) # Return cuda if all the input tensors are cpu while the memory is pinned if is_cpu and is_pinned_memory: +<<<<<<< HEAD return 'cuda' return device_types[0] if len(device_types) > 0 else 'cuda' @@ -452,16 +562,193 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu scope = {"launcher_body": launcher_body} exec(src, scope) return scope[self.fn.__name__] +======= + return "cuda" + + return device_types[0] if len(device_types) > 0 else "cuda" + + def run(self, *args, **kwargs): + from ..compiler import CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps + + # Get a compiler-flags arg 
like `num_warps` and remove it from kwargs. + def get_special_arg(name: str, default=None): + if name not in kwargs: + return default + ret = kwargs[name] + del kwargs[name] + return ret + + grid = get_special_arg("grid") + num_warps = get_special_arg("num_warps") + num_ctas = get_special_arg("num_ctas", 1) + num_stages = get_special_arg("num_stages") + enable_warp_specialization = get_special_arg("enable_warp_specialization", False) + enable_fp_fusion = get_special_arg("enable_fp_fusion", True) + extern_libs = get_special_arg("extern_libs") + stream = get_special_arg("stream") + warmup = get_special_arg("warmup", False) + device = get_special_arg("device") + device_type = get_special_arg("device_type") + + # Bind the remaining arguments to `fn`. + bound_args = self.signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + assert len(bound_args.arguments) == len(self.params) + args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)] + + non_constexpr_arg_values = [arg.value for arg in args if not arg.param.is_constexpr] + + sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr) + spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize) + constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr) + + assert num_ctas > 0 + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. 
+ grid = grid(dict(bound_args.arguments)) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + if device_type is None: + device_types = [self._device_of(arg) for arg in non_constexpr_arg_values] + device_types = [_device_type for _device_type in device_types if _device_type != ""] + device_type = self._conclude_device_type(device_types, + [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values]) + + device_backend = None + if device_type not in ["cuda"]: + device_backend = get_backend(device_type) + if device_backend is None: + raise ValueError("Cannot find backend for " + device_type) + + if device is None: + if device_type in ["cuda"]: + device = get_current_device() + set_current_device(device) + else: + device = device_backend.get_current_device() + device_backend.set_current_device(device) + if stream is None and not warmup: + if device_type in ["cuda"]: + stream = get_cuda_stream(device) + else: + stream = device_backend.get_stream() + + if num_warps is None: + num_warps = get_arch_default_num_warps(device_type) + if num_stages is None: + num_stages = get_arch_default_num_stages(device_type) + + if device_type in ["cuda"]: + version_key = get_cuda_version_key() + else: + version_key = device_backend.get_version_key() + key = ( + version_key, + sig_key, + constexpr_key, + spec_key, + num_warps, + num_ctas, + num_stages, + enable_warp_specialization, + enable_fp_fusion, + self.debug, + ) + if extern_libs is not None: + key = (key, tuple(extern_libs.items())) + + # Kernel is not cached; we have to compile. 
+ if key not in self.cache[device]: + configs = (self._get_config(*[arg.value for arg in args]), ) + constants = { + arg.param.num: arg.value + for arg in args + if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + # Build kernel signature -- doesn't include constexpr arguments. + signature = { + arg.param.num: self._type_of(self._key_of(arg.value)) + for arg in args + if not arg.param.is_constexpr + } + + if self._call_hook( + key, + signature, + device, + constants, + num_warps, + num_ctas, + num_stages, + enable_warp_specialization, + enable_fp_fusion, + extern_libs, + configs, + ): + return None + + self.cache[device][key] = compile( + self, + signature=signature, + device=device, + constants=constants, + num_warps=num_warps, + num_ctas=num_ctas, + num_stages=num_stages, + enable_warp_specialization=enable_warp_specialization, + enable_fp_fusion=enable_fp_fusion, + extern_libs=extern_libs, + configs=configs, + debug=self.debug, + device_type=device_type, + ) + + bin = self.cache[device][key] + if not warmup: + bin.c_wrapper( + grid_0, + grid_1, + grid_2, + bin.num_warps, + bin.num_ctas, + bin.clusterDims[0], + bin.clusterDims[1], + bin.clusterDims[2], + bin.shared, + stream, + bin.cu_function, + CompiledKernel.launch_enter_hook, + CompiledKernel.launch_exit_hook, + bin, + *bin.assemble_tensormap_to_arg(non_constexpr_arg_values), + ) + return bin +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + self.fn = fn self.module = fn.__module__ self.version = version - # function signature information - signature = inspect.signature(fn) - self.arg_names = [v.name for v in signature.parameters.values()] - self.arg_defaults = [v.default for v in 
signature.parameters.values()] - self.has_defaults = any(v != inspect._empty for v in self.arg_defaults) + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) + self.params.append(KernelParam(i, param, dns)) + # function source code (without decorators) self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] @@ -470,22 +757,18 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu self.hash = None # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ - self.kernel_decorators = [] self.kernel = None self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug self.noinline = noinline - # annotations - self.__annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} - # index of constexprs - self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty] - # specialization hints - regular_args = [arg for i, arg in enumerate(self.arg_names) if i not in self.constexprs] - self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize - self.do_not_specialize = {regular_args.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize} + # tma info self.tensormaps_info = TMAInfos() - # launcher - self.run = self._make_launcher() + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. 
+ self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + # re-use docs of wrapped function self.__doc__ = fn.__doc__ self.__name__ = fn.__name__ @@ -498,7 +781,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu if self.hash is None: dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) dependencies_finder.visit(self.parse()) - self.hash = dependencies_finder.ret + version_key() + self.hash = dependencies_finder.ret return self.hash def warmup(self, *args, **kwargs): @@ -518,14 +801,10 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") def __setattr__(self, name, value): - # - when kernel decorators change, cached kernel - # needs to be cleared - if name == 'kernel_decorators': - self.kernel = None super(JITFunction, self).__setattr__(name, value) # - when `.src` attribute is set, cache path needs # to be reinitialized - if name == 'src': + if name == "src": self.hash = None def __repr__(self): @@ -591,12 +870,14 @@ def jit( debug=debug, noinline=noinline, ) + if fn is not None: return decorator(fn) else: return decorator + # ----------------------------------------------------------------------------- # Utilities for mocking tensors # ----------------------------------------------------------------------------- @@ -607,10 +888,10 @@ class MockTensor: Can be used in place of real tensors when calling: kernel.warmup(MockTensor(torch.float32), ...) 
""" + @staticmethod def wrap_dtype(arg): - if arg.__class__.__name__ == "dtype" and\ - arg.__module__ == "torch": + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": return MockTensor(arg) return arg @@ -623,6 +904,7 @@ class MockTensor: class TensorWrapper: + def __init__(self, base, dtype): self.dtype = dtype self.base = base @@ -637,7 +919,7 @@ class TensorWrapper: return self.base.stride(i) def __str__(self) -> str: - return f'TensorWrapper[{self.dtype}]({self.base})' + return f"TensorWrapper[{self.dtype}]({self.base})" def element_size(self): return self.base.element_size() @@ -655,4 +937,4 @@ def reinterpret(tensor, dtype): # A new wrapper is needed around an unwrapped tensor. return TensorWrapper(tensor, dtype) else: - raise TypeError(f'Cannot reinterpret a {type(tensor)}.') + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") diff --git a/python/triton/testing.py b/python/triton/testing.py index f01d4f8e3..848db7c1b 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None): return torch.mean(torch.tensor(ret)).item() -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, - quantiles=None, - fast_flush=True, - return_mode="mean"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): assert return_mode in ["min", "max", "mean", "median"] import torch """ @@ -261,11 +258,12 @@ class Benchmark: class Mark: + def __init__(self, fn, benchmarks): self.fn = fn self.benchmarks = benchmarks - def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags): + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags): import os import matplotlib.pyplot as plt @@ -321,24 +319,36 @@ class Mark: if save_path: plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) df = df[x_names + bench.line_names] + if diff_col and 
df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + if print_data: print(bench.plot_name + ':') print(df) if save_path: df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False) + return df - def run(self, show_plots=False, print_data=False, save_path='', **kwargs): + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): has_single_bench = isinstance(self.benchmarks, Benchmark) benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] if save_path: html = open(os.path.join(save_path, "results.html"), "w") html.write("\n") for bench in benchmarks: - self._run(bench, save_path, show_plots, print_data, **kwargs) + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) if save_path: html.write(f"\n") if save_path: html.write("\n") + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None def perf_report(benchmarks): @@ -393,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None): tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops + # create decorator that wraps test function into # a cuda-memcheck system call def cuda_memcheck(**target_kwargs): + def decorator(test_fn): + @functools.wraps(test_fn) def wrapper(*args, **kwargs): import psutil @@ -416,7 +429,9 @@ def cuda_memcheck(**target_kwargs): assert "ERROR SUMMARY: 0 errors" in str(out.stdout) else: test_fn(*args, **kwargs) + return wrapper + return decorator @@ -424,22 +439,18 @@ def cuda_memcheck(**target_kwargs): def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): try: subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) - subprocess.check_output( - [ - "nvidia-smi", - "-i", - "0", - f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", - ] - ) - subprocess.check_output( - [ - "nvidia-smi", - "-i", - "0", - 
f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", - ] - ) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) cur_sm_clock = nvsmi(["clocks.current.sm"])[0] cur_mem_clock = nvsmi(["clocks.current.memory"])[0] assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index f19fbd561..6f00e8192 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -141,8 +141,7 @@ class ExternLibrary(ABC): f.write(file_str) f.close() if self._format: - subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], - stdout=subprocess.PIPE).communicate() + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() @@ -208,56 +207,36 @@ class Libdevice(ExternLibrary): # Group functions together by renaming. 
renaming = { - 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', - 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn', - 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', - 'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin', - 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', - 'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt', - 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', - 'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi', - 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', - 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', - 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru', - 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', - 'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx', - 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', - 'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs', - 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', - 'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor', - 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', - 'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', - 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', - 'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', - 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', - 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', - 'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10', - 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', - 'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max', - 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', - 'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', - 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru', - 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', - 'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi', - 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter', - 'norm3df': 'norm3d', 
'norm4df': 'norm4d', 'normcdff': 'normcdf', - 'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', - 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd', - 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', - 'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', - 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', - 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', - 'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn', - 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', - 'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh', - 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', - 'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', - 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz', - 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', - 'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', - 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', - 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', - 'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn' + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 
'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 
'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' } for symbol in self._symbols.values(): @@ -347,8 +326,7 @@ class LLVMDisassembler: self._ll_file = "/tmp/extern_lib.ll" def disasm(self, lib_path: str) -> None: - subprocess.Popen([self._path, lib_path, "-o", self.ll_file], - stdout=subprocess.PIPE).communicate() + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() @property def ll_file(self) -> str: diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index d80f15e8a..a69c7100d 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -40,10 +40,13 @@ if __name__ == "__main__": # command-line arguments parser = ArgumentParser(description=desc) - parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.") - parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. 
File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") - parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) @@ -104,7 +107,8 @@ if __name__ == "__main__": config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: constexprs.update({i: 1}) - ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages) + ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], + num_warps=args.num_warps, num_stages=args.num_stages) arg_names = [] arg_types = [] for i in signature.keys(): diff --git a/python/triton/tools/link.py b/python/triton/tools/link.py index 836c89c5f..eb39b4bda 100644 --- a/python/triton/tools/link.py +++ b/python/triton/tools/link.py @@ -27,6 +27,7 @@ class KernelLinkerMeta: class HeaderParser: + def __init__(self) -> None: import re @@ -42,7 +43,6 @@ class HeaderParser: self.kernels = defaultdict(list) def extract_linker_meta(self, header: str): - for ln in header.splitlines(): if ln.startswith("//"): m = self.linker_directives.match(ln) @@ -76,7 +76,7 @@ class HeaderParser: m = self.c_sig.findall(c_sig) if len(m): tys, args = [], [] - for (ty, arg_name) in m: + for ty, arg_name in m: tys.append(ty) 
args.append(arg_name) return tys, args @@ -84,7 +84,7 @@ class HeaderParser: raise LinkerError(f"{c_sig} is not a valid argument signature") def _match_suffix(self, suffix: str, c_sig: str): - args = c_sig.split(',') + args = c_sig.split(",") s2i = {"c": 1, "d": 16} num_specs = 0 sizes = [] @@ -110,7 +110,7 @@ class HeaderParser: if name in self.kernels: last: KernelLinkerMeta = self.kernels[name][-1] - for (cur, new_) in zip(last.arg_ctypes, ker.arg_ctypes): + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): if cur != new_: raise LinkerError( f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" @@ -152,7 +152,7 @@ void unload_{meta.orig_kernel_name}(); # generate dispatcher function for kernels with different meta-parameter and constant values def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" - src += f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") src += "}\n" return src @@ -164,12 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" src += "\n" - src += f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{" + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - cond_fn = lambda val, hint: f"({val} % {hint} == 0)" if hint == 16 else f"({val} == {hint})" if hint == 1 else None - conds = " && ".join([cond_fn(val, hint) for val, hint in zip(meta.arg_names, meta.sizes) if hint is not None]) - src += f" if ({conds})\n" + cond_fn = ( # + lambda val, hint: f"({val} % 
{hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" src += "\n" @@ -183,7 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - src += f"void {mode}_{name}() {{" src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") src += "}\n" return src @@ -252,7 +262,12 @@ if __name__ == "__main__": help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", ) parser.add_argument("--out", "-o", type=Path, help="Out filename") - parser.add_argument("--prefix", type=str, default="", help="String to prefix kernel dispatcher names") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) args = parser.parse_args() # metadata diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 3463ddf1c..1c1900a07 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -25,14 +25,13 @@ import triton.language as tl @triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. 
- # NOTE: `constexpr` so it can be used as a shape value. -): +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): # There are multiple 'programs' processing different data. We identify which program # we are here: pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. @@ -66,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor): # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. # In this case, we use a 1D grid where the size is the number of blocks: - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) # NOTE: # - Each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. @@ -88,10 +87,8 @@ output_torch = x + y output_triton = add(x, y) print(output_torch) print(output_triton) -print( - f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}' -) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') # %% # Seems like we're good to go! @@ -108,9 +105,7 @@ print( @triton.testing.perf_report( triton.testing.Benchmark( x_names=['size'], # Argument names to use as an x-axis for the plot. - x_vals=[ - 2 ** i for i in range(12, 28, 1) - ], # Different possible values for `x_name`. + x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`. x_log=True, # x axis is logarithmic. 
line_arg='provider', # Argument name whose value corresponds to a different line in the plot. line_vals=['triton', 'torch'], # Possible values for `line_arg`. @@ -119,8 +114,7 @@ print( ylabel='GB/s', # Label name for the y-axis. plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot. args={}, # Values for function arguments not in `x_names` and `y_name`. - ) -) + )) def benchmark(size, provider): x = torch.rand(size, device='cuda', dtype=torch.float32) y = torch.rand(size, device='cuda', dtype=torch.float32) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 31b5b0a01..6d7d39b33 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -71,10 +71,7 @@ def naive_softmax(x): @triton.jit -def softmax_kernel( - output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, - BLOCK_SIZE: tl.constexpr -): +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): # The rows of the softmax are independent, so we parallelize across those row_idx = tl.program_id(0) # The stride represents how much we need to increase the pointer to advance 1 row @@ -118,7 +115,7 @@ def softmax(x): y = torch.empty_like(x) # Enqueue kernel. 
The 1D launch grid is simple: we have one kernel instance per row o # f the input matrix - softmax_kernel[(n_rows,)]( + softmax_kernel[(n_rows, )]( y, x, x.stride(0), @@ -158,9 +155,7 @@ assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'], # argument names to use as an x-axis for the plot - x_vals=[ - 128 * i for i in range(2, 100) - ], # different possible values for `x_name` + x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot line_vals=[ 'triton', @@ -176,8 +171,7 @@ assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) ylabel="GB/s", # label name for the y-axis plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` - ) -) + )) def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index acb942fd8..da5696025 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -165,6 +165,7 @@ import pytest # provided configs @triton.autotune( configs=[ +<<<<<<< HEAD triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), @@ -179,6 +180,24 @@ import pytest triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0), 
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0), +======= + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ], key=['M', 'N', 'K'], ) @@ -187,6 +206,7 @@ import pytest }) @triton.jit def matmul_kernel( +<<<<<<< HEAD # Pointers to matrices a_ptr, b_ptr, c_ptr, # Matrix dimensions @@ -202,6 +222,22 @@ def matmul_kernel( EVEN_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr, +======= + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). 
+ stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -300,16 +336,14 @@ def matmul(a, b, activation=""): # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ACTIVATION=activation + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation # ) return c @@ -363,6 +397,7 @@ verbose = False @triton.testing.perf_report( triton.testing.Benchmark( x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot +<<<<<<< HEAD x_vals=[ (1024, 1024, 1024), (2048, 2048, 2048), @@ -370,6 +405,9 @@ verbose = False (8192, 8192, 8192), (9728, 8192, 65536) ], # Different possible values for `x_name` +======= + x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 line_arg='provider', # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` line_vals=['rocblas', 'triton'], @@ -380,8 +418,7 @@ verbose = False ylabel="TFLOPS", # Label name for the y-axis plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
args={}, - ) -) + )) def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index 3c4d217e2..fe52f0d8e 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -32,7 +32,6 @@ In doing so, you will learn about: # # Let's first take a look at the baseline implementation. - import tabulate import torch @@ -66,22 +65,22 @@ def dropout(x, x_keep, p): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) return output # Input tensor -x = torch.randn(size=(10,)).cuda() +x = torch.randn(size=(10, )).cuda() # Dropout mask p = 0.5 -x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda() +x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda() # output = dropout(x, x_keep=x_keep, p=p) print(tabulate.tabulate([ ["input"] + x.tolist(), ["keep mask"] + x_keep.tolist(), - ["output"] + output.tolist() + ["output"] + output.tolist(), ])) # %% @@ -134,23 +133,24 @@ def seeded_dropout(x, p, seed): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) return output -x = torch.randn(size=(10,)).cuda() +x = torch.randn(size=(10, )).cuda() # Compare this to the baseline - dropout mask is never instantiated! 
output = seeded_dropout(x, p=0.5, seed=123) output2 = seeded_dropout(x, p=0.5, seed=123) output3 = seeded_dropout(x, p=0.5, seed=512) -print(tabulate.tabulate([ - ["input"] + x.tolist(), - ["output (seed = 123)"] + output.tolist(), - ["output (seed = 123)"] + output2.tolist(), - ["output (seed = 512)"] + output3.tolist() -])) +print( + tabulate.tabulate([ + ["input"] + x.tolist(), + ["output (seed = 123)"] + output.tolist(), + ["output (seed = 123)"] + output2.tolist(), + ["output (seed = 512)"] + output3.tolist(), + ])) # %% # Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same! diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 6f131c963..95cd22d1a 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -126,24 +126,22 @@ def _layer_norm_fwd_fused( # In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`. # In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`. 
+ @triton.jit -def _layer_norm_bwd_dx_fused( - DX, # pointer to the input gradient - DY, # pointer to the output gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - Lock, # pointer to the lock - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - GROUP_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr -): +def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + Lock, # pointer to the lock + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # Map the program id to the elements of X, DX, and DY it should compute. 
row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) @@ -192,16 +190,13 @@ def _layer_norm_bwd_dx_fused( @triton.jit -def _layer_norm_bwd_dwdb( - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - FINAL_DW, # pointer to the weights gradient - FINAL_DB, # pointer to the biases gradient - M, # GROUP_SIZE_M - N, # number of columns - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr -): +def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + FINAL_DW, # pointer to the weights gradient + FINAL_DB, # pointer to the biases gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # Map the program id to the elements of DW and DB it should compute. pid = tl.program_id(0) cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -258,9 +253,10 @@ class LayerNorm(torch.autograd.Function): else: num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, - x_arg.stride(0), N, eps, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -280,23 +276,25 @@ class LayerNorm(torch.autograd.Function): locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) - dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) - db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) + dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device) + db = 
torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device) dx = torch.empty_like(dy) # enqueue kernel using forward pass heuristics # also compute partial sums for DW and DB x_arg = x.reshape(-1, x.shape[-1]) M, N = x_arg.shape - _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, - x_arg.stride(0), N, ctx.eps, - BLOCK_SIZE_N=ctx.BLOCK_SIZE, - GROUP_SIZE_M=GROUP_SIZE_M, - num_warps=ctx.num_warps) + _layer_norm_bwd_dx_fused[(M, )]( # + dx, dy, _dw, _db, x, w, b, m, v, locks, # + x_arg.stride(0), N, ctx.eps, # + BLOCK_SIZE_N=ctx.BLOCK_SIZE, # + GROUP_SIZE_M=GROUP_SIZE_M, # + num_warps=ctx.num_warps) grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] # accumulate partial sums in separate kernel - _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, - BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128, num_ctas=1) + _layer_norm_bwd_dwdb[grid]( + _dw, _db, dw, db, GROUP_SIZE_M, N, # + BLOCK_SIZE_M=32, # + BLOCK_SIZE_N=128, num_ctas=1) return dx, None, dw, db, None @@ -340,10 +338,16 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', +<<<<<<< HEAD plot_name='layer-norm-forward', args={'M': 4096, 'dtype': torch.float16, 'mode': 'forward'} ) ) +======= + plot_name='layer-norm-backward', + args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, + )) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): # create data x_shape = (M, N) @@ -356,24 +360,34 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c quantiles = [0.5, 0.2, 0.8] # utility functions if provider == 'triton': - def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 - if provider == 'torch': - def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 - if 
provider == 'apex': - apex_layer_norm = apex.normalization.FusedLayerNorm( - w_shape).to(x.device).to(x.dtype) - def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704 + def y_fwd(): + return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == 'torch': + + def y_fwd(): + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == 'apex': + apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype) + + def y_fwd(): + return apex_layer_norm(x) # noqa: F811, E704 + # forward pass if mode == 'forward': gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) # backward pass if mode == 'backward': - def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704 + + def gbps(ms): + return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704 + y = y_fwd() - ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), - quantiles=quantiles, grad_to_none=[x], rep=500) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, + grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 2ec90c1f3..5878cdcad 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -29,6 +29,7 @@ if TORCH_HAS_FP8E5FNUZ: TORCH_HAS_FP8 = True @triton.jit +<<<<<<< HEAD def _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, @@ -42,6 +43,14 @@ def _attn_fwd_inner( N_CTX, pre_load_v: tl.constexpr, ): +======= +def _attn_fwd_inner(acc, l_i, m_i, q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: 
tl.constexpr, # + N_CTX: tl.constexpr): +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M @@ -83,6 +92,7 @@ def _attn_fwd_inner( K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i + # We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting # the code below and commenting out the equivalent parameters is convenient for # re-tuning. @@ -99,6 +109,7 @@ def _attn_fwd_inner( @triton.jit +<<<<<<< HEAD def _attn_fwd( Q, K, V, sm_scale, M, Out, stride_qz, stride_qh, stride_qm, stride_qk, @@ -113,6 +124,20 @@ def _attn_fwd( BLOCK_N: tl.constexpr, pre_load_v: tl.constexpr, ): +======= +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, # + N_CTX: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 start_m = tl.program_id(0) off_hz = tl.program_id(1) qvk_offset = off_hz * stride_qh @@ -168,6 +193,7 @@ def _attn_fwd( # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: +<<<<<<< HEAD acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, @@ -175,11 +201,19 @@ def _attn_fwd( 4 - STAGE, offs_m, offs_n, N_CTX, pre_load_v, ) +======= + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX # + ) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 # stage 2: on-band if STAGE & 2: # barrier makes it easier for compielr to schedule the # two loops independently tl.debug_barrier() +<<<<<<< HEAD acc, 
l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, @@ -187,6 +221,13 @@ def _attn_fwd( 2, offs_m, offs_n, N_CTX, pre_load_v, ) +======= + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, # + 2, offs_m, offs_n, N_CTX # + ) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 # epilogue # write back m acc = acc / l_i[:, None] @@ -197,8 +238,14 @@ def _attn_fwd( @triton.jit def _attn_bwd_preprocess(O, DO, # +<<<<<<< HEAD NewDO, Delta, # BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, # +======= + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr # +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) @@ -212,6 +259,7 @@ def _attn_bwd_preprocess(O, DO, # @triton.jit +<<<<<<< HEAD def _bwd_kernel_dk_dv( Q, K, V, sm_scale, Out, DO, DK, DV, @@ -422,11 +470,242 @@ def _bwd_kernel_dq( order=(1, 0) ) tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16)) +======= +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, BLOCK_DMODEL) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. 
+ offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, BLOCK_DMODEL) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. 
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, # + start_n, start_m, num_steps, # + MASK=True # + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. 
+ dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, # + start_n, start_m, num_steps, # + MASK=False # + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # + ) + # Write back dQ. 
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) + +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): + @staticmethod def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): # shape constraints @@ -453,6 +732,7 @@ class _attention(torch.autograd.Function): M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( +<<<<<<< HEAD q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -462,6 +742,21 @@ class _attention(torch.autograd.Function): N_CTX=q.shape[2], BLOCK_DMODEL=Lk, STAGE=stage, +======= + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + BLOCK_M=BLOCK_M, # + BLOCK_N=BLOCK_N, # + BLOCK_DMODEL=Lk, # + STAGE=stage, # + num_warps=num_warps, # + num_stages=num_stages # +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ) ## restore the grid for bwd kernel @@ -493,6 +788,7 @@ class _attention(torch.autograd.Function): dk = torch.empty_like(k) dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] +<<<<<<< HEAD delta = torch.empty_like(L) do_scaled = torch.empty_like(do) # Figure out what BLOCK size fwd used and adjust num_blocks accordingly. 
@@ -506,6 +802,39 @@ class _attention(torch.autograd.Function): o, do, # do_scaled, delta, # BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, # +======= + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 1 + if torch.cuda.get_device_capability()[0] == 9: + NUM_STAGES = 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL # + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ) if not ctx.split_kernel: _bwd_kernel[(ctx.grid[1],)]( @@ -599,11 +928,17 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): ]) def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) +<<<<<<< HEAD causal = True q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() +======= + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, 
std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 sm_scale = 0.5 split_kernel = True dout = torch.randn_like(q) @@ -672,17 +1007,26 @@ for mode in ['fwd', 'bwd']: ylabel='ms', plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}', args={ +<<<<<<< HEAD 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}) ) +======= + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.float16, + "mode": mode, + "causal": causal, + }, + )) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 @triton.testing.perf_report(configs) -def bench_flash_attention( - BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda" -): +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 @@ -706,9 +1050,7 @@ def bench_flash_attention( fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn( - (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True - ) + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) if mode == "bwd": o = fn() diff --git a/python/tutorials/07-math-functions.py b/python/tutorials/07-math-functions.py index 1ded3aa98..f60f07efb 100644 --- a/python/tutorials/07-math-functions.py +++ b/python/tutorials/07-math-functions.py @@ -22,10 +22,10 @@ import triton.language as tl @triton.jit def asin_kernel( - x_ptr, - y_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE @@ -35,12 +35,12 @@ def asin_kernel( x = tl.math.asin(x) tl.store(y_ptr + offsets, x, 
mask=mask) + # %% # Using the default libdevice library path # ----------------------------------------- # We can use the default libdevice library path encoded in `triton/language/math.py` - torch.manual_seed(0) size = 98432 x = torch.rand(size, device='cuda') @@ -48,14 +48,12 @@ output_triton = torch.zeros(size, device='cuda') output_torch = torch.asin(x) assert x.is_cuda and output_triton.is_cuda n_elements = output_torch.numel() -grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) print(output_torch) print(output_triton) -print( - f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}' -) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') # %% # Customize the libdevice library path @@ -67,7 +65,5 @@ asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) print(output_torch) print(output_triton) -print( - f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}' -) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') diff --git a/python/tutorials/08-experimental-block-pointer.py b/python/tutorials/08-experimental-block-pointer.py index 7147b69de..4486349fb 100644 --- a/python/tutorials/08-experimental-block-pointer.py +++ b/python/tutorials/08-experimental-block-pointer.py @@ -98,14 +98,22 @@ import triton.language as tl @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, 
num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), ], key=['M', 'N', 'K'], ) @@ -118,13 +126,11 @@ def matmul_kernel_with_block_pointers( # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. 
E.g. `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, + stride_am, stride_ak, # + stride_bk, stride_bn, # stride_cm, stride_cn, # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr -): + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ @@ -196,16 +202,13 @@ def matmul(a, b): # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) matmul_kernel_with_block_pointers[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ) + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1)) return c diff --git a/python/tutorials/09-experimental-tma-matrix-multiplication.py b/python/tutorials/09-experimental-tma-matrix-multiplication.py index 8a79720c7..8cf81ef69 100644 --- a/python/tutorials/09-experimental-tma-matrix-multiplication.py +++ b/python/tutorials/09-experimental-tma-matrix-multiplication.py @@ -40,23 +40,24 @@ if torch.cuda.get_device_capability()[0] < 9: @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, + num_warps=4), # triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 
8}, num_stages=7, num_warps=4, num_ctas=2), # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4), ], key=['M', 'N', 'K'], ) @triton.jit -def matmul_kernel( - a_ptr, b_ptr, z_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_zm, stride_zn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, - B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr -): +def matmul_kernel(a_ptr, b_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_zm, stride_zn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr # + ): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -70,9 +71,11 @@ def matmul_kernel( block_offset_n = pid_n * BLOCK_SIZE_N a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1)) + offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), + order=(A_ORDER_0, A_ORDER_1)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1)) + offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(B_ORDER_0, B_ORDER_1)) z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M) @@ -101,15 +104,17 @@ def matmul(a, b, a_order, b_order): z = torch.empty((M, N), device=a.device, dtype=torch.float16) def grid(META): - return (triton.cdiv(M, META['BLOCK_SIZE_M']) 
* triton.cdiv(N, META['BLOCK_SIZE_N']),) - matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_zm=z.stride(0), stride_zn=z.stride(1), - A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], - B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] - ) + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + matmul_kernel[grid]( + a_ptr=a, b_ptr=b, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] # + ) return z @@ -160,14 +165,12 @@ def test_matmul(): # label name for the lines line_names=["cuBLAS", "Triton"], # line styles - styles=[('green', '-'), ('green', '--'), - ('blue', '-'), ('blue', '--')], + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot. 
args={}, - ) -) + )) def benchmark(M, N, K, TRANS_A, TRANS_B, provider): if (TRANS_A): a = torch.randn((K, M), device='cuda', dtype=torch.float16).T @@ -185,14 +188,15 @@ def benchmark(M, N, K, TRANS_A, TRANS_B, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, + fast_flush=False) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, a_order, b_order), rep=100, + quantiles=quantiles, fast_flush=False) def perf(ms): return 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/10-experimental-tma-store-matrix-multiplication.py b/python/tutorials/10-experimental-tma-store-matrix-multiplication.py index 37d58863d..966e8d1e2 100644 --- a/python/tutorials/10-experimental-tma-store-matrix-multiplication.py +++ b/python/tutorials/10-experimental-tma-store-matrix-multiplication.py @@ -40,21 +40,21 @@ if torch.cuda.get_device_capability()[0] < 9: @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, + num_warps=4), # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4), ], key=['M', 'N', 'K'], ) @triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: 
tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): +def matmul_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr # + ): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -67,20 +67,10 @@ def matmul_kernel( block_offset_m = pid_m * BLOCK_SIZE_M block_offset_n = pid_n * BLOCK_SIZE_N - a_tile_ptr = tl.make_block_ptr( - base=a_ptr, shape=( - M, K), strides=( - stride_am, stride_ak), offsets=( - block_offset_m, 0), block_shape=( - BLOCK_SIZE_M, BLOCK_SIZE_K), order=( - 1, 0)) - b_tile_ptr = tl.make_block_ptr( - base=b_ptr, shape=( - K, N), strides=( - stride_bk, stride_bn), offsets=( - 0, block_offset_n), block_shape=( - BLOCK_SIZE_K, BLOCK_SIZE_N), order=( - 0, 1)) + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(0, 1)) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): @@ -91,7 +81,8 @@ def matmul_kernel( b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0)) tl.store(c_block_ptr, accumulator) @@ -101,20 +92,19 @@ def matmul(a, b): assert a.shape[1] == b.shape[0], "incompatible dimensions" M, K = a.shape K, N = b.shape - assert ( - K % 32 == 0 - ), "We don't check memory-out-of-bounds with K so K 
must be divisible by BLOCK_SIZE_K" + assert (K % 32 == 0), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K" c = torch.empty((M, N), device=a.device, dtype=torch.float32) def grid(META): - return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1)) + matmul_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1)) return c @@ -126,12 +116,7 @@ c = torch.nn.functional.normalize(c) golden = torch.nn.functional.normalize(torch.matmul(a, b)) torch.set_printoptions(profile="full") -assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) +assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) @triton.testing.perf_report( @@ -143,7 +128,7 @@ assert_close( [2048, 1024, 1024], [2048, 2048, 2048], [2048, 4096, 4096], - [2048, 8192, 8192] + [2048, 8192, 8192], ], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot @@ -152,27 +137,26 @@ assert_close( # label name for the lines line_names=["cuBLAS", "Triton"], # line styles - styles=[('green', '-'), ('green', '--'), - ('blue', '-'), ('blue', '--')], + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot. 
args={}, - ) -) + )) def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((N, K), device='cuda', dtype=torch.float16).T quantiles = [0.5, 0.2, 0.8] if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, + fast_flush=False) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles, + fast_flush=False) def perf(ms): return 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py index ee1328cd8..d4756535c 100644 --- a/python/tutorials/11-grouped-gemm.py +++ b/python/tutorials/11-grouped-gemm.py @@ -1,3 +1,13 @@ +<<<<<<< HEAD +======= +""" +Group GEMM +============================ +This group gemm kernel launches a fixed number of CTA to compute a group +of gemms. The scheduling is static and we do it on device. +""" + +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 # Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining @@ -28,6 +38,7 @@ import triton.language as tl # of gemms. 
The scheduling is static and we do it on device @triton.autotune( configs=[ +<<<<<<< HEAD triton.Config( { 'BLOCK_SIZE_M': 128, @@ -111,6 +122,32 @@ import triton.language as tl num_stages = 0, num_warps = 2, ), +======= + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 ], key=['SUM_M', 'SUM_N', 'SUM_K'], ) @@ -149,9 +186,7 @@ def grouped_matmul_kernel( num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles # iterate through the tiles in the current gemm problem - while ( - tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles - ): + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): # pick up a tile from the current gemm problem k = gk lda = tl.load(g_lds + g * 3) @@ -171,9 +206,7 @@ def grouped_matmul_kernel( offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] - accumulator = tl.zeros( - (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 - ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): # hint to Triton compiler to do proper loop pipelining tl.multiple_of(a_ptrs, [16, 16]) @@ -224,7 +257,7 @@ def group_gemm_fn(group_A, group_B): group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) - C_addrs .append(C.data_ptr()) + C_addrs.append(C.data_ptr()) g_sizes += [M, N, K] SUM_M += M SUM_N += N @@ -235,14 +268,10 @@ def group_gemm_fn(group_A, group_B): d_a_ptrs = torch.tensor(A_addrs, 
device=device) d_b_ptrs = torch.tensor(B_addrs, device=device) d_c_ptrs = torch.tensor(C_addrs, device=device) - d_g_sizes = torch.tensor( - g_sizes, dtype=torch.int32, device=device - ) - d_g_lds = torch.tensor( - g_lds, dtype=torch.int32, device=device - ) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) # we use a fixed number of CTA, and it's auto-tunable - grid = lambda META: (META['NUM_SM'],) + grid = lambda META: (META['NUM_SM'], ) grouped_matmul_kernel[grid]( d_a_ptrs, d_b_ptrs, @@ -283,8 +312,13 @@ for i in range(group_size): # only launch the kernel, no tensor preparation here to remove all overhead +<<<<<<< HEAD def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, sum_m, sum_n, sum_k): grid = lambda META: (META['NUM_SM'],) +======= +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): + grid = lambda META: (META['NUM_SM'], ) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 grouped_matmul_kernel[grid]( a_ptrs, b_ptrs, @@ -307,7 +341,7 @@ def torch_perf_fn(group_A, group_B): triton.testing.Benchmark( # argument names to use as an x-axis for the plot x_names=['N'], - x_vals=[2 ** i for i in range(7, 11)], # different possible values for `x_name` + x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -320,8 +354,7 @@ def torch_perf_fn(group_A, group_B): plot_name="group-gemm-performance", # name for the plot. Used also as a file name for saving the plot. 
args={}, - ) -) + )) def benchmark(N, provider): group_size = 4 group_A = [] @@ -341,7 +374,7 @@ def benchmark(N, provider): group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) - C_addrs .append(C.data_ptr()) + C_addrs.append(C.data_ptr()) g_sizes += [N, N, N] g_lds += [N, N, N] @@ -355,7 +388,12 @@ def benchmark(N, provider): if provider == 'cublas': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) if provider == 'triton': +<<<<<<< HEAD ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, group_size*N, group_size*N, group_size*N), quantiles=quantiles) +======= + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 return ms, max_ms, min_ms diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 77eae5f22..ab938f3c8 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -292,7 +292,7 @@ tt.func @select() { // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 %5 = arith.select %4, %3, %7 : tensor<128xi1> // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = - %8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1> + %8 = arith.select %7, %3, %2 : tensor<128xi1>, tensor<128xi1> tt.return } diff --git a/test/Conversion/dedup-by-constancy.mlir b/test/Conversion/dedup-by-constancy.mlir new file mode 100644 index 000000000..455a71548 --- /dev/null +++ b/test/Conversion/dedup-by-constancy.mlir @@ -0,0 +1,72 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s + +// CHECK-LABEL: 
dedup_by_constancy_full +// CHECK-COUNT-5: llvm.add +// CHECK-NOT: llvm.add +// CHECK: llvm.icmp "slt" +// CHECK-NOT: llvm.icmp "slt" +// CHECK: llvm.sdiv +// CHECK-NOT: llvm.sdiv +// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]] +// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]] +// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]] +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<256> : tensor<1024xi32, #blocked> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> + %11 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked> + tt.return + } +} + 
+// ----- + +// CHECK-LABEL: dedup_by_constancy_partial +// CHECK-COUNT-8: llvm.add +// CHECK-NOT: llvm.add +// CHECK: llvm.icmp "slt" +// CHECK-NOT: llvm.icmp "slt" +// CHECK-COUNT-2: llvm.sdiv +// CHECK-NOT: llvm.sdiv +// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]] +// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]] +// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]] +// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]] +// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]] +// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]] +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<4> : tensor<1024xi32, #blocked> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> + %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> + %11 = tt.splat %arg1 : (!tt.ptr) -> 
tensor<1024x!tt.ptr, #blocked> + %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked> + tt.return + } +} diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 9d53e092c..92d11e758 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -79,8 +79,7 @@ tt.func public @select_op(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32> - // CHECK: %[[splat:.*]] = tt.splat %arg2 : (i1) -> tensor<128xi1, #blocked> - // CHECK-NEXT: %{{.*}} = "triton_gpu.select"(%[[splat]], %{{.*}}, %{{.*}}) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked> + // CHECK: %{{.*}} = arith.select %arg2, %{{.*}}, %{{.*}} : tensor<128xf32, #blocked> %4 = arith.select %arg2, %cst, %3 : tensor<128xf32> %5 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x!tt.ptr> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 827a6c895..3949cdfb8 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -331,7 +331,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr, #blocked> %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr, #blocked>, tensor<64xi32, #blocked> %9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked> - %10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked> + %10 = arith.cmpi "slt", %4, %9 : tensor<64xi32, #blocked> // load op has a vector width = 1 due to the %mask's alignment // GCN-NOT: llvm.inline_asm // GCN: llvm.addrspacecast 
{{.*}} : !llvm.ptr to !llvm.ptr @@ -1648,6 +1648,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { +<<<<<<< HEAD // GCN-NOT: llvm.inline_asm // GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr, f32 // PTX: llvm.inline_asm @@ -1655,6 +1656,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // PTX: llvm.inline_asm // PTX-SAME: @$3 atom.global.gpu.add.f32 %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32} : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> +======= + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32 + %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32, scope = 1 : i32} : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 tt.return } } @@ -1664,12 +1672,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { +<<<<<<< HEAD // GCN-NOT: llvm.inline_asm // GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr, f32 // PTX: llvm.icmp "eq" // PTX: llvm.inline_asm // PTX-SAME: @$3 atom.global.gpu.relaxed.add.f32 %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (!tt.ptr, f32, i1) -> f32 +======= + // CHECK: llvm.icmp 
"eq" + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32 + %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32, scope = 1 : i32} : (!tt.ptr, f32, i1) -> f32 + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f32 + tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.sys.relaxed.add.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: @$3 atom.global.sys.relaxed.add.f32 + %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32, scope = 3 : i32} : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> +>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52 tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index e078661b6..6a280f84e 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -190,6 +190,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_f16_conversion tt.func 
@test_fp8_to_f16_conversion( @@ -199,16 +200,18 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> + // CHECK-COUNT-2: mul.rn.bf16x2 + %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> - %out2 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> + %out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> - %out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out4 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> - %out4 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> + %out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> - %out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out6 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> tt.return } } diff --git a/test/Triton/print.mlir b/test/Triton/print.mlir index f164e4684..c94e79350 100644 --- a/test/Triton/print.mlir +++ b/test/Triton/print.mlir @@ -15,7 +15,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> %4 = arith.addi %3, %2 : tensor<1024xi32, 
#blocked> %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> - %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir index 8a1c9123e..e421230e6 100644 --- a/test/Triton/vecadd.mlir +++ b/test/Triton/vecadd.mlir @@ -55,7 +55,7 @@ module { // %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> // %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding">> // %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">> -// %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> +// %6 = arith.cmpi "slt", %4, %5 : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">> // %7 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> // %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<256xi32> // %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu<"coalesced encoding">> diff --git a/test/TritonGPU/atomic-cas.mlir b/test/TritonGPU/atomic-cas.mlir new file mode 100644 index 000000000..e8c4239ed --- /dev/null +++ b/test/TritonGPU/atomic-cas.mlir @@ -0,0 +1,27 @@ +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu 2>&1 | FileCheck %s --check-prefix=GPU +// RUN: triton-opt %s 
-split-input-file -convert-triton-to-tritongpu -convert-triton-gpu-to-llvm 2>&1 | FileCheck %s --check-prefix=LLVM + +// GPU: %9 = "tt.atomic_cas"(%8, %cst_0, %cst) <{scope = 2 : i32, sem = 4 : i32}> : (tensor<2x!tt.ptr, #blocked>, tensor<2xi64, #blocked>, tensor<2xi64, #blocked>) -> tensor<2xi64, #blocked> +// LLVM: llvm.inline_asm {{.*}} "mov.u64 $0, 0x0;\0A\09@$4 atom.global.acq_rel.cta.cas.b64 $0, [ $1 + 0 ], $2, $3;", "=l,l,l,l,b" + +module { + tt.func public @atomic_cas_kernel_0d1d2e(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<2> : tensor<2xi64> + %cst_0 = arith.constant dense<1> : tensor<2xi64> + %c2_i32 = arith.constant 2 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c2_i32 : i32 + %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %3 = tt.splat %1 : (i32) -> tensor<2xi32> + %4 = arith.addi %3, %2 : tensor<2xi32> + %5 = tt.splat %arg2 : (i32) -> tensor<2xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<2xi32> + %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<2x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<2x!tt.ptr>, tensor<2xi32> + %9 = "tt.atomic_cas"(%8, %cst_0, %cst) <{sem = 4 : i32, scope = 2 : i32}> : (tensor<2x!tt.ptr>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<2x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %11, %9, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<2xi64> + tt.return + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir new file mode 100644 index 000000000..cb2bba970 --- /dev/null +++ b/test/TritonGPU/canonicalize.mlir @@ -0,0 +1,34 @@ +// RUN: triton-opt %s -split-input-file -canonicalize | FileCheck %s + + +// CHECK-LABEL: @test_canonicalize_convert_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 +// CHECK-NOT: triton_gpu.convert_layout +// 
CHECK: %[[V:.+]] = tt.view %[[ARG]] +// CHECK: tt.return %[[V]] +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = triton_gpu.convert_layout %arg0 : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #blocked2> + %r = tt.view %c : (tensor<64x64xf32, #blocked2>) -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +} + +// ----- + +// test that the convert doesn't get combined with view if the resulting operations +// is an expensive view which would require moving data across threads. 
+// CHECK-LABEL: @test_canonicalize_convert_expensive_view +// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 +// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] +// CHECK: %[[V:.+]] = tt.view %[[C]] +// CHECK: tt.return %[[V]] +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { + %c = triton_gpu.convert_layout %arg0 : (tensor<256x16xf32, #blocked0>) -> tensor<256x16xf32, #blocked2> + %r = tt.view %c : (tensor<256x16xf32, #blocked2>) -> tensor<4096xf32, #blocked1> + tt.return %r : tensor<4096xf32, #blocked1> +} diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 4979c77bd..cffe3a902 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -86,7 +86,7 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> - %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : 
tensor<1024xf32, #blocked> @@ -120,7 +120,7 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr {tt.divisibility = %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> - %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked> %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 869f39146..bee7edd70 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -360,7 +360,7 @@ tt.func @loop_if(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i3 %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1> %10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr, #blocked1>, tensor<64x64xi32, #blocked1> %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr, #blocked1>) { - %33 = "triton_gpu.cmpi"(%i0, %i0) {predicate = 4 : i64} : (i32, i32) -> i1 + %33 = arith.cmpi "sgt", %i0, %i0 : i32 %34 = scf.if %33 -> (tensor<64x64xf32, #blocked1>) { %23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr, #blocked1>) -> tensor<64x64x!tt.ptr, #blocked3> %24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3> @@ -443,7 +443,7 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %4 = triton_gpu.convert_layout %3 : (tensor<1x1xi32, #blocked1>) -> tensor<1x1xi32, #blocked2> %5 = tt.splat %0 : (i32) -> tensor<1x1xi32, #blocked2> %6 = arith.addi %5, %4 
: tensor<1x1xi32, #blocked2> - %7 = "triton_gpu.cmpi"(%6, %cst_1) {predicate = 2 : i64} : (tensor<1x1xi32, #blocked2>, tensor<1x1xi32, #blocked2>) -> tensor<1x1xi1, #blocked2> + %7 = arith.cmpi "slt", %6, %cst_1 : tensor<1x1xi32, #blocked2> %8 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked0> %9 = triton_gpu.convert_layout %8 : (tensor<512xi32, #blocked0>) -> tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<512xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x512xi32, #blocked2> @@ -455,7 +455,7 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %16 = arith.index_cast %arg3 : index to i32 %17 = tt.splat %16 : (i32) -> tensor<1x512xi32, #blocked2> %18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2> - %19 = "triton_gpu.cmpi"(%18, %cst_0) {predicate = 2 : i64} : (tensor<1x512xi32, #blocked2>, tensor<1x512xi32, #blocked2>) -> tensor<1x512xi1, #blocked2> + %19 = arith.cmpi "slt", %18, %cst_0 : tensor<1x512xi32, #blocked2> %20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2> %21 = tt.addptr %13, %20 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> %22 = arith.andi %19, %14 : tensor<1x512xi1, #blocked2> @@ -464,9 +464,9 @@ tt.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr %25 = tt.load %23, %24 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf64, #blocked3> %26 = triton_gpu.convert_layout %25 : (tensor<1x512xf64, #blocked3>) -> tensor<1x512xf64, #blocked2> %27 = arith.andi %14, %19 : tensor<1x512xi1, #blocked2> - %28 = "triton_gpu.cmpf"(%arg4, %26) {predicate = 4 : i64} : (tensor<1x512xf64, #blocked2>, tensor<1x512xf64, #blocked2>) -> tensor<1x512xi1, #blocked2> + %28 = arith.cmpf "olt", %arg4, %26 : tensor<1x512xf64, #blocked2> %29 = arith.andi %27, %28 : tensor<1x512xi1, #blocked2> - %30 = "triton_gpu.select"(%29, %26, %arg4) : (tensor<1x512xi1, 
#blocked2>, tensor<1x512xf64, #blocked2>, tensor<1x512xf64, #blocked2>) -> tensor<1x512xf64, #blocked2> + %30 = arith.select %29, %26, %arg4 : tensor<1x512xi1, #blocked2>, tensor<1x512xf64, #blocked2> %31 = triton_gpu.convert_layout %21 : (tensor<1x512x!tt.ptr, #blocked2>) -> tensor<1x512x!tt.ptr, #blocked3> %32 = triton_gpu.convert_layout %30 : (tensor<1x512xf64, #blocked2>) -> tensor<1x512xf64, #blocked3> %33 = triton_gpu.convert_layout %27 : (tensor<1x512xi1, #blocked2>) -> tensor<1x512xi1, #blocked3> @@ -504,7 +504,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked0> %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked0> %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked0> - %5 = "triton_gpu.cmpi"(%4, %cst_11) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %5 = arith.cmpi "slt", %4, %cst_11 : tensor<1024xi32, #blocked0> %6 = tt.splat %arg5 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked0> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %8 = triton_gpu.convert_layout %7 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked1> @@ -536,282 +536,282 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %34 = arith.addf %33, %26 : tensor<1024xf32, #blocked0> %35 = arith.divf %26, %34 : tensor<1024xf32, #blocked0> %36 = tt.splat %30 : (f32) -> tensor<1024xf32, #blocked0> - %37 = "triton_gpu.cmpf"(%36, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %38 = "triton_gpu.select"(%37, %cst_14, %cst_9) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %39 = "triton_gpu.select"(%37, %cst_8, %cst_7) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, 
#blocked0>) -> tensor<1024xi32, #blocked0> + %37 = arith.cmpf "oge", %36, %35 : tensor<1024xf32, #blocked0> + %38 = arith.select %37, %cst_14, %cst_9 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %39 = arith.select %37, %cst_8, %cst_7 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %40 = arith.subi %39, %38 : tensor<1024xi32, #blocked0> - %41 = "triton_gpu.cmpi"(%40, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %42 = "triton_gpu.cmpi"(%41, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %41 = arith.cmpi "slt", %40, %cst_14 : tensor<1024xi32, #blocked0> + %42 = arith.cmpi "ne", %41, %cst_5 : tensor<1024xi1, #blocked0> %43 = arith.remsi %40, %cst_6 : tensor<1024xi32, #blocked0> - %44 = "triton_gpu.cmpi"(%43, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %44 = arith.cmpi "ne", %43, %cst_14 : tensor<1024xi32, #blocked0> %45 = arith.divsi %40, %cst_6 : tensor<1024xi32, #blocked0> %46 = arith.subi %45, %cst_12 : tensor<1024xi32, #blocked0> - %47 = "triton_gpu.select"(%44, %46, %45) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %48 = "triton_gpu.select"(%42, %47, %45) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %47 = arith.select %44, %46, %45 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %48 = arith.select %42, %47, %45 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %49 = arith.addi %38, %48 : tensor<1024xi32, #blocked0> - %50 = "triton_gpu.cmpi"(%38, %39) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %51 = "triton_gpu.select"(%50, %49, %cst_14) : (tensor<1024xi1, #blocked0>, 
tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %50 = arith.cmpi "slt", %38, %39 : tensor<1024xi32, #blocked0> + %51 = arith.select %50, %49, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %52 = tt.splat %arg6 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked0> %53 = tt.addptr %52, %51 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %54 = triton_gpu.convert_layout %53 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %55 = tt.load %54 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %56 = "triton_gpu.cmpf"(%55, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %57 = "triton_gpu.cmpi"(%56, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %56 = arith.cmpf "oge", %55, %35 :tensor<1024xf32, #blocked0> + %57 = arith.cmpi "eq", %56, %cst_5 : tensor<1024xi1, #blocked0> %58 = arith.andi %57, %50 : tensor<1024xi1, #blocked0> %59 = arith.addi %51, %cst_12 : tensor<1024xi32, #blocked0> - %60 = "triton_gpu.select"(%58, %59, %38) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %60 = arith.select %58, %59, %38 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %61 = arith.andi %56, %50 : tensor<1024xi1, #blocked0> - %62 = "triton_gpu.select"(%61, %51, %39) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %63 = "triton_gpu.cmpi"(%60, %62) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %62 = arith.select %61, %51, %39 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %63 = arith.cmpi "slt", %60, %62 : tensor<1024xi32, #blocked0> %64 = arith.subi %62, %60 : tensor<1024xi32, #blocked0> - 
%65 = "triton_gpu.cmpi"(%64, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %66 = "triton_gpu.cmpi"(%65, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %65 = arith.cmpi "slt", %64, %cst_14 : tensor<1024xi32, #blocked0> + %66 = arith.cmpi "ne", %65, %cst_5 : tensor<1024xi1, #blocked0> %67 = arith.remsi %64, %cst_6 : tensor<1024xi32, #blocked0> - %68 = "triton_gpu.cmpi"(%67, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %68 = arith.cmpi "ne", %67, %cst_14 : tensor<1024xi32, #blocked0> %69 = arith.divsi %64, %cst_6 : tensor<1024xi32, #blocked0> %70 = arith.subi %69, %cst_12 : tensor<1024xi32, #blocked0> - %71 = "triton_gpu.select"(%68, %70, %69) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %72 = "triton_gpu.select"(%66, %71, %69) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %71 = arith.select %68, %70, %69 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %72 = arith.select %66, %71, %69 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %73 = arith.addi %60, %72 : tensor<1024xi32, #blocked0> - %74 = "triton_gpu.select"(%63, %73, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %74 = arith.select %63, %73, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %75 = tt.addptr %52, %74 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %76 = triton_gpu.convert_layout %75 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %77 = tt.load %76 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %78 = "triton_gpu.cmpf"(%77, 
%35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %79 = "triton_gpu.cmpi"(%78, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %78 = arith.cmpf "oge", %77, %35 :tensor<1024xf32, #blocked0> + %79 = arith.cmpi "eq", %78, %cst_5 : tensor<1024xi1, #blocked0> %80 = arith.andi %79, %63 : tensor<1024xi1, #blocked0> %81 = arith.addi %74, %cst_12 : tensor<1024xi32, #blocked0> - %82 = "triton_gpu.select"(%80, %81, %60) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %82 = arith.select %80, %81, %60 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %83 = arith.andi %78, %63 : tensor<1024xi1, #blocked0> - %84 = "triton_gpu.select"(%83, %74, %62) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %85 = "triton_gpu.cmpi"(%82, %84) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %84 = arith.select %83, %74, %62 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %85 = arith.cmpi "slt", %82, %84 : tensor<1024xi32, #blocked0> %86 = arith.subi %84, %82 : tensor<1024xi32, #blocked0> - %87 = "triton_gpu.cmpi"(%86, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %88 = "triton_gpu.cmpi"(%87, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %87 = arith.cmpi "slt", %86, %cst_14 : tensor<1024xi32, #blocked0> + %88 = arith.cmpi "ne", %87, %cst_5 : tensor<1024xi1, #blocked0> %89 = arith.remsi %86, %cst_6 : tensor<1024xi32, #blocked0> - %90 = "triton_gpu.cmpi"(%89, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, 
#blocked0> + %90 = arith.cmpi "ne", %89, %cst_14 : tensor<1024xi32, #blocked0> %91 = arith.divsi %86, %cst_6 : tensor<1024xi32, #blocked0> %92 = arith.subi %91, %cst_12 : tensor<1024xi32, #blocked0> - %93 = "triton_gpu.select"(%90, %92, %91) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %94 = "triton_gpu.select"(%88, %93, %91) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %93 = arith.select %90, %92, %91 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %94 = arith.select %88, %93, %91 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %95 = arith.addi %82, %94 : tensor<1024xi32, #blocked0> - %96 = "triton_gpu.select"(%85, %95, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %96 = arith.select %85, %95, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %97 = tt.addptr %52, %96 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %98 = triton_gpu.convert_layout %97 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %99 = tt.load %98 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %100 = "triton_gpu.cmpf"(%99, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %101 = "triton_gpu.cmpi"(%100, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %100 = arith.cmpf "oge", %99, %35 : tensor<1024xf32, #blocked0> + %101 = arith.cmpi "eq", %100, %cst_5 : tensor<1024xi1, #blocked0> %102 = arith.andi %101, %85 : tensor<1024xi1, #blocked0> %103 = arith.addi %96, %cst_12 : tensor<1024xi32, #blocked0> - %104 = "triton_gpu.select"(%102, %103, %82) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, 
tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %104 = arith.select %102, %103, %82 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %105 = arith.andi %100, %85 : tensor<1024xi1, #blocked0> - %106 = "triton_gpu.select"(%105, %96, %84) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %107 = "triton_gpu.cmpi"(%104, %106) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %106 = arith.select %105, %96, %84 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %107 = arith.cmpi "slt", %104, %106 : tensor<1024xi32, #blocked0> %108 = arith.subi %106, %104 : tensor<1024xi32, #blocked0> - %109 = "triton_gpu.cmpi"(%108, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %110 = "triton_gpu.cmpi"(%109, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %109 = arith.cmpi "slt", %108, %cst_14 : tensor<1024xi32, #blocked0> + %110 = arith.cmpi "ne", %109, %cst_5 : tensor<1024xi1, #blocked0> %111 = arith.remsi %108, %cst_6 : tensor<1024xi32, #blocked0> - %112 = "triton_gpu.cmpi"(%111, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %112 = arith.cmpi "ne", %111, %cst_14 : tensor<1024xi32, #blocked0> %113 = arith.divsi %108, %cst_6 : tensor<1024xi32, #blocked0> %114 = arith.subi %113, %cst_12 : tensor<1024xi32, #blocked0> - %115 = "triton_gpu.select"(%112, %114, %113) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %116 = "triton_gpu.select"(%110, %115, %113) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %115 = arith.select %112, %114, %113 : tensor<1024xi1, 
#blocked0>, tensor<1024xi32, #blocked0> + %116 = arith.select %110, %115, %113 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %117 = arith.addi %104, %116 : tensor<1024xi32, #blocked0> - %118 = "triton_gpu.select"(%107, %117, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %118 = arith.select %107, %117, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %119 = tt.addptr %52, %118 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %120 = triton_gpu.convert_layout %119 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %121 = tt.load %120 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %122 = "triton_gpu.cmpf"(%121, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %123 = "triton_gpu.cmpi"(%122, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %122 = arith.cmpf "oge", %121, %35 : tensor<1024xf32, #blocked0> + %123 = arith.cmpi "eq", %122, %cst_5 : tensor<1024xi1, #blocked0> %124 = arith.andi %123, %107 : tensor<1024xi1, #blocked0> %125 = arith.addi %118, %cst_12 : tensor<1024xi32, #blocked0> - %126 = "triton_gpu.select"(%124, %125, %104) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %126 = arith.select %124, %125, %104 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %127 = arith.andi %122, %107 : tensor<1024xi1, #blocked0> - %128 = "triton_gpu.select"(%127, %118, %106) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %129 = "triton_gpu.cmpi"(%126, %128) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %128 = arith.select %127, 
%118, %106 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %129 = arith.cmpi "slt", %126, %128 : tensor<1024xi32, #blocked0> %130 = arith.subi %128, %126 : tensor<1024xi32, #blocked0> - %131 = "triton_gpu.cmpi"(%130, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %132 = "triton_gpu.cmpi"(%131, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %131 = arith.cmpi "slt", %130, %cst_14 : tensor<1024xi32, #blocked0> + %132 = arith.cmpi "ne", %131, %cst_5 : tensor<1024xi1, #blocked0> %133 = arith.remsi %130, %cst_6 : tensor<1024xi32, #blocked0> - %134 = "triton_gpu.cmpi"(%133, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %134 = arith.cmpi "ne", %133, %cst_14 : tensor<1024xi32, #blocked0> %135 = arith.divsi %130, %cst_6 : tensor<1024xi32, #blocked0> %136 = arith.subi %135, %cst_12 : tensor<1024xi32, #blocked0> - %137 = "triton_gpu.select"(%134, %136, %135) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %138 = "triton_gpu.select"(%132, %137, %135) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %137 = arith.select %134, %136, %135 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %138 = arith.select %132, %137, %135 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %139 = arith.addi %126, %138 : tensor<1024xi32, #blocked0> - %140 = "triton_gpu.select"(%129, %139, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %140 = arith.select %129, %139, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %141 = tt.addptr %52, %140 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> 
%142 = triton_gpu.convert_layout %141 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %143 = tt.load %142 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %144 = "triton_gpu.cmpf"(%143, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %145 = "triton_gpu.cmpi"(%144, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %144 = arith.cmpf "oge", %143, %35 : tensor<1024xf32, #blocked0> + %145 = arith.cmpi "eq", %144, %cst_5 : tensor<1024xi1, #blocked0> %146 = arith.andi %145, %129 : tensor<1024xi1, #blocked0> %147 = arith.addi %140, %cst_12 : tensor<1024xi32, #blocked0> - %148 = "triton_gpu.select"(%146, %147, %126) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %148 = arith.select %146, %147, %126 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %149 = arith.andi %144, %129 : tensor<1024xi1, #blocked0> - %150 = "triton_gpu.select"(%149, %140, %128) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %151 = "triton_gpu.cmpi"(%148, %150) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %150 = arith.select %149, %140, %128 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %151 = arith.cmpi "slt", %148, %150 : tensor<1024xi32, #blocked0> %152 = arith.subi %150, %148 : tensor<1024xi32, #blocked0> - %153 = "triton_gpu.cmpi"(%152, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %154 = "triton_gpu.cmpi"(%153, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %153 = arith.cmpi "slt", %152, %cst_14 : 
tensor<1024xi32, #blocked0> + %154 = arith.cmpi "ne", %153, %cst_5 : tensor<1024xi1, #blocked0> %155 = arith.remsi %152, %cst_6 : tensor<1024xi32, #blocked0> - %156 = "triton_gpu.cmpi"(%155, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %156 = arith.cmpi "ne", %155, %cst_14 : tensor<1024xi32, #blocked0> %157 = arith.divsi %152, %cst_6 : tensor<1024xi32, #blocked0> %158 = arith.subi %157, %cst_12 : tensor<1024xi32, #blocked0> - %159 = "triton_gpu.select"(%156, %158, %157) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %160 = "triton_gpu.select"(%154, %159, %157) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %159 = arith.select %156, %158, %157 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %160 = arith.select %154, %159, %157 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %161 = arith.addi %148, %160 : tensor<1024xi32, #blocked0> - %162 = "triton_gpu.select"(%151, %161, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %162 = arith.select %151, %161, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %163 = tt.addptr %52, %162 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %164 = triton_gpu.convert_layout %163 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %165 = tt.load %164 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %166 = "triton_gpu.cmpf"(%165, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %167 = "triton_gpu.cmpi"(%166, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %166 = arith.cmpf "oge", 
%165, %35 : tensor<1024xf32, #blocked0> + %167 = arith.cmpi "eq", %166, %cst_5 : tensor<1024xi1, #blocked0> %168 = arith.andi %167, %151 : tensor<1024xi1, #blocked0> %169 = arith.addi %162, %cst_12 : tensor<1024xi32, #blocked0> - %170 = "triton_gpu.select"(%168, %169, %148) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %170 = arith.select %168, %169, %148 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %171 = arith.andi %166, %151 : tensor<1024xi1, #blocked0> - %172 = "triton_gpu.select"(%171, %162, %150) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %173 = "triton_gpu.cmpi"(%170, %172) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %172 = arith.select %171, %162, %150 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %173 = arith.cmpi "slt", %170, %172 : tensor<1024xi32, #blocked0> %174 = arith.subi %172, %170 : tensor<1024xi32, #blocked0> - %175 = "triton_gpu.cmpi"(%174, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %176 = "triton_gpu.cmpi"(%175, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %175 = arith.cmpi "slt", %174, %cst_14 : tensor<1024xi32, #blocked0> + %176 = arith.cmpi "ne", %175, %cst_5 : tensor<1024xi1, #blocked0> %177 = arith.remsi %174, %cst_6 : tensor<1024xi32, #blocked0> - %178 = "triton_gpu.cmpi"(%177, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %178 = arith.cmpi "ne", %177, %cst_14 : tensor<1024xi32, #blocked0> %179 = arith.divsi %174, %cst_6 : tensor<1024xi32, #blocked0> %180 = arith.subi %179, %cst_12 : tensor<1024xi32, #blocked0> - %181 = "triton_gpu.select"(%178, %180, %179) : 
(tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %182 = "triton_gpu.select"(%176, %181, %179) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %181 = arith.select %178, %180, %179 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %182 = arith.select %176, %181, %179 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %183 = arith.addi %170, %182 : tensor<1024xi32, #blocked0> - %184 = "triton_gpu.select"(%173, %183, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %184 = arith.select %173, %183, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %185 = tt.addptr %52, %184 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %186 = triton_gpu.convert_layout %185 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %187 = tt.load %186 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %188 = "triton_gpu.cmpf"(%187, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %189 = "triton_gpu.cmpi"(%188, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %188 = arith.cmpf "oge", %187, %35 : tensor<1024xf32, #blocked0> + %189 = arith.cmpi "eq", %188, %cst_5 : tensor<1024xi1, #blocked0> %190 = arith.andi %189, %173 : tensor<1024xi1, #blocked0> %191 = arith.addi %184, %cst_12 : tensor<1024xi32, #blocked0> - %192 = "triton_gpu.select"(%190, %191, %170) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %192 = arith.select %190, %191, %170 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %193 = arith.andi %188, %173 : tensor<1024xi1, #blocked0> - %194 
= "triton_gpu.select"(%193, %184, %172) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %195 = "triton_gpu.cmpi"(%192, %194) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %194 = arith.select %193, %184, %172 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %195 = arith.cmpi "slt", %192, %194 : tensor<1024xi32, #blocked0> %196 = arith.subi %194, %192 : tensor<1024xi32, #blocked0> - %197 = "triton_gpu.cmpi"(%196, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %198 = "triton_gpu.cmpi"(%197, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %197 = arith.cmpi "slt", %196, %cst_14 : tensor<1024xi32, #blocked0> + %198 = arith.cmpi "ne", %197, %cst_5 : tensor<1024xi1, #blocked0> %199 = arith.remsi %196, %cst_6 : tensor<1024xi32, #blocked0> - %200 = "triton_gpu.cmpi"(%199, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %200 = arith.cmpi "ne", %199, %cst_14 : tensor<1024xi32, #blocked0> %201 = arith.divsi %196, %cst_6 : tensor<1024xi32, #blocked0> %202 = arith.subi %201, %cst_12 : tensor<1024xi32, #blocked0> - %203 = "triton_gpu.select"(%200, %202, %201) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %204 = "triton_gpu.select"(%198, %203, %201) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %203 = arith.select %200, %202, %201 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %204 = arith.select %198, %203, %201 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %205 = arith.addi %192, %204 : tensor<1024xi32, #blocked0> - %206 = 
"triton_gpu.select"(%195, %205, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %206 = arith.select %195, %205, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %207 = tt.addptr %52, %206 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %208 = triton_gpu.convert_layout %207 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %209 = tt.load %208 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %210 = "triton_gpu.cmpf"(%209, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %211 = "triton_gpu.cmpi"(%210, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %210 = arith.cmpf "oge", %209, %35 :tensor<1024xf32, #blocked0> + %211 = arith.cmpi "eq", %210, %cst_5 : tensor<1024xi1, #blocked0> %212 = arith.andi %211, %195 : tensor<1024xi1, #blocked0> %213 = arith.addi %206, %cst_12 : tensor<1024xi32, #blocked0> - %214 = "triton_gpu.select"(%212, %213, %192) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %214 = arith.select %212, %213, %192 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %215 = arith.andi %210, %195 : tensor<1024xi1, #blocked0> - %216 = "triton_gpu.select"(%215, %206, %194) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %217 = "triton_gpu.cmpi"(%214, %216) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %216 = arith.select %215, %206, %194 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %217 = arith.cmpi "slt", %214, %216 : tensor<1024xi32, #blocked0> %218 = arith.subi %216, %214 : tensor<1024xi32, #blocked0> - %219 = 
"triton_gpu.cmpi"(%218, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %220 = "triton_gpu.cmpi"(%219, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %219 = arith.cmpi "slt", %218, %cst_14 : tensor<1024xi32, #blocked0> + %220 = arith.cmpi "ne", %219, %cst_5 : tensor<1024xi1, #blocked0> %221 = arith.remsi %218, %cst_6 : tensor<1024xi32, #blocked0> - %222 = "triton_gpu.cmpi"(%221, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %222 = arith.cmpi "ne", %221, %cst_14 : tensor<1024xi32, #blocked0> %223 = arith.divsi %218, %cst_6 : tensor<1024xi32, #blocked0> %224 = arith.subi %223, %cst_12 : tensor<1024xi32, #blocked0> - %225 = "triton_gpu.select"(%222, %224, %223) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %226 = "triton_gpu.select"(%220, %225, %223) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %225 = arith.select %222, %224, %223 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %226 = arith.select %220, %225, %223 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %227 = arith.addi %214, %226 : tensor<1024xi32, #blocked0> - %228 = "triton_gpu.select"(%217, %227, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %228 = arith.select %217, %227, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %229 = tt.addptr %52, %228 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %230 = triton_gpu.convert_layout %229 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %231 = tt.load %230 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, 
#blocked0> - %232 = "triton_gpu.cmpf"(%231, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %233 = "triton_gpu.cmpi"(%232, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %232 = arith.cmpf "oge", %231, %35 : tensor<1024xf32, #blocked0> + %233 = arith.cmpi "eq", %232, %cst_5 : tensor<1024xi1, #blocked0> %234 = arith.andi %233, %217 : tensor<1024xi1, #blocked0> %235 = arith.addi %228, %cst_12 : tensor<1024xi32, #blocked0> - %236 = "triton_gpu.select"(%234, %235, %214) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %236 = arith.select %234, %235, %214 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %237 = arith.andi %232, %217 : tensor<1024xi1, #blocked0> - %238 = "triton_gpu.select"(%237, %228, %216) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %239 = "triton_gpu.cmpi"(%236, %238) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %238 = arith.select %237, %228, %216 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %239 = arith.cmpi "slt", %236, %238 : tensor<1024xi32, #blocked0> %240 = arith.subi %238, %236 : tensor<1024xi32, #blocked0> - %241 = "triton_gpu.cmpi"(%240, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %242 = "triton_gpu.cmpi"(%241, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %241 = arith.cmpi "slt", %240, %cst_14 : tensor<1024xi32, #blocked0> + %242 = arith.cmpi "ne", %241, %cst_5 : tensor<1024xi1, #blocked0> %243 = arith.remsi %240, %cst_6 : tensor<1024xi32, #blocked0> - %244 = "triton_gpu.cmpi"(%243, %cst_14) {predicate = 1 
: i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %244 = arith.cmpi "ne", %243, %cst_14 : tensor<1024xi32, #blocked0> %245 = arith.divsi %240, %cst_6 : tensor<1024xi32, #blocked0> %246 = arith.subi %245, %cst_12 : tensor<1024xi32, #blocked0> - %247 = "triton_gpu.select"(%244, %246, %245) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %248 = "triton_gpu.select"(%242, %247, %245) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %247 = arith.select %244, %246, %245 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %248 = arith.select %242, %247, %245 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %249 = arith.addi %236, %248 : tensor<1024xi32, #blocked0> - %250 = "triton_gpu.select"(%239, %249, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %250 = arith.select %239, %249, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %251 = tt.addptr %52, %250 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %252 = triton_gpu.convert_layout %251 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %253 = tt.load %252 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %254 = "triton_gpu.cmpf"(%253, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %255 = "triton_gpu.cmpi"(%254, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %254 = arith.cmpf "oge", %253, %35 : tensor<1024xf32, #blocked0> + %255 = arith.cmpi "eq", %254, %cst_5 : tensor<1024xi1, #blocked0> %256 = arith.andi %255, %239 : tensor<1024xi1, #blocked0> %257 = arith.addi %250, %cst_12 : tensor<1024xi32, 
#blocked0> - %258 = "triton_gpu.select"(%256, %257, %236) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %258 = arith.select %256, %257, %236 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %259 = arith.andi %254, %239 : tensor<1024xi1, #blocked0> - %260 = "triton_gpu.select"(%259, %250, %238) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %261 = "triton_gpu.cmpi"(%258, %260) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %260 = arith.select %259, %250, %238 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %261 = arith.cmpi "slt", %258, %260 : tensor<1024xi32, #blocked0> %262 = arith.subi %260, %258 : tensor<1024xi32, #blocked0> - %263 = "triton_gpu.cmpi"(%262, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %264 = "triton_gpu.cmpi"(%263, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %263 = arith.cmpi "slt", %262, %cst_14 : tensor<1024xi32, #blocked0> + %264 = arith.cmpi "ne", %263, %cst_5 : tensor<1024xi1, #blocked0> %265 = arith.remsi %262, %cst_6 : tensor<1024xi32, #blocked0> - %266 = "triton_gpu.cmpi"(%265, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %266 = arith.cmpi "ne", %265, %cst_14 : tensor<1024xi32, #blocked0> %267 = arith.divsi %262, %cst_6 : tensor<1024xi32, #blocked0> %268 = arith.subi %267, %cst_12 : tensor<1024xi32, #blocked0> - %269 = "triton_gpu.select"(%266, %268, %267) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %270 = "triton_gpu.select"(%264, %269, %267) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, 
#blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %269 = arith.select %266, %268, %267 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %270 = arith.select %264, %269, %267 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %271 = arith.addi %258, %270 : tensor<1024xi32, #blocked0> - %272 = "triton_gpu.select"(%261, %271, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %272 = arith.select %261, %271, %cst_14 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %273 = tt.addptr %52, %272 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %274 = triton_gpu.convert_layout %273 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %275 = tt.load %274 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %276 = "triton_gpu.cmpf"(%275, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %277 = "triton_gpu.cmpi"(%276, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %276 = arith.cmpf "oge", %275, %35 : tensor<1024xf32, #blocked0> + %277 = arith.cmpi "eq", %276, %cst_5 : tensor<1024xi1, #blocked0> %278 = arith.andi %277, %261 : tensor<1024xi1, #blocked0> %279 = arith.addi %272, %cst_12 : tensor<1024xi32, #blocked0> - %280 = "triton_gpu.select"(%278, %279, %258) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %280 = arith.select %278, %279, %258 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %281 = arith.andi %276, %261 : tensor<1024xi1, #blocked0> - %282 = "triton_gpu.select"(%281, %272, %260) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %283 = "triton_gpu.cmpi"(%280, %282) {predicate = 2 
: i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %282 = arith.select %281, %272, %260 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %283 = arith.cmpi "slt", %280, %282 : tensor<1024xi32, #blocked0> %284 = arith.subi %282, %280 : tensor<1024xi32, #blocked0> - %285 = "triton_gpu.cmpi"(%284, %cst_14) {predicate = 2 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> - %286 = "triton_gpu.cmpi"(%285, %cst_5) {predicate = 1 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %285 = arith.cmpi "slt", %284, %cst_14 : tensor<1024xi32, #blocked0> + %286 = arith.cmpi "ne", %285, %cst_5 : tensor<1024xi1, #blocked0> %287 = arith.remsi %284, %cst_6 : tensor<1024xi32, #blocked0> - %288 = "triton_gpu.cmpi"(%287, %cst_14) {predicate = 1 : i64} : (tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi1, #blocked0> + %288 = arith.cmpi "ne", %287, %cst_14 : tensor<1024xi32, #blocked0> %289 = arith.divsi %284, %cst_6 : tensor<1024xi32, #blocked0> %290 = arith.subi %289, %cst_12 : tensor<1024xi32, #blocked0> - %291 = "triton_gpu.select"(%288, %290, %289) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> - %292 = "triton_gpu.select"(%286, %291, %289) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %291 = arith.select %288, %290, %289 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> + %292 = arith.select %286, %291, %289 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %293 = arith.addi %280, %292 : tensor<1024xi32, #blocked0> - %294 = "triton_gpu.select"(%283, %293, %cst_14) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %294 = arith.select %283, %293, %cst_14 : tensor<1024xi1, 
#blocked0>, tensor<1024xi32, #blocked0> %295 = tt.addptr %52, %294 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %296 = triton_gpu.convert_layout %295 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %297 = tt.load %296 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked0> - %298 = "triton_gpu.cmpf"(%297, %35) {predicate = 3 : i64} : (tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xi1, #blocked0> - %299 = "triton_gpu.cmpi"(%298, %cst_5) {predicate = 0 : i64} : (tensor<1024xi1, #blocked0>, tensor<1024xi1, #blocked0>) -> tensor<1024xi1, #blocked0> + %298 = arith.cmpf "oge", %297, %35 :tensor<1024xf32, #blocked0> + %299 = arith.cmpi "eq", %298, %cst_5 : tensor<1024xi1, #blocked0> %300 = arith.andi %299, %283 : tensor<1024xi1, #blocked0> %301 = arith.addi %294, %cst_12 : tensor<1024xi32, #blocked0> - %302 = "triton_gpu.select"(%300, %301, %280) : (tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0>, tensor<1024xi32, #blocked0>) -> tensor<1024xi32, #blocked0> + %302 = arith.select %300, %301, %280 : tensor<1024xi1, #blocked0>, tensor<1024xi32, #blocked0> %303 = arith.extsi %cst_12 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> - %304 = "triton_gpu.cmpi"(%17, %303) {predicate = 0 : i64} : (tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi1, #blocked0> + %304 = arith.cmpi "eq", %17, %303 : tensor<1024xi64, #blocked0> %305 = arith.fptosi %23 : tensor<1024xf32, #blocked0> to tensor<1024xi64, #blocked0> %306 = arith.extsi %cst_14 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> - %307 = "triton_gpu.cmpi"(%306, %305) {predicate = 4 : i64} : (tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi1, #blocked0> + %307 = arith.cmpi "sgt", %306, %305 : tensor<1024xi64, #blocked0> %308 = arith.extsi %cst_4 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> - %309 = "triton_gpu.cmpi"(%305, %308) {predicate 
= 4 : i64} : (tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi1, #blocked0> - %310 = "triton_gpu.select"(%309, %306, %305) : (tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi64, #blocked0> - %311 = "triton_gpu.select"(%307, %306, %310) : (tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi64, #blocked0> - %312 = "triton_gpu.select"(%304, %311, %306) : (tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0>, tensor<1024xi64, #blocked0>) -> tensor<1024xi64, #blocked0> + %309 = arith.cmpi "sgt", %305, %308 : tensor<1024xi64, #blocked0> + %310 = arith.select %309, %306, %305 : tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0> + %311 = arith.select %307, %306, %310 : tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0> + %312 = arith.select %304, %311, %306 : tensor<1024xi1, #blocked0>, tensor<1024xi64, #blocked0> %313 = arith.extsi %cst_3 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> %314 = arith.muli %312, %313 : tensor<1024xi64, #blocked0> %315 = arith.extsi %302 : tensor<1024xi32, #blocked0> to tensor<1024xi64, #blocked0> @@ -823,7 +823,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %321 = triton_gpu.convert_layout %320 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %322 = tt.load %321 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf64, #blocked0> %323 = arith.extf %cst_2 : tensor<1024xf32, #blocked0> to tensor<1024xf64, #blocked0> - %324 = "triton_gpu.cmpf"(%322, %323) {predicate = 2 : i64} : (tensor<1024xf64, #blocked0>, tensor<1024xf64, #blocked0>) -> tensor<1024xi1, #blocked0> + %324 = arith.cmpf "ogt", %322, %323 : tensor<1024xf64, #blocked0> %325 = tt.splat %arg10 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked0> %326 = tt.addptr %325, %318 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi64, #blocked0> %327 = 
triton_gpu.convert_layout %326 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> @@ -833,7 +833,7 @@ tt.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %331 = arith.mulf %330, %cst_1 : tensor<1024xf32, #blocked0> %332 = arith.mulf %35, %cst_0 : tensor<1024xf32, #blocked0> %333 = arith.addf %331, %332 : tensor<1024xf32, #blocked0> - %334 = "triton_gpu.select"(%324, %333, %35) : (tensor<1024xi1, #blocked0>, tensor<1024xf32, #blocked0>, tensor<1024xf32, #blocked0>) -> tensor<1024xf32, #blocked0> + %334 = arith.select %324, %333, %35 : tensor<1024xi1, #blocked0>, tensor<1024xf32, #blocked0> %335 = tt.addptr %319, %317 : tensor<1024x!tt.ptr, #blocked0>, tensor<1024xi32, #blocked0> %336 = triton_gpu.convert_layout %335 : (tensor<1024x!tt.ptr, #blocked0>) -> tensor<1024x!tt.ptr, #blocked0> %337 = tt.load %336 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf64, #blocked0> @@ -896,10 +896,10 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
%5 = triton_gpu.convert_layout %4 : (tensor<16x1xi32, #blocked1>) -> tensor<16x1xi32, #blocked2> %6 = tt.splat %1 : (i32) -> tensor<16x1xi32, #blocked2> %7 = arith.addi %6, %5 : tensor<16x1xi32, #blocked2> - %8 = "triton_gpu.cmpi"(%7, %cst_1) {predicate = 2 : i64} : (tensor<16x1xi32, #blocked2>, tensor<16x1xi32, #blocked2>) -> tensor<16x1xi1, #blocked2> + %8 = arith.cmpi "slt", %7, %cst_1 : tensor<16x1xi32, #blocked2> %9 = triton_gpu.convert_layout %2 : (tensor<16xi32, #blocked0>) -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> %10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x16xi32, #blocked3> - %11 = "triton_gpu.cmpi"(%10, %cst_0) {predicate = 2 : i64} : (tensor<1x16xi32, #blocked3>, tensor<1x16xi32, #blocked3>) -> tensor<1x16xi1, #blocked3> + %11 = arith.cmpi "slt", %10, %cst_0 : tensor<1x16xi32, #blocked3> %12 = arith.muli %7, %cst : tensor<16x1xi32, #blocked2> %13 = tt.broadcast %10 : (tensor<1x16xi32, #blocked3>) -> tensor<16x16xi32, #blocked3> %14 = triton_gpu.convert_layout %13 : (tensor<16x16xi32, #blocked3>) -> tensor<16x16xi32, #blocked2> @@ -915,9 +915,9 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
%24 = triton_gpu.convert_layout %22 : (tensor<16x16xi1, #blocked2>) -> tensor<16x16xi1, #blocked4> %25 = tt.load %23, %24 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<16x16xf32, #blocked4> %26 = triton_gpu.convert_layout %25 : (tensor<16x16xf32, #blocked4>) -> tensor<16x16xf32, #blocked2> - %27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2> + %27 = arith.cmpf "olt", %cst_2, %26 : tensor<16x16xf32, #blocked2> %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2> - %29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2> + %29 = arith.select %28, %26, %cst_2 : tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2> %30 = "tt.reduce" (%29) ({ ^bb0(%arg4: f32, %arg5: f32): %max = arith.maximumf %arg4, %arg5 : f32 @@ -937,7 +937,7 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
%42 = arith.subf %40, %41 : tensor<16x16xf32, #blocked2> %43 = math.exp %42 : tensor<16x16xf32, #blocked2> %44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2> - %45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2> + %45 = arith.select %22, %44, %36 : tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2> %46 = "tt.reduce" (%45) ({ ^bb0(%arg4: f32, %arg5: f32): %add = arith.addf %arg4, %arg5 : f32 @@ -997,7 +997,7 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %5 = triton_gpu.convert_layout %4 : (tensor<64x1xi32, #blocked1>) -> tensor<64x1xi32, #blocked2> %6 = tt.splat %1 : (i32) -> tensor<64x1xi32, #blocked2> %7 = arith.addi %6, %5 : tensor<64x1xi32, #blocked2> - %8 = "triton_gpu.cmpi"(%7, %cst_5) {predicate = 2 : i64} : (tensor<64x1xi32, #blocked2>, tensor<64x1xi32, #blocked2>) -> tensor<64x1xi1, #blocked2> + %8 = arith.cmpi "slt", %7, %cst_5 : tensor<64x1xi32, #blocked2> %9 = triton_gpu.convert_layout %2 : (tensor<64xi32, #blocked0>) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> %10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>) -> tensor<1x64xi32, #blocked3> %11 = arith.remsi %7, %cst_4 : tensor<64x1xi32, #blocked2> @@ -1018,7 +1018,7 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %44 = arith.index_cast %arg6 : index to i32 %45 = tt.splat %44 : (i32) -> tensor<1x64xi32, #blocked3> %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> - %47 = "triton_gpu.cmpi"(%46, %cst_2) {predicate = 2 : i64} : (tensor<1x64xi32, #blocked3>, tensor<1x64xi32, #blocked3>) -> tensor<1x64xi1, #blocked3> + %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> %48 = tt.broadcast %46 : (tensor<1x64xi32, #blocked3>) -> tensor<64x64xi32, #blocked3> %49 = triton_gpu.convert_layout %48 : (tensor<64x64xi32, 
#blocked3>) -> tensor<64x64xi32, #blocked2> %50 = arith.addi %49, %16 : tensor<64x64xi32, #blocked2> @@ -1039,13 +1039,13 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %65 = tt.load %63, %64 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x64xf32, #blocked5> %66 = triton_gpu.convert_layout %65 : (tensor<64x64xf32, #blocked5>) -> tensor<64x64xf32, #blocked2> %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> - %68 = "triton_gpu.cmpf"(%67, %67) {predicate = 13 : i64} : (tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xi1, #blocked2> - %69 = "triton_gpu.cmpf"(%67, %cst) {predicate = 2 : i64} : (tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xi1, #blocked2> - %70 = "triton_gpu.select"(%69, %67, %cst) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> - %71 = "triton_gpu.select"(%68, %67, %70) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> + %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> + %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> + %70 = arith.select %69, %67, %cst : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + %71 = arith.select %68, %67, %70 : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> %72 = math.exp %71 : tensor<64x64xf32, #blocked2> %73 = arith.addf %arg7, %72 : tensor<64x64xf32, #blocked2> - %74 = "triton_gpu.select"(%54, %73, %arg7) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> + %74 = arith.select %54, %73, %arg7 : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> scf.yield %74 : tensor<64x64xf32, #blocked2> } %26 = "tt.reduce" (%25) ({ @@ -1074,7 +1074,7 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %44 = 
arith.index_cast %arg6 : index to i32 %45 = tt.splat %44 : (i32) -> tensor<1x64xi32, #blocked3> %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3> - %47 = "triton_gpu.cmpi"(%46, %cst_2) {predicate = 2 : i64} : (tensor<1x64xi32, #blocked3>, tensor<1x64xi32, #blocked3>) -> tensor<1x64xi1, #blocked3> + %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3> %48 = tt.broadcast %46 : (tensor<1x64xi32, #blocked3>) -> tensor<64x64xi32, #blocked3> %49 = triton_gpu.convert_layout %48 : (tensor<64x64xi32, #blocked3>) -> tensor<64x64xi32, #blocked2> %50 = arith.addi %49, %32 : tensor<64x64xi32, #blocked2> @@ -1095,10 +1095,10 @@ tt.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt %65 = tt.load %63, %64 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x64xf32, #blocked5> %66 = triton_gpu.convert_layout %65 : (tensor<64x64xf32, #blocked5>) -> tensor<64x64xf32, #blocked2> %67 = arith.addf %59, %66 : tensor<64x64xf32, #blocked2> - %68 = "triton_gpu.cmpf"(%67, %67) {predicate = 13 : i64} : (tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xi1, #blocked2> - %69 = "triton_gpu.cmpf"(%67, %cst) {predicate = 2 : i64} : (tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xi1, #blocked2> - %70 = "triton_gpu.select"(%69, %67, %cst) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> - %71 = "triton_gpu.select"(%68, %67, %70) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2> + %68 = arith.cmpf "une", %67, %67 : tensor<64x64xf32, #blocked2> + %69 = arith.cmpf "ogt", %67, %cst : tensor<64x64xf32, #blocked2> + %70 = arith.select %69, %67, %cst : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> + %71 = arith.select %68, %67, %70 : tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2> %72 = math.exp %71 : tensor<64x64xf32, #blocked2> 
%73 = arith.divf %72, %41 : tensor<64x64xf32, #blocked2> %74 = tt.addptr %42, %50 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> @@ -1163,7 +1163,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #blocked1> %1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked> - %3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked> + %3 = arith.cmpi "slt", %2, %cst_0 : tensor<1x2xi32, #blocked> %4 = "tt.reduce" (%cst) ({ ^bb0(%arg3: i32, %arg4: i32): %add = arith.addi %arg3, %arg4 : i32 @@ -1220,7 +1220,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %4 = triton_gpu.convert_layout %3 : (tensor<1x1xi32, #blocked2>) -> tensor<1x1xi32, #blocked> %5 = tt.splat %0 : (i32) -> tensor<1x1xi32, #blocked> %6 = arith.addi %5, %4 : tensor<1x1xi32, #blocked> - %7 = "triton_gpu.cmpi"(%6, %cst_5) {predicate = 2 : i64} : (tensor<1x1xi32, #blocked>, tensor<1x1xi32, #blocked>) -> tensor<1x1xi1, #blocked> + %7 = arith.cmpi "slt", %6, %cst_5 : tensor<1x1xi32, #blocked> %8 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> %9 = triton_gpu.convert_layout %8 : (tensor<256xi32, #blocked1>) -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> %10 = tt.expand_dims %9 {axis = 0 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x256xi32, #blocked> @@ -1232,7 +1232,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %42 = arith.index_cast %arg5 : index to i32 %43 = tt.splat %42 : (i32) -> tensor<1x256xi32, #blocked> %44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked> - %45 = "triton_gpu.cmpi"(%44, %cst_4) {predicate = 2 
: i64} : (tensor<1x256xi32, #blocked>, tensor<1x256xi32, #blocked>) -> tensor<1x256xi1, #blocked> + %45 = arith.cmpi "slt", %44, %cst_4 : tensor<1x256xi32, #blocked> %46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked> %47 = arith.divsi %44, %cst_3 : tensor<1x256xi32, #blocked> %48 = arith.addi %46, %12 : tensor<1x256xi32, #blocked> @@ -1246,7 +1246,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %56 = tt.load %53, %54, %55 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x256xf32, #blocked3> %57 = triton_gpu.convert_layout %56 : (tensor<1x256xf32, #blocked3>) -> tensor<1x256xf32, #blocked> %58 = arith.addf %arg6, %57 : tensor<1x256xf32, #blocked> - %59 = "triton_gpu.select"(%52, %58, %arg6) : (tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked>, tensor<1x256xf32, #blocked>) -> tensor<1x256xf32, #blocked> + %59 = arith.select %52, %58, %arg6 : tensor<1x256xi1, #blocked>, tensor<1x256xf32, #blocked> scf.yield %59 : tensor<1x256xf32, #blocked> } %16 = "tt.reduce" (%15) ({ @@ -1312,14 +1312,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %20 = triton_gpu.convert_layout %19 : (tensor<16x16xi32, #blocked3>) -> tensor<16x16xi32, #blocked1> %21 = arith.addi %18, %20 : tensor<16x16xi32, #blocked1> %22 = tt.splat %arg2 : (!tt.ptr) -> tensor<16x16x!tt.ptr, #blocked1> - %23 = "triton_gpu.cmpi"(%13, %cst_3) <{predicate = 2 : i64}> : (tensor<16x1xi32, #blocked1>, tensor<16x1xi32, #blocked1>) -> tensor<16x1xi1, #blocked1> + %23 = arith.cmpi "slt", %13, %cst_3 : tensor<16x1xi32, #blocked1> %24 = tt.broadcast %23 : (tensor<16x1xi1, #blocked1>) -> tensor<16x16xi1, #blocked1> %25 = arith.truncf %cst_2 : tensor<16x16xf32, #blocked1> to tensor<16x16xf16, #blocked1> %26 = arith.muli %2, %arg11 : i32 %27 = arith.muli %3, %arg12 : i32 %28 = arith.addi %26, %27 : i32 %29 = tt.splat %arg10 : (!tt.ptr) -> tensor<16x!tt.ptr, #blocked> - %30 = "triton_gpu.cmpi"(%7, %cst_1) <{predicate = 2 : i64}> : (tensor<16xi32, 
#blocked>, tensor<16xi32, #blocked>) -> tensor<16xi1, #blocked> + %30 = arith.cmpi "slt", %7, %cst_1 : tensor<16xi32, #blocked> %31 = arith.muli %2, %arg8 : i32 %32 = arith.muli %3, %arg9 : i32 %33 = arith.addi %31, %32 : i32 @@ -1351,8 +1351,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %83 = triton_gpu.convert_layout %30 : (tensor<16xi1, #blocked>) -> tensor<16xi1, #blocked> %84 = triton_gpu.convert_layout %cst_0 : (tensor<16xf32, #blocked>) -> tensor<16xf32, #blocked> %85 = tt.load %82, %83, %84 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16xf32, #blocked> - %86 = "triton_gpu.cmpf"(%arg20, %85) <{predicate = 2 : i64}> : (tensor<16xf32, #blocked>, tensor<16xf32, #blocked>) -> tensor<16xi1, #blocked> - %87 = "triton_gpu.select"(%86, %arg20, %85) : (tensor<16xi1, #blocked>, tensor<16xf32, #blocked>, tensor<16xf32, #blocked>) -> tensor<16xf32, #blocked> + %86 = arith.cmpf "ogt", %arg20, %85 : tensor<16xf32, #blocked> + %87 = arith.select %86, %arg20, %85 : tensor<16xi1, #blocked>, tensor<16xf32, #blocked> %88 = arith.subf %arg20, %87 : tensor<16xf32, #blocked> %89 = math.exp %88 : tensor<16xf32, #blocked> %90 = arith.subf %85, %87 : tensor<16xf32, #blocked> @@ -1395,7 +1395,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %51 = arith.addi %50, %49 : tensor<16x16xi32, #blocked1> %52 = tt.splat %arg13 : (!tt.ptr) -> tensor<16x16x!tt.ptr, #blocked1> %53 = tt.addptr %52, %51 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> - %54 = "triton_gpu.cmpi"(%41, %cst_3) <{predicate = 2 : i64}> : (tensor<16x1xi32, #blocked1>, tensor<16x1xi32, #blocked1>) -> tensor<16x1xi1, #blocked1> + %54 = arith.cmpi "slt", %41, %cst_3 : tensor<16x1xi32, #blocked1> %55 = tt.broadcast %54 : (tensor<16x1xi1, #blocked1>) -> tensor<16x16xi1, #blocked1> %56 = arith.truncf %35#0 : tensor<16x16xf32, #blocked1> to tensor<16x16xf16, #blocked1> %57 = triton_gpu.convert_layout %53 : 
(tensor<16x16x!tt.ptr, #blocked1>) -> tensor<16x16x!tt.ptr, #blocked4> @@ -1549,10 +1549,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %30 = triton_gpu.convert_layout %29 : (tensor<32x32xf32, #blocked5>) -> tensor<32x32xf32, #blocked> %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): - %37 = "triton_gpu.cmpf"(%arg3, %arg5) <{predicate = 1 : i64}> : (f32, f32) -> i1 - %38 = "triton_gpu.cmpi"(%arg4, %arg6) <{predicate = 2 : i64}> : (i32, i32) -> i1 + %37 = arith.cmpf "oeq", %arg3, %arg5 : f32 + %38 = arith.cmpi "slt", %arg4, %arg6 : i32 %39 = arith.andi %37, %38 : i1 - %40 = "triton_gpu.cmpf"(%arg3, %arg5) <{predicate = 2 : i64}> : (f32, f32) -> i1 + %40 = arith.cmpf "ogt", %arg3, %arg5 : f32 %41 = arith.ori %40, %39 : i1 %42 = arith.select %41, %arg3, %arg5 : f32 %43 = arith.select %41, %arg4, %arg6 : i32 @@ -1810,10 +1810,10 @@ tt.func @reduce_to_scalar(%ptr: tensor<1024x!tt.ptr, #blocked>) -> (f32, i3 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked1> %3:2 = "tt.reduce"(%1, %2) <{axis = 0 : i32}> ({ ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): - %51 = "triton_gpu.cmpf"(%arg7, %arg9) <{predicate = 1 : i64}> : (f32, f32) -> i1 - %52 = "triton_gpu.cmpi"(%arg8, %arg10) <{predicate = 2 : i64}> : (i32, i32) -> i1 + %51 = arith.cmpf "oeq", %arg7, %arg9 : f32 + %52 = arith.cmpi "slt", %arg8, %arg10 : i32 %53 = arith.andi %51, %52 : i1 - %54 = "triton_gpu.cmpf"(%arg7, %arg9) <{predicate = 2 : i64}> : (f32, f32) -> i1 + %54 = arith.cmpf "ogt", %arg7, %arg9 : f32 %55 = arith.ori %54, %53 : i1 %56 = arith.select %55, %arg7, %arg9 : f32 %57 = arith.select %55, %arg8, %arg10 : i32 diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 8bfb7b576..649b9e6a4 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -10,16 +10,15 @@ #A = 
#triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop +// CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 -// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] -// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] @@ -29,18 +28,24 @@ // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} 
iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], 
%[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -93,31 +98,37 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop_nested + +// CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = 
triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK-DAG: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK-DAG: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = 
triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -171,23 +182,28 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, 
%[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -303,3 +319,124 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // tt.return %res#0, %res#1, %res#2 : !tt.ptr, 1>, !tt.ptr, 1>, tensor<128x128xf32, #C> // } //} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = 
[4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: dot_chained_single_load + tt.func @dot_chained_single_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %6 = 
tt.broadcast %3 : (tensor<128x1x!tt.ptr, #blocked1>) -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %10 = tt.splat %0 : (!tt.ptr) -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : (tensor<1x16x!tt.ptr, #blocked>) -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : (tensor<64x1xi32, #blocked>) -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: triton_gpu.async_wait {num = 1 : i32} + // CHECK: tt.dot + // CHECK: triton_nvidia_gpu.dot_async + // CHECK: triton_gpu.insert_slice_async + // CHECK: triton_gpu.async_commit_group + // CHECK: scf.yield + %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { + %18 = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> + %19 = triton_gpu.convert_layout %9 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared> + %20 = triton_gpu.convert_layout %18 : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1> + %21 = tt.dot %19, %20, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %22 = arith.truncf %21 : tensor<128x16xf32, 
#mma1> to tensor<128x16xf16, #mma1> + %23 = tt.trans %20 : (tensor<64x16xf16, #shared1>) -> tensor<16x64xf16, #shared> + %24 = triton_gpu.convert_layout %22 : (tensor<128x16xf16, #mma1>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %25 = tt.dot %24, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> + } + tt.return %17#0 : tensor<128x64xf32, #mma> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: two_accumulator_escape + tt.func 
@two_accumulator_escape(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : (tensor<128x1x!tt.ptr, #blocked1>) -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %10 = tt.splat %0 : (!tt.ptr) -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : 
(tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : (tensor<1x16x!tt.ptr, #blocked>) -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : (tensor<64x1xi32, #blocked>) -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: %[[R:.+]]:{{.+}} = scf.for + // CHECK: triton_gpu.async_wait {num = 2 : i32} + // CHECK: triton_nvidia_gpu.dot_async + // CHECK: triton_nvidia_gpu.dot_async + // CHECK: triton_nvidia_gpu.dot_wait %35 {pendings = 2 : i32} + // CHECK: scf.yield + // CHECK: %{{.*}}:2 = triton_nvidia_gpu.dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { + %18 = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> + %19 = triton_gpu.convert_layout %9 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared> + %20 = triton_gpu.convert_layout %18 : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1> + %21 = tt.dot %19, %20, %arg6 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %l = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> + %c = triton_gpu.convert_layout %l : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1> + %23 = tt.trans %c : (tensor<64x16xf16, #shared1>) -> tensor<16x64xf16, #shared> + %25 = tt.dot %cst_4, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %26 = tt.addptr %arg5, %cst 
: tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> + } + tt.return %17#0, %17#2 : tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1> + } +} diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index d56b27291..f3a9b01aa 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -11,15 +11,15 @@ #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop +// CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] -// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] @@ -29,25 +29,25 @@ // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// 
CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], 
%[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -92,36 +92,36 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, tt.return %loop#2: tensor<128x128xf32, #C> } -// CHECK: tt.func @matmul_loop_nested +// CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, 
%[[CONSTANT_0]] // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, 
%[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK-DAG: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK-DAG: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -168,7 +168,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, } -// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-LABEL: 
tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 @@ -176,20 +176,20 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, 
%[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -228,18 +228,18 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, tt.return %loop#1 : tensor<128x128xf32, #C> } -// CHECK: tt.func @lut_bmm_scalar +// CHECK-LABEL: tt.func @lut_bmm_scalar // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.async_commit_group -// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] +// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[LUT_BUFFER_1:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_0]] // CHECK: %[[LUT_BUFFER_2:.*]] = tt.splat %[[LUT_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_2]] -// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}} -// 
CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] // CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, @@ -271,19 +271,19 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK: tt.func @lut_bmm_vector +// CHECK-LABEL: tt.func @lut_bmm_vector // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.async_commit_group -// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] +// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[LUT_BUFFER_1:.*]] = tt.expand_dims %[[LUT_BUFFER_0]] {axis = 1 : i32} // CHECK: %[[LUT_BUFFER_2:.*]] = tt.broadcast %[[LUT_BUFFER_1]] // CHECK: %[[LUT_BUFFER_3:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_2]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_3]] -// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}} -// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] // CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, @@ -317,11 +317,11 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK: tt.func @post_load_inv +// CHECK-LABEL: tt.func @post_load_inv // CHECK: scf.for -// CHECK: arith.index_cast // CHECK-DAG: %[[IV:.*]] = arith.index_cast // CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// CHECK: arith.index_cast // CHECK-NOT: arith.addi %[[NEXT_IV]] tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr 
{tt.divisibility = 16 : i32}, @@ -350,11 +350,11 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %107 = arith.muli %130, %c32_i32 : i32 %108 = arith.subi %arg5, %107 : i32 %109 = tt.splat %108 : (i32) -> tensor<1x32xi32, #AL> - %110 = "triton_gpu.cmpi"(%50, %109) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL> + %110 = arith.cmpi "slt", %50, %109 : tensor<1x32xi32, #AL> %111 = tt.broadcast %110 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL> %112 = tt.load %arg11, %111, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %113 = tt.splat %108 : (i32) -> tensor<32x1xi32, #AL> - %114 = "triton_gpu.cmpi"(%66, %113) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL> + %114 = arith.cmpi "slt", %66, %113 : tensor<32x1xi32, #AL> %115 = tt.broadcast %114 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL> %116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> @@ -373,17 +373,11 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %85#0 : tensor<32x32xf32, #C> } -// CHECK: tt.func @cross_iter_dep -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: %[[PTR0:.*]] = tt.addptr -// CHECK: %[[PTR1:.*]] = tt.addptr -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[BUF0:.*]] = %[[PTR0]], {{.*}}, %[[BUF1:.*]] = %[[PTR1]] +// CHECK-LABEL: tt.func @cross_iter_dep +// TODO: enable pipelining with distance of 2 +// CHECK-NOT: triton_gpu.async_commit_group +// CHECK: scf.for // CHECK: scf.yield -// CHECK-SAME: %[[BUF0]] -// CHECK-SAME: %[[BUF1]] tt.func 
@cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -413,11 +407,11 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %141 = arith.muli %161, %c32_i32 : i32 %142 = arith.subi %arg5, %141 : i32 %143 = tt.splat %142 : (i32) -> tensor<1x32xi32, #AL> - %144 = "triton_gpu.cmpi"(%65, %143) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL> + %144 = arith.cmpi "slt", %65, %143 : tensor<1x32xi32, #AL> %145 = tt.broadcast %144 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL> %146 = tt.load %arg11, %145, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %147 = tt.splat %142 : (i32) -> tensor<32x1xi32, #AL> - %148 = "triton_gpu.cmpi"(%88, %147) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL> + %148 = arith.cmpi "slt", %88, %147 : tensor<32x1xi32, #AL> %149 = tt.broadcast %148 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL> %150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> @@ -436,7 +430,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// CHECK: tt.func @dep_arg_two_uses +// CHECK-LABEL: tt.func @dep_arg_two_uses // CHECK: tt.expand_dims // CHECK: tt.expand_dims // CHECK: tt.expand_dims %arg5 @@ -470,8 +464,8 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %175 = arith.index_cast %1750 : index to i32 %176 = tt.splat %175 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> %177 = tt.splat %175 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> - %178 = "triton_gpu.cmpi"(%57, %176) 
<{predicate = 2 : i64}> : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>) -> tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #AL}>> - %179 = "triton_gpu.cmpi"(%58, %177) <{predicate = 2 : i64}> : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>>, tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>>) -> tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #BL}>> + %178 = arith.cmpi "slt", %57, %176 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>> + %179 = arith.cmpi "slt", %58, %177 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #BL}>> %180 = tt.expand_dims %178 {axis = 0 : i32} : (tensor<32xi1, #triton_gpu.slice<{dim = 0, parent = #AL}>>) -> tensor<1x32xi1, #AL> %181 = tt.expand_dims %179 {axis = 1 : i32} : (tensor<32xi1, #triton_gpu.slice<{dim = 1, parent = #BL}>>) -> tensor<32x1xi1, #BL> %182 = tt.expand_dims %arg21 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>) -> tensor<1x32xi32, #AL> diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir new file mode 100644 index 000000000..56b56d03c --- /dev/null +++ b/test/TritonGPU/optimize-locality.mlir @@ -0,0 +1,574 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-thread-locality -canonicalize | FileCheck %s + +// CHECK-LABEL: negative_zero_accumulator +// CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.view %[[LOAD]] : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout 
%[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @negative_zero_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<-0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr 
%18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: positive_zero_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<0.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.view +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.addf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @positive_zero_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, 
parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: slice_layout +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK-NEXT: "tt.reduce"(%[[LOAD]]) <{axis = 1 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK-NEXT: scf.yield +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[LOOP_OUTPUT]] +#blocked3d = #triton_gpu.blocked<{sizePerThread = [1, 4, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [0, 1, 2]}> +#slice2d = #triton_gpu.slice<{dim = 2, parent = #blocked3d}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @slice_layout( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #slice2d> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, 
#triton_gpu.slice<{dim = 0, parent = #slice2d}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #slice2d}>>) -> tensor<1x128xi32, #slice2d> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #slice2d>) -> tensor<32x128xi32, #slice2d> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #slice2d>, tensor<32x128xi32, #slice2d> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #slice2d> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #slice2d>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #slice2d}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: mma_layout +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK-NEXT: "tt.reduce"(%[[LOAD]]) <{axis = 1 : i32}> +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK-NEXT: scf.yield +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[LOOP_OUTPUT]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], 
CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @mma_layout( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #mma> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>) -> tensor<1x128xi32, #mma> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #mma>) -> tensor<32x128xi32, #mma> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #mma>, tensor<32x128xi32, #mma> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #mma> + %34 = "tt.reduce"(%33) 
<{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #mma>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %35 = arith.addf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: max_reduce +// CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0xFF800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.view %[[LOAD]] : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.maximumf +// CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.maximumf +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @max_reduce( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 
{tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.maximumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> 
tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: max_reduce_zero_int_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<0xFF800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.view +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.maximumf +// CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.maximumf +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.maximumf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @max_reduce_zero_int_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : 
i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.maximumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: min_reduce +// CHECK: %[[CST:.*]] = arith.constant dense<0x7F800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.view %[[LOAD]] : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = 
"tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.minimumf +// CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.minimumf +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @min_reduce( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0x7F800000> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, 
#triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.minimumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: min_reduce_zero_int_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<0x7F800000> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.view +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.minimumf +// CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.minimumf +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.minimumf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], 
order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @min_reduce_zero_int_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : 
tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.minimumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.minimumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: mul_reduce +// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: tt.view %[[LOAD]] : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.mulf +// CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.mulf +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: tt.store {{%.*}}, %[[CVT_OUTPUT]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @mul_reduce( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, 
+ %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.mulf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, 
#triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: mul_reduce_zero_int_accumulator +// CHECK: %[[CST:.*]] = arith.constant dense +// CHECK-NEXT: %[[CST1:.*]] = arith.constant dense<1.000000e+00> +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST1]]) -> {{.*}} +// CHECK: tt.load +// CHECK: tt.view +// CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> +// CHECK: arith.mulf +// CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +// CHECK: %[[FINAL_REDUCE:.*]] = "tt.reduce"(%[[LOOP_OUTPUT]]) <{axis = 1 : i32}> +// CHECK: arith.mulf +// CHECK: %[[CVT_OUTPUT:.*]] = triton_gpu.convert_layout %[[FINAL_REDUCE]] +// CHECK: arith.mulf %[[CVT_OUTPUT]], %[[CST]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @mul_reduce_zero_int_accumulator( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = 
tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%33) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.mulf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.mulf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} + + +// ----- + +// CHECK-LABEL: remains_unchanged +// CHECK: %[[CST:.*]] = arith.constant dense +// CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} +// CHECK: %[[LOAD:.*]] = tt.load +// CHECK: %[[MULF:.*]] = arith.mulf %[[LOAD]], %[[LOAD]] +// CHECK-NEXT: %[[REDUCE:.*]] = 
"tt.reduce"(%[[MULF]]) <{axis = 1 : i32}> +// CHECK: arith.maximumf +// CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] +// CHECK-NEXT: scf.yield +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @remains_unchanged( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, + %18: tensor<32x128x!tt.ptr, #blocked> {tt.divisibility = 16 : i32}, + %11: i32 {tt.divisibility = 16 : i32}, + %25: tensor<32x!tt.ptr, #blocked1> {tt.divisibility = 16 : i32} + ) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %c128_i32 = arith.constant 128 : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_num_programs {axis = 1 : i32} : i32 + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = scf.for %arg3 = %1 to %11 step %2 iter_args(%arg4 = %cst) -> (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) : i32 { + %27 = arith.muli %arg3, %c128_i32 : i32 + %28 = tt.splat %27 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %12 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.expand_dims %29 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked> + %31 = tt.broadcast %30 : 
(tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked> + %32 = tt.addptr %18, %31 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %33 = tt.load %32 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf32, #blocked> + %333 = arith.mulf %33, %33: tensor<32x128xf32, #blocked> + %34 = "tt.reduce"(%333) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %36 = arith.maximumf %arg5, %arg6 : f32 + tt.reduce.return %36 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %35 = arith.maximumf %arg4, %34 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + scf.yield %35 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + %26 = triton_gpu.convert_layout %19 : (tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #blocked1> + tt.store %25, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32, #blocked1> + tt.return + } +} diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir new file mode 100644 index 000000000..56c55816d --- /dev/null +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -0,0 +1,94 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer -canonicalize -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked3 = 
#triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked4 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: two_dependent_dot + tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: f32 , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: !tt.ptr {tt.divisibility = 16 : i32} , %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg12: i32 {tt.divisibility = 16 : i32, 
tt.max_divisibility = 8 : i32} , %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg18: i32 , %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_1 = arith.constant dense<0xFF800000> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %cst_4 = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = arith.divsi %2, %arg8 : i32 + %4 = arith.extsi %arg21 : i32 to i64 + %5 = arith.extsi %arg11 : i32 to i64 + %6 = tt.make_tensor_ptr %arg1, [%c128_i64, %4], [%c1_i64, %5], [%c0_i32, %3] {order = array} : , 1> + %7 = arith.extsi %arg14 : i32 to i64 + %8 = tt.make_tensor_ptr %arg2, [%4, %c128_i64], [%7, %c1_i64], [%3, %c0_i32] {order = array} : , 1> + %9 = arith.muli %0, %c128_i32 : i32 + %10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %11 = 
tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3> + %13 = tt.splat %9 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %14 = tt.splat %9 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %15 = tt.splat %9 : (i32) -> tensor<128xi32, #blocked3> + %16 = arith.addi %13, %10 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %17 = arith.addi %14, %11 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %18 = arith.addi %15, %12 : tensor<128xi32, #blocked3> + %19 = arith.mulf %arg3, %cst_4 : f32 + %20 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %21 = tt.expand_dims %16 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %22 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128x1xi32, #mma> + %23 = tt.splat %arg8 : (i32) -> tensor<128x1xi32, #blocked2> + %24 = arith.muli %21, %23 : tensor<128x1xi32, #blocked2> + %25 = tt.splat %20 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #blocked2> + %26 = tt.addptr %25, %24 : tensor<128x1x!tt.ptr, #blocked2>, tensor<128x1xi32, #blocked2> + %27 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %28 = tt.expand_dims %27 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2> + %29 = tt.broadcast %26 : (tensor<128x1x!tt.ptr, #blocked2>) -> tensor<128x128x!tt.ptr, #blocked2> + %30 = tt.broadcast %28 : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2> + %31 = tt.addptr %29, %30 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> + %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, 
#blocked2> + %33 = tt.splat %19 : (f32) -> tensor<128x128xf32, #blocked2> + %34 = arith.extf %32 : tensor<128x128xf16, #blocked2> to tensor<128x128xf32, #blocked2> + %35 = arith.mulf %34, %33 : tensor<128x128xf32, #blocked2> + %36 = arith.truncf %35 : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2> + %37 = arith.addi %0, %c1_i32 : i32 + %38 = arith.muli %37, %c128_i32 : i32 + %42:5 = scf.for %arg22 = %c0_i32 to %38 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %8) -> (tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr, 1>, !tt.ptr, 1>) : i32 { + %59 = tt.load %arg26 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<128x64xf16, #blocked4> + %60 = tt.load %arg27 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked2> + %66 = triton_gpu.convert_layout %36 : (tensor<128x128xf16, #blocked2>) -> tensor<128x128xf16, #shared> + %67 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked4>) -> tensor<128x64xf16, #shared1> + %68 = tt.dot %66, %67, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared> * tensor<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %81 = arith.truncf %68 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + %82 = triton_gpu.convert_layout %60 : (tensor<64x128xf16, #blocked2>) -> tensor<64x128xf16, #shared> + %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + // CHECK: triton_nvidia_gpu.dot_async + // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: scf.yield + %84 = tt.dot %83, %82, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, 
#shared> -> tensor<128x128xf32, #mma1> + %85 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %87 = arith.addf %85, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %88 = tt.advance %arg26, [%c0_i32, %c64_i32] : , 1> + %89 = tt.advance %arg27, [%c64_i32, %c0_i32] : , 1> + scf.yield %84, %87, %arg25, %88, %89 : tensor<128x128xf32, #mma1>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr, 1>, !tt.ptr, 1> + } + %54 = arith.addi %3, %9 : i32 + %55 = arith.extsi %arg17 : i32 to i64 + %56 = tt.make_tensor_ptr %arg5, [%4, %c128_i64], [%55, %c1_i64], [%54, %c0_i32] {order = array} : , 1> + %57 = arith.truncf %42 : tensor<128x128xf32, #mma1> to tensor<128x128xf16, #mma1> + %58 = triton_gpu.convert_layout %57 : (tensor<128x128xf16, #mma1>) -> tensor<128x128xf16, #blocked2> + tt.store %56, %58 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128x128xf16, #blocked2> + tt.return + } +} diff --git a/test/TritonGPU/rewrite-tensor-pointer-tma.mlir b/test/TritonGPU/rewrite-tensor-pointer-tma.mlir new file mode 100644 index 000000000..f5bf34177 --- /dev/null +++ b/test/TritonGPU/rewrite-tensor-pointer-tma.mlir @@ -0,0 +1,40 @@ +// RUN: ENABLE_TMA=1 triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer=compute-capability=90 | FileCheck %s +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public 
@matmul_kernel_0d1d2d3de4de5de6de7c8c9de10de11c(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c127_i32 = arith.constant 127 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked> + %c1_i64 = arith.constant 1 : i64 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg4, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg3, %c127_i32 : i32 + %4 = arith.divsi %3, %c128_i32 : i32 + %5 = arith.muli %2, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %4, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.muli %11, %c128_i32 : i32 + %15 = arith.muli %13, %c128_i32 : i32 + %16 = arith.extsi %arg3 : i32 to i64 + %17 = arith.extsi %arg5 : i32 to i64 + %18 = arith.extsi %arg6 : i32 to i64 + // CHECK: tt.make_tensor_ptr + %19 = tt.make_tensor_ptr %arg0, [%16, %17], [%18, %c1_i64], [%14, %c0_i32] {order = array} : , 1> + %20 = arith.extsi %arg4 : i32 to i64 + %21 = arith.extsi %arg7 : i32 to i64 + // CHECK: tt.make_tensor_ptr + %22 = tt.make_tensor_ptr %arg1, [%17, %20], [%c1_i64, %21], [%c0_i32, %15] {order = array} : , 1> + tt.return + } +} 
diff --git a/test/TritonGPU/rewrite-tensor-pointer.mlir b/test/TritonGPU/rewrite-tensor-pointer.mlir index cfe46c787..b953714bc 100644 --- a/test/TritonGPU/rewrite-tensor-pointer.mlir +++ b/test/TritonGPU/rewrite-tensor-pointer.mlir @@ -21,7 +21,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %6 = arith.divsi %0, %5 : i32 %7 = arith.muli %6, %c8_i32 : i32 %8 = arith.subi %4, %7 : i32 - %9 = "triton_gpu.cmpi"(%8, %c8_i32) {predicate = 2 : i64} : (i32, i32) -> i1 + %9 = arith.cmpi "slt", %8, %c8_i32 : i32 %10 = arith.select %9, %8, %c8_i32 : i32 %11 = arith.remsi %0, %10 : i32 %12 = arith.addi %7, %11 : i32 @@ -86,13 +86,13 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %6 = arith.extsi %arg5 : i32 to i64 // CHECK-NOT: tt.make_tensor_ptr %7 = tt.make_tensor_ptr %arg0, [%4, %5], [%6, %c1_i64], [%3, %c0_i32] {order = array} : , 1> - %8 = "triton_gpu.cmpi"(%2, %c132_i32) <{predicate = 5 : i64}> : (i32, i32) -> i1 + %8 = arith.cmpi "sge", %2, %c132_i32 : i32 scf.if %8 { %9 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> %10 = tt.splat %arg7 : (i32) -> tensor<64x1xi32, #blocked> %11 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked> %12 = scf.for %arg8 = %0 to %2 step %c132_i32 iter_args(%arg9 = %7) -> (!tt.ptr, 1>) : i32 { - %13 = "triton_gpu.cmpi"(%arg8, %c132_i32) <{predicate = 5 : i64}> : (i32, i32) -> i1 + %13 = arith.cmpi "sge", %arg8, %c132_i32 : i32 %14 = scf.if %13 -> (!tt.ptr, 1>) { %25 = arith.subi %arg8, %0 : i32 %26 = arith.muli %25, %c64_i32 : i32 diff --git a/test/TritonGPU/wsdecomposing.mlir b/test/TritonGPU/wsdecomposing.mlir index 7d89baa8d..a7fda209a 100644 --- a/test/TritonGPU/wsdecomposing.mlir +++ b/test/TritonGPU/wsdecomposing.mlir @@ -112,9 +112,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> 
tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -223,9 +223,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: %78 = arith.addi %76, %77 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> // CHECK-NEXT: %79 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> // CHECK-NEXT: %80 = tt.addptr %79, %78 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - // CHECK-NEXT: %81 = "triton_gpu.cmpi"(%28, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, 
tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %81 = arith.cmpi "slt", %28, %37 {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> // CHECK-NEXT: %82 = tt.expand_dims %81 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - // CHECK-NEXT: %83 = "triton_gpu.cmpi"(%35, %40) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %83 = arith.cmpi "slt", %35, %40 {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> // CHECK-NEXT: %84 = tt.expand_dims %83 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> // CHECK-NEXT: %85 = tt.broadcast %82 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> // CHECK-NEXT: %86 = tt.broadcast %84 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -351,9 +351,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent 
= #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -467,9 +467,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: %78 = arith.addi %76, %77 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> // CHECK-NEXT: %79 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> // CHECK-NEXT: %80 = tt.addptr %79, %78 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - // CHECK-NEXT: %81 = "triton_gpu.cmpi"(%28, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %81 = arith.cmpi "slt", %28, %37 {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = 
#blocked2}>> // CHECK-NEXT: %82 = tt.expand_dims %81 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - // CHECK-NEXT: %83 = "triton_gpu.cmpi"(%35, %40) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %83 = arith.cmpi "slt", %35, %40 {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> // CHECK-NEXT: %84 = tt.expand_dims %83 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> // CHECK-NEXT: %85 = tt.broadcast %82 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> // CHECK-NEXT: %86 = tt.broadcast %84 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -611,9 +611,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) ->
tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -741,9 +741,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: %78 = arith.addi %76, %77 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2> // CHECK-NEXT: %79 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> // CHECK-NEXT: %80 = tt.addptr %79, %78 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - // CHECK-NEXT: %81 = "triton_gpu.cmpi"(%28, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + // CHECK-NEXT: %81 = arith.cmpi "slt", %28, %37 {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> // CHECK-NEXT: %82 = tt.expand_dims %81 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - // CHECK-NEXT: %83 = "triton_gpu.cmpi"(%35, %40) {async_agent = dense<1> : vector<1xi32>, 
predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + // CHECK-NEXT: %83 = arith.cmpi "slt", %35, %40 {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> // CHECK-NEXT: %84 = tt.expand_dims %83 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> // CHECK-NEXT: %85 = tt.broadcast %82 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> // CHECK-NEXT: %86 = tt.broadcast %84 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> diff --git a/test/TritonGPU/wspipeline.mlir b/test/TritonGPU/wspipeline.mlir index 5356002b2..11552a6e9 100644 --- a/test/TritonGPU/wspipeline.mlir +++ b/test/TritonGPU/wspipeline.mlir @@ -136,9 +136,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, 
#triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> diff --git a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir index 981d4748d..e1383c32d 100644 --- a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir +++ b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir @@ -111,9 +111,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 
: tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -241,9 +241,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -386,9 +386,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 
4 : %78 = arith.addi %76, %77 : tensor<128x128xi32, #blocked2> %79 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2> %80 = tt.addptr %79, %78 : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2> - %81 = "triton_gpu.cmpi"(%28, %37) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %81 = arith.cmpi "slt", %28, %37 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %82 = tt.expand_dims %81 {axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2> - %83 = "triton_gpu.cmpi"(%35, %40) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %83 = arith.cmpi "slt", %35, %40 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %84 = tt.expand_dims %83 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2> %85 = tt.broadcast %82 : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> %86 = tt.broadcast %84 : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2> @@ -465,10 +465,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> 
tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = arith.cmpi "slt", %22, %36 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> - %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = arith.cmpi "slt", %26, %39 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> @@ -545,10 +545,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = arith.cmpi "slt", %22, %36 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> %39 = tt.splat %arg4 : (i32) -> 
tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> - %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = arith.cmpi "slt", %26, %39 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> @@ -627,10 +627,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = arith.cmpi "slt", %22, %36 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> - %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = arith.cmpi "slt", 
%26, %39 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> %43 = tt.broadcast %41 : (tensor<1x64xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> @@ -713,16 +713,16 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %34 = tt.broadcast %32 : (tensor<1x64xi32, #blocked4>) -> tensor<64x64xi32, #blocked4> %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr, #blocked4>, tensor<64x64xi32, #blocked4> %36 = tt.splat %arg3 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %37 = "triton_gpu.cmpi"(%22, %36) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> + %37 = arith.cmpi "slt", %22, %36 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> %38 = tt.expand_dims %37 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked4}>>) -> tensor<64x1xi1, #blocked4> %39 = tt.splat %arg4 : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> - %40 = "triton_gpu.cmpi"(%26, %39) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> + %40 = arith.cmpi "slt", %26, %39 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> %41 = tt.expand_dims %40 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>) -> tensor<1x64xi1, #blocked4> %42 = tt.broadcast %38 : (tensor<64x1xi1, #blocked4>) -> tensor<64x64xi1, #blocked4> %43 = tt.broadcast %41 : (tensor<1x64xi1, 
#blocked4>) -> tensor<64x64xi1, #blocked4> %44 = arith.andi %42, %43 : tensor<64x64xi1, #blocked4> %45 = triton_gpu.convert_layout %14#0 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked4> - "tt.atomic_cas"(%arg100, %c0_i32, %c1_i32) {sem = 1 : i32}: (!tt.ptr, i32, i32) -> i32 + "tt.atomic_cas"(%arg100, %c0_i32, %c1_i32) {sem = 1 : i32, scope = 1 : i32}: (!tt.ptr, i32, i32) -> i32 %46 = tt.load %35, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked4> %47 = arith.addf %45, %46 : tensor<64x64xf32, #blocked4> tt.store %35, %47, %44 {cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked4> @@ -780,7 +780,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %23:4 = scf.for %arg9 = %0 to %7 step %c114_i32 iter_args(%arg10 = %15, %arg11 = %18, %arg12 = %8, %arg13 = %9) -> (!tt.ptr, 1>, !tt.ptr, 1>, i32, i32) : i32 { %24 = arith.divsi %arg9, %4 : i32 %25 = arith.remsi %arg9, %4 : i32 - %26 = "triton_gpu.cmpi"(%arg9, %c114_i32) {predicate = 5 : i64} : (i32, i32) -> i1 + %26 = arith.cmpi "sge", %arg9, %c114_i32 : i32 %27:2 = scf.if %26 -> (!tt.ptr, 1>, !tt.ptr, 1>) { %43 = arith.subi %24, %arg12 : i32 %44 = arith.muli %43, %c256_i32 : i32 @@ -928,7 +928,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %7 = arith.divsi %0, %6 : i32 %8 = arith.muli %7, %c8_i32 : i32 %9 = arith.subi %4, %8 : i32 - %10 = "triton_gpu.cmpi"(%9, %c8_i32) {predicate = 2 : i64} : (i32, i32) -> i1 + %10 = arith.cmpi "slt", %9, %c8_i32 : i32 %11 = arith.select %10, %9, %c8_i32 : i32 %12 = arith.remsi %0, %6 : i32 %13 = arith.remsi %12, %11 : i32 @@ -959,7 +959,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %38 = arith.divsi %arg11, %6 : i32 %39 = arith.muli %38, %c8_i32 : i32 %40 = arith.subi %4, %39 : i32 - %41 = "triton_gpu.cmpi"(%40, %c8_i32) {predicate = 2 : i64} : (i32, i32) -> i1 + %41 = arith.cmpi "slt", %40, %c8_i32 : i32 %42 = 
arith.select %41, %40, %c8_i32 : i32 %43 = arith.remsi %arg11, %6 : i32 %44 = arith.remsi %43, %42 : i32 @@ -982,9 +982,9 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %61 = tt.broadcast %59 : (tensor<64x1x!tt.ptr, #blocked2>) -> tensor<64x64x!tt.ptr, #blocked2> %62 = tt.broadcast %60 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2> %63 = tt.addptr %61, %62 : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2> - %64 = "triton_gpu.cmpi"(%52, %31) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %64 = arith.cmpi "slt", %52, %31 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %65 = tt.expand_dims %64 {axis = 1 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi1, #blocked2> - %66 = "triton_gpu.cmpi"(%56, %32) {predicate = 2 : i64} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %66 = arith.cmpi "slt", %56, %32 : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %67 = tt.expand_dims %66 {axis = 0 : i32} : (tensor<64xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi1, #blocked2> %68 = tt.broadcast %65 : (tensor<64x1xi1, #blocked2>) -> tensor<64x64xi1, #blocked2> %69 = tt.broadcast %67 : (tensor<1x64xi1, #blocked2>) -> tensor<64x64xi1, #blocked2> @@ -1011,7 +1011,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %81 = tt.advance %77#2, [%36, %c0_i32] : , 1> %82 = "tt.reduce"(%78) ({ ^bb0(%arg16: f32, %arg17: f32): - %91 = "triton_gpu.cmpf"(%arg16, %arg17) {predicate = 2 : i64} : (f32, f32) -> i1 + %91 = arith.cmpf "ogt", %arg16, %arg17 : f32 %92 = 
arith.select %91, %arg16, %arg17 : f32 tt.reduce.return %92 : f32 }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 5ea9c458d..b4d1528cd 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -30,20 +30,14 @@ config.test_exec_root = os.path.join(config.triton_obj_root, 'test') config.substitutions.append(('%PATH%', config.environment['PATH'])) config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) -llvm_config.with_system_environment( - ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) # llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. -config.excludes = [ - 'Inputs', - 'Examples', - 'CMakeLists.txt', - 'README.txt', - 'LICENSE.txt'] +config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -52,10 +46,7 @@ config.test_source_root = os.path.dirname(__file__) config.test_exec_root = os.path.join(config.triton_obj_root, 'test') config.triton_tools_dir = os.path.join(config.triton_obj_root, 'bin') config.filecheck_dir = os.path.join(config.triton_obj_root, 'bin', 'FileCheck') -tool_dirs = [ - config.triton_tools_dir, - config.llvm_tools_dir, - config.filecheck_dir] +tool_dirs = [config.triton_tools_dir, config.llvm_tools_dir, config.filecheck_dir] # Tweak the PATH to include the tools dir. 
for d in tool_dirs: diff --git a/third_party/triton_shared b/third_party/triton_shared index 07ea84207..450e6be65 160000 --- a/third_party/triton_shared +++ b/third_party/triton_shared @@ -1 +1 @@ -Subproject commit 07ea84207ac7763af16206e6d790c7bc37c6c2d9 +Subproject commit 450e6be65f99a0b15fd130892594b85e0897574c diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp index a0669baa9..06195ddd0 100644 --- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp @@ -114,7 +114,7 @@ int evalThreadIdOp(mlir::gpu::ThreadIdOp threadIdOp, int ctaid, int tid) { else if (dim == mlir::gpu::Dimension::z) return 0; else - assert(0 && "Invalid thread dim"); + llvm::report_fatal_error("Invalid thread dim"); return 0; } @@ -127,7 +127,7 @@ int evalInlineAsmOp(mlir::LLVM::InlineAsmOp asmOp, int ctaid, int tid) { else if (asmStr.find("%cluster_ctaid.z") != std::string::npos) return 0; else if (asmStr.find("%cluster_nctaid.x") != std::string::npos) - assert(0 && "%cluster_nctaid.x not supported"); + llvm::report_fatal_error("%cluster_nctaid.x not supported"); else if (asmStr.find("%cluster_nctaid.y") != std::string::npos) return 1; else if (asmStr.find("%cluster_nctaid.z") != std::string::npos) @@ -135,7 +135,7 @@ int evalInlineAsmOp(mlir::LLVM::InlineAsmOp asmOp, int ctaid, int tid) { else if (asmStr.find("%mock_smem_base") != std::string::npos) return 0; else - assert(0 && "Unrecognized ASM string"); + llvm::report_fatal_error("Unrecognized ASM string"); return 0; } @@ -177,7 +177,7 @@ int eval(Value value, int ctaid, int tid) { } else if (auto gepOp = llvm::dyn_cast(op)) { return evalGEPOp(gepOp, ctaid, tid); } else { - assert(0 && "Unrecognized op type in the index expression"); + llvm::report_fatal_error("Unrecognized op type in the index expression"); return 0; } } diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp 
b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp index 20603cd2e..5aa4f473e 100644 --- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp @@ -209,7 +209,8 @@ private: if (elemTyStr == "F16") return FloatType::getF16(&context); else - assert(0 && "getElemTy not implemented"); + llvm::report_fatal_error("getElemTy not implemented"); + return nullptr; } protected: