Merge remote-tracking branch 'oai/main' into ifu230620

Conflicts:
	include/triton/Conversion/TritonToTritonGPU/Passes.td
	include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/test/unit/language/assert_helper.py
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	test/Conversion/triton_to_tritongpu.mlir
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-06-29 21:47:27 +00:00
80 changed files with 2508 additions and 959 deletions

37
.github/workflows/Dockerfile vendored Normal file
View File

@@ -0,0 +1,37 @@
FROM centos:7
ARG llvm_dir=llvm-project
# Add the cache artifacts and the LLVM source tree to the container
ADD sccache /sccache
ADD "${llvm_dir}" /source/llvm-project
ENV SCCACHE_DIR="/sccache"
ENV SCCACHE_CACHE_SIZE="2G"
# Install build dependencies
RUN yum install --assumeyes centos-release-scl
RUN yum install --assumeyes devtoolset-9-gcc* python3-devel python3-pip
SHELL [ "/usr/bin/scl", "enable", "devtoolset-9" ]
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --upgrade cmake ninja sccache
# Install MLIR's Python Dependencies
RUN python3 -m pip install -r /source/llvm-project/mlir/python/requirements.txt
# Configure, Build, Test, and Install LLVM
RUN cmake -GNinja -Bbuild \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=gcc \
-DCMAKE_CXX_COMPILER=g++ \
-DCMAKE_C_COMPILER_LAUNCHER=sccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=sccache \
-DCMAKE_INSTALL_PREFIX="/install" \
-DLLVM_BUILD_UTILS=ON \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_INSTALL_UTILS=ON \
-DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \
/source/llvm-project/llvm
RUN ninja -C build check-mlir install

View File

@@ -25,7 +25,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"]]'
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "H100"], ["self-hosted", "gfx908"], ["self-hosted", "arc770"]]'
else
echo '::set-output name=matrix::["ubuntu-latest"]'
fi
@@ -44,7 +44,7 @@ jobs:
uses: actions/checkout@v2
- name: Set CUDA ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
@@ -53,6 +53,11 @@ jobs:
run: |
echo "BACKEND=ROCM" >> "${GITHUB_ENV}"
- name: Set XPU ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'arc770')}}
run: |
echo "BACKEND=XPU" >> "${GITHUB_ENV}"
- name: Clear cache
run: |
rm -rf ~/.triton
@@ -62,13 +67,22 @@ jobs:
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"
- name: Check pre-commit
if: ${{ matrix.runner != 'macos-10.15' }}
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] != 'arc770') }}
run: |
python3 -m pip install --upgrade pre-commit
python3 -m pre_commit run --all-files
- name: Check pre-commit arc770
if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }}
run: |
source ${HOME}/triton_vars.sh
source ${HOME}/miniconda3/bin/activate
conda activate triton-xpu-ci
python3 -m pip install --upgrade pre-commit
python3 -m pre_commit run --all-files
- name: Install Triton
if: ${{ env.BACKEND != 'ROCM'}}
if: ${{ env.BACKEND == 'CUDA'}}
run: |
cd python
python3 -m pip install --upgrade pip
@@ -84,8 +98,23 @@ jobs:
python3 -m pip install torch==1.13.1 --index-url https://download.pytorch.org/whl/rocm5.2
python3 -m pip install --no-build-isolation -vvv '.[tests]'
- name: Install Triton on XPU
if: ${{ env.BACKEND == 'XPU'}}
run: |
source ${HOME}/triton_vars.sh
source ${HOME}/miniconda3/bin/activate
conda activate triton-xpu-ci
git submodule update --init --recursive
cd python
python3 -m pip install --upgrade pip
python3 -m pip install cmake==3.24
export TRITON_CODEGEN_INTEL_XPU_BACKEND=1
python3 -m pip uninstall -y triton
python3 setup.py build
python3 -m pip install --no-build-isolation -vvv '.[tests]'
- name: Run lit tests
if: ${{ env.BACKEND != 'ROCM'}}
if: ${{ env.BACKEND == 'CUDA'}}
run: |
python3 -m pip install lit
cd python
@@ -102,19 +131,20 @@ jobs:
python3 -m pytest
- name: Create artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
tar -czvf artifacts.tar.gz ~/.triton/cache
cd ~/.triton
tar -czvf artifacts.tar.gz cache
- name: Upload artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
uses: actions/upload-artifact@v2
with:
name: artifacts
path: artifacts.tar.gz
path: ~/.triton/artifacts.tar.gz
- name: Run CXX unittests
if: ${{ env.BACKEND != 'ROCM'}}
if: ${{ env.BACKEND == 'CUDA'}}
run: |
cd python
cd "build/$(ls build | grep -i cmake)"
@@ -126,6 +156,15 @@ jobs:
cd python/test/unit/language
python3 -m pytest --capture=tee-sys -rfs --verbose "test_core.py::test_empty_kernel"
- name: Run python tests on XPU
if: ${{ env.BACKEND == 'XPU'}}
run: |
source ${HOME}/triton_vars.sh
source ${HOME}/miniconda3/bin/activate
conda activate triton-xpu-ci
cd python/test/backend/third_party_backends
python3 -m pytest --capture=tee-sys -rfs --verbose --backend xpu
- name: Regression tests
if: ${{ contains(matrix.runner, 'A100') }}
run: |

195
.github/workflows/llvm-build.yml vendored Normal file
View File

@@ -0,0 +1,195 @@
name: LLVM Build
on:
push:
branches:
- llvm-head
paths:
- llvm-hash.txt
workflow_dispatch:
env:
SCCACHE_DIR: ${{ github.workspace }}/sccache
permissions:
contents: read
id-token: write
jobs:
build:
strategy:
fail-fast: true
matrix:
platform: [
ubuntu-20.04-x64,
ubuntu-22.04-x64,
centos-7-x64,
macos-x64,
macos-arm64
]
include:
# Specify OS versions
- platform: ubuntu-20.04-x64
host-os: ubuntu-20.04
target-os: ubuntu
arch: x64
- platform: ubuntu-22.04-x64
host-os: ubuntu-22.04
target-os: ubuntu
arch: x64
- platform: centos-7-x64
host-os: ubuntu-22.04
target-os: centos
arch: x64
- platform: macos-x64
host-os: macos-12
target-os: macos
arch: x64
- platform: macos-arm64
host-os: macos-12
target-os: macos
arch: arm64
runs-on: ${{ matrix.host-os }}
steps:
- name: Checkout Repo
uses: actions/checkout@v3
with:
path: llvm-build
- name: Fetch LLVM Commit Hash
run: |
LLVM_COMMIT_HASH="$(cat llvm-build/llvm-hash.txt)"
echo "Found LLVM commit hash: ${LLVM_COMMIT_HASH}"
echo "llvm_commit_hash=${LLVM_COMMIT_HASH}" >> ${GITHUB_ENV}
SHORT_LLVM_COMMIT_HASH="${LLVM_COMMIT_HASH:0:8}"
echo "Short LLVM commit hash: ${SHORT_LLVM_COMMIT_HASH}"
echo "short_llvm_commit_hash=${SHORT_LLVM_COMMIT_HASH}" >> ${GITHUB_ENV}
INSTALL_DIR="llvm-${SHORT_LLVM_COMMIT_HASH}-${{ matrix.platform }}"
echo "LLVM installation directory name: ${INSTALL_DIR}"
echo "llvm_install_dir=${INSTALL_DIR}" >> ${GITHUB_ENV}
- name: Checkout LLVM
uses: actions/checkout@v3
with:
repository: llvm/llvm-project
path: llvm-project
ref: ${{ env.llvm_commit_hash }}
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install Prerequisites
run: |
python3 -m pip install cmake ninja sccache
mkdir -p ${{ env.SCCACHE_DIR }}
rm -rf ${{ env.SCCACHE_DIR }}/*
- name: Enable Cache
uses: actions/cache@v3
with:
path: ${{ env.SCCACHE_DIR }}
key: ${{ matrix.platform }}-${{ env.short_llvm_commit_hash }}
restore-keys: ${{ matrix.platform }}-
- name: Configure, Build, Test, and Install LLVM (Ubuntu and macOS x64)
if: matrix.arch == 'x64' && contains(fromJSON('["ubuntu", "macos"]'), matrix.target-os)
run: >
python3 -m pip install -r llvm-project/mlir/python/requirements.txt
cmake -GNinja -Bllvm-project/build
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
-DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache
-DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}"
-DCMAKE_LINKER=lld
-DLLVM_BUILD_UTILS=ON
-DLLVM_ENABLE_ASSERTIONS=ON
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
-DLLVM_ENABLE_PROJECTS=mlir
-DLLVM_INSTALL_UTILS=ON
-DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
llvm-project/llvm
ninja -C llvm-project/build check-mlir install
tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
- name: Configure, Build, and Install LLVM (macOS arm64)
if: matrix.arch == 'arm64' && matrix.target-os == 'macos'
run: >
python3 -m pip install -r llvm-project/mlir/python/requirements.txt
cmake -GNinja -Bllvm-project/build
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
-DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache
-DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}"
-DCMAKE_LINKER=lld
-DCMAKE_OSX_ARCHITECTURES=arm64
-DLLVM_BUILD_UTILS=ON
-DLLVM_ENABLE_ASSERTIONS=ON
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
-DLLVM_ENABLE_PROJECTS=mlir
-DLLVM_ENABLE_ZSTD=OFF
-DLLVM_INSTALL_UTILS=ON
-DLLVM_TARGETS_TO_BUILD="AArch64"
-DLLVM_USE_HOST_TOOLS=ON
llvm-project/llvm
ninja -C llvm-project/build install
tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
- name: Configure, Build, Test, and Install LLVM (CentOS)
if: matrix.target-os == 'centos'
run: |
docker build --tag llvm-build --build-arg llvm_dir=llvm-project \
-f llvm-build/.github/workflows/Dockerfile .
# Create temporary container to copy cache and installed artifacts.
CONTAINER_ID=$(docker create llvm-build)
docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}"
tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
# We remove the existing directory, otherwise docker will
# create a subdirectory inside the existing directory.
rm -rf "${{ env.SCCACHE_DIR }}"
docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}"
sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}"
docker rm "${CONTAINER_ID}"
- name: Azure Login
uses: azure/login@v1
with:
client-id: ${{ secrets.AZURE_CLIENT_ID }}
tenant-id: ${{ secrets.AZURE_TENANT_ID }}
subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
- name: Upload LLVM Artifacts to Azure
run: |
az storage blob upload --account-name tritonlang --auth-mode login --container-name llvm-builds --file "${{ env.llvm_install_dir }}.tar.gz" --name "${{ env.llvm_install_dir }}.tar.gz" --overwrite
URL=$(az storage blob url --account-name tritonlang --auth-mode login --container-name llvm-builds --name "${{ env.llvm_install_dir }}.tar.gz")
echo "Blob URL: ${URL}"
- name: Azure Logout
run: |
az logout
az cache purge
az account clear
if: always()
- name: Dump Sccache Statistics
run: sccache --show-stats

View File

@@ -1,24 +1,45 @@
name: Wheels
on:
workflow_dispatch:
#schedule:
# - cron: "0 0 * * *"
schedule:
- cron: "0 2 * * *"
jobs:
Build-Wheels:
runs-on: [self-hosted, V100]
permissions:
id-token: write
contents: read
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Install Azure CLI
run: |
curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
- name: Azure login
uses: azure/login@v1
with:
client-id: ${{ secrets.AZURE_CLIENT_ID }}
tenant-id: ${{ secrets.AZURE_TENANT_ID }}
subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
- id: generate-token
name: Generate token
run: |
AZ_TOKEN=$(az account get-access-token --query accessToken)
echo "::add-mask::$AZ_TOKEN"
echo "access_token=$AZ_TOKEN" >> "$GITHUB_OUTPUT"
- name: Patch setup.py
run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d%H%M%S' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg
@@ -29,11 +50,10 @@ jobs:
export CIBW_MANYLINUX_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest"
#export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest"
export CIBW_BEFORE_BUILD="pip install cmake;"
export CIBW_SKIP="{cp,pp}35-*"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64 cp3*-musllinux_x86_64"
export CIBW_SKIP="{cp,pp}{35,36}-*"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
python3 -m cibuildwheel python --output-dir wheelhouse
- name: Upload wheels to PyPI
- name: Publish wheels to Azure DevOps
run: |
python3 -m twine upload wheelhouse/* -u __token__ -p ${{ secrets.PYPY_API_TOKEN }}
python3 -m twine upload -r Triton-Nightly -u TritonArtifactsSP -p ${{ steps.generate-token.outputs.access_token }} --config-file utils/nightly.pypirc --non-interactive --verbose wheelhouse/*

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "third_party/intel_xpu_backend"]
path = third_party/intel_xpu_backend
url = http://github.com/intel/intel-xpu-backend-for-triton

View File

@@ -22,6 +22,7 @@ endif()
# Options
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
# Force TRITON_USE_ROCM for ROCm support
set(TRITON_USE_ROCM ON)
@@ -276,6 +277,14 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif()
list(LENGTH TRITON_CODEGEN_BACKENDS CODEGEN_BACKENDS_LEN)
if (${CODEGEN_BACKENDS_LEN} GREATER 0)
set(PYTHON_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/third_party)
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
endif()
add_subdirectory(test)
add_subdirectory(unittest)

View File

@@ -25,12 +25,12 @@ You can install the latest stable release of Triton from pip:
```bash
pip install triton
```
Binary wheels are available for CPython 3.6-3.11 and PyPy 3.7-3.9.
Binary wheels are available for CPython 3.8-3.11 and PyPy 3.8-3.9.
And the latest nightly release:
```bash
pip install -U --pre triton
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
# Install from source
```

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 54 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 49 KiB

View File

@@ -12,13 +12,13 @@ You can install the latest stable release of Triton from pip:
pip install triton
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
Binary wheels are available for CPython 3.8-3.11 and PyPy 3.8-3.9.
And the latest nightly release:
.. code-block:: bash
pip install -U --pre triton
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
-----------

View File

@@ -1,7 +1,7 @@
Welcome to Triton's documentation!
==================================
Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
Triton_ is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
Getting Started
@@ -52,3 +52,5 @@ Check out the following documents to learn more about Triton and how it compares
programming-guide/chapter-1/introduction
programming-guide/chapter-2/related-work
.. _Triton: https://github.com/openai/triton

View File

@@ -125,7 +125,7 @@ public:
}
}
/// Returns the size of the given buffer is a virtual buffer.
/// Returns if the given buffer is a virtual buffer.
bool isVirtualBuffer(BufferId bufferId) const {
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
}

View File

@@ -73,9 +73,11 @@ bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
Type getElementType(Value value);
bool isSingleValue(Value value);
std::string getValueOperandName(Value value, AsmState &state);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
Type getElementType(Value value);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
@@ -91,7 +93,7 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
// output[i] = input[order[i]]
/// output[i] = input[order[i]]
template <typename T, typename RES_T = T>
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
size_t rank = order.size();
@@ -103,6 +105,7 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
return result;
}
/// Get the highest power of 2 divisor of an integer.
template <typename T> T highestPowOf2Divisor(T n) {
if (n == 0) {
return (static_cast<T>(1) << (sizeof(T) * 8 - 2));
@@ -110,9 +113,18 @@ template <typename T> T highestPowOf2Divisor(T n) {
return (n & (~(n - 1)));
}
bool isSingleValue(Value value);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
/// Get the next power of 2 for an integer (or the integer itself if it is a
/// power of 2).
template <typename T> T nextPowOf2(T n) {
if (n == 0) {
return 1;
}
n--;
for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) {
n |= n >> i;
}
return n + 1;
}
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.

View File

@@ -23,7 +23,11 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
"number of warps">,
Option<"threadsPerWarp", "threads-per-warp",
<<<<<<< HEAD
"int32_t", /*default*/"TRITONGPU_DEFAULT_WARPSIZE",
=======
"int32_t", /*default*/"32",
>>>>>>> oai/main
"number of threads per warp">,
];
}

View File

@@ -14,6 +14,17 @@ def TT_CacheModifierAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}
def TT_MemSemanticAttr : I32EnumAttr<
"MemSemantic", "",
[
I32EnumAttrCase<"RELAXED", 1, "relaxed">,
I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
I32EnumAttrCase<"RELEASE", 3, "release">,
I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
]> {
let cppNamespace = "::mlir::triton";
}
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[

View File

@@ -234,7 +234,8 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
}];
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
TT_Type:$val, Optional<TT_BoolLike>:$mask);
TT_Type:$val, Optional<TT_BoolLike>:$mask,
TT_MemSemanticAttr:$sem);
let results = (outs TT_Type:$result);
}
@@ -255,7 +256,8 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
return $old
}];
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val);
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val,
TT_MemSemanticAttr:$sem);
let results = (outs TT_Type:$result);
}
@@ -444,7 +446,6 @@ class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
}
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
}

View File

@@ -23,6 +23,9 @@ namespace gpu {
unsigned getTotalElemsPerThread(Type type);
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
Type eltTy);
SmallVector<unsigned> getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
@@ -74,11 +77,11 @@ SmallVector<unsigned> getOrder(Attribute layout);
bool isaDistributedLayout(Attribute layout);
bool isSharedEncoding(Value value);
} // namespace gpu
} // namespace triton
bool isSharedEncoding(Value value);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -34,7 +34,11 @@ def TritonGPU_Dialect : Dialect {
static int getThreadsPerWarp(ModuleOp mod) {
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
if(!threadsPerWarp) {
<<<<<<< HEAD
return TRITONGPU_DEFAULT_WARPSIZE;
=======
return 32;
>>>>>>> oai/main
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}

View File

@@ -1,9 +1,13 @@
#ifndef TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#define TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/MapVector.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
LogicalResult fixupLoops(ModuleOp mod);
@@ -12,9 +16,11 @@ LogicalResult fixupLoops(ModuleOp mod);
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret);
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
bool expensiveToRemat(Operation *op, Attribute &targetEncoding);
bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding);
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);
// skipInit is True when we only consider the operands of the initOp but
// not the initOp itself.
@@ -36,4 +42,4 @@ LogicalResult canMoveOutOfLoop(BlockArgument arg,
} // namespace mlir
#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -1,19 +0,0 @@
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
#define TRITON_TARGET_AMDGCNTRANSLATION_H
#include <string>
#include <tuple>
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate LLVM IR to AMDGCN code.
std::tuple<std::string, std::string>
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);
} // namespace triton
#endif

View File

@@ -38,7 +38,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isSharedEncoding(result)) {
} else if (triton::gpu::isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;
}

View File

@@ -151,7 +151,7 @@ private:
}
for (Value result : op->getResults()) {
if (isSharedEncoding(result)) {
if (triton::gpu::isSharedEncoding(result)) {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto tensorType = result.getType().dyn_cast<RankedTensorType>();

View File

@@ -60,9 +60,9 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto argLayout = getSrcLayout();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
return {{1, 1}, {1, 1}};
// if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
// triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
// return {{1, 1}, {1, 1}};
/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
@@ -181,13 +181,6 @@ Type getElementType(Value value) {
return type;
}
std::string getValueOperandName(Value value, AsmState &state) {
std::string opName;
llvm::raw_string_ostream ss(opName);
value.printAsOperand(ss, state);
return opName;
}
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>

View File

@@ -107,7 +107,7 @@ static Value loadA(Value tensor, const SharedMemoryObject &smemObj,
auto order = sharedLayout.getOrder();
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
bool isARow = order[0] != 0;
auto resultEncoding = resultTy.cast<RankedTensorType>()
@@ -232,7 +232,7 @@ static Value loadB(Value tensor, const SharedMemoryObject &smemObj,
auto shape = tensorTy.getShape();
auto order = sharedLayout.getOrder();
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Value smem = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
bool isBRow = order[0] != 0; // is row-major in shared memory layout
// isBRow_ indicates whether B is row-major in DotOperand layout
auto resultEncoding = resultTy.cast<RankedTensorType>()

View File

@@ -18,7 +18,8 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
// Data loader for mma.16816 instruction.
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
MMA16816SmemLoader(int warpsPerTile, ArrayRef<uint32_t> order,
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
@@ -27,7 +28,7 @@ public:
const Location &loc);
// lane = thread % 32
// warpOff = (thread/32) % wpt(0)
// warpOff = (thread/32) % warpsPerTile(0)
llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane,
Value cSwizzleOffset) {
if (canUseLdmatrix)
@@ -48,12 +49,14 @@ public:
Value cSwizzleOffset);
// Load 4 matrices and returns 4 vec<2> elements.
std::tuple<Value, Value, Value, Value>
loadX4(int mat0, int mat1, ArrayRef<Value> offs, ArrayRef<Value> ptrs,
Type matTy, Type shemPtrTy) const;
std::tuple<Value, Value, Value, Value> loadX4(int mat0, int mat1,
ArrayRef<Value> ptrs,
Type matTy,
Type shemPtrTy) const;
private:
SmallVector<uint32_t> order;
SmallVector<uint32_t> warpsPerCTA;
int kOrder;
int kWidth;
SmallVector<int64_t> tileShape;
@@ -66,42 +69,50 @@ private:
const Location &loc;
MLIRContext *ctx{};
int cMatShape;
int sMatShape;
// ldmatrix loads a matrix of size stridedMatShape x contiguousMatShape
int contiguousMatShape;
int stridedMatShape;
Value sStride;
// Offset in shared memory to increment on the strided axis
// This would be different than the tile shape in the case of a sliced tensor
Value stridedSmemOffset;
bool needTrans;
bool canUseLdmatrix;
int numPtrs;
int pLoadStrideInMat;
int sMatStride;
// Load operations offset in number of Matrices on contiguous and strided axes
int contiguousLoadMatOffset;
int stridedLoadMatOffset;
int matArrStride;
int warpOffStride;
// Offset in number of matrices to increment on non-k dim within a warp's 2x2
// matrices
int inWarpMatOffset;
// Offset in number of matrices to increment on non-k dim across warps
int warpMatOffset;
};
SmallVector<Value>
MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
Value cSwizzleOffset) {
// 4x4 matrices
Value c = urem(lane, i32_val(8));
Value s = udiv(lane, i32_val(8)); // sub-warp-id
Value rowInMat = urem(lane, i32_val(8)); // row in the 8x8 matrix
Value matIndex =
udiv(lane, i32_val(8)); // linear index of the matrix in the 2x2 matrices
// Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a
// Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in a
// warp
Value s0 = urem(s, i32_val(2));
Value s1 = udiv(s, i32_val(2));
Value s0 = urem(matIndex, i32_val(2));
Value s1 = udiv(matIndex, i32_val(2));
// We use different orders for a and b for better performance.
Value kMatArr = kOrder == 1 ? s1 : s0;
Value nkMatArr = kOrder == 1 ? s0 : s1;
Value kMatArr = kOrder == 1 ? s1 : s0; // index of matrix on the k dim
Value nkMatArr = kOrder == 1 ? s0 : s1; // index of matrix on the non-k dim
// Matrix coordinates inside a CTA,
// the matrix layout is [2wpt[0], 2] for A and [2, 2wpt[1]] for B.
// e.g., Setting wpt=4, the data layout for A(kOrder=1) is
// the matrix layout is [2warpsPerTile[0], 2] for A and [2, 2warpsPerTile[1]]
// for B. e.g., Setting warpsPerTile=4, the data layout for A(kOrder=1) is
// |0 0| -> 0,1,2,3 are the warpids
// |0 0|
// |1 1|
@@ -118,34 +129,42 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
// address (s0,s1) annotates.
Value matOff[2];
matOff[kOrder ^ 1] =
add(mul(warpId, i32_val(warpOffStride)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(matArrStride))); // matrix offset inside a warp (kOrder=1)
matOff[kOrder ^ 1] = add(
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
matOff[kOrder] = kMatArr;
// Physical offset (before swizzling)
Value cMatOff = matOff[order[0]];
Value sMatOff = matOff[order[1]];
Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
cMatOff = add(cMatOff, cSwizzleMatOff);
// row offset inside a matrix, each matrix has 8 rows.
Value sOffInMat = c;
Value contiguousMatIndex = matOff[order[0]];
Value stridedMatIndex = matOff[order[1]];
// Add the offset of the slice
Value contiguousSliceMatOffset =
udiv(cSwizzleOffset, i32_val(contiguousMatShape));
SmallVector<Value> offs(numPtrs);
Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
// To prevent out-of-bound access of B when wpt * 16 > tile_size.
Value phase = urem(udiv(rowInMat, i32_val(perPhase)), i32_val(maxPhase));
// To prevent out-of-bound access of B when warpsPerTile * 16 > tile_size.
// In such a case, we need to wrap around the offset of B.
// |0 1 2 3 0 1 2 3| -> | 0(0) 1(1) 2(2) 3(3) |
// |0 1 2 3 0 1 2 3| | 0(0) 1(1) 2(2) 3(3) |
// ~~~~~~~ out-of-bound access
Value sOff = urem(add(sOffInMat, mul(sMatOff, i32_val(sMatShape))),
i32_val(tileShape[order[1]]));
Value rowOffset =
urem(add(rowInMat, mul(stridedMatIndex, i32_val(stridedMatShape))),
i32_val(tileShape[order[1]]));
auto contiguousTileNumMats = tileShape[order[0]] / matShape[order[0]];
for (int i = 0; i < numPtrs; ++i) {
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
cMatOffI = xor_(cMatOffI, phase);
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride));
Value contiguousIndex =
add(contiguousMatIndex, i32_val(i * contiguousLoadMatOffset));
if (warpsPerCTA[order[0]] > contiguousTileNumMats ||
contiguousTileNumMats % warpsPerCTA[order[0]] != 0)
contiguousIndex = urem(contiguousIndex, i32_val(contiguousTileNumMats));
contiguousIndex = add(contiguousIndex, contiguousSliceMatOffset);
Value contiguousIndexSwizzled = xor_(contiguousIndex, phase);
offs[i] = add(mul(contiguousIndexSwizzled, i32_val(contiguousMatShape)),
mul(rowOffset, stridedSmemOffset));
}
return offs;
@@ -208,14 +227,14 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
Value jOff = i32_val(0);
if (!needTrans) {
jOff = add(jOff, i32_val(quadId));
jOff = add(jOff, i32_val(rep * pLoadStrideInMat));
jOff = add(jOff, i32_val(rep * contiguousLoadMatOffset));
}
// outer index offset
Value iOff = mul(warpOff, i32_val(warpOffStride));
Value iOff = mul(warpOff, i32_val(warpMatOffset));
if (needTrans) {
int pStride = kOrder == 1 ? 1 : 2;
iOff = add(iOff, i32_val(quadId * matArrStride));
iOff = add(iOff, i32_val(rep * pLoadStrideInMat * pStride));
iOff = add(iOff, i32_val(quadId * inWarpMatOffset));
iOff = add(iOff, i32_val(rep * contiguousLoadMatOffset * pStride));
}
// swizzle
if (!needTrans) {
@@ -234,9 +253,9 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
// i = urem(i, i32_val(cTileShape));
// j = urem(j, i32_val(sTileShape));
if (needTrans) {
offs[idx] = add(i, mul(j, sStride));
offs[idx] = add(i, mul(j, stridedSmemOffset));
} else {
offs[idx] = add(mul(i, sStride), j);
offs[idx] = add(mul(i, stridedSmemOffset), j);
}
}
@@ -244,8 +263,7 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
}
std::tuple<Value, Value, Value, Value>
MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
ArrayRef<Value> ptrs, Type matTy,
MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
Type shemPtrTy) const {
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
int matIdx[2] = {mat0, mat1};
@@ -282,15 +300,16 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
}
if (canUseLdmatrix) {
Value sOffset =
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride);
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
Value stridedOffset =
mul(i32_val(matIdx[order[1]] * stridedLoadMatOffset * stridedMatShape),
stridedSmemOffset);
Value readPtr = gep(shemPtrTy, ptr, stridedOffset);
PTXBuilder builder;
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
// thread.
auto resArgs = builder.newListOperand(4, "=r");
auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
auto addrArg = builder.newAddrOperand(readPtr, "r");
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
->o("trans", needTrans /*predicate*/)
@@ -313,14 +332,14 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
for (int i = 0; i < vecWidth; i++)
ptrs[1][i] = getPtr(ptrIdx + i + vecWidth);
// static offsets along outer dimension
int _i0 = matIdx[order[1]] * (sMatStride * sMatShape);
int _i0 = matIdx[order[1]] * (stridedLoadMatOffset * stridedMatShape);
int _i1 = _i0;
if (needTrans)
_i1 += sMatStride * sMatShape;
_i1 += stridedLoadMatOffset * stridedMatShape;
else
_i1 += (kOrder == 1 ? 1 : sMatStride) * sMatShape;
Value i0 = mul(i32_val(_i0), sStride);
Value i1 = mul(i32_val(_i1), sStride);
_i1 += (kOrder == 1 ? 1 : stridedLoadMatOffset) * stridedMatShape;
Value i0 = mul(i32_val(_i0), stridedSmemOffset);
Value i1 = mul(i32_val(_i1), stridedSmemOffset);
std::array<Value, 2> ii = {i0, i1};
// load 4 32-bit values from shared memory
// (equivalent to ldmatrix.x4)
@@ -363,21 +382,23 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
}
MMA16816SmemLoader::MMA16816SmemLoader(
int wpt, ArrayRef<uint32_t> order, uint32_t kOrder, int kWidth,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
int warpsPerTile, ArrayRef<uint32_t> order, ArrayRef<uint32_t> warpsPerCTA,
uint32_t kOrder, int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
: order(order.begin(), order.end()), kOrder(kOrder), kWidth(kWidth),
tileShape(tileShape.begin(), tileShape.end()),
: order(order.begin(), order.end()),
warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
instrShape(instrShape.begin(), instrShape.end()),
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
ctx(rewriter.getContext()) {
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
contiguousMatShape = matShape[order[0]];
stridedMatShape = matShape[order[1]];
sStride = smemStrides[order[1]];
stridedSmemOffset = smemStrides[order[1]];
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
@@ -385,12 +406,13 @@ MMA16816SmemLoader::MMA16816SmemLoader(
canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes);
if (canUseLdmatrix) {
// Each CTA, the warps is arranged as [1xwpt] if not transposed,
// otherwise [wptx1], and each warp will perform a mma.
numPtrs =
tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]];
// Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed,
// otherwise [warpsPerTilex1], and each warp will perform a mma.
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
instrShape[order[0]];
} else {
numPtrs = tileShape[order[0]] / (needTrans ? wpt : 1) / matShape[order[0]];
numPtrs = tileShape[order[0]] / (needTrans ? warpsPerTile : 1) /
matShape[order[0]];
numPtrs *= 4 / elemBytes;
}
numPtrs = std::max<int>(numPtrs, 2);
@@ -398,23 +420,24 @@ MMA16816SmemLoader::MMA16816SmemLoader(
// Special rule for i8/u8, 4 ptrs for each matrix
// if (!canUseLdmatrix && elemBytes == 1)
int loadStrideInMat[2];
loadStrideInMat[kOrder] =
int loadOffsetInMat[2];
loadOffsetInMat[kOrder] =
2; // instrShape[kOrder] / matShape[kOrder], always 2
loadStrideInMat[kOrder ^ 1] =
wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
loadOffsetInMat[kOrder ^ 1] =
warpsPerTile * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
pLoadStrideInMat = loadStrideInMat[order[0]];
contiguousLoadMatOffset = loadOffsetInMat[order[0]];
sMatStride =
loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
stridedLoadMatOffset =
loadOffsetInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
// Each matArr contains warpOffStride matrices.
matArrStride = kOrder == 1 ? 1 : wpt;
warpOffStride = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
// The stride (in number of matrices) within warp
inWarpMatOffset = kOrder == 1 ? 1 : warpsPerTile;
// The stride (in number of matrices) of each warp
warpMatOffset = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
}
Type getShemPtrTy(Type argType) {
Type getSharedMemPtrTy(Type argType) {
MLIRContext *ctx = argType.getContext();
if (argType.isF16())
return ptr_ty(type::f16Ty(ctx), 3);
@@ -451,13 +474,12 @@ Value composeValuesToDotOperandLayoutStruct(
return result;
}
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, int kWidth,
SmallVector<int> instrShape, SmallVector<int> matShape,
Value warpId, Value lane, ValueTable &vals, bool isA,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
std::function<void(int, int)> getLoadMatrixFn(
Value tensor, const SharedMemoryObject &smemObj, MmaEncodingAttr mmaLayout,
int warpsPerTile, uint32_t kOrder, int kWidth, SmallVector<int> instrShape,
SmallVector<int> matShape, Value warpId, Value lane, ValueTable &vals,
bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
Type eltTy = tensorTy.getElementType();
// We assumes that the input operand of Dot should be from shared layout.
@@ -471,25 +493,27 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
// (a, b) is the coordinate.
auto load = [=, &rewriter, &vals](int a, int b) {
MMA16816SmemLoader loader(
wpt, sharedLayout.getOrder(), kOrder, kWidth, smemObj.strides,
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
maxPhase, elemBytes, rewriter, typeConverter, loc);
warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(),
kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/,
instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter,
typeConverter, loc);
// Offset of a slice within the original tensor in shared memory
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
SmallVector<Value> offs =
loader.computeOffsets(warpId, lane, cSwizzleOffset);
// initialize pointers
const int numPtrs = loader.getNumPtrs();
SmallVector<Value> ptrs(numPtrs);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Type smemPtrTy = getShemPtrTy(eltTy);
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
Type smemPtrTy = getSharedMemPtrTy(eltTy);
for (int i = 0; i < numPtrs; ++i)
ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy);
// actually load from shared memory
auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(),
SmallVector<Type>(4, i32_ty));
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
ptrs, matTy, getShemPtrTy(eltTy));
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, ptrs,
matTy, getSharedMemPtrTy(eltTy));
if (!isA)
std::swap(ha1, ha2);
// the following is incorrect
@@ -527,32 +551,34 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
int kWidth = encoding.getMMAv2kWidth();
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
int wpt1 = mmaLayout.getWarpsPerCTA()[1];
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
Value warp = udiv(thread, i32_val(32));
Value lane = urem(thread, i32_val(32));
Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16));
Value warpMN = udiv(warp, i32_val(wpt0));
Value warpN = urem(urem(warpMN, i32_val(wpt1)), i32_val(shape[1] / 8));
// Note: warps are currently column major in MMA layout
Value warpRowIndex = urem(warp, i32_val(warpsPerCTA[0]));
Value warpColIndex =
urem(udiv(warp, i32_val(warpsPerCTA[0])), i32_val(warpsPerCTA[1]));
Value warpM = urem(warpRowIndex, i32_val(shape[0] / 16));
Value warpN = urem(warpColIndex, i32_val(shape[1] / 8));
int wpt;
int warpsPerTile;
if (isA)
wpt = std::min<int>(wpt0, shape[0] / 16);
warpsPerTile = std::min<int>(warpsPerCTA[0], shape[0] / 16);
else
wpt = std::min<int>(wpt1, shape[1] / 16);
warpsPerTile = std::min<int>(warpsPerCTA[1], shape[1] / 16);
std::function<void(int, int)> loadFn;
if (isA)
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, kWidth,
{mmaInstrM, mmaInstrK} /*instrShape*/,
tensor, smemObj, mmaLayout, warpsPerTile /*warpsPerTile*/, 1 /*kOrder*/,
kWidth, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
rewriter /*rewriter*/, loc /*loc*/);
else
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, kWidth,
{mmaInstrK, mmaInstrN} /*instrShape*/,
tensor, smemObj, mmaLayout, warpsPerTile /*warpsPerTile*/, 0 /*kOrder*/,
kWidth, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
rewriter /*rewriter*/, loc /*loc*/);

View File

@@ -402,7 +402,14 @@ struct StoreOpConversion
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
auto &ptxStoreInstr =
ptxBuilder.create<>("st")->global().v(nWords).b(width);
ptxBuilder.create<>("st")
->global()
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
.v(nWords)
.b(width);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
@@ -528,11 +535,6 @@ struct AtomicCASOpConversion
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
Value mask = getMask(valueTy, rewriter, loc);
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
@@ -546,7 +548,10 @@ struct AtomicCASOpConversion
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.global().o(semStr).o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
@@ -557,8 +562,8 @@ struct AtomicCASOpConversion
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(mask);
auto ASMReturnTy = void_ty(ctx);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
Value ret = load(atomPtr);
barrier();
@@ -731,7 +736,7 @@ struct AtomicRMWOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
//
auto atomicRmwAttr = op.getAtomicRmwOp();
Value val = op.getVal();
@@ -832,7 +837,10 @@ struct AtomicRMWOpConversion
default:
return failure();
}
atom.o(rmwOp).o(sTy);
std::string semStr;
llvm::raw_string_ostream os(semStr);
os << op.getSem();
atom.o(semStr).o(rmwOp).o(sTy);
if (tensorTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
@@ -842,13 +850,13 @@ struct AtomicRMWOpConversion
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
}
} else {
PTXBuilder ptxBuilderMemfence;
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
if (op->user_begin() == op->user_end()) {
rewriter.replaceOp(op, {old});
return success();
}
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
// Only threads with rmwMask = True store the result

View File

@@ -253,8 +253,8 @@ struct SharedMemoryObject {
return offsets[order];
}
Value getBaseBeforeSwizzle(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
Value getBaseBeforeSlice(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(order);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();

View File

@@ -6,6 +6,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
@@ -309,11 +310,43 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
LogicalResult
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// For now, this behaves like generic, but this will evolve when
// we add support for `can_reorder=False`
Type retType = this->getTypeConverter()->convertType(op.getType());
// The cat op satisfy two conditions:
// 1. output.numel = lhs.numel + rhs.numel
// 2. output.total_elems_per_thread =
// next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread)
// For now, this behaves like generic, but this
// will evolve when we add support for `can_reorder=False`.
auto retType = this->getTypeConverter()
->convertType(op.getType())
.cast<RankedTensorType>();
auto retEncoding =
retType.getEncoding().cast<triton::gpu::BlockedEncodingAttr>();
auto lhsType = adaptor.getLhs().getType().cast<RankedTensorType>();
auto rhsType = adaptor.getRhs().getType().cast<RankedTensorType>();
auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType);
auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType);
auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType);
auto retShape = retType.getShape();
auto retOrder = retEncoding.getOrder();
auto retSizePerThread = retEncoding.getSizePerThread();
auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
// Get new retSizePerThread if ret elems per thread is not enough.
// We have to round it up to the next power of 2 due to triton's tensor size
// constraint.
auto newRetTotalElemsPerThread =
nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
auto newRetSizePerThread = retSizePerThread.vec();
newRetSizePerThread[retOrder[0]] *=
newRetTotalElemsPerThread / retTotalElemsPerThread;
triton::gpu::BlockedEncodingAttr newRetEncoding =
triton::gpu::BlockedEncodingAttr::get(getContext(), newRetSizePerThread,
retThreadsPerWarp, retWarpsPerCTA,
retOrder);
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
newRetEncoding);
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
op, retType, adaptor.getOperands()),
op, newRetType, adaptor.getOperands()),
adaptor.getAttributes());
return success();
}
@@ -393,7 +426,8 @@ struct TritonAtomicCASPattern
ConversionPatternRewriter &rewriter) const override {
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal()),
adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal(),
op.getSem()),
adaptor.getAttributes());
return success();
}
@@ -409,7 +443,7 @@ struct TritonAtomicRMWPattern
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
op, typeConverter->convertType(op.getType()),
adaptor.getAtomicRmwOp(), adaptor.getPtr(),
adaptor.getVal(), adaptor.getMask()),
adaptor.getVal(), adaptor.getMask(), op.getSem()),
adaptor.getAttributes());
return success();
}

View File

@@ -7,6 +7,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -410,9 +411,6 @@ bool isaDistributedLayout(Attribute layout) {
layout.isa<MfmaEncodingAttr>() || layout.isa<SliceEncodingAttr>();
}
} // namespace gpu
} // namespace triton
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
@@ -422,6 +420,9 @@ bool isSharedEncoding(Value value) {
return false;
}
} // namespace gpu
} // namespace triton
} // namespace mlir
static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
@@ -1282,6 +1283,10 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
}
// cvt(cat) -> cat
if (auto cat = dyn_cast<triton::CatOp>(arg)) {
auto encoding =
op->getResult(0).getType().cast<RankedTensorType>().getEncoding();
if (isExpensiveCat(cat, encoding))
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::CatOp>(op, op->getResult(0).getType(),
cat.getOperands());
return mlir::success();
@@ -1289,7 +1294,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
@@ -1299,7 +1304,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
@@ -1321,7 +1326,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
if (!triton::gpu::isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
@@ -1351,12 +1356,13 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(op.getOperand()) &&
!isSharedEncoding(op.getResult())) {
!triton::gpu::isSharedEncoding(arg->getOperand(0)) &&
triton::gpu::isSharedEncoding(op.getOperand()) &&
!triton::gpu::isSharedEncoding(op.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) {
if (triton::gpu::isSharedEncoding(op.getOperand()) &&
triton::gpu::isSharedEncoding(op.getResult())) {
return mlir::failure();
}
auto srcType = op.getOperand().getType().cast<RankedTensorType>();

View File

@@ -7,7 +7,7 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
return failure();
for (auto result : op->getResults())
if (!isSharedEncoding(result))
if (!triton::gpu::isSharedEncoding(result))
return op->emitOpError() << "requires all results to be shared encoding";
return success();

View File

@@ -47,7 +47,7 @@ SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
auto slices = mlir::getSlice(dotOp, {filter});
for (Operation *op : slices)
if (isa<triton::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
@@ -234,8 +234,8 @@ public:
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, isCvt);
getBackwardSlice(b, &bBwdSlices, isCvt);
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)

View File

@@ -14,8 +14,8 @@ using namespace mlir::triton;
template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
SmallVector<unsigned, 4> ret(arr.size());
std::iota(ret.begin(), ret.end(), 0);
std::sort(ret.begin(), ret.end(),
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
std::stable_sort(ret.begin(), ret.end(),
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
return ret;
}

View File

@@ -1,4 +1,3 @@
#include "Utility.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -6,6 +5,7 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <memory>
using namespace mlir;

File diff suppressed because it is too large Load Diff

View File

@@ -172,7 +172,7 @@ LogicalResult Prefetcher::initialize() {
break;
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op))
if (isSharedEncoding(cvt.getOperand())) {
if (triton::gpu::isSharedEncoding(cvt.getOperand())) {
foundConvertFromShared = true;
break;
}

View File

@@ -1,4 +1,3 @@
#include "Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -16,6 +15,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <memory>
@@ -84,6 +84,7 @@ public:
!dstParent.isa<triton::gpu::MfmaEncodingAttr>()))
return mlir::failure();
<<<<<<< HEAD
if (dstParent.isa<triton::gpu::MmaEncodingAttr>()) {
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
@@ -97,6 +98,15 @@ public:
return mlir::failure();
return processEncoding(dstParentMfma, convert, dstType, rewriter);
}
=======
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), dstParentMma);
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, dstType,
tmp);
return mlir::success();
>>>>>>> oai/main
}
return mlir::failure();
}
@@ -373,13 +383,13 @@ public:
!(isa<triton::ReduceOp>(op) &&
!op->getResult(0).getType().isa<RankedTensorType>());
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, {filter});
if (cvtSlices.empty())
return failure();
for (Operation *op : cvtSlices) {
// don't rematerialize anything expensive
if (expensiveToRemat(op, dstEncoding))
if (isExpensiveToRemat(op, srcEncoding))
return failure();
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
@@ -428,8 +438,8 @@ public:
if (!op)
return mlir::failure();
// we don't want to rematerialize any conversion to/from shared
if (isSharedEncoding(cvt->getResults()[0]) ||
isSharedEncoding(cvt->getOperand(0)))
if (triton::gpu::isSharedEncoding(cvt->getResults()[0]) ||
triton::gpu::isSharedEncoding(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accommodate fused attention

View File

@@ -1,10 +1,10 @@
#include "Utility.h"
#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace mlir {
@@ -88,7 +88,7 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
return success();
}
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1: A size 1 tensor is not expensive since all threads will load the
// same
if (isSingleValue(op->getOperand(0)))
@@ -96,22 +96,34 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 2: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
IntegerAttr numWarps =
op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
"triton_gpu.num-warps");
if (numWarps) {
int sizePerThread = triton::gpu::getTotalElemsPerThread(ptrType);
if (ptrType.getNumElements() < numWarps.getInt() * 32)
return false;
}
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
if (ptrType.getNumElements() < numWarps * threadsPerWarp)
return false;
return true;
}
bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
auto tensorTy = cat.getResult().getType().cast<RankedTensorType>();
auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto elemTy = tensorTy.getElementType();
auto newTotalElemsPerThread =
triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy);
return newTotalElemsPerThread < totalElemsPerThread;
}
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return expensiveLoadOrStore(op, targetEncoding);
return isExpensiveLoadOrStore(op, targetEncoding);
if (isa<triton::CatOp>(op))
return isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
@@ -122,10 +134,11 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
return false;
}
bool canFoldConversion(Operation *op) {
bool canFoldConversion(Operation *op, Attribute &targetEncoding) {
if (isa<triton::CatOp>(op))
return !isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp,
triton::CatOp>(*op);
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
}
int simulateBackwardRematerialization(
@@ -145,7 +158,7 @@ int simulateBackwardRematerialization(
queue.pop_back();
// If the current operation is expensive to rematerialize,
// we stop everything
if (expensiveToRemat(currOp, currLayout))
if (isExpensiveToRemat(currOp, currLayout))
break;
// A conversion will be removed here (i.e. transferred to operands)
numCvts -= 1;
@@ -173,7 +186,7 @@ int simulateBackwardRematerialization(
continue;
// If the conversion can be folded into opArgI then
// we don't count this conversion as expensive
if (canFoldConversion(opArgI))
if (canFoldConversion(opArgI, newEncoding))
continue;
// We add one expensive conversion for the current operand

View File

@@ -31,6 +31,17 @@ def get_build_type():
# TODO: change to release when stable enough
return "TritonRelBuildWithAsserts"
def get_codegen_backends():
    """Collect extra codegen backend names requested via environment variables.

    Scans ``os.environ`` for variables named ``TRITON_CODEGEN_<BACKEND>`` whose
    value is truthy according to ``check_env_flag`` (helper defined elsewhere
    in this file) and returns the lowercased ``<BACKEND>`` names. Returns an
    empty list when no such variables are set.
    """
    backends = []
    env_prefix = "TRITON_CODEGEN_"
    for name, _ in os.environ.items():
        if name.startswith(env_prefix) and check_env_flag(name):
            # A well-formed variable contains the prefix exactly once; a
            # repeated prefix would make replace() strip more than intended.
            assert name.count(env_prefix) <= 1
            backends.append(name.replace(env_prefix, '').lower())
    return backends
# --- third party packages -----
@@ -210,6 +221,11 @@ class CMakeBuild(build_ext):
cfg = get_build_type()
build_args = ["--config", cfg]
codegen_backends = get_codegen_backends()
if len(codegen_backends) > 0:
all_codegen_backends = ';'.join(codegen_backends)
cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends]
if platform.system() == "Windows":
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
@@ -249,16 +265,14 @@ setup(
"triton/_C",
"triton/common",
"triton/compiler",
"triton/debugger",
"triton/interpreter",
"triton/language",
"triton/language/extra",
"triton/ops",
"triton/ops/blocksparse",
"triton/runtime",
"triton/runtime/backends",
"triton/third_party/cuda/bin",
"triton/third_party/cuda/include",
"triton/third_party/cuda/lib",
"triton/third_party",
"triton/tools",
],
install_requires=[

View File

@@ -92,6 +92,13 @@ void init_triton_ir(py::module &&m) {
.value("CG", mlir::triton::CacheModifier::CG)
.export_values();
py::enum_<mlir::triton::MemSemantic>(m, "MEM_SEMANTIC")
.value("ACQUIRE_RELEASE", mlir::triton::MemSemantic::ACQUIRE_RELEASE)
.value("ACQUIRE", mlir::triton::MemSemantic::ACQUIRE)
.value("RELEASE", mlir::triton::MemSemantic::RELEASE)
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
.export_values();
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY")
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
.value("EVICT_FIRST", mlir::triton::EvictionPolicy::EVICT_FIRST)
@@ -1370,7 +1377,7 @@ void init_triton_ir(py::module &&m) {
// // atomic
.def("create_atomic_cas",
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
mlir::Value &val) -> mlir::Value {
mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value {
auto loc = self.getUnknownLoc();
mlir::Type dstType;
if (auto srcTensorType =
@@ -1386,12 +1393,12 @@ void init_triton_ir(py::module &&m) {
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
cmp, val);
cmp, val, sem);
})
.def("create_atomic_rmw",
[](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
mlir::Value &ptr, mlir::Value &val,
mlir::Value &mask) -> mlir::Value {
mlir::Value &ptr, mlir::Value &val, mlir::Value &mask,
mlir::triton::MemSemantic sem) -> mlir::Value {
auto loc = self.getUnknownLoc();
mlir::Type dstType;
if (auto srcTensorType =
@@ -1407,7 +1414,7 @@ void init_triton_ir(py::module &&m) {
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
ptr, val, mask);
ptr, val, mask, sem);
})
// External
.def("create_extern_elementwise",

View File

@@ -0,0 +1,42 @@
#include <Python.h>
#include <stdio.h>
#include <stdlib.h>
/* Report fixed, dummy hardware properties for the extension device.
 * Returns a dict with the same keys/values the real backends expose. */
static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  static const char *fmt = "{s:i, s:i, s:i, s:i, s:i}";
  return Py_BuildValue(fmt, "max_shared_mem", 1024, "multiprocessor_count", 16,
                       "sm_clock_rate", 2100, "mem_clock_rate", 2300,
                       "mem_bus_width", 2400);
}
/* Pretend to load a kernel binary for the extension device.
 * Returns (module_handle, function_handle, n_regs, n_spills) with all
 * values zeroed, matching the tuple shape of the real loaders. */
static PyObject *loadBinary(PyObject *self, PyObject *args) {
  const uint64_t mod_handle = 0;
  const uint64_t fun_handle = 0;
  const int n_regs = 0;
  const int n_spills = 0;
  return Py_BuildValue("(KKii)", mod_handle, fun_handle, n_regs, n_spills);
}
// Method table exposing the two stubs above to Python callers.
static PyMethodDef ModuleMethods[] = {
    {"load_binary", loadBinary, METH_VARARGS,
     "Load dummy binary for the extension device"},
    {"get_device_properties", getDeviceProperties, METH_VARARGS,
     "Get the properties for the extension device"},
    {NULL, NULL, 0, NULL} // sentinel
};
// Module definition for "ext_utils"; PyModule_Create registers
// ModuleMethods from this struct when the module is created.
static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "ext_utils",
                                       NULL, // documentation
                                       -1,   // size
                                       ModuleMethods};
/* Module init entry point for the dummy "ext_utils" extension module.
 * PyModule_Create already registers ModuleMethods through ModuleDef, so the
 * previous extra PyModule_AddFunctions call (which re-registered the same
 * table a second time) has been dropped.  Returns NULL on failure, which is
 * exactly what PyModule_Create yields. */
PyMODINIT_FUNC PyInit_ext_utils(void) {
  return PyModule_Create(&ModuleDef);
}

View File

@@ -0,0 +1,262 @@
import functools
import hashlib
import importlib
import os
import shutil
import subprocess
import sysconfig
import tempfile
from pathlib import Path
import setuptools
import torch
import triton
import triton.language as tl
from triton.common.backend import BaseBackend, register_backend
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
from triton.runtime.jit import version_key
def build_for_backend(name, src, srcdir):
    """Compile a single C source file into a Python extension module.

    Tries a direct compiler invocation first (gcc/clang or ``$CC``) and falls
    back to a setuptools build if that fails.

    :param name: module name; the output file is ``<name><EXT_SUFFIX>``
    :param src: path to the C source file
    :param srcdir: build/output directory (also added to the include path)
    :returns: path to the built shared object
    :raises RuntimeError: if no C compiler can be found
    """
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(srcdir, f'{name}{suffix}')
    # try to avoid setuptools if possible
    cc = os.environ.get("CC")
    if cc is None:
        # TODO: support more things here.
        clang = shutil.which("clang")
        gcc = shutil.which("gcc")
        cc = gcc if gcc is not None else clang
        if cc is None:
            raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    # This function was renamed and made public in Python 3.10
    if hasattr(sysconfig, 'get_default_scheme'):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    # BUGFIX: use `call` (not `check_call`) so a failing direct compilation
    # falls through to the setuptools path below instead of raising and
    # leaving the fallback unreachable.
    ret = subprocess.call([cc, src, f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so])
    if ret == 0:
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = []
    include_dirs = [srcdir]
    libraries = []
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # build extension module in-place, quietly
    args = ['build_ext']
    args.append('--build-temp=' + srcdir)
    args.append('--build-lib=' + srcdir)
    args.append('-q')
    args = dict(
        name=name,
        ext_modules=[ext],
        script_args=args,
    )
    with quiet():
        setuptools.setup(**args)
    return so
class ExtensionUtils:
    # Singleton wrapper around a small compiled helper module ("ext_utils")
    # that provides load_binary / get_device_properties for the dummy device.

    def __new__(cls):
        # Classic singleton: reuse one instance process-wide.
        if not hasattr(cls, 'instance'):
            cls.instance = super(ExtensionUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # NOTE(review): __init__ still runs on every construction even though
        # __new__ returns the cached instance; the file-cache lookup below is
        # what keeps repeated construction cheap.
        dirname = os.path.dirname(os.path.realpath(__file__))
        src = Path(os.path.join(dirname, "extension_backend.c")).read_text()
        # Key the cache on the helper's source so edits trigger a rebuild.
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = get_cache_manager(key)
        fname = "ext_utils.so"
        cache_path = cache.get_file(fname)
        if cache_path is None:
            # Cache miss: compile the helper in a temp dir, then persist the
            # resulting shared object into the cache.
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = build_for_backend("ext_utils", src_path, tmpdir)
                with open(so, "rb") as f:
                    cache_path = cache.put(f.read(), fname, binary=True)
        import importlib.util
        # Load the compiled helper and re-export its two entry points.
        spec = importlib.util.spec_from_file_location("ext_utils", cache_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
class ExtensionDriver(DriverBase):
    """Singleton driver object for the dummy extension device."""

    def __new__(cls):
        # Lazily create and memoize the single shared instance.
        cached = getattr(cls, 'instance', None)
        if cached is None:
            cached = super(ExtensionDriver, cls).__new__(cls)
            cls.instance = cached
        return cached

    def __init__(self):
        self.utils = ExtensionUtils()
class ExtensionBackend(BaseBackend):
    """Minimal Triton backend used by the extension-backend test.

    Keeps compilation at the ttgir stage, launches kernels through a tiny
    generated C stub, and records the path of that stub so the test can
    import it and query its launch counter.
    """

    # Path of the most recently built/fetched launcher stub (class-wide so
    # the test can read it after kernel launches).
    stub_so_path = ""

    def __init__(self, device_type: str) -> None:
        super(ExtensionBackend, self).__init__(device_type)
        self.driver = ExtensionDriver()

    def add_stages(self, arch, extern_libs, stages):
        # Keep only the front-end stages; drop everything past ttgir
        # (no real codegen for the dummy device).
        filter_in_stages = ["ast", "ttir", "ttgir"]
        filter_out_stages = []
        for key, _ in stages.items():
            if key not in filter_in_stages:
                filter_out_stages.append(key)
        for filter_out_key in filter_out_stages:
            stages.pop(filter_out_key)

    def add_meta_info(self, ir, cur_module, next_module, metadata, asm):
        # Fixed marker name so tests can recognize this backend's output.
        metadata["name"] = "extension_backend_name"

    def get_driver(self):
        return self.driver

    def get_stream(self):
        # The dummy device has no real stream concept.
        return ""

    @functools.lru_cache(None)
    def get_device_properties(self, device):
        # NOTE(review): lru_cache on an instance method also keys on `self`;
        # acceptable here since the driver is a singleton.
        return self.driver.utils.get_device_properties()

    def get_current_device(self):
        return torch.device("cpu")

    def set_current_device(self, device):
        # Nothing to do for the dummy device.
        pass

    def get_load_binary_fn(self):
        return self.driver.utils.load_binary

    def get_kernel_bin(self):
        # The "binary" handed to the loader is the ttgir text itself.
        return "ttgir"

    def get_architecture_descriptor(self, **kwargs):
        return ""

    def make_launcher_stub(self, name, signature, constants):
        """Build (or fetch from cache) the launcher stub shared object."""
        # name of files that are cached
        so_cache_key = make_so_cache_key(version_key(), signature, constants)
        so_cache_manager = get_cache_manager(so_cache_key)
        so_name = f"{name}.so"
        # retrieve stub from cache if it exists
        cache_path = so_cache_manager.get_file(so_name)
        if cache_path is None:
            with tempfile.TemporaryDirectory() as tmpdir:
                src = self._generate_launcher(constants, signature)
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = build_for_backend(name, src_path, tmpdir)
                with open(so, "rb") as f:
                    so_path = so_cache_manager.put(f.read(), so_name, binary=True)
                # Remember the stub path class-wide for the test to inspect.
                type(self).stub_so_path = so_path
                return so_path
        else:
            type(self).stub_so_path = cache_path
            return cache_path

    def _generate_launcher(self, constants, signature):
        """Return C source for a launcher that only counts launches."""
        # generate glue code
        src = """
#define __EXTENSION_BACKEND__
#include <Python.h>
#include <stdio.h>
static PyObject* launch_counter(PyObject* self, PyObject* args) {
static int64_t launch_counter = 0;
launch_counter += 1;
return PyLong_FromLong(launch_counter);
}
static PyObject* launch(PyObject* self, PyObject* args) {
if (PyErr_Occurred()) {
return NULL;
}
launch_counter(self, args);
// return None
Py_INCREF(Py_None);
return Py_None;
}
static PyMethodDef ModuleMethods[] = {
{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
{"launch_counter", launch_counter, METH_VARARGS, "Entry point to get launch counter"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
};
PyMODINIT_FUNC PyInit___triton_launcher(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
"""
        return src
def test_dummy_backend():
    """End-to-end check that a registered extension backend actually
    launches kernels through its generated stub."""
    # Route "cpu" kernels through ExtensionBackend for this process.
    register_backend("cpu", ExtensionBackend)

    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        xnumel = 10
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)

    inp = torch.randn(10)
    out = torch.randn(10)
    # First launch builds (or fetches) the launcher stub shared object and
    # records its path in ExtensionBackend.stub_so_path.
    kernel[(10,)](inp, out, 10, XBLOCK=16)
    # Import the stub directly so its launch counter can be queried.
    spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    launch_counter = getattr(mod, "launch_counter")
    for _ in range(100):
        kernel[(10,)](inp, out, 10, XBLOCK=16)
    # The stub increments its counter on every launch; it must have fired.
    assert launch_counter() > 0

View File

@@ -0,0 +1,14 @@
# content of conftest.py
import pytest
def pytest_addoption(parser):
    """Register the --backend command-line option with pytest."""
    parser.addoption("--backend", action="store", default="", help="Codegen backend")
@pytest.fixture
def cmdopt(request):
    """Fixture exposing the value passed via --backend to tests."""
    backend = request.config.getoption("--backend")
    return backend

View File

@@ -0,0 +1,33 @@
import torch
import triton
import triton.language as tl
def test_xpu_backend(cmdopt):
if cmdopt == "xpu":
has_ipex = False
try:
# Import IPEX to provide Intel GPU runtime
import intel_extension_for_pytorch # type: ignore # noqa: F401
has_ipex = True if hasattr(torch, "xpu") else False
except Exception:
has_ipex = False
@triton.jit()
def kernel(x_ptr, y_ptr, out_ptr):
pid = tl.program_id(axis=0)
x = tl.load(x_ptr + pid)
y = tl.load(y_ptr + pid)
out = x + y
tl.store(out_ptr + pid, out)
if has_ipex:
for _ in range(1000):
x = torch.randn((65536,), device="xpu", dtype=torch.float32)
y = torch.randn((65536,), device="xpu", dtype=torch.float32)
z = torch.zeros((65536,), device="xpu", dtype=torch.float32)
kernel[(65536,)](x, y, z, num_warps=32)
assert torch.all(x + y == z)
else:
return

View File

@@ -69,7 +69,7 @@ matmul_data = {
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.044, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.030, 'float32': 0.0509, 'int8': 0.0169},
(64, 1024, 1024): {'float16': 0.028, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.163, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.285, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.033, 'float32': 0.0458, 'int8': 0.017},

View File

@@ -4,7 +4,7 @@ import torch
import triton
import triton.language as tl
from triton.debugger.debugger import program_ids_from_grid
from triton.interpreter.interpreter import program_ids_from_grid
def test_addition():

View File

@@ -51,8 +51,13 @@ def test_assert(func: str):
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "no_debug":
<<<<<<< HEAD
# TRITON_DEBUG=True can override the debug flag
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
=======
# TRITON_DEBUG=1 can override the debug flag
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
>>>>>>> oai/main
elif func == "assert":
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "static_assert":

View File

@@ -937,15 +937,16 @@ def test_noinline(mode):
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
@pytest.mark.parametrize("op, dtype_x_str, mode, sem", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
('add', 'float16', mode, sem),
('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem),
('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem),
('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
def test_atomic_rmw(op, dtype_x_str, mode, sem, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
if dtype_x_str == 'float16':
@@ -959,7 +960,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
x = tl.load(X + pid)
old = GENERATE_TEST_HERE
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
sem_arg = sem if sem is None else f'"{sem}"'
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
@@ -981,7 +983,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
kernel[(n_programs, )](x_tri, z_tri)
h = kernel[(n_programs, )](x_tri, z_tri)
# torch result
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# compare
@@ -990,6 +992,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
sem_str = "acq_rel" if sem is None else sem
assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
def test_atomic_rmw_predicate(device="cuda"):
@@ -1047,7 +1051,8 @@ def test_tensor_atomic_rmw_block(device="cuda"):
assert torch.min(x).item() == 0.0
def test_atomic_cas():
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
def test_atomic_cas(sem):
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
@@ -1060,9 +1065,9 @@ def test_atomic_cas():
# 2. only one block enters the critical section
@triton.jit
def serialized_add(data, Lock):
def serialized_add(data, Lock, SEM: tl.constexpr):
ptrs = data + tl.arange(0, 128)
while tl.atomic_cas(Lock, 0, 1) == 1:
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
pass
tl.store(ptrs, tl.load(ptrs) + 1.0)
@@ -1073,8 +1078,10 @@ def test_atomic_cas():
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
h = serialized_add[(64,)](data, Lock, SEM=sem)
sem_str = "acq_rel" if sem is None else sem
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
assert f"atom.global.{sem_str}" in h.asm["ptx"]
# ---------------
@@ -1140,6 +1147,28 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
assert to_numpy(z_tri) == z_ref
@pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]])
def test_cat(dtype_str, num_warps):
check_type_supported(dtype_str)
@triton.jit
def kernel(X, Y, Z, N: tl.constexpr):
offs = tl.arange(0, N)
x = tl.load(X + offs)
y = tl.load(Y + offs)
z = tl.cat(x, y, can_reorder=True)
tl.store(Z + tl.arange(0, 2 * N), z)
x = torch.arange(0, 128, device='cuda').to(getattr(torch, dtype_str))
y = torch.arange(-128, 0, device='cuda').to(getattr(torch, dtype_str))
z_ref = torch.cat([x, y], dim=0).sum()
z = torch.zeros((256,), dtype=getattr(torch, dtype_str), device='cuda')
kernel[(1, )](x, y, z, N=128, num_warps=num_warps)
assert z.sum() == z_ref
# check if there's no duplicate value in z
assert z.unique().size(0) == z.size(0)
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
def test_store_constant(dtype_str):
check_type_supported(dtype_str)
@@ -1197,8 +1226,8 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None):
# special cases, exp is 0b11..1
if dtype == tl.float8e4:
# float8e4m3 does not have infinities
output[fp == torch.tensor(0b01111111, dtype=torch.int8)] = torch.nan
output[fp == torch.tensor(0b11111111, dtype=torch.int8)] = torch.nan
output[fp == 0b01111111] = torch.nan
output[fp == 0b11111111] = torch.nan
else:
output = torch.where(exp == (1 << exp_width) - 1,
((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32),
@@ -1338,7 +1367,11 @@ def get_reduced_dtype(dtype_str, op):
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
for op in ['min', 'max',
'min-with-indices',
'max-with-indices',
'argmin', 'argmax',
'sum']
for dtype in dtypes_with_bfloat16
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -1348,15 +1381,22 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
@triton.jit
def kernel(X, Z, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.store(Z, GENERATE_TEST_HERE)
GENERATE_TEST_HERE
tl.store(Z, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=0)'})
if 'with-indices' in op:
patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)'
else:
patch = f'z = tl.{op}(x, axis=0)'
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch})
# input
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
x_tri = to_triton(x, device=device)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'max-with-indices': np.max,
'min-with-indices': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
# numpy result
z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str
@@ -2444,11 +2484,11 @@ def test_call(type):
# -------------
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and_dynamic", "if_and_static"])
def test_if(if_type):
@triton.jit
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
pid = tl.program_id(0)
cond = tl.load(Cond)
if IfType == "if":
@@ -2458,17 +2498,22 @@ def test_if(if_type):
tl.store(Ret, tl.load(XFalse))
elif IfType == "if_exp":
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
elif IfType == "if_and":
elif IfType == "if_and_dynamic":
if BoolVar and pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
elif IfType == "if_and_static":
if StaticVaue != 0 and StaticVaue != 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
cond = torch.ones(1, dtype=torch.int32, device='cuda')
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
assert torch.equal(ret, x_true)

View File

@@ -4,75 +4,103 @@ import pytest
import torch
import triton
import triton.language as tl
import triton.ops
def f8_to_f16(x):
@triton.jit
def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < N
x = tl.load(X + offs, mask=mask)
y = x.to(tl.float8e5)
tl.store(Y + offs, y, mask=mask)
ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
kernel[grid](ret, triton.reinterpret(x, tl.float8e5), ret.numel(), BLOCK_SIZE=1024)
return ret
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# split-k
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
],
# mixed-precision
*[
[
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(32, 64, 16, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
(128, 128, 32, 8, 4, 2, 1024, 1024, 1024, AT, BT, ADTYPE, BDTYPE),
] for ADTYPE, BDTYPE in [("float8", "float16")] for AT in [False, True] for BT in [False, True]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8 and DTYPE == "bfloat16":
if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"):
pytest.skip("Only test bfloat16 on devices with sm >= 80")
if DTYPE == "bfloat16" and SPLIT_K != 1:
if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1:
pytest.skip("bfloat16 matmuls don't allow split_k for now")
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
@@ -87,12 +115,18 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
def get_input(n, m, t, dtype):
if t:
return get_input(m, n, False, dtype).t()
if dtype == "float8":
x = torch.randint(10, 50, (n, m), device="cuda", dtype=torch.int8)
return f8_to_f16(x)
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
return .1 * torch.randn((n, m), device="cuda", dtype=dtype)
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE]
a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
a = get_input(M, K, AT, ADTYPE)
b = get_input(K, N, BT, BDTYPE)
# run test
th_c = torch.matmul(a, b)
try:

View File

@@ -38,7 +38,7 @@ def test_memory_leak() -> None:
kernel[(10,)](inp, out, 10, XBLOCK=16)
gc.collect()
end, _ = tracemalloc.get_traced_memory()
assert end - begin < 1000
assert end - begin < 5000
finally:
tracemalloc.stop()

View File

@@ -18,7 +18,6 @@ from .runtime import (
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from .debugger.debugger import program_ids_from_grid
from . import language
from . import testing
@@ -43,7 +42,6 @@ __all__ = [
"runtime",
"TensorWrapper",
"testing",
"program_ids_from_grid",
]
@@ -52,11 +50,11 @@ __all__ = [
# into any specific module
# -------------------------------------
def cdiv(x, y):
def cdiv(x: int, y: int):
return (x + y - 1) // y
def next_power_of_2(n):
def next_power_of_2(n: int):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
@@ -64,5 +62,6 @@ def next_power_of_2(n):
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n

View File

@@ -0,0 +1,96 @@
import importlib
import importlib.util
from typing import Dict
from ..runtime.driver import DriverBase
class BaseBackend:
    """Abstract interface that every Triton device backend must implement.

    Subclasses override the methods below; each base implementation simply
    raises NotImplementedError.
    """

    def __init__(self, device_type: str) -> None:
        self.device_type = device_type

    def add_stages(self, arch, extern_libs, stages):
        """Adjust arch, extern_libs and the compilation `stages` dict to the
        backend's needs."""
        raise NotImplementedError

    def add_meta_info(self, ir, cur_module, next_module, metadata, asm):
        """Record backend-specific ir/module/metadata/asm information."""
        raise NotImplementedError

    def get_load_binary_fn(self):
        """Return the callable used to load a compiled kernel binary."""
        raise NotImplementedError

    def get_driver(self) -> DriverBase:
        """Return this backend's driver; see DriverBase for the contract."""
        raise NotImplementedError

    def get_stream(self):
        """Return the stream associated with the current device."""
        raise NotImplementedError

    def get_device_properties(self, device):
        """Return the property dict for `device`."""
        raise NotImplementedError

    def get_current_device(self):
        """Return the currently active device."""
        raise NotImplementedError

    def set_current_device(self, device):
        """Make `device` the active device."""
        raise NotImplementedError

    def get_kernel_bin(self):
        """Return the name of the stage whose output is the kernel binary."""
        raise NotImplementedError

    def make_launcher_stub(self, name, signature, constants):
        """Build the stub used to launch the kernel."""
        raise NotImplementedError

    def get_architecture_descriptor(self, **kwargs):
        """Return the architecture descriptor for this backend."""
        raise NotImplementedError

    @classmethod
    def create_backend(cls, device_type: str):
        """Factory hook used by the registry to instantiate a backend."""
        return cls(device_type)
# Process-wide registry mapping a device-type string to its backend instance.
_backends: Dict[str, BaseBackend] = {}


def register_backend(device_type: str, backend_cls: type):
    """Instantiate and register `backend_cls` for `device_type`.

    The first registration for a device type wins; later calls are no-ops.
    """
    if device_type in _backends:
        return
    _backends[device_type] = backend_cls.create_backend(device_type)


def get_backend(device_type: str):
    """Return the backend registered for `device_type`, or None.

    On a registry miss, importing `triton.third_party.<device_type>` is
    attempted so that out-of-tree packages can self-register on import.
    """
    if device_type not in _backends:
        pkg = f"triton.third_party.{device_type}"
        if not importlib.util.find_spec(pkg):
            return None
        try:
            importlib.import_module(pkg)
        except Exception:
            return None
    return _backends.get(device_type)

View File

@@ -6,12 +6,12 @@ import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from .. import language
from .._C.libtriton.triton import ir
from ..language import constexpr, tensor
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure,
UnsupportedLanguageConstruct)
from triton._C.libtriton.triton import ir
def mangle_ty(ty):
@@ -595,12 +595,14 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Compare(self, node):
if not (len(node.comparators) == 1 and len(node.ops) == 1):
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported")
lhs = _unwrap_if_constexpr(self.visit(node.left))
rhs = _unwrap_if_constexpr(self.visit(node.comparators[0]))
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
lhs_value = _unwrap_if_constexpr(lhs)
rhs_value = _unwrap_if_constexpr(rhs)
if type(node.ops[0]) == ast.Is:
return constexpr(lhs is rhs)
return constexpr(lhs_value is rhs_value)
if type(node.ops[0]) == ast.IsNot:
return constexpr(lhs is not rhs)
return constexpr(lhs_value is not rhs_value)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
@@ -988,7 +990,7 @@ class CodeGenerator(ast.NodeVisitor):
if not (0 < arg_count <= 2) or len(node.keywords):
raise TypeError("`static_assert` requires one or two positional arguments only")
passed = self.visit(node.args[0])
passed = _unwrap_if_constexpr(self.visit(node.args[0]))
if not isinstance(passed, bool):
raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values")
if not passed:

View File

@@ -11,12 +11,19 @@ from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
import triton
import triton._C.libtriton.triton as _triton
from ..runtime import driver
# import triton
from .._C.libtriton.triton import (add_external_libs, compile_ptx_to_cubin,
get_shared_memory_size, ir,
translate_llvmir_to_hsaco, translate_llvmir_to_ptx,
translate_triton_gpu_to_llvmir)
from ..common.backend import get_backend
# from ..runtime import driver, jit, JITFunction
# TODO: runtime.errors
from ..runtime.autotuner import OutOfResources
from ..runtime.cache import get_cache_manager
from ..runtime.driver import driver
from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device,
get_device_capability, version_key)
from ..tools.disasm import extract
from .code_generator import ast_to_ttir
from .make_launcher import make_stub
@@ -24,7 +31,7 @@ from .make_launcher import make_stub
CUDA_DEFAULT_WARP_SIZE = 32
def inline_triton_ir(mod):
pm = _triton.ir.pass_manager(mod.context)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.run(mod)
@@ -34,7 +41,7 @@ def inline_triton_ir(mod):
def ttir_compute_capability_rewrite(mod, arch):
# For hardware without support, we must rewrite all load/store
# with block (tensor) pointers into tensors of pointers
pm = _triton.ir.pass_manager(mod.context)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if _is_cuda(arch):
pm.add_rewrite_tensor_pointer_pass(arch)
@@ -45,7 +52,7 @@ def ttir_compute_capability_rewrite(mod, arch):
def optimize_ttir(mod, arch):
mod = inline_triton_ir(mod)
mod = ttir_compute_capability_rewrite(mod, arch)
pm = _triton.ir.pass_manager(mod.context)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_inliner_pass()
pm.add_triton_combine_pass()
@@ -57,15 +64,22 @@ def optimize_ttir(mod, arch):
return mod
<<<<<<< HEAD
def ttir_to_ttgir(mod, num_warps, warpsize):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize)
=======
def ttir_to_ttgir(mod, num_warps):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_convert_triton_to_tritongpu_pass(num_warps)
>>>>>>> oai/main
pm.run(mod)
return mod
def optimize_ttgir(mod, num_stages, arch):
pm = _triton.ir.pass_manager(mod.context)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
@@ -92,7 +106,7 @@ def _add_external_libs(mod, libs):
for name, path in libs.items():
if len(name) == 0 or len(path) == 0:
return
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
add_external_libs(mod, list(libs.keys()), list(libs.values()))
def ttgir_to_llir(mod, extern_libs, arch):
@@ -100,9 +114,9 @@ def ttgir_to_llir(mod, extern_libs, arch):
_add_external_libs(mod, extern_libs)
# TODO: separate tritongpu_to_llvmir for different backends
if _is_cuda(arch):
return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
return translate_triton_gpu_to_llvmir(mod, arch, False)
else:
return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)
return translate_triton_gpu_to_llvmir(mod, 0, True)
# PTX translation
@@ -132,8 +146,9 @@ def path_to_ptxas():
]
for ptxas in paths:
if os.path.exists(ptxas) and os.path.isfile(ptxas):
result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
ptxas_bin = ptxas.split(" ")[0]
if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin):
result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
@@ -150,7 +165,7 @@ def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
if ptx_version is None:
_, cuda_version = path_to_ptxas()
ptx_version = ptx_get_version(cuda_version)
return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)
return translate_llvmir_to_ptx(mod, arch, ptx_version)
def ptx_to_cubin(ptx: str, arch: int):
@@ -161,7 +176,7 @@ def ptx_to_cubin(ptx: str, arch: int):
:return: str
'''
ptxas, _ = path_to_ptxas()
return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
return compile_ptx_to_cubin(ptx, ptxas, arch)
# AMDGCN translation
@@ -228,7 +243,7 @@ def llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_featu
- AMDGCN code
- Path to HSACO object
'''
return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
return translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features)
# ------------------------------------------------------------------------------
@@ -255,7 +270,7 @@ def convert_type_repr(x):
def make_hash(fn, arch, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
if isinstance(fn, JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
constants = kwargs.get("constants", dict())
@@ -268,7 +283,7 @@ def make_hash(fn, arch, **kwargs):
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest()
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
@@ -312,7 +327,7 @@ def _get_jsonable_constants(constants):
def parse_mlir_module(path, context):
module = _triton.ir.parse_mlir_module(path, context)
module = ir.parse_mlir_module(path, context)
# module takes ownership of the context
module.context = context
return module
@@ -341,8 +356,8 @@ def get_architecture_descriptor(capability):
raise ImportError("Triton requires PyTorch to be installed")
if capability is None:
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
device = get_current_device()
capability = get_device_capability(device)
capability = capability[0] * 10 + capability[1]
else:
capability = get_amdgpu_arch_fulldetails()
@@ -375,6 +390,7 @@ def add_cuda_stages(arch, extern_libs, stages):
def compile(fn, **kwargs):
<<<<<<< HEAD
if is_hip():
capability = None
else:
@@ -384,6 +400,22 @@ def compile(fn, **kwargs):
warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3]
context = _triton.ir.context()
asm = dict()
=======
# Get device type to decide which backend should be used
device_type = kwargs.get("device_type", "cuda")
_device_backend = get_backend(device_type)
if device_type in ["cuda", "hip"]:
arch = get_architecture_descriptor(kwargs.get("cc", None))
else:
_device_backend = get_backend(device_type)
assert _device_backend
arch = _device_backend.get_architecture_descriptor(**kwargs)
is_cuda = device_type == "cuda" and _is_cuda(arch)
is_hip = device_type in ["cuda", "hip"] and not is_cuda
context = ir.context()
>>>>>>> oai/main
constants = kwargs.get("constants", dict())
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
@@ -391,6 +423,7 @@ def compile(fn, **kwargs):
if extern_libs is None:
extern_libs = dict()
debug = kwargs.get("debug", False)
# build compilation stages
stages = dict()
stages["ast"] = (lambda path: fn, None)
@@ -402,11 +435,13 @@ def compile(fn, **kwargs):
lambda src: ttgir_to_llir(src, extern_libs, arch))
if is_cuda:
add_cuda_stages(arch, extern_libs, stages)
else:
elif is_hip:
add_rocm_stages(arch, extern_libs, stages)
else:
_device_backend.add_stages(arch, extern_libs, stages)
# find out the signature of the function
if isinstance(fn, triton.runtime.JITFunction):
if isinstance(fn, JITFunction):
configs = kwargs.get("configs", None)
signature = kwargs["signature"]
if configs is None:
@@ -420,27 +455,31 @@ def compile(fn, **kwargs):
kwargs["signature"] = signature
else:
assert isinstance(fn, str)
_, ir = os.path.basename(fn).split(".")
_, ir_name = os.path.basename(fn).split(".")
src = Path(fn).read_text()
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir], signature)
if ir == 'ttgir':
types = re.findall(arg_type_pattern[ir_name], signature)
if ir_name == 'ttgir':
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile"
num_warps = int(num_warps_matches[0])
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir)
first_stage = list(stages.keys()).index(ir_name)
# cache manager
so_path = make_stub(name, signature, constants)
if is_cuda or is_hip:
so_path = make_stub(name, signature, constants)
else:
so_path = _device_backend.make_launcher_stub(name, signature, constants)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
if isinstance(fn, JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
@@ -464,19 +503,23 @@ def compile(fn, **kwargs):
"warp_size": warp_size,
"num_stages": num_stages,
"constants": _get_jsonable_constants(constants),
"debug": debug}
"debug": debug,
"arch": arch, }
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
# Add device type to meta information
metadata["device_type"] = device_type
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
ir_filename = f"{name}.{ir}"
for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]:
ir_filename = f"{name}.{ir_name}"
if ir == ext:
if ir_name == ext:
next_module = parse(fn)
else:
path = metadata_group.get(ir_filename)
@@ -490,7 +533,7 @@ def compile(fn, **kwargs):
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
fn_cache_manager.put(next_module, ir_filename)
else:
if ir == "amdgcn":
if ir_name == "amdgcn":
extra_file_name = f"{name}.hsaco_path"
hasco_path = metadata_group.get(extra_file_name)
assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn"
@@ -498,19 +541,21 @@ def compile(fn, **kwargs):
else:
next_module = parse(path)
if ir == "cubin":
asm[ir] = next_module
elif ir == "amdgcn":
asm[ir] = str(next_module[0])
if ir_name == "cubin":
asm[ir_name] = next_module
elif ir_name == "amdgcn":
asm[ir_name] = str(next_module[0])
else:
asm[ir] = str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
asm[ir_name] = str(next_module)
if ir_name == "llir" and "shared" not in metadata:
metadata["shared"] = get_shared_memory_size(module)
if ir_name == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
if ir == "amdgcn":
if ir_name == "amdgcn":
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
asm["hsaco_path"] = next_module[1]
if not is_cuda and not is_hip:
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
module = next_module
# write-back metadata, if it didn't come from the cache
if metadata_path is None:
@@ -536,11 +581,13 @@ class CompiledKernel:
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
# initialize metadata
self.shared = metadata["shared"]
self.shared = metadata["shared"] if "shared" in metadata else 0
self.num_warps = metadata["num_warps"]
self.warp_size = metadata["warp_size"]
self.num_stages = metadata["num_stages"]
self.constants = metadata["constants"]
self.device_type = metadata["device_type"]
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
# initialize asm dict
self.asm = asm
# binaries are lazily initialized
@@ -553,15 +600,26 @@ class CompiledKernel:
def _init_handles(self):
if self.cu_module is not None:
return
device = triton.runtime.jit.get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
if self.device_type in ["cuda", "hip"]:
device = get_current_device()
bin_path = {
driver.HIP: "hsaco_path",
driver.CUDA: "cubin"
}[driver.backend]
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
fn_load_binary = driver.utils.load_binary
else:
assert self.device_backend
device = self.device_backend.get_current_device()
bin_path = self.device_backend.get_kernel_bin()
max_shared = self.device_backend.get_device_properties(device)["max_shared_mem"]
fn_load_binary = self.device_backend.get_load_binary_fn()
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
self.n_spills = n_spills
self.n_regs = n_regs
@@ -578,8 +636,16 @@ class CompiledKernel:
def runner(*args, stream=None):
if stream is None:
<<<<<<< HEAD
stream = triton.runtime.jit.get_cuda_stream()
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.warp_size, self.shared, stream, self.cu_function,
=======
if self.device_type in ["cuda", "rocm"]:
stream = get_cuda_stream()
else:
stream = get_backend(self.device_type).get_stream(None)
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
>>>>>>> oai/main
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
return runner

View File

@@ -2,13 +2,14 @@ import itertools
import random
from typing import Tuple
import triton
import triton.language as tl
from .. import language as tl
# import .language.core as lcore
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
debugger_constexpr)
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
tl_method_backup = {}
@@ -59,12 +60,12 @@ class DebuggerFunction:
self.grid = grid
def _is_constexpr(self, name):
return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr
return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr
def _get_constexpr(self):
result = []
for name, annotation in self.func.__annotations__.items():
if annotation is triton.language.core.constexpr:
if annotation is lcore.constexpr:
result.append(name)
return result

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import dataclasses
from triton.debugger import torch_wrapper
from triton.interpreter import torch_wrapper
torch = torch_wrapper.torch

View File

@@ -1,7 +1,10 @@
import triton
from __future__ import annotations
# import triton
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext
from .memory_map import MemoryMap
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
@@ -387,7 +390,7 @@ class TritonLangProxy:
if not isinstance(d.value, int):
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
shape = [x.value for x in shape]
if isinstance(dtype, triton.language.core.dtype):
if isinstance(dtype, lcore.dtype):
if dtype.is_fp32():
dtype = torch.float32
elif dtype.is_fp16():
@@ -405,7 +408,9 @@ class TritonLangProxy:
return torch.zeros(size=shape, dtype=dtype, device="cuda")
@_tensor_operation
def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
def dequantize(self, input, scale, shift, nbit, dst_ty=None):
if dst_ty is None:
dst_ty = torch.float16
raise NotImplementedError()
@_tensor_operation

View File

@@ -77,7 +77,7 @@ from .core import (
static_range,
tensor,
trans,
triton,
# triton,
uint16,
uint32,
uint64,

View File

@@ -5,9 +5,10 @@ from enum import Enum
from functools import wraps
from typing import Callable, List, Sequence, TypeVar
import triton
from .._C.libtriton.triton import ir
# import triton
from ..runtime.jit import jit
from . import semantic
from triton._C.libtriton.triton import ir
T = TypeVar('T')
@@ -1112,59 +1113,67 @@ def _add_atomic_docstr(name: str) -> Callable[[T], T]:
@builtin
@_add_atomic_docstr("compare-and-swap")
def atomic_cas(pointer, cmp, val, _builder=None):
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
return semantic.atomic_cas(pointer, cmp, val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, _builder=None):
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_xchg(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, _builder=None):
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_add(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_add(pointer, val, mask, sem, _builder)
@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, _builder=None):
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_max(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_max(pointer, val, mask, sem, _builder)
@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, _builder=None):
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_min(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_min(pointer, val, mask, sem, _builder)
@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, _builder=None):
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_and(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_and(pointer, val, mask, sem, _builder)
@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, _builder=None):
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_or(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_or(pointer, val, mask, sem, _builder)
@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, _builder=None):
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
val = _to_tensor(val, _builder)
return semantic.atomic_xor(pointer, val, mask, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
# -----------------------
@@ -1339,7 +1348,7 @@ def _promote_reduction_input(t, _builder=None):
@builtin
def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None):
axis = _constexpr_to_value(axis)
n = input.shape[axis]
index = arange(0, n, _builder=_builder)
@@ -1353,10 +1362,10 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
rvalue, rindices = reduce((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
return rindices
return rvalue, rindices
@triton.jit
@jit
def minimum(x, y):
"""
Computes the element-wise minimum of :code:`x` and :code:`y`.
@@ -1369,7 +1378,7 @@ def minimum(x, y):
return where(x < y, x, y)
@triton.jit
@jit
def maximum(x, y):
"""
Computes the element-wise maximum of :code:`x` and :code:`y`.
@@ -1381,83 +1390,93 @@ def maximum(x, y):
"""
return where(x > y, x, y)
# max and argmax
@triton.jit
@jit
def _max_combine(a, b):
return maximum(a, b)
@triton.jit
@_add_reduction_docstr("maximum")
def max(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _max_combine)
@triton.jit
@jit
def _argmax_combine(value1, index1, value2, index2):
gt = value1 > value2
lt = value1 < value2
index_min = minimum(index1, index2)
index_ret = where(gt, index1, where(lt, index2, index_min))
value_ret = maximum(value1, value2)
value_ret = where(gt, value1, value2)
index_ret = where(gt, index1, index2)
return value_ret, index_ret
@triton.jit
@jit
@_add_reduction_docstr("maximum")
def max(input, axis=None, return_indices=False):
input = _promote_reduction_input(input)
if return_indices:
return _reduce_with_indices(input, axis, _argmax_combine)
else:
return reduce(input, axis, _max_combine)
@jit
@_add_reduction_docstr("maximum index")
def argmax(input, axis):
input = _promote_reduction_input(input)
return _argreduce(input, axis, _argmax_combine)
(_, ret) = max(input, axis, return_indices=True)
return ret
# min and argmin
@triton.jit
@jit
def _min_combine(a, b):
# TODO: minimum/maximum doesn't get lowered to fmin/fmax...
return minimum(a, b)
@triton.jit
@_add_reduction_docstr("minimum")
def min(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _min_combine)
@triton.jit
@jit
def _argmin_combine(value1, index1, value2, index2):
lt = value1 < value2
gt = value1 > value2
index_min = minimum(index1, index2)
index_ret = where(lt, index1, where(gt, index2, index_min))
value_ret = minimum(value1, value2)
value_ret = where(lt, value1, value2)
index_ret = where(lt, index1, index2)
return value_ret, index_ret
@triton.jit
@jit
@_add_reduction_docstr("minimum")
def min(input, axis=None, return_indices=False):
input = _promote_reduction_input(input)
if return_indices:
return _reduce_with_indices(input, axis, _argmin_combine)
else:
return reduce(input, axis, _min_combine)
@jit
@_add_reduction_docstr("minimum index")
def argmin(input, axis):
input = _promote_reduction_input(input)
return _argreduce(input, axis, _argmin_combine)
_, ret = min(input, axis, return_indices=True)
return ret
@triton.jit
@jit
def _sum_combine(a, b):
return a + b
# sum
@triton.jit
@jit
@_add_reduction_docstr("sum")
def sum(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _sum_combine)
@triton.jit
@jit
def _xor_combine(a, b):
return a ^ b
# xor sum
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, _builder=None, _generator=None):

View File

@@ -1,4 +1,4 @@
import triton
from ..runtime.jit import jit
from . import core as tl
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
@@ -12,7 +12,7 @@ N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
# -------------------
@triton.jit
@jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
@@ -33,7 +33,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
return c0, c1, c2, c3
@triton.jit
@jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
seed = seed.to(tl.uint64)
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
@@ -45,7 +45,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
@triton.jit
@jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns a single
@@ -61,7 +61,7 @@ def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
return ret
@triton.jit
@jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block, returns four
@@ -82,7 +82,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# rand
# -------------------
# @triton.jit
# @jit
# def uint32_to_uniform_float(x):
# """
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
@@ -90,7 +90,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
# return x * two_to_the_minus_32
@triton.jit
@jit
def uint32_to_uniform_float(x):
"""
Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
@@ -102,7 +102,7 @@ def uint32_to_uniform_float(x):
return x * scale
@triton.jit
@jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
@@ -116,7 +116,7 @@ def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
return uint32_to_uniform_float(source)
@triton.jit
@jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offsets` block,
@@ -138,7 +138,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
# -------------------
@triton.jit
@jit
def pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = tl.maximum(1.0e-7, u1)
@@ -147,7 +147,7 @@ def pair_uniform_to_normal(u1, u2):
return r * tl.cos(th), r * tl.sin(th)
@triton.jit
@jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,
@@ -163,7 +163,7 @@ def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
return n1
@triton.jit
@jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
"""
Given a :code:`seed` scalar and an :code:`offset` block,

View File

@@ -820,6 +820,22 @@ def _str_to_padding_option(padding_option):
return padding
def _str_to_sem(sem_option):
sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
if sem_option:
if sem_option == "acquire":
sem = ir.MEM_SEMANTIC.ACQUIRE
elif sem_option == "release":
sem = ir.MEM_SEMANTIC.RELEASE
elif sem_option == "acq_rel":
sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
elif sem_option == "relaxed":
sem = ir.MEM_SEMANTIC.RELAXED
else:
raise ValueError(f"Memory semantic {sem_option} not supported")
return sem
def _canonicalize_boundary_check(boundary_check, block_shape):
if boundary_check:
if not hasattr(boundary_check, "__iter__"):
@@ -1030,11 +1046,13 @@ def store(ptr: tl.tensor,
def atomic_cas(ptr: tl.tensor,
cmp: tl.tensor,
val: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
sem = _str_to_sem(sem)
element_ty = ptr.type.scalar.element_ty
if element_ty.primitive_bitwidth not in [16, 32, 64]:
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem), val.type)
def atom_red_typechecking_impl(ptr: tl.tensor,
@@ -1044,7 +1062,6 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
element_ty = ptr.type.scalar.element_ty
if element_ty is tl.float16 and op != 'add':
raise ValueError("atomic_" + op + " does not support fp16")
@@ -1069,8 +1086,10 @@ def atom_red_typechecking_impl(ptr: tl.tensor,
def atomic_max(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
sem = _str_to_sem(sem)
sca_ty = val.type.scalar
# direct call to atomic_max for integers
if sca_ty.is_int():
@@ -1078,13 +1097,15 @@ def atomic_max(ptr: tl.tensor,
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
ptr.handle,
val.handle,
mask.handle),
mask.handle,
sem),
val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
ptr.handle,
val.handle,
mask.handle),
mask.handle,
sem),
val.type)
# ROCM TODO: implement atomic_max/min for f32 as they are supported by MI cards.
# for float
@@ -1094,16 +1115,18 @@ def atomic_max(ptr: tl.tensor,
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
return where(pos, pos_ret, neg_ret, builder)
def atomic_min(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
sem = _str_to_sem(sem)
sca_ty = val.type.scalar
# direct call to atomic_min for integers
if sca_ty.is_int():
@@ -1111,13 +1134,15 @@ def atomic_min(ptr: tl.tensor,
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
ptr.handle,
val.handle,
mask.handle),
mask.handle,
sem),
val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
ptr.handle,
val.handle,
mask.handle),
mask.handle,
sem),
val.type)
# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
@@ -1129,12 +1154,14 @@ def atomic_min(ptr: tl.tensor,
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,
and_(mask, pos, builder).handle),
and_(mask, pos, builder).handle,
sem),
i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
i_ptr.handle,
i_val.handle,
and_(mask, neg, builder).handle),
and_(mask, neg, builder).handle,
sem),
i_val.type)
return where(pos, pos_ret, neg_ret, builder)
@@ -1142,43 +1169,53 @@ def atomic_min(ptr: tl.tensor,
def atomic_add(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
sem = _str_to_sem(sem)
sca_ty = val.type.scalar
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type)
def atomic_and(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type)
def atomic_or(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type)
def atomic_xor(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type)
def atomic_xchg(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type)
# ===----------------------------------------------------------------------===//
# Linear Algebra
@@ -1213,12 +1250,13 @@ def dot(lhs: tl.tensor,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
assert lhs.shape[1].value == rhs.shape[0].value
assert lhs.dtype == rhs.dtype, f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!"
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16,\
"small blocks not supported!"
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if lhs.type.scalar.is_int():
assert lhs.type.scalar == tl.int8, "only int8 supported!"
# TODO: This is CUDA specific, check if ROCm has the same limitation

View File

@@ -1,7 +1,10 @@
import torch
import triton
import triton.language as tl
from ... import cdiv, heuristics, jit
from ... import language as tl
# import triton
# import language as tl
# ********************************************************
# --------------------------------------------------------
@@ -13,10 +16,10 @@ import triton.language as tl
# ********************************************************
@triton.heuristics({
@heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
@jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
@@ -127,7 +130,7 @@ def sdd_lut(layout, block, device):
# -----------------------------
@triton.jit
@jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
@@ -227,7 +230,7 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N
# meta-parameter heuristics
TILE_N = 128
# compute output
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),

View File

@@ -1,7 +1,10 @@
import torch
import triton
import triton.language as tl
# import triton
# import language as tl
from ... import jit
from ... import language as tl
from ... import next_power_of_2
def num_warps(n):
@@ -16,7 +19,7 @@ def num_warps(n):
return 16
@triton.jit
@jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
@@ -71,7 +74,7 @@ def _blocksparse_softmax_fwd(
tl.store(Out + off_a + lane_n, out, mask=mask)
@triton.jit
@jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
@@ -169,7 +172,7 @@ class _softmax(torch.autograd.Function):
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=triton.next_power_of_2(maxlut),
ROW_SIZE=next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
)
@@ -208,7 +211,7 @@ class _softmax(torch.autograd.Function):
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
ROW_SIZE=next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
)

View File

@@ -1,7 +1,10 @@
import torch
import triton
import triton.language as tl
# import triton
# import language as tl
from .. import heuristics, jit
from .. import language as tl
from .. import next_power_of_2
def num_warps(N):
@@ -12,9 +15,9 @@ def num_warps(N):
return 16
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
@@ -37,9 +40,9 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
tl.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)

View File

@@ -7,11 +7,14 @@ This is a Triton implementation of the Flash Attention algorithm
import torch
import triton
import triton.language as tl
from .. import cdiv, jit
from .. import language as tl
# import triton
# import language as tl
@triton.jit
@jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
@@ -87,7 +90,7 @@ def _fwd_kernel(
tl.store(out_ptrs, acc)
@triton.jit
@jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
@@ -107,7 +110,7 @@ def _bwd_preprocess(
tl.store(Delta + off_m, delta)
@triton.jit
@jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
@@ -205,7 +208,7 @@ class _attention(torch.autograd.Function):
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
grid = (cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8

View File

@@ -1,9 +1,12 @@
import torch
import triton
import triton.language as tl
from .. import Config, autotune, cdiv, heuristics, jit
from .. import language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
# import triton
# import language as tl
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
@@ -17,37 +20,37 @@ def get_configs_io_bound():
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
@autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
@@ -56,10 +59,10 @@ def get_configs_io_bound():
'top_k': 10
},
)
@triton.heuristics({
@heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
@jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
@@ -146,7 +149,7 @@ class _matmul(torch.autograd.Function):
else:
dot_out_dtype = tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),

View File

@@ -2,10 +2,11 @@ import heapq
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.runtime import driver
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
# import triton
from .. import cdiv
from .._C.libtriton.triton import runtime
from ..runtime import driver
from ..testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
@@ -41,13 +42,13 @@ def estimate_matmul_time(
):
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
backend = runtime.backend.CUDA
device = torch.cuda.current_device()
dtype = A.dtype
dtsize = A.element_size()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_m = cdiv(M, BLOCK_M)
num_cta_n = cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k

View File

@@ -8,9 +8,15 @@ import os
import subprocess
import textwrap
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, cast,
overload)
import triton
# import triton
# from .. import compile, CompiledKernel
from ..common.backend import get_backend
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
def get_cuda_stream(idx=None):
@@ -65,7 +71,7 @@ class DependenciesFinder(ast.NodeVisitor):
lhs = self.visit(node.value)
while isinstance(lhs, ast.Attribute):
lhs = self.visit(lhs.value)
if lhs is None or lhs is triton:
if lhs is None or lhs.__name__ == "triton":
return None
return getattr(lhs, node.attr)
@@ -100,15 +106,15 @@ def version_key():
with open(__file__, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# compiler
compiler_path = os.path.join(*triton.__path__, 'compiler')
compiler_path = os.path.join(TRITON_PATH, 'compiler')
for lib in pkgutil.iter_modules([compiler_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# backend
with open(triton._C.libtriton.__file__, "rb") as f:
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
# language
language_path = os.path.join(*triton.__path__, 'language')
language_path = os.path.join(TRITON_PATH, 'language')
for lib in pkgutil.iter_modules([language_path]):
with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
contents += [hashlib.md5(f.read()).hexdigest()]
@@ -117,7 +123,7 @@ def version_key():
ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
except Exception:
ptxas_version = ''
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
return '-'.join(TRITON_VERSION) + '-' + ptxas_version + '-' + '-'.join(contents)
class KernelInterface(Generic[T]):
@@ -158,6 +164,22 @@ class JITFunction(KernelInterface[T]):
else:
raise TypeError(f'Unsupported type {type(arg)} for {arg}')
@staticmethod
def _device_of(arg):
if hasattr(arg, "device"):
if hasattr(arg.device, 'type'):
return arg.device.type
return ''
@staticmethod
def _pinned_memory_of(arg):
if hasattr(arg, "is_pinned"):
if isinstance(arg.is_pinned, Callable):
return arg.is_pinned()
return False
@staticmethod
def _spec_of(arg):
if hasattr(arg, "data_ptr"):
@@ -261,12 +283,29 @@ class JITFunction(KernelInterface[T]):
else:
return f'_key_of({arg})'
def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str:
device_types = [device_type for device_type in device_types if device_type != '']
# Return cuda if one of the input tensors is cuda
if 'cuda' in device_types:
import torch
return 'hip' if torch.version.hip else 'cuda'
is_cpu = all(device_type == 'cpu' for device_type in device_types)
is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags)
# Return cuda if all the input tensors are cpu while the memory is pinned
if is_cpu and is_pinned_memory:
return 'cuda'
return device_types[0] if len(device_types) > 0 else 'cuda'
def _make_launcher(self):
regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs]
constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs]
args = ', '.join(regular_args)
# cache key for regular argument type
sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args])
device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']'
pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']'
# cache key for constexpr argument values
constexpr_keys = ', '.join(constexpr_args)
# cache key for argument specialization
@@ -280,7 +319,12 @@ class JITFunction(KernelInterface[T]):
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
src = f"""
<<<<<<< HEAD
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, warp_size=32, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None):
=======
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel
>>>>>>> oai/main
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
@@ -294,15 +338,38 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, warp_size
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if device_type is None:
device_types = [_device_type for _device_type in {device_types} if _device_type != '']
device_type = self._conclude_device_type(device_types, {pinned_memory_flags})
device_backend = None
if device_type not in ['cuda', 'hip']:
device_backend = get_backend(device_type)
if device_backend is None:
raise ValueError('Cannot find backend for ' + device_type)
if device is None:
device = get_current_device()
set_current_device(device)
if device_type in ['cuda', 'hip']:
device = get_current_device()
set_current_device(device)
else:
device = device_backend.get_current_device()
device_backend.set_current_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
if device_type in ['cuda', 'hip']:
stream = get_cuda_stream(device)
else:
stream = device_backend.get_stream()
bin = cache[device].get(key, None)
if bin is not None:
if not warmup:
<<<<<<< HEAD
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.warp_size, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
=======
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, {args})
>>>>>>> oai/main
return bin
# kernel not cached -- compile
else:
@@ -319,17 +386,31 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, warp_size
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
<<<<<<< HEAD
if not self._call_hook(key, signature, device, constants, num_warps, warp_size, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, warp_size=warp_size, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.warp_size, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
=======
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args)
>>>>>>> oai/main
self.cache[device][key] = bin
return bin
return None
"""
scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream,
"self": self, "_spec_of": self._spec_of, "_key_of": self._key_of,
"cache": self.cache, "triton": triton,
scope = {"version_key": version_key(),
"get_cuda_stream": get_cuda_stream,
"self": self,
"_spec_of": self._spec_of,
"_key_of": self._key_of,
"_device_of": self._device_of,
"_pinned_memory_of": self._pinned_memory_of,
"cache": self.cache,
"__spec__": __spec__,
"get_backend": get_backend,
"get_current_device": get_current_device,
"set_current_device": set_current_device}
exec(src, scope)
@@ -462,7 +543,7 @@ def jit(
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if interpret:
from ..debugger.debugger import GridSelector
from ..interpreter.interpreter import GridSelector
return GridSelector(fn)
else:
return JITFunction(

View File

@@ -4,7 +4,7 @@ import subprocess
import sys
from contextlib import contextmanager
import triton._C.libtriton.triton as _triton
from ._C.libtriton.triton import runtime
def nvsmi(attrs):
@@ -86,9 +86,9 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
end_event[i].record()
# Record clocks
torch.cuda.synchronize()
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
if quantiles is not None:
ret = torch.quantile(times, torch.tensor(quantiles)).tolist()
ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
if len(ret) == 1:
ret = ret[0]
return ret
@@ -281,7 +281,7 @@ def get_dram_gbps(backend=None, device=None):
from .runtime import driver
if not backend:
backend = _triton.runtime.backend.CUDA
backend = runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = driver.utils.get_device_properties(device)["mem_clock_rate"] # in kHz
@@ -295,7 +295,7 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None)
from .runtime import driver
if not backend:
backend = _triton.runtime.backend.CUDA
backend = runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
@@ -398,7 +398,7 @@ def get_max_simd_tflops(dtype, backend=None, device=None):
from .runtime import driver
if not backend:
backend = _triton.runtime.backend.CUDA
backend = runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()

View File

@@ -1,8 +1,11 @@
import argparse
import sys
import triton._C.libtriton.triton as libtriton
import triton.compiler.compiler as tc
from .._C.libtriton.triton import ir
# import triton.compiler.compiler as tc
from ..compiler.compiler import (get_amdgpu_arch_fulldetails, llir_to_amdgcn_and_hsaco,
llir_to_ptx, optimize_ttgir, optimize_ttir,
ttgir_to_llir, ttir_to_ttgir)
if __name__ == '__main__':
@@ -32,12 +35,12 @@ if __name__ == '__main__':
sys.exit(0)
# parse source file to MLIR module
context = libtriton.ir.context()
module = libtriton.ir.parse_mlir_module(args.src, context)
context = ir.context()
module = ir.parse_mlir_module(args.src, context)
module.context = context
# optimizer triton-ir
module = tc.optimize_ttir(module, arch=args.sm)
module = optimize_ttir(module, arch=args.sm)
if args.target == 'triton-ir':
print(module.str())
sys.exit(0)
@@ -49,7 +52,7 @@ if __name__ == '__main__':
if args.target == 'amdgcn':
# auto detect available architecture and features
# if nothing detected, set with default values
arch_details = tc.get_amdgpu_arch_fulldetails()
arch_details = get_amdgpu_arch_fulldetails()
if not arch_details:
arch_name = ""
arch_triple = "amdgcn-amd-amdhsa"
@@ -72,13 +75,21 @@ if __name__ == '__main__':
# triton-ir -> triton-gpu-ir
# use compute_capability == 80
<<<<<<< HEAD
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80)
module = tc.optimize_ttgir(module, num_stages=3, arch=args.gfx)
# triton-gpu-ir -> llvm-ir
# use compute_capability == 80
module = tc.ttgir_to_llir(module, extern_libs=None, arch=args.gfx)
=======
module = ttir_to_ttgir(module, num_warps=args.num_warps) # num_stages=3, compute_capability=80)
module = optimize_ttgir(module, num_stages=3, arch=80)
# triton-gpu-ir -> llvm-ir
# use compute_capability == 80
module = ttgir_to_llir(module, extern_libs=None, arch=80)
>>>>>>> oai/main
# llvm-ir -> amdgcn asm, hsaco binary
module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features)
print(hsaco_path)
print(module)
@@ -93,14 +104,23 @@ if __name__ == '__main__':
raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation")
# triton-ir -> triton-gpu-ir
<<<<<<< HEAD
module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=tc.CUDA_DEFAULT_WARP_SIZE)
module = tc.optimize_ttgir(module, num_stages=3, arch=arch)
=======
module = ttir_to_ttgir(module, num_warps=args.num_warps)
module = optimize_ttgir(module, num_stages=3, arch=args.sm)
>>>>>>> oai/main
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)
# triton-gpu-ir -> llvm-ir
<<<<<<< HEAD
module = tc.ttgir_to_llir(module, extern_libs=None, arch=arch)
=======
module = ttgir_to_llir(module, extern_libs=None, arch=args.sm)
>>>>>>> oai/main
if args.target == 'llvm-ir':
print(module)
sys.exit(0)
@@ -111,12 +131,12 @@ if __name__ == '__main__':
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
module = tc.llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
module = llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version)
# llvm-ir -> amdgcn
if args.target == 'amdgcn':
if not args.gfx:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, args.gfx)
module, hsaco_path = llir_to_amdgcn_and_hsaco(module, args.gfx)
print(module)

View File

@@ -1,7 +1,11 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
tt.func @ops() {
<<<<<<< HEAD
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
=======
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
>>>>>>> oai/main
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
@@ -33,10 +37,17 @@ tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test if the total number of threadsPerWarp is 64
// Test if the total number of warps is 2
<<<<<<< HEAD
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {{.*}}
=======
// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {{.*}}
>>>>>>> oai/main
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
%c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>

View File

@@ -132,6 +132,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_with_cache_attr
tt.func @store_with_cache_attr(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.L1::evict_last.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.L1::evict_last.b32
tt.store %a_ptr_init, %cst_0, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
tt.return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_no_vec
@@ -1462,6 +1477,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTC: llvm.inline_asm
@@ -1469,6 +1485,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
=======
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
>>>>>>> oai/main
tt.return
}
}
@@ -1478,6 +1501,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32_scalar
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTX: llvm.icmp "eq"
@@ -1485,6 +1509,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
=======
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.relaxed.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (!tt.ptr<f32>, f32, i1) -> f32
>>>>>>> oai/main
tt.return
}
}

View File

@@ -9,6 +9,8 @@
// CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: cst
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -67,8 +69,6 @@ tt.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.return
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>, #layout1>
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #layout1>
@@ -80,7 +80,6 @@ tt.func @remat_fast_load(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.store %5, %4 : tensor<16xi32, #layout0>
tt.return
}
}
// CHECK-LABEL: if
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
@@ -164,6 +163,8 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
tt.return
}
}
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#slice1dim1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
@@ -173,6 +174,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility =
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: transpose
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[$row_layout]]>
@@ -212,8 +214,10 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
tt.store %24, %25, %26 : tensor<64x64xf32, #blocked4>
tt.return
}
}
// CHECK-LABEL: loop
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[$row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[$row_layout]]>)
@@ -266,8 +270,10 @@ tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32,
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
tt.return
}
}
// CHECK-LABEL: loop_if
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
@@ -318,8 +324,10 @@ tt.func @loop_if(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i3
tt.store %20, %21, %22 : tensor<64x64xf32, #blocked1>
tt.return
}
}
// CHECK-LABEL: vecadd
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
@@ -349,9 +357,11 @@ tt.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
tt.store %21, %22 : tensor<256xf32, #layout1>
tt.return
}
}
// Select has args with different element types
// CHECK-LABEL: select
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
@@ -400,9 +410,11 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
}
tt.return
}
}
// Make sure the following IR doesn't hang the compiler.
// CHECK-LABEL: long_func
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
@@ -796,10 +808,12 @@ tt.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg
tt.store %365, %366 : tensor<1024xf64, #blocked0>
tt.return
}
}
// A mnist model from torch inductor.
// Check if topological sort is working correct and there's no unnecessary convert
// CHECK-LABEL: mnist
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
@@ -884,17 +898,19 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
tt.store %61, %62, %63 : tensor<16x16xf32, #blocked4>
tt.return
}
}
// -----
// cmpf and cmpi have different operands and result types
// CHECK-LABEL: cmp
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
// cmpf and cmpi have different operands and result types
// CHECK-LABEL: cmp
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
@@ -1034,11 +1050,13 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
}
tt.return
}
}
// -----
// Just make sure it doesn't crash on non-tensor types.
// CHECK-LABEL: if_no_tensor
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c-1_i64 = arith.constant -1 : i64
@@ -1062,6 +1080,7 @@ tt.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
tt.return
}
}
// -----

View File

@@ -11,6 +11,8 @@
#BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: tt.func @push_elementwise1
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
@@ -122,3 +124,5 @@ tt.func @push_elementwise5(
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
tt.return %newc : tensor<16x16xf32, #Cv1>
}
}

View File

@@ -313,3 +313,122 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt
}
tt.return %79#0 : tensor<16x16xf32, #C>
}
// CHECK: tt.func @post_load_inv
// CHECK: scf.for
// CHECK: arith.index_cast
// CHECK-DAG: %[[IV:.*]] = arith.index_cast
// CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32
// CHECK-NOT: arith.addi %[[NEXT_IV]]
tt.func @post_load_inv(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> {
%c0_index = arith.constant 0 : index
%c1_index = arith.constant 1 : index
%c1_i32 = arith.constant 1 : i32
%c32_i32 = arith.constant 32 : i32
%84 = arith.constant 900 : index
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL>
%50 = tt.splat %arg3 : (i32) -> tensor<1x32xi32, #AL>
%59 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%81 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%66 = tt.splat %arg4 : (i32) -> tensor<32x1xi32, #AL>
%60 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%82 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%85:3 = scf.for %arg9 = %c0_index to %84 step %c1_index iter_args(%arg10 = %cst, %arg11 = %59, %arg12 = %81) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>) {
%130 = arith.index_cast %arg9 : index to i32
%107 = arith.muli %130, %c32_i32 : i32
%108 = arith.subi %arg5, %107 : i32
%109 = tt.splat %108 : (i32) -> tensor<1x32xi32, #AL>
%110 = "triton_gpu.cmpi"(%50, %109) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL>
%111 = tt.broadcast %110 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL>
%112 = tt.load %arg11, %111, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%113 = tt.splat %108 : (i32) -> tensor<32x1xi32, #AL>
%114 = "triton_gpu.cmpi"(%66, %113) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL>
%115 = tt.broadcast %114 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL>
%116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
%118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
%119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%131 = arith.index_cast %arg9 : index to i32
%120 = arith.addi %131, %c1_i32 : i32
%121 = arith.muli %120, %c32_i32 : i32
%122 = tt.splat %121 : (i32) -> tensor<32x32xi32, #AL>
%123 = tt.addptr %60, %122 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
%124 = arith.muli %121, %arg7 : i32
%125 = tt.splat %124 : (i32) -> tensor<32x32xi32, #AL>
%126 = tt.addptr %82, %125 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
scf.yield %119, %123, %126 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>
}
tt.return %85#0 : tensor<32x32xf32, #C>
}
// CHECK: tt.func @cross_iter_dep
// CHECK: triton_gpu.async_commit_group
// CHECK: triton_gpu.async_commit_group
// CHECK: triton_gpu.async_commit_group
// CHECK: triton_gpu.async_commit_group
// CHECK: %[[PTR0:.*]] = tt.addptr
// CHECK: %[[PTR1:.*]] = tt.addptr
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[BUF0:.*]] = %[[PTR0]], {{.*}}, %[[BUF1:.*]] = %[[PTR1]]
// CHECK: scf.yield
// CHECK-SAME: %[[BUF0]]
// CHECK-SAME: %[[BUF1]]
tt.func @cross_iter_dep(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32}) -> tensor<32x32xf32, #C> {
%c0_i32 = arith.constant 0 : index
%118 = arith.constant 32 : index
%c1_i32 = arith.constant 1 : index
%c2_i32 = arith.constant 2 : i32
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #C>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #AL>
%78 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%110 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%112 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%113 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%116 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%65 = tt.splat %arg3 : (i32) -> tensor<1x32xi32, #AL>
%88 = tt.splat %arg4 : (i32) -> tensor<32x1xi32, #AL>
%80 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>, #AL>
%119:5 = scf.for %arg9 = %c0_i32 to %118 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %78, %arg12 = %110, %arg13 = %113, %arg14 = %116) -> (tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>) {
%161 = arith.index_cast %arg9 : index to i32
%141 = arith.muli %161, %c32_i32 : i32
%142 = arith.subi %arg5, %141 : i32
%143 = tt.splat %142 : (i32) -> tensor<1x32xi32, #AL>
%144 = "triton_gpu.cmpi"(%65, %143) <{predicate = 2 : i64}> : (tensor<1x32xi32, #AL>, tensor<1x32xi32, #AL>) -> tensor<1x32xi1, #AL>
%145 = tt.broadcast %144 : (tensor<1x32xi1, #AL>) -> tensor<32x32xi1, #AL>
%146 = tt.load %arg11, %145, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%147 = tt.splat %142 : (i32) -> tensor<32x1xi32, #AL>
%148 = "triton_gpu.cmpi"(%88, %147) <{predicate = 2 : i64}> : (tensor<32x1xi32, #AL>, tensor<32x1xi32, #AL>) -> tensor<32x1xi1, #AL>
%149 = tt.broadcast %148 : (tensor<32x1xi1, #AL>) -> tensor<32x32xi1, #AL>
%150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL>
%151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>>
%152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>>
%153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C>
%162 = arith.index_cast %arg9 : index to i32
%154 = arith.addi %162, %c2_i32 : i32
%155 = arith.muli %154, %c32_i32 : i32
%156 = tt.splat %155 : (i32) -> tensor<32x32xi32, #AL>
%157 = tt.addptr %80, %156 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
%158 = arith.muli %155, %arg7 : i32
%159 = tt.splat %158 : (i32) -> tensor<32x32xi32, #AL>
%160 = tt.addptr %112, %159 : tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32xi32, #AL>
scf.yield %153, %arg13, %arg14, %157, %160 : tensor<32x32xf32, #C>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>, tensor<32x32x!tt.ptr<f32>, #AL>
}
tt.return %119#0 : tensor<32x32xf32, #C>
}

View File

@@ -13,6 +13,13 @@ struct TestAliasPass
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
static std::string getValueOperandName(Value value, AsmState &state) {
std::string opName;
llvm::raw_string_ostream ss(opName);
value.printAsOperand(ss, state);
return opName;
}
static void print(StringRef name, SmallVector<std::string, 4> &vals,
raw_ostream &os) {
if (vals.empty())

6
utils/nightly.pypirc Normal file
View File

@@ -0,0 +1,6 @@
[distutils]
Index-servers =
Triton-Nightly
[Triton-Nightly]
Repository = https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/upload/