Merge commit '36fc54b6f28168d3644808bfe299f1ba06a36272' into ifu230908-2

Conflicts:
	.gitignore
	bin/triton-translate.cpp
	include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
	lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.h
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/compiler/make_launcher.py
	python/triton/language/semantic.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/triton_to_tritongpu.mlir
	test/Conversion/tritongpu_to_llvm.mlir
	test/TritonGPU/coalesce.mlir
	unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
This commit is contained in:
Jason Furmanek
2023-10-02 18:01:04 +00:00
259 changed files with 32652 additions and 3712 deletions

85
.github/workflows/compare-artifacts.yml vendored Normal file
View File

@@ -0,0 +1,85 @@
name: Compare Artifacts
on:
workflow_run:
workflows:
- Integration Tests
types:
- completed
jobs:
Compare-artifacts:
runs-on: ubuntu-latest
if: ${{ github.event.workflow_run.conclusion == 'success' }}
steps:
- name: Download PR number artifact
uses: actions/github-script@v6
with:
script: |
let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: context.payload.workflow_run.id,
});
let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
return artifact.name == "pr_number"
})[0];
let download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: matchArtifact.id,
archive_format: 'zip',
});
let fs = require('fs');
fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/pr_number.zip`, Buffer.from(download.data));
- name: Download comparison result artifact
uses: actions/github-script@v6
with:
script: |
let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: context.payload.workflow_run.id,
});
let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => {
return artifact.name == "comparison_result"
})[0];
let download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: matchArtifact.id,
archive_format: 'zip',
});
let fs = require('fs');
fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/comparison_result.zip`, Buffer.from(download.data));
- name: Unzip artifacts
run: |
unzip pr_number.zip
unzip comparison_result.zip
- name: Print artifacts
uses: actions/github-script@v6
with:
script: |
let fs = require('fs');
let pr_number = Number(fs.readFileSync('./pr_number'));
let comparison_result = fs.readFileSync('./comparison_result', 'utf8');
console.log("PR number = ", pr_number);
console.log("Comparison result = ", comparison_result);
- name: Comment on PR
uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
let fs = require('fs');
let run_id = context.payload.workflow_run.id;
let issue_number = Number(fs.readFileSync('./pr_number'));
let comparison_result = fs.readFileSync('./comparison_result', 'utf8');
const message = `:warning: **This PR does not produce bitwise identical kernels as the branch it's merged against.** Please check artifacts for details. [Download the output file here](https://github.com/${{ github.repository }}/actions/runs/${run_id}).`;
if (comparison_result.trim() !== 'SUCCESS') {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: issue_number,
body: message
});
}

View File

@@ -24,6 +24,7 @@ jobs:
run: |
pip3 install tabulate
pip3 install cmake
pip3 install sphinx
#- name: Fetch dependent branches
# run: |
@@ -33,7 +34,7 @@ jobs:
run: |
cd docs
export PATH=$(python3 -c "import cmake; print(cmake.CMAKE_BIN_DIR)"):$PATH
python3 -m sphinx_multiversion . _build/html/
python3 -m sphinx . _build/html/main
- name: Update docs
run: |

View File

@@ -27,7 +27,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]'
echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]'
echo '::set-output name=matrix-optional::[]'
else
echo '::set-output name=matrix-required::["ubuntu-latest"]'
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
@@ -50,6 +50,9 @@ jobs:
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}"
echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}"
- name: Clear cache
run: |
@@ -59,6 +62,11 @@ jobs:
run: |
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"
- name: Check pre-commit
run: |
python3 -m pip install --upgrade pre-commit
python3 -m pre_commit run --all-files --verbose
- name: Install Triton
if: ${{ env.BACKEND == 'CUDA'}}
run: |
@@ -80,19 +88,53 @@ jobs:
fi
lit -v "${LIT_TEST_DIR}"
- name: Run python tests on CUDA
if: ${{ env.BACKEND == 'CUDA'}}
- name: Enable MMAV3 and TMA
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}}
run: |
echo "ENABLE_TMA=1" >> "${GITHUB_ENV}"
echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}"
- name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime
python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
# run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py
- name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime --ignore=hopper --ignore=operators --ignore=language/test_line_info.py
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
# run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py
- name: Clear cache
run: |
rm -rf ~/.triton
- name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 operators
- name: Run partial tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 operators
- name: Create artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
cd ~/.triton
tar -czvf artifacts.tar.gz cache
tar -czf artifacts.tar.gz cache
- name: Upload artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
@@ -120,6 +162,7 @@ jobs:
Integration-Tests-Third-Party:
needs: Runner-Preparation
if: false
runs-on: ${{ matrix.runner }}
@@ -219,10 +262,22 @@ jobs:
sudo apt update
sudo apt install gh
- name: Save PR number to a file
env:
PR_NUMBER: ${{ github.event.number }}
run: |
echo $PR_NUMBER > pr_number
- name: Upload PR number to artifacts
uses: actions/upload-artifact@v3
with:
name: pr_number
path: pr_number
- name: Download latest main artifacts
env:
ARTIFACT_NAME: artifacts A100
ARTIFACT_JOB_NAME: Integration-Tests-Nvidia
MAX_NUM_ACTIONS_PAGES: 30
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
OWNER_REPO="${{ github.repository }}"
@@ -239,18 +294,27 @@ jobs:
USER_ID=$(gh api repos/$OWNER_REPO/pulls/$PR_NUMBER --jq '.user.id')
echo "USER_ID: $USER_ID"
run_id_found=false
page=1
while true; do
if [ "$page" -gt $MAX_NUM_ACTIONS_PAGES ]; then
break
fi
run_id=$(gh api --method GET "repos/$OWNER_REPO/actions/runs?page=$page&per_page=100" | jq --arg branch_name "$BRANCH_NAME" --arg run_name "Integration Tests" --arg user_id "$USER_ID" '.workflow_runs[] | select(.head_branch == $branch_name and .name == $run_name and .actor.id == ($user_id | tonumber))' | jq '.id' | head -1)
if [ "$run_id" != "" ]; then
echo "First run ID on branch $BRANCH_NAME is: $run_id"
WORKFLOW_RUN_ID=$run_id
run_id_found=true
break
fi
((page++))
done
if ! $run_id_found; then
echo "No run_id found for PR ${PR_NUMBER}, moving to the next PR."
continue
fi
echo "WORKFLOW_RUN_ID: $WORKFLOW_RUN_ID"
ARTIFACT_URL=$(gh api repos/$OWNER_REPO/actions/runs/$WORKFLOW_RUN_ID/artifacts | jq --arg artifact_name "$ARTIFACT_NAME" '.artifacts[] | select(.name == $artifact_name).archive_download_url' --raw-output)
echo "ARTIFACT_URL: $ARTIFACT_URL"
@@ -290,7 +354,7 @@ jobs:
- name: Compare artifacts
run: |
set +e
python3 python/test/tools/compare_files.py --path1 reference --path2 current --kernels python/test/kernel_comparison/kernels.yml
python3 python/test/tools/compare_files.py --path1 reference --path2 current
exit_code=$?
set -e
echo $exit_code
@@ -304,34 +368,20 @@ jobs:
echo "Error while comparing artifacts"
echo "COMPARISON_RESULT=error" >> $GITHUB_ENV
fi
echo "COMPARISON_RESULT=env.COMPARISON_RESULT"
- name: Check exit code and handle failure
if: ${{ env.COMPARISON_RESULT == 'error' }}
- name: Check comparison result and write to file
run: |
echo "Error while comparing artifacts"
exit 1
- name: Fetch Run ID
id: get_run_id
run: echo "RUN_ID=${{ github.run_id }}" >> $GITHUB_ENV
if [ "${{ env.COMPARISON_RESULT }}" = "true" ]; then
echo "SUCCESS" > comparison_result
else
echo "FAILED" > comparison_result
fi
- name: Upload comparison result to artifacts
uses: actions/upload-artifact@v3
with:
name: comparison_result
path: comparison_result
- name: Upload results as artifact
uses: actions/upload-artifact@v2
with:
name: kernels-reference-check
path: kernels_reference_check.txt
- name: Check output and comment on PR
if: ${{ env.COMPARISON_RESULT == 'false' }}
uses: actions/github-script@v5
with:
github-token: ${{ secrets.CI_ACCESS_TOKEN }}
script: |
const run_id = ${{ env.RUN_ID }};
const issue_number = context.payload.pull_request.number;
const message = `:warning: **This PR does not produce bitwise identical kernels as the branch it's merged against.** Please check artifacts for details. [Download the output file here](https://github.com/${{ github.repository }}/actions/runs/${run_id}).`;
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: issue_number,
body: message
});

5
.gitignore vendored
View File

@@ -25,9 +25,14 @@ venv.bak/
.idea
cmake-build-*
<<<<<<< HEAD
# cache dumps
triton_cache*
log_*
#
python/triton/third_party/cuda/bin/ptxas
=======
# Third-party binaries
ptxas
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272

View File

@@ -218,6 +218,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO

View File

@@ -4,11 +4,38 @@
[![Wheels](https://github.com/openai/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/openai/triton/actions/workflows/wheels.yml)
We're hiring! If you are interested in working on Triton at OpenAI, we have roles open for [Compiler Engineers](https://openai.com/careers/software-engineer-triton-compiler) and [Kernel Engineers](https://openai.com/careers/kernel-engineer).
**`Documentation`** |
------------------- |
[![Documentation](https://github.com/openai/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/)
# Triton Developer Conference Registration Open
The Triton Developer Conference will be held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference will be held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. Please use the link below to register to attend either in-person or virtually online.
Registration Link for Triton Developer Conference is [here](https://forms.office.com/r/m4jQXShDts)
Tentative Agenda for the conference (subject to change):
|Time |Title |Speaker
|--------|-------|-------|
|10:00 AM|Welcome|Kevin Scott (Microsoft)|
|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)|
|11:00 AM|**Break**||
|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)|
|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)|
|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)|
|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)|
|12:30 PM|**Lunch**||
|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)|
|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)|
|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)|
|2:40 PM|**Break**||
|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)|
|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)|
|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)|
|4:00 PM|**Reception**||
# Triton
@@ -116,10 +143,6 @@ Version 2.0 is out! New features include:
Community contributions are more than welcome, whether it be to fix bugs or to add new features at [github](https://github.com/openai/triton/). For more detailed instructions, please visit our [contributor's guide](CONTRIBUTING.md).
If you're interested in joining our team and working on Triton & GPU kernels, [we're hiring](https://openai.com/jobs/#acceleration)!
# Compatibility

View File

@@ -9,6 +9,7 @@ target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
@@ -29,6 +30,7 @@ target_link_libraries(triton-reduce PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
@@ -48,6 +50,7 @@ llvm_update_compile_flags(triton-translate)
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO

View File

@@ -1,10 +1,13 @@
#pragma once
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Conversion/NVGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
@@ -23,15 +26,18 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::registerTritonNvidiaGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
mlir::triton::registerConvertNVGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect>();

View File

@@ -14,6 +14,7 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
@@ -39,6 +40,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
mlir::DialectRegistry registry;
registry
.insert<TritonDialect, triton::gpu::TritonGPUDialect,
triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::math::MathDialect, arith::ArithDialect, scf::SCFDialect>();
context.appendDialectRegistry(registry);
@@ -122,6 +124,7 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
}
llvm::LLVMContext llvmContext;
<<<<<<< HEAD
#ifdef USE_ROCM
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), true /*isRocm*/);
@@ -129,6 +132,12 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module,
SMArch.getValue(), false /*isRocm*/);
#endif
=======
mlir::triton::gpu::TMAMetadataTy tmaInfos;
auto llvmir = translateTritonGPUToLLVMIR(
&llvmContext, *module, SMArch.getValue(), tmaInfos, Target::Default);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}

View File

@@ -101,11 +101,12 @@ sphinx_gallery_conf = {
'gallery_dirs': 'getting-started/tutorials',
'filename_pattern': '',
# XXX: Temporarily disable fused attention tutorial on V100
'ignore_pattern': r'__init__\.py',
'ignore_pattern': r'(__init__\.py|09.*\.py|10.*\.py)',
'within_subsection_order': FileNameSortKey,
'reference_url': {
'sphinx_gallery': None,
}
},
'abort_on_example_error': True,
}
# Add any paths that contain templates here, relative to this directory.
@@ -144,7 +145,7 @@ release = ''
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = 'en'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.

View File

@@ -0,0 +1,41 @@
#### Agenda:
##### Announcements:
1. Triton conference registration opening soon. Conference on 20th September at the Microsoft Silicon Valley Campus.
##### Items:
1. H100 updates
2. Triton release plan update
3. Linalg updates
4. Intel GPU Backend status update.
5. Intel working on the CPU backend for Triton.
6. AMD updates
7. Open discussion
##### Minutes:
Recording link [here](https://drive.google.com/file/d/19Nnc0i7zUyn-ni2RSFHbPHHiPkYU96Mz/view)
1. H100 updates:
- Preliminary support is merged, disabled by default, can be enabled with env variables
- Supports latest tensor cores, FP8s. Support for Flash Attention on the main branch coming soon.
- Performance is very good on Matmuls, 80-90% of cublas on large Matmuls right now, will eventually reach parity with cublas. Above 600 teraflops on fp16 on xxm card, cublas is 670 on random input data. FP8 is twice that, around 1.2 petaflops.
- Hopper support includes the full FP8 support for compute.
2. Triton release plan update
- No specific dates for now, plan is to release before end of 2023.
- Will move to 3.0 release due to minor backward compatibility breaking changes. For eg. Will move compiler options in the indexing operators as hardcoded operators in the kernel, will bump the major version.
- Functionally the main goal will be to have 3rd party plugins for Intel and AMD gpus.
- May synchronise with a PyTorch release so that PyTorch can benefit from the latest features, however continuous integration workflow is the default release cadence expected.
- Will switch the default behavior to optimized mode for the release, needs more discussion with Nvidia.
- Will expose flags for a user to enable kernel selection themselves.
- Open question: PyTorch hasn't rebased to the latest Triton, and it is close to the PyTorch code freeze — will PyTorch still sync with Triton 2.0? Will we have another release to support Triton 2.0?
- Community can start with the latest stable branch and rebase 3rd party plugin on top of that. OAI has no resources to commit to, but community can contribute.
3. Linalg updates
- Discussion on Github for Linalg as a middle layer between the language and target hardware. Includes support for block pointers and modulo operators.
- Please join the conversation [here](https://github.com/openai/triton/discussions/1842)
- Branch pushed is behind the tip, will work on getting it caught up on the tip.
4. Intel GPU Backend status update.
- Please refer to slides [here](https://github.com/openai/triton/blob/main/docs/meetups/Intel%20XPU%20Backend%20for%20Triton%20-%20Update%20-%200823.pptx)
5. Intel working on the CPU backend for Triton.
- Please refer to slides [here](https://github.com/openai/triton/blob/main/docs/meetups/Intel%20XPU%20Backend%20for%20Triton%20-%20Update%20-%200823.pptx)
6. AMD updates
- Please refer to slides [here](https://github.com/openai/triton/blob/main/docs/meetups/Triton_AMD_update_0823.pdf).

Binary file not shown.

View File

@@ -192,3 +192,4 @@ Iterators
:nosignatures:
static_range
multiple_of

View File

@@ -9,6 +9,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
@@ -158,17 +159,17 @@ private:
BufferKind kind;
BufferId id;
size_t size;
size_t alignment;
size_t offset;
bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }
BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind)
: kind(kind), id(InvalidBufferId), size(0), offset(0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
BufferT() : BufferT(BufferKind::Explicit, 0) {}
BufferT(BufferKind kind, size_t size, size_t alignment = 4,
size_t offset = 0)
: kind(kind), id(nextId++), size(size), alignment(alignment),
offset(offset) {}
};
/// Op -> Scratch Buffer

View File

@@ -21,7 +21,7 @@ namespace mlir {
/// This lattice value represents known information on the axes of a lattice.
class AxisInfo {
public:
typedef SmallVector<int64_t, 4> DimVectorT;
typedef SmallVector<int64_t> DimVectorT;
public:
/// Default constructor

View File

@@ -3,6 +3,7 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
@@ -125,7 +126,11 @@ bool isSingleValue(Value value);
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
Type getElementType(Value value);
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
@@ -332,6 +337,10 @@ protected:
FuncDataMapT funcMap;
SmallVector<FunctionOpInterface> roots;
};
// Create a basic DataFlowSolver with constant and dead code analysis included.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
} // namespace mlir

View File

@@ -1,2 +1,3 @@
add_subdirectory(TritonToTritonGPU)
add_subdirectory(TritonGPUToLLVM)
add_subdirectory(NVGPUToLLVM)

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name NVGPUToLLVM)
add_public_tablegen_target(NVGPUConversionPassIncGen)

View File

@@ -0,0 +1,19 @@
#ifndef TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H
#define TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertNVGPUToLLVMPass();
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,16 @@
#ifndef NVGPU_CONVERSION_PASSES_H
#define NVGPU_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.h"
namespace mlir {
namespace triton {
#define GEN_PASS_REGISTRATION
#include "triton/Conversion/NVGPUToLLVM/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,20 @@
#ifndef NVGPU_CONVERSION_PASSES
#define NVGPU_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert NVGPU to LLVM";
let description = [{
}];
let constructor = "mlir::triton::createConvertNVGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::LLVM::LLVMDialect",
"mlir::NVVM::NVVMDialect",
"mlir::triton::nvgpu::NVGPUDialect"];
}
#endif

View File

@@ -151,6 +151,12 @@ struct PTXBuilder {
// aggressive optimizations that may lead to incorrect results.
Operand *newOperand(StringRef constraint, bool init = false);
// Create a new operand that is tied to a previous operand. In this case the
// asm would be permitted to write to an input register. Instead of providing
// constraint code for this operand, the constraint code of the tied operand
// is used.
Operand *newOperand(unsigned operandIndex);
// Create a constant integer operand.
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.

View File

@@ -19,6 +19,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::tensor::TensorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::NVVM::NVVMDialect"];
@@ -26,9 +27,16 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">,
Option<"isROCM", "is-rocm",
"bool", /*default*/"false",
"compile for ROCM-compatible LLVM">,
Option<"tmaMetadata", "tma-metadata",
"mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr",
"tma metadata to the runtime">,
Option<"target", "target", "enum Target", "mlir::triton::Target::Default",
"compile for target compatible LLVM",
"llvm::cl::values("
"clEnumValN(mlir::triton::Target::NVVM, \"nvvm\", \"compile for "
"NVVM-compatible LLVM\"), "
"clEnumValN(mlir::triton::Target::ROCDL, \"rocdl\", \"compile for "
"ROCDL-compatible LLVM\"))">,
];
}

View File

@@ -3,6 +3,8 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include <memory>
namespace mlir {
@@ -12,7 +14,14 @@ template <typename T> class OperationPass;
namespace triton {
enum Target { NVVM, ROCDL, Default = NVVM };
#define GEN_PASS_DECL
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
<<<<<<< HEAD
#ifdef USE_ROCM
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = true);
@@ -20,6 +29,10 @@ createConvertTritonGPUToLLVMPass(int computeCapability = 80,
createConvertTritonGPUToLLVMPass(int computeCapability = 80,
bool isROCM = false);
#endif
=======
createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
} // namespace triton
} // namespace mlir

View File

@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_PASSES_H
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Target/PTX/TmaMetadata.h"
namespace mlir {
namespace triton {

View File

@@ -25,6 +25,12 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"TRITONGPU_DEFAULT_WARPSIZE",
"number of threads per warp">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of ctas in a cga">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"compute capability">
];
}

View File

@@ -11,6 +11,9 @@ template <typename T> class OperationPass;
namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
constexpr static char AttrComputeCapabilityName[] =
"triton_gpu.compute-capability";
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";
@@ -19,7 +22,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32);
createConvertTritonToTritonGPUPass(int numWarps, int threadsPerWarp = 32,
int numCTAs = 1, int computeCapability = 80);
} // namespace triton
} // namespace mlir

View File

@@ -1,2 +1,4 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)
add_subdirectory(TritonNvidiaGPU)
add_subdirectory(NVGPU)

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
#add_subdirectory(Transforms)

View File

@@ -0,0 +1,14 @@
set(LLVM_TARGET_DEFINITIONS NVGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu)
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(NVGPUTableGen)
set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td)
mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(NVGPUAttrDefsIncGen)

View File

@@ -0,0 +1,47 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_NVGPU_IR_DIALECT_H_
#define TRITON_DIALECT_NVGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TableGen-generated declarations for the NVGPU dialect, its ops, enums,
// and attributes.
#include "triton/Dialect/NVGPU/IR/Dialect.h.inc"
#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/NVGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
// Reserved for hand-written NVGPU helpers; currently everything is generated.
namespace nvgpu {} // namespace nvgpu
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_NVGPU_IR_DIALECT_H_

View File

@@ -0,0 +1,33 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef NVGPU_ATTRDEFS
#define NVGPU_ATTRDEFS
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
// Base class for attributes of the NVGPU dialect; concrete attributes would
// derive from this (none are defined in this file yet).
class NVGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<NVGPU_Dialect, name, traits, baseCppClass> {
}
#endif

View File

@@ -0,0 +1,40 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef NVGPU_DIALECT
#define NVGPU_DIALECT
include "mlir/IR/OpBase.td"
// Dialect declaration for `nvgpu` (C++ namespace ::mlir::triton::nvgpu).
// Depends on the LLVM dialect because the NVGPU ops are built on
// LLVM_OpBase and use LLVM pointer/struct types.
def NVGPU_Dialect : Dialect {
let name = "nvgpu";
let cppNamespace = "::mlir::triton::nvgpu";
let description = [{
NVGPU Dialect.
}];
let dependentDialects = [
"mlir::LLVM::LLVMDialect"
];
}
#endif

View File

@@ -0,0 +1,248 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef NVGPU_OPS
#define NVGPU_OPS
include "triton/Dialect/NVGPU/IR/NVGPUDialect.td"
include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
// Shorthand LLVM pointer types: address space 1 ("global") and 3 ("shared").
def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
def I64Ptr_shared : LLVM_IntPtrBase<64, 3>;
// Base class for all NVGPU ops, built on LLVM_OpBase.
class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
// Warp-group MMA synchronization ops; no operands or results.
def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
let assemblyFormat = "attr-dict";
}
def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
let assemblyFormat = "attr-dict";
}
// Waits until at most $pendings wgmma groups remain outstanding
// (presumed semantics from the name -- confirm against the lowering).
def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", []> {
let arguments = (ins I32Attr:$pendings);
let assemblyFormat = "attr-dict";
}
// Initializes an mbarrier in shared memory with arrival count $count,
// guarded by predicate $pred. Writes memory.
def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, I32Attr:$count);
let assemblyFormat = "$mbarrier `,` $pred attr-dict `:` type($mbarrier)";
}
// Kind of mbarrier arrive operation. NOTE: the generated C++ enum is named
// "MBarriveType" (sic); kept as-is for API compatibility with existing users.
// Fixes vs. previous revision: removed the trailing comma after the last
// enum case (TableGen list literals take no trailing comma, and no sibling
// enum list in this file uses one) and listed 'remote' in the description.
def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType",
"mbarrier arrive type, either 'normal', 'cp_async', 'expect_tx', or 'remote'",
[
I32EnumAttrCase<"normal", 0>,
I32EnumAttrCase<"cp_async", 1>,
I32EnumAttrCase<"expect_tx", 2>,
I32EnumAttrCase<"remote", 3>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
// Performs an mbarrier arrive of kind $arriveType; the optional $ctaId
// targets another CTA's barrier, and $txCount is presumably the expect-tx
// transaction count (verify against the lowering).
def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)";
}
// Waits on an mbarrier with the given $phase bit (name-based; confirm
// exact phase semantics in the lowering).
def NVGPU_MBarrierWaitOp : NVGPU_Op<"mbarrier_wait", []> {
let arguments = (ins I64Ptr_shared:$mbarrier, I1:$phase);
let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type(operands)";
}
// Named-barrier arrive/wait over $numThreads threads at barrier id $bar.
def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> {
let arguments = (ins I32:$bar, I32:$numThreads);
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> {
let arguments = (ins I32:$bar, I32:$numThreads);
let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
// Shared-memory swizzling mode encoded in a WGMMA matrix descriptor.
def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode",
"wgmma desc mode, either 'none', 'swizzle128', 'swizzle64', or 'swizzle32'",
[
I32EnumAttrCase<"none", 0>,
I32EnumAttrCase<"swizzle128", 1>,
I32EnumAttrCase<"swizzle64", 2>,
I32EnumAttrCase<"swizzle32", 3>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
// Builds an i64 WGMMA matrix descriptor from a buffer pointer, its $height,
// and the swizzle $mode attribute.
def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> {
let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode);
let results = (outs I64:$res);
let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)";
}
// TMA tiled load into shared memory $dst; completion is signalled on
// $mbarrier. The optional $mcastMask presumably selects multicast
// destination CTAs (confirm in the lowering).
def NVGPU_TMALoadTiledOp : NVGPU_Op<"tma_load_tiled", [AttrSizedOperandSegments]> {
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc,
I1:$pred, Variadic<I32>:$coords, Optional<I16>:$mcastMask);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// TMA im2col load variant; $im2colOffsets is a struct of per-dim offsets.
def NVGPU_TMALoadIm2colOp : NVGPU_Op<"tma_load_im2col", []> {
let arguments = (ins I8Ptr_shared:$dst, I64Ptr_shared:$mbarrier, I8Ptr_global:$tmaDesc, I64:$l2Desc, LLVM_AnyStruct:$im2colOffsets, I1:$pred, Variadic<I32>:$coords, I16Attr:$mcastMask);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// Row/column layout of a WGMMA operand.
def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout",
"wgmma layout, either 'row' or 'col'",
[
I32EnumAttrCase<"row", 0>,
I32EnumAttrCase<"col", 1>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
// Element type of a WGMMA operand or accumulator.
def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType",
"wgmma operand type, either 's8', 's32', 'e4m3', 'e5m2', 'f16', 'bf16', 'tf32', or 'f32'",
[
I32EnumAttrCase<"s8", 0>,
I32EnumAttrCase<"s32", 1>,
I32EnumAttrCase<"e4m3", 2>,
I32EnumAttrCase<"e5m2", 3>,
I32EnumAttrCase<"f16", 4>,
I32EnumAttrCase<"bf16", 5>,
I32EnumAttrCase<"tf32", 6>,
I32EnumAttrCase<"f32", 7>
]>{
let cppNamespace = "::mlir::triton::nvgpu";
}
// Operand A/B is either a struct of register values or an i64 descriptor.
def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">;
// Warp-group MMA with MxNxK shape, per-operand element types and layouts;
// $opC is the accumulator input and $res the accumulator result.
def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> {
let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC,
I32Attr:$m, I32Attr:$n, I32Attr:$k,
WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB,
WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB);
let results = (outs LLVM_AnyStruct:$res);
let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)";
}
// CGA-wide (cluster) barrier sync/arrive/wait ops; no operands or results.
def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> {
let assemblyFormat = "attr-dict";
}
def NVGPU_CGABarrierArriveOp : NVGPU_Op<"cga_barrier_arrive", []> {
let assemblyFormat = "attr-dict";
}
def NVGPU_CGABarrierWaitOp : NVGPU_Op<"cga_barrier_wait", []> {
let assemblyFormat = "attr-dict";
}
// Load from another CTA's shared memory (distributed shared memory);
// $bitwidth is the element width in bits, $vec the vector width.
def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> {
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec);
let builders = [
OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>,
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>,
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)>
];
let results = (outs LLVM_LoadableType:$result);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
// Store $values to another CTA's shared memory, guarded by $pred.
// getBitwidth()/getVec() are implemented in the dialect's C++ sources.
def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId,
Variadic<LLVM_LoadableType>:$values, I1:$pred);
let builders = [
OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>,
];
let assemblyFormat = "operands attr-dict `:` type(operands)";
let extraClassDeclaration = [{
unsigned getBitwidth();
unsigned getVec();
}];
}
// Fence for async-proxy visibility of shared memory; $bCluster presumably
// widens the scope from CTA to cluster (confirm in the lowering).
def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
let arguments = (ins BoolAttr:$bCluster);
let assemblyFormat = "attr-dict";
}
// Fence ordering mbarrier initialization (name-based; confirm in lowering).
def NVGPU_FenceMBarrierInitOp : NVGPU_Op<"fence_mbarrier_init", []> {
let assemblyFormat = "attr-dict";
}
// Cluster barrier arrive ($relaxed selects the relaxed variant) and wait.
def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> {
let arguments = (ins I1Attr:$relaxed);
let assemblyFormat = "attr-dict";
}
def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
let assemblyFormat = "attr-dict";
}
// TMA tiled store from shared memory $src to the destination described by
// $tmaDesc at the given $coords, guarded by $pred.
def NVGPU_TMAStoreTiledOp : NVGPU_Op<"tma_store_tiled", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I8Ptr_global:$tmaDesc, I8Ptr_shared:$src, I1:$pred, Variadic<I32>:$coords);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// stmatrix: store the packed 32-bit $datas to shared memory at $addr.
def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I8Ptr_shared:$addr, Variadic<I32>:$datas);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// Computes the shared-memory offset used by stmatrix.x4 for the given
// thread / warp row / element index, optionally with swizzling.
def NVGPU_OffsetOfStmatrixV4Op : NVGPU_Op<"offset_of_stmatrix_v4", []> {
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
let results = (outs I32:$offset);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
}
// Same as above for the sts64 (64-bit shared-memory store) path.
def NVGPU_OffsetOfSts64Op : NVGPU_Op<"offset_of_sts64", []> {
let arguments = (ins I32:$threadId, I32:$rowOfWarp, I32:$elemIdx, I32Attr:$leadingDimOffset, I32Attr:$rowStride, I1Attr:$swizzleEnabled);
let results = (outs I32:$offset);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($offset)";
}
// 64-bit shared-memory store of the pair ($d0, $d1) at $offset.
def NVGPU_Sts64Op : NVGPU_Op<"sts64", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins I32:$offset, AnyTypeOf<[F32, I32]>:$d0, AnyTypeOf<[F32, I32]>:$d1);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// Returns this CTA's id within its cluster; pure (no side effects).
def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
let results = (outs I32:$result);
let assemblyFormat = "attr-dict";
}
// Per-warp register (de)allocation hints of $regCount registers
// (name-based; confirm target instruction in the lowering).
def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
#endif

View File

@@ -9,6 +9,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"

View File

@@ -9,6 +9,8 @@ include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -135,7 +137,7 @@ def TT_LoadOp : TT_Op<"load",
[SameLoadStoreOperandsAndResultShape,
SameLoadStoreOperandsAndResultEncoding,
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr", "$_self",
"mlir::OpTrait::impl::verifyLoadStorePointerAndValueType">,
@@ -461,33 +463,23 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
//
// External Elementwise op
//
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
TT_Op<mnemonic,
traits # [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
return $libpath/$libname:$symbol($args...)
}];
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure);
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
}
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
}
def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>]> {
let summary = "FFI for impure element-wise extern LLVM bitcode functions";
}
//
// Make Range Op
//
@@ -506,6 +498,30 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
let results = (outs TT_IntTensor:$result);
let assemblyFormat = "attr-dict `:` type($result)";
let hasFolder = 1;
}
//
// ElementwiseInlineAsm Op
//
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [Elementwise,
SameOperandsAndResultEncoding,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "inline assembly applying elementwise operation to a group of packed element.";
let description = [{
This will apply the given in inline assembly to `packed_element` number of
elements of the inputs. The elements packed together is unknown and will
depend on the backend implementation.
}];
let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic<AnyTypeOf<[TT_Type]>>:$args);
let results = (outs TT_Type:$result);
let assemblyFormat = [{
$asm_string attr-dict ($args^ `:` type($args))? `->` type($result)
}];
}
//
@@ -563,6 +579,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
let results = (outs TT_TensorPtr:$result);
// TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly
// Add additional `[]` to increase readability and split variadic lists
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)";

View File

@@ -14,7 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
}
// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
@@ -74,7 +74,7 @@ def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
// Scalar Pointer Type: `ptr<>`
def TT_Ptr : TT_PtrOf<[AnyType]>;
// Tensor of Pointer Type
// Tensor of Pointer Type: `tensor<ptr<>>`
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
// Tensor of Pointer Type or Pointer type: `tensor<ptr<>>` or `ptr<>`

View File

@@ -14,6 +14,8 @@ namespace triton {
bool isTensorPointerType(Type type);
bool isTensorOrTensorPointerType(Type type);
unsigned getPointeeBitWidth(Type type);
Type getPointeeType(Type type);

View File

@@ -9,7 +9,6 @@ namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
std::unique_ptr<Pass> createReorderBroadcastPass();
std::unique_ptr<Pass>
createRewriteTensorPointerPass(int computeCapability = 80,
bool isROCM = false);

View File

@@ -3,9 +3,13 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu)
add_public_tablegen_target(TritonGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

View File

@@ -7,6 +7,7 @@
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
@@ -73,17 +74,41 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
SmallVector<unsigned>
getShapePerCTA(Attribute layout,
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
SmallVector<unsigned> getOrder(Attribute layout);
CTALayoutAttr getCTALayout(Attribute layout);
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
SmallVector<unsigned> getCTASplitNum(Attribute layout);
SmallVector<unsigned> getCTAOrder(Attribute layout);
/* The difference between ShapePerCTATile and ShapePerCTA:
* (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
* WarpsPerCTA in each dimension and is independent from the tensor shape.
* (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
* (3) In the implementation of emitIndices, ShapePerCTATile will
* be replicated or wraped to fit ShapePerCTA.
*/
SmallVector<unsigned>
getShapePerCTATile(Attribute layout,
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Type type);
unsigned getNumWarpsPerCTA(Attribute layout);
unsigned getNumCTAs(Attribute layout);
bool isaDistributedLayout(Attribute layout);
bool isSharedEncoding(Value value);
bool isExpensiveCat(CatOp cat, Attribute &targetEncoding);
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
} // namespace gpu
} // namespace triton

View File

@@ -41,6 +41,19 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
}];
}
//===----------------------------------------------------------------------===//
// CTA Layout
//===----------------------------------------------------------------------===//
def CTALayoutAttr : TritonGPU_Attr<"CTALayout"> {
let parameters = (
ins
ArrayRefParameter<"unsigned">:$CTAsPerCGA,
ArrayRefParameter<"unsigned">:$CTASplitNum,
ArrayRefParameter<"unsigned">:$CTAOrder
);
}
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
@@ -64,19 +77,41 @@ are stored contiguously
_ _ _ _ /\_ _ _ _
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case,
when the matrix is stored in shared memory, there will be an offset not
only in the stride dimension, but also in the leading dimension. For example,
a matrix of size 16x128 and data type I8 is stored in the shared memory with
64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64,
compared to 1*64 when the hasLeadingOffset is false.
}];
// swizzle info: vec, perPhase, maxPhase
// order: the fastest-changing axis first
let parameters = (
ins
// swizzle info
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
"unsigned":$vec,
"unsigned":$perPhase,
"unsigned":$maxPhase,
ArrayRefParameter<"unsigned">:$order,
"CTALayoutAttr":$CTALayout,
"bool":$hasLeadingOffset
);
let builders = [
AttrBuilder<(ins "unsigned":$vec,
"unsigned":$perPhase,
"unsigned":$maxPhase,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout), [{
bool hasLeadingOffset = false; // default value
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
}]>,
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"unsigned":$typeWidthInBit), [{
#ifdef USE_ROCM
@@ -123,9 +158,10 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
return get(context, 1, 1, 1, order, CTALayout);
int opIdx = dotOpEnc.getOpIdx();
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
// number of rows per phase
@@ -134,34 +170,42 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// ---- begin Volta ----
if (mmaEnc.isVolta()) {
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8));
perPhase = std::max<int>(perPhase, 1);
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) :
is_row && (shapePerCTA[order[0]] <= 16);
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if ((32 / typeWidthInBit != dotOpEnc.getKWidth()) && order[0] == inner)
return get(context, 1, 1, 1, order, CTALayout);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
// --- handle B operand ---
@@ -169,12 +213,19 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}
llvm_unreachable("invalid operand index");
}
// ---- begin version 3 ----
if (mmaEnc.isHopper()) {
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
" is Hopper has not been implemented yet");
return $_get(context, 1, 1, 1, order, CTALayout, true);
}
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,
@@ -182,9 +233,38 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy), [{
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
return get(context, dotOpEnc, shape, order, bitwidth);
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth);
}]>,
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"CTALayoutAttr":$CTALayout,
"Type":$eltTy), [{
auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
// get proper shared memory swizzling mode from the contiguous dimension
// size of the origin blocked layout.
auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
perPhase = 1;
maxPhase = 8;
} else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
perPhase = 2;
maxPhase = 4;
} else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
perPhase = 4;
maxPhase = 2;
} else {
llvm_unreachable("unsupported shared memory layout for MMAv3");
}
return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
}]>
];
@@ -236,7 +316,7 @@ used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows:
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
@@ -252,82 +332,143 @@ for
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
CTAsPerCGA = {1, 1}
}>
Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows:
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
CTAsPerCGA = {1, 1}
}>
Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and
4 CTAs (taking 2x2 for example) as follows:
CTA [0,0] CTA [0,1]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
CTA [1,0] CTA [1,1]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
CTAsPerCGA = {2, 2}
}>
}];
let builders = [
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// TODO: compiles on MacOS but not linux?
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
// "ArrayRef<unsigned>":$threadsPerWarp,
// "ArrayRef<unsigned>":$warpsPerCTA,
// "ArrayRef<unsigned>":$order), [{
// int rank = threadsPerWarp.size();
// SmallVector<unsigned, 4> sizePerWarp(rank);
// SmallVector<unsigned, 4> sizePerCTA(rank);
// for (unsigned i = 0; i < rank; i++) {
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
// }
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
// }]>,
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// Default builder takes sizePerThread, order and numWarps, and tries to
// pack numWarps*32 threads in the provided order for use in a type
// of the given shape.
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps,
"unsigned":$threadsPerWarp), [{
int rank = sizePerThread.size();
unsigned remainingLanes = threadsPerWarp;
unsigned remainingThreads = numWarps*threadsPerWarp;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> rankedThreadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
rankedThreadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / rankedThreadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= rankedThreadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= rankedThreadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
rankedThreadsPerWarp[order[rank-1]] = threadsPerWarp / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, rankedThreadsPerWarp, warpsPerCTA, order);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
// fastest-changing axis first
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
// These attributes can be inferred from the rest
// ArrayRefParameter<"unsigned">:$sizePerWarp,
// ArrayRefParameter<"unsigned">:$sizePerCTA
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
"CTALayoutAttr":$CTALayout
);
let builders = [
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps,
"unsigned":$numThreadsPerWarp,
"CTALayoutAttr":$CTALayout), [{
unsigned rank = sizePerThread.size();
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
SmallVector<int64_t> shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
unsigned remainingLanes = numThreadsPerWarp;
unsigned remainingThreads = numWarps * numThreadsPerWarp;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
// starting from the contiguous dimension
for (unsigned d = 0; d < rank - 1; ++d) {
unsigned i = order[d];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes;
warpsPerCTA[order[rank - 1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout);
}]>,
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps,
"unsigned":$numThreadsPerWarp,
"unsigned":$numCTAs), [{
unsigned rank = sizePerThread.size();
SmallVector<unsigned, 4> CTAsPerCGA(rank);
SmallVector<unsigned, 4> CTASplitNum(rank);
ArrayRef<unsigned> CTAOrder = order;
unsigned remainingCTAs = numCTAs;
// starting from the most strided dimension
for (int d = rank - 1; d >= 0; --d) {
unsigned i = order[d];
CTAsPerCGA[i] = std::clamp<unsigned>(remainingCTAs, 1, shape[i] / sizePerThread[i]);
CTASplitNum[i] = CTAsPerCGA[i];
remainingCTAs /= CTAsPerCGA[i];
}
CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level
CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let hasCustomAssemblyFormat = 1;
}
@@ -423,13 +564,17 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
ins
"unsigned":$versionMajor,
"unsigned":$versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA
ArrayRefParameter<"unsigned">:$warpsPerCTA,
"CTALayoutAttr":$CTALayout,
ArrayRefParameter<"unsigned">:$instrShape
);
let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"int":$numWarps,
"CTALayoutAttr":$CTALayout,
"ArrayRef<unsigned>":$instrShape,
"ArrayRef<int64_t>":$shapeC,
"bool":$isARow,
"bool":$isBRow,
@@ -443,7 +588,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
// TODO: Share code with
// DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
// rep,spw and fpw.
@@ -468,11 +612,13 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
} while (wpt_nm1 != wpt);
return $_get(context, versionMajor, versionMinor, wpt);
return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape);
}]>,
AttrBuilder<(ins "int":$versionMajor,
"int":$numWarps,
"CTALayoutAttr":$CTALayout,
"ArrayRef<unsigned>":$instrShape,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"ArrayRef<int64_t>":$shapeC,
@@ -482,15 +628,21 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool isVolta() const;
bool isTuring() const;
bool isAmpere() const;
bool isHopper() const;
unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
// Number of bits in versionMinor to hold the ID of the MMA encoding instance.
// Here 5 bits can hold 32 IDs in a single module.
static constexpr int numBitsToHoldMmaV1ID{5};
@@ -670,6 +822,4 @@ section 9.7.13.4.1 for more details.
}];
}
#endif

View File

@@ -16,6 +16,7 @@ def TritonGPU_Dialect : Dialect {
let dependentDialects = [
"triton::TritonDialect",
"mlir::triton::nvgpu::NVGPUDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];
@@ -23,14 +24,27 @@ def TritonGPU_Dialect : Dialect {
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
Attribute numWarps = mod->getDiscardableAttr("triton_gpu.num-warps");
if(!numWarps)
if(!mod->hasAttr("triton_gpu.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps.cast<IntegerAttr>().getInt();
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
}
static int getNumCTAs(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-ctas"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-ctas attribute");
return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
}
static int getComputeCapability(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.compute-capability"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.compute-capability attribute");
return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
}
void registerTypes();
static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; }
static int getThreadsPerWarp(ModuleOp mod) {
Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp");
if(!threadsPerWarp) {
@@ -38,6 +52,7 @@ def TritonGPU_Dialect : Dialect {
}
return threadsPerWarp.cast<IntegerAttr>().getInt();
}
<<<<<<< HEAD
static int getSharedSize(ModuleOp mod) {
Attribute sharedAttr = mod->getDiscardableAttr("triton_gpu.shared");
if(!sharedAttr) {
@@ -46,6 +61,8 @@ def TritonGPU_Dialect : Dialect {
return sharedAttr.cast<IntegerAttr>().getInt();
}
=======
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}];
let useDefaultAttributePrinterParser = 1;

View File

@@ -2,6 +2,7 @@
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
@@ -46,6 +47,20 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
}];
}
def TTG_AsyncBulkWaitOp : TTG_Op<"async_bulk_wait"> {
let summary = "async bulk wait";
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 90;
}
}];
}
def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
let summary = "async commit group";
@@ -58,6 +73,18 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
}];
}
def TTG_AsyncBulkCommitGroupOp : TTG_Op<"async_bulk_commit_group"> {
let summary = "async bulk commit group";
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 90;
}
}];
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
@@ -106,6 +133,98 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
let results = (outs TT_Type:$result);
}
// TODO[goostavz]: extract a base class for InsertSlice & InsertSliceAsync once the op definition is verified
def TTG_InsertSliceOp : TTG_Op<"insert_slice",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"infer other type from src type",
"src", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
let summary = "insert slice";
let description = [{
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operations
`$index` argument and `$axis` attribute.
It returns a copy of `$dst` with the proper slice updated with the value of `$src`.
When converting from `tt.load` to `triton_gpu.insert_slice`, the `$evict`, `$cache`, and `$isVolatile` fields
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
The insert_slice operation supports the following arguments:
* src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted.
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
* mask: optional tensor-rank number of boolean masks which specify which
elements of the `$src` tensor are inserted into the `$dst` tensor.
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
ttgpu.load_tile_async depracate
triton_gpu.insert_slice might be further lowered into triton_gpu_async for different hardware implementations
like tt.load, ttgpu.insert_slice/insert_slice_async has two modes up to the type of src
mode 1: ptr/src is a tensor of pointers
mode 2: ptr/src is a tensor pointer
Some typical lowering paths are:
in case the load is pipelined by the pipeline pass( load is inside kBlock loop, which means "pipeline pass):
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1)
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -(Pipeline)-> ttgpu.insert_slice(mode 1) -(MaterializeLoad)> ttgpu.insert_slice_async(mode 1) + ttgpu.await-> llvm
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -(Pipeline)-> ttgpu.insert_slice(mode 2) -(MaterializeLoad)> ttgpu.insert_slice_async_v2(mode 2) + ttgpu.await-> llvm
otherwise:
Load from global + store to shared : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1)
Non-bulk cp.async : tt.load(mode 1) -(tt->ttgpu+Coalesce)-> tt.load(mode 1) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 1) + ttgpu.await -> llvm
TMA load : tt.load(mode 2) -(tt->ttgpu+Coalesce)-> tt.load(mode 2) -> ... -(MaterializeLoad)-> ttgpu.insert_slice_async(mode 2) + ttgpu.await -> llvm
Example:
```
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
%2 = triton_gpu.insert_slice %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
```
}];
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let builders = [
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"Value":$mask, "Value":$other,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
];
let results = (outs TT_Tensor:$result);
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}];
let hasCustomAssemblyFormat = 1;
}
def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
@@ -173,7 +292,8 @@ def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
MemoryEffects<[MemRead]>,
// TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
@@ -219,7 +339,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
```
}];
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
let arguments = (ins TT_PtrLike:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);

View File

@@ -0,0 +1,26 @@
#ifndef TRITONGPU_TYPES
#define TRITONGPU_TYPES
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
class TTG_TypeDef<string name, string _mnemonic, list<Trait> traits = []>
: TypeDef<TritonGPU_Dialect, name, traits> {
let mnemonic = _mnemonic;
}
def TTG_TokenType : TTG_TypeDef<"Token", "token"> {
let parameters = (ins "int32_t":$type);
let builders = [
TypeBuilder<(ins "unsigned":$type), [{
return $_get($_ctxt, type);
}]>
];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
#endif

View File

@@ -0,0 +1,10 @@
#ifndef TRITONGPU_IR_TYPES_H_
#define TRITONGPU_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -2,9 +2,14 @@
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 3,
int numWarps = 4,
int numCTAs = 1,
int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
@@ -27,6 +32,8 @@ std::unique_ptr<Pass> createTritonGPUVerifier();
std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();
std::unique_ptr<Pass> createTritonGPUOptimizeEpiloguePass();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

View File

@@ -14,13 +14,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let constructor = "mlir::createTritonGPUPipelinePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"2",
"number of pipeline stages">
"int32_t", /*default*/"3",
"number of pipeline stages">,
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps per block">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of CTAs per CGA">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
@@ -65,6 +75,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
@@ -85,6 +96,7 @@ def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir
let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::triton::TritonDialect"];
}
@@ -111,6 +123,20 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUOptimizeEpilogue : Pass<"tritongpu-optimize-epilogue", "mlir::ModuleOp"> {
let summary = "Optimize epilogue: (1) Store accumulators directly without going thorough SMEM in epilogue.";
let description = [{
}];
let constructor = "mlir::createTritonGPUOptimizeEpiloguePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {

View File

@@ -13,15 +13,17 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps,
int threadsPerWarp);
TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp,
int numCTAs);
int getNumWarps() const { return numWarps; }
int getThreadsPerWarp() const { return threadsPerWarp; }
int getNumCTAs() const { return numCTAs; }
private:
MLIRContext *context;
int numWarps;
int threadsPerWarp;
int numCTAs;
};
class TritonGPUConversionTarget : public ConversionTarget {

View File

@@ -10,33 +10,143 @@
namespace mlir {
LogicalResult fixupLoops(ModuleOp mod);
namespace triton {
class LoadOp;
class StoreOp;
class FuncOp;
namespace gpu {
class SharedEncodingAttr;
}
} // namespace triton
// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret);
SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
const ArrayRef<int64_t> &shape,
RankedTensorType type);
bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);
/// Returns true if the Load is for TMA
bool isLoadFromTensorPtr(triton::LoadOp op);
bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);
/// Returns true if the store is for TMA
bool isStoreToTensorPtr(triton::StoreOp op);
// skipInit is True when we only consider the operands of the initOp but
// not the initOp itself.
int simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding);
/// Return the first consumer of v
Operation *getFirstUser(Value v);
/// Return the proper SharedEncodingAttr according to shape/order
triton::gpu::SharedEncodingAttr getSharedEncoding(RankedTensorType tensorTy);
/* Dump Triton IR in graphviz dot format.
*
* You can override `onValue` and `onOperation` in a subclass to mark
* specific Values and Operations. The below subclass
* GraphLayoutMarker is an example.
*
* Default NodeInfo for Value nodes:
* {{"shape": "box"},
* {"style", "filled"},
* {"fillcolor", "white"},
* {"label", shapeStr}}
*
* Default NodeInfo for Operation nodes:
* {{"shape": "ellipse"},
* {"style", "filled"},
* {"fillcolor", "white"},
* {"label", operationName}}
*
* If the key "label" is not set by `onValue` or `onOperation`, default labels
* will be generated. For Value node, the default label is the shape string and
* for Operation node, it is the operation name.
*
* Reference:
* https://graphviz.org/doc/info/shapes.html
* https://graphviz.org/doc/info/colors.html
*
* Usage:
* C++: GraphDumper().dumpToFile(func, "func.dot");
* Shell: dot -Tjpg func.dot -o func.jpg
*/
class GraphDumper {
public:
using NodeInfo = std::map<std::string, std::string>;
// Override this function to mark specific Values
virtual NodeInfo onValue(Value value) const;
// Override this function to mark specific Operations
virtual NodeInfo onOperation(Operation *op) const;
std::string dump(triton::FuncOp func) const;
void dumpToFile(triton::FuncOp func, const std::string &filename) const;
protected:
std::string getShapeStr(const Type &type) const;
std::string getUniqueId(Value value) const;
std::string getUniqueId(Operation *op) const;
std::string emitNode(const std::string &id, const NodeInfo style) const;
std::string emitEdge(const std::string &srcId,
const std::string &destId) const;
std::string emitValueNode(Value value) const;
std::string emitOperationNode(Operation *op) const;
};
/* A subclass of GraphDumper that marks different layout kinds in different
* colors.*/
class GraphLayoutMarker : public GraphDumper {
public:
NodeInfo onValue(Value value) const override;
protected:
std::string getColor(const Type &type) const;
};
// Infers the encoding of the result of op given the source encoding.
std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding);
// Infers the encoding of the source of op given the result encoding.
std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding);
bool isExpensiveLoadOrStore(Operation *op);
bool canFoldIntoConversion(Operation *op, Attribute targetEncoding);
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);
void rematerializeConversionChain(
const llvm::MapVector<Value, Attribute> &toConvert,
mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
IRMapping &mapping);
// Get backward slice of tensor values starting from the root node along with
// encoding propagation.
LogicalResult getConvertBackwardSlice(
Value root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);
LogicalResult canMoveOutOfLoop(BlockArgument arg,
SmallVector<Operation *> &cvts);
// Populate pattern to remove dead cycles in ForOp.
void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
SmallVector<Value> delinearize(OpBuilder &b, Location loc, unsigned linear,
ArrayRef<unsigned> shape);
SmallVector<Value> delinearize(OpBuilder &b, Location loc, Value linear,
ArrayRef<unsigned> shape);
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);
// Returns null if the op is not inside a agent region (warp specialization
// mode). Note that there should be at most one agent id attached to the
// operation.
std::optional<int> getWSAgentId(Operation *op);
std::optional<int> getWSRoleId(Operation *op);
void setRoleId(Operation *op, int roleId);
} // namespace mlir

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,15 @@
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu)
add_public_tablegen_target(TritonNvidiaGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td)
mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen)

View File

@@ -0,0 +1,46 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonNvidiaGPU depends on Triton
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonNvidiaGPU/IR/Traits.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_

View File

@@ -0,0 +1,53 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_NVIDIA_GPU_IR_TRAITS_H_
#define TRITON_NVIDIA_GPU_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySource1IsSharedEncoding(Operation *op);
} // namespace impl
template <typename ConcreteType>
class Source1IsSharedEncoding
: public TraitBase<ConcreteType, Source1IsSharedEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySource1IsSharedEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_ATTRDEFS
#define TRITONNVIDIAGPU_ATTRDEFS
include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
#endif

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_DIALECT
#define TRITONNVIDIAGPU_DIALECT
include "mlir/IR/OpBase.td"
// TableGen definition of the TritonNvidiaGPU dialect (`triton_nvidia_gpu`),
// which hosts the NVIDIA-specific ops layered on top of the generic
// TritonGPU dialect (mbarriers, warp specialization, clusters, async dots).
def TritonNvidiaGPU_Dialect : Dialect {
  let name = "triton_nvidia_gpu";
  let cppNamespace = "::mlir::triton::nvidia_gpu";
  let hasOperationAttrVerify = 1;
  let description = [{
    Triton Nvidia GPU Dialect.
  }];
  let dependentDialects = [
    "triton::TritonDialect",
    "triton::gpu::TritonGPUDialect",
    "mlir::triton::nvgpu::NVGPUDialect",
    "mlir::gpu::GPUDialect",
    "tensor::TensorDialect",
  ];
  let extraClassDeclaration = [{
    // NOTE: these helpers deliberately read "triton_gpu.*" module attributes
    // (not "triton_nvidia_gpu.*") so they interoperate with modules produced
    // by the TritonGPU passes. Each reports a fatal error when the required
    // attribute is missing.
    static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
    static int getNumWarps(ModuleOp mod) {
      if(!mod->hasAttr("triton_gpu.num-warps"))
        llvm::report_fatal_error(
            "TritonGPU module should contain a triton_gpu.num-warps attribute");
      return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
    }
    static int getNumCTAs(ModuleOp mod) {
      if(!mod->hasAttr("triton_gpu.num-ctas"))
        llvm::report_fatal_error(
            "TritonGPU module should contain a triton_gpu.num-ctas attribute");
      return mod->getAttr("triton_gpu.num-ctas").cast<IntegerAttr>().getInt();
    }
    static int getComputeCapability(ModuleOp mod) {
      if(!mod->hasAttr("triton_gpu.compute-capability"))
        llvm::report_fatal_error(
            "TritonGPU module should contain a triton_gpu.compute-capability attribute");
      return mod->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability").getInt();
    }
    void registerTypes();
    // Warp specialization related:
    // getWSSupportedAttr returns 0 when the attribute is absent, i.e. warp
    // specialization is treated as disabled by default.
    static std::string getWSSupportedAttrName() { return "triton_gpu.enable-warp-specialization"; }
    static int getWSSupportedAttr(ModuleOp mod) {
      auto name = getWSSupportedAttrName();
      if (!mod->hasAttr(name)) return 0;
      return mod->getAttrOfType<IntegerAttr>(name).getInt();
    }
  }];
  let useDefaultTypePrinterParser = 1;
}
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
#endif

View File

@@ -0,0 +1,385 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_OPS
#define TRITONNVIDIAGPU_OPS
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td"
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
// Op traits requiring the first source operand / all results to carry a
// shared-memory encoding (verified by the matching NativeOpTrait C++ impls).
def Source1IsSharedEncoding: NativeOpTrait<"Source1IsSharedEncoding">;
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
// Base class for every op in the triton_nvidia_gpu dialect.
class TTNG_Op<string mnemonic, list<Trait> traits = []> :
    Op<TritonNvidiaGPU_Dialect, mnemonic, traits>;
// --------------------------------------------------------------------------------------------------
// MBarrier related Ops:
// 1, These mbarrier commands are currently not needed, and not taken into consideration:
// (1), mbarrier.expect_tx
// (2), mbarrier.arrive_drop
// (3), mbarrier.complete_tx
// (4), mbarrier.inval
//
// 2, The mbarriers is supported to be created in vector, and accessed in seperate via tensor.extract.
// The mbarriers created in vector will have counters initialized in the same configuration. A
// typical example to demonstrate this:
//
// %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4x!tt.ptr<i64>>
// scf.for %iv = %lb to %ub step %step iter_args() -> () {
// %buffer_id = arith.remi %iv, %c4 : i32
// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
// triton_nvidia_gpu.mbarrier_arrive %2 {expectTx = 2048} : !tt.ptr<i64> -> ()
// }
// ...
// scf.for %iv = %lb to %ub step %step iter_args() -> () {
// %buffer_id = arith.remi %iv, %c4 : i32
// %2 = triton_nvidia_gpu.extract_mbarrier %1[%buffer_id] : tensor<4xi64>, i32 -> !tt.ptr<i64>
// triton_nvidia_gpu.mbarrier_wait %2, %c0 : !tt.ptr<i64>, i1 -> ()
// }
// Allocates one mbarrier (scalar result) or a vector of identically
// configured mbarriers; `count` is the arrival count of each element.
def TTNG_AllocMBarrierOp : TTNG_Op<"alloc_mbarrier", [MemoryEffects<[MemAlloc]>]> {
  let summary = "allocate a vector of mbarriers";
  let description = [{
    Allocate and initialize a vector of mbarriers. The size of the vector is implied in the returned type.
    Each mbarrier is initialized as:
    1, the current phase initialized to 0.
    2, the expected arrival count initialized to 'count'.
    3, the pending arrival count initialized to 'count'.
    4, the tx-count initialized to 0.
    Example:
    case a. when created in vector:
    %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : tensor<4xi64>
    case b. when created in scalar:
    %1 = triton_nvidia_gpu.alloc_mbarrier { count = 1 } : !tt.ptr<i64>
  }];
  let assemblyFormat = [{attr-dict `:` type($result)}];
  let arguments = (ins I32Attr:$count);
  let results = (outs AnyTypeOf<[TT_Ptr, I64Tensor]>:$result);
}
// Pure indexing op: yields a pointer to the $index-th mbarrier of a vector.
def TTNG_ExtractMBarrierOp : TTNG_Op<"extract_mbarrier", [Pure]> {
  let summary = "extract a mbarrier from a vector of mbarriers";
  let description = [{
    Extract a mbarrier from a vector of mbarriers
    Example:
    %1 = triton_nvidia_gpu.extract_mbarrier %mbarriers[%idx] : tensor<4xi64>, index -> !tt.ptr<i64>
  }];
  let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor) `,` type($index) `->` type($result)";
  let arguments = (ins I64Tensor:$tensor, I32:$index);
  let results = (outs TT_Ptr:$result);
}
// Blocks execution until the mbarrier completes the given $phase.
def TTNG_MBarrierWaitOp : TTNG_Op<"mbarrier_wait", [MemoryEffects<[MemRead, MemWrite]>]> {
  let summary = "mbarrier wait";
  let description = [{
    This operation defining the waiting action for a mbarrier.
    The subsequent operations should not execute until this operation completes waiting.
    Example:
    triton_nvidia_gpu.mbarrier_wait %0, %1 : !tt.ptr<i64>
  }];
  let arguments = (ins TT_Ptr:$mbarrier, I1: $phase);
  let assemblyFormat = "$mbarrier `,` $phase attr-dict `:` type($mbarrier)";
}
// Signals arrival on an mbarrier. The attributes/operands select between the
// plain, expect-tx, cp.async-tracking and remote-CTA arrive flavors (see
// the per-attribute notes in the description).
def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments,
                                    MemoryEffects<[MemWrite]>]> {
  let summary = "mbarrier arrive";
  let description = [{
    This operation defining the arriving action for a mbarrier.
    txCount:
      An optional attribute that set tx-count. This Op will be lowered into
      mbarrier.arrive.expect_tx if the optional attribute exist.
    trackAsyncOp:
      If true, this op will be lowered into cp.async.mbarrier.arrive.noinc.
    pred:
      Only perform arrive action when pred is true.
    remoteCtaId:
      if set, perform an remote arrive action.
    Example:
    triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr<i64>
  }];
  let arguments = (ins TT_Ptr:$mbarrier,
                   Optional<I1>:$pred,
                   Optional<I32>:$remoteCtaId,
                   I1Attr: $trackAsyncOp,
                   DefaultValuedAttr<I32Attr, "0">: $txCount
                  );
  let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// Fence between the generic proxy and the async proxy; `bCluster` widens
// the scope from CTA to cluster. Only supported on sm_90+ (see isSupported).
def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
  let arguments = (ins BoolAttr:$bCluster);
  let summary = "fence proxy async";
  let assemblyFormat = "attr-dict";
  let extraClassDeclaration = [{
    static bool isSupported(int computeCapability) {
      return computeCapability >= 90;
    }
  }];
}
// TODO[goostavz]: ThreadId & ClusterCTAId should not be exposed to
// ttgpu level. Remove them when async dialect is ready.
def TTNG_GetThreadIdOp : TTNG_Op<"get_thread_id", [Pure]> {
  let description = [{
    Returns the one dimensional threadId.
  }];
  let results = (outs I32:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> {
  let description = [{
    Returns the one dimensional cluster_cta_id.
  }];
  let results = (outs I32:$result);
  let assemblyFormat = "attr-dict `:` type($result)";
}
// Named-barrier primitives: `bar` is the barrier id, `numThreads` the number
// of threads participating in the barrier.
def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> {
  let summary = "named barrier arrive";
  let arguments = (ins I32:$bar, I32: $numThreads);
  let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> {
  let summary = "named barrier wait";
  let arguments = (ins I32:$bar, I32: $numThreads);
  let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}
// v2 of insert_slice_async: asynchronously copies a tile from `$src` into
// slice `$index` of `$dst` (shared-encoding result), signalling completion
// through the mbarrier `$mbar`.
def TTNG_InsertSliceAsyncV2Op : TTNG_Op<"insert_slice_async_v2",
                                [AttrSizedOperandSegments,
                                 ResultsAreSharedEncoding,
                                 // TODO: Check if MemWrite will degrade performance of non-warp-specialized kernel
                                 MemoryEffects<[MemRead, MemWrite]>]> {
  let arguments = (ins AnyTypeOf<[TT_Ptr, TT_PtrTensor]>:$src, TT_Tensor:$dst,
                       I32:$index, TT_Ptr:$mbar,
                       Optional<AnyTypeOf<[I1Tensor, I1]>>:$mask, Optional<TT_Type>:$other,
                       TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
                       BoolAttr:$isVolatile, I32Attr:$axis);
  let results = (outs TT_Tensor:$result);
  let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
// TODO: the abstraction of barriers in ttgpu level is pending, will revisit later
// def TTNG_AwaitOp : TTNG_Op<"await", []> {
//   let arguments = (ins TTNG_TokenType:$token);
//   let assemblyFormat = "$token attr-dict `:` type($token)";
// }
// CGA cluster barrier: arrive (optionally relaxed) then wait.
def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> {
  let arguments = (ins I1Attr:$relaxed);
  let assemblyFormat = "attr-dict";
}
def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
  let assemblyFormat = "attr-dict";
}
//
// DotAsync Op
//
// Asynchronous matrix multiply: $d = $a * $b + $c, issued without waiting
// for completion; pair with dot_wait to consume the result.
def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
                             DeclareOpInterfaceMethods<InferTypeOpInterface>,
                             TypesMatchWith<"result's type matches accumulator's type",
                                            "d", "c", "$_self">]> {
  let summary = "dot async";
  let description = [{
    $d = matrix_multiply($a, $b) + $c
  }];
  let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
  let results = (outs TT_FpIntTensor:$d);
  let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
// Waits until at most `$pendings` async dot operations remain outstanding.
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
  let summary = "dot wait";
  let description = [{
    This operation defining the waiting action for a async dot, MMAv3 .e.g.
    The subsequent operations should not execute until this operation completes waiting.
  }];
  let arguments = (ins I32Attr:$pendings);
  let assemblyFormat = "attr-dict";
}
// Asynchronous store of `$src` through the tensor pointer `$dst`.
def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
                       [MemoryEffects<[MemWrite]>]> {
  let summary = "store asynchronous by a tensor pointer";
  let arguments = (ins TT_TensorPtr:$dst, TT_Tensor:$src,
                       DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache);
  let assemblyFormat = "operands attr-dict `:` type(operands)";
}
// Returns the id of the warp-specialization agent the current warp belongs to.
def TTNG_GetAgentIdOp : TTNG_Op<"get_agent_id", [Pure]> {
  let results = (outs I32:$result);
  let builders = [OpBuilder<(ins)>];
  let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Token
//
// Producer/consumer synchronization tokens for warp specialization.
// `num` tokens are created; `idx` selects one. A producer acquires a slot,
// fills it, then commits; a consumer waits on it and releases it when done.
def TTNG_CreateTokenOp : TTNG_Op<"create_token"> {
  let results = (outs TensorOf<[TTNG_TokenType]>:$result);
  let arguments = (ins I32Attr:$num);
  let builders = [OpBuilder<(ins "uint32_t":$num)>];
  let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> {
  let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);
  let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}
//
// Mutex
//
// Mutex primitives used by the WS-mutex pass to let two math sub-agents
// alternate over a shared critical section; lock/unlock bracket it.
def TTNG_GetMutexRoleIdOp : TTNG_Op<"get_mutex_role_id"> {
  let results = (outs I32:$result);
  let arguments = (ins I32Attr:$num);
  let builders = [OpBuilder<(ins "uint32_t":$num)>];
  let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_CreateMutexOp : TTNG_Op<"create_mutex"> {
  let results = (outs TTNG_MutexType:$result);
  let builders = [OpBuilder<(ins)>];
  let assemblyFormat = "attr-dict `:` type($result)";
}
def TTNG_LockOp : TTNG_Op<"lock"> {
  let arguments = (ins TTNG_MutexType:$mutex);
  let assemblyFormat = "$mutex attr-dict `:` type(operands)";
}
def TTNG_UnlockOp : TTNG_Op<"unlock"> {
  let arguments = (ins TTNG_MutexType:$mutex);
  let assemblyFormat = "$mutex attr-dict `:` type(operands)";
}
// Per-thread register budget adjustment between agents
// (presumably lowered to setmaxnreg.inc/.dec — confirm in the lowering).
def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
  let summary = "register allocation";
  let arguments = (ins I32Attr: $regCount);
  let assemblyFormat = "$regCount attr-dict";
}
def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
  let summary = "register deallocation";
  let arguments = (ins I32Attr: $regCount);
  let assemblyFormat = "$regCount attr-dict";
}
#endif

View File

@@ -0,0 +1,37 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_TYPES
#define TRITONNVIDIAGPU_TYPES
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
// Base class for TritonNvidiaGPU type definitions; `_mnemonic` is the
// keyword used in the textual IR (e.g. !triton_nvidia_gpu.token).
class TTNG_TypeDef<string name, string _mnemonic>
    : TypeDef<TritonNvidiaGPU_Dialect, name> {
  let mnemonic = _mnemonic;
}
// Opaque handle types for warp specialization: tokens synchronize
// producer/consumer agents, mutexes serialize the math sub-agents.
def TTNG_TokenType : TTNG_TypeDef<"Token", "token">;
def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">;
#endif

View File

@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITONNVIDIAGPU_IR_TYPES_H_
#define TRITONNVIDIAGPU_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
// Pull in the TableGen-generated type class declarations for the
// TritonNvidiaGPU dialect (Token, Mutex).
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc"
#endif // TRITONNVIDIAGPU_IR_TYPES_H_

View File

@@ -0,0 +1,3 @@
# Generate the pass declarations (Passes.h.inc) from Passes.td for the
# TritonNvidiaGPU transform passes, and expose them as a tablegen target.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU)
add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen)

View File

@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
namespace nvidia_gpu {

// Used by Triton runtime
// CGA cluster dimensions chosen during compilation; defaults to a 1x1x1
// cluster (i.e. clustering disabled).
struct ClusterInfo {
  ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {}
  int clusterDimX;
  int clusterDimY;
  int clusterDimZ;
};

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir

// Factory functions for the TritonNvidiaGPU transform passes declared in
// Passes.td. The default arguments mirror the tablegen option defaults.
namespace mlir {

std::unique_ptr<Pass>
createTritonNvidiaGPUMaterializeLoadStorePass(int numWarps = 4,
                                              int computeCapability = 80);

// Plans CTA layout within a CGA; when `clusterInfo` is non-null the chosen
// cluster dimensions are written there for the runtime.
std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
    mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSFeasibilityCheckingPass(int computeCapability = 90);

// NOTE(review): default here is 90 while the tablegen option default for
// this pass is 80 — confirm which is intended.
std::unique_ptr<Pass>
createTritonNvidiaGPUWSDecomposingPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSPipelinePass(int numStages = 3, int numWarps = 4,
                                    int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSMutexPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUWSMaterializationPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);

std::unique_ptr<Pass>
createTritonGPURewriteTensorPointerPass(int computeCapability = 80);

std::unique_ptr<Pass> createTritonNvidiaGPUWSFixupMissingAttrs();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

} // namespace mlir
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_

View File

@@ -0,0 +1,246 @@
// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files
// (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of the Software,
// and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#ifndef TRITONNVIDIAGPU_PASSES
#define TRITONNVIDIAGPU_PASSES
include "mlir/Pass/PassBase.td"
// Post-pipelining cleanup: turns leftover tensor-pointer tt.LoadOps into
// async copies with the required barriers.
def MaterializeLoadStore : Pass<"triton-nvidia-gpu-materialize-load-store", "mlir::ModuleOp"> {
  let summary = "materialize load & store";
  let description = [{
    This pass works after pipeline pass, converting the remaining tt.LoadOp taking
    ptr<tensor> as input into ttg.InsertSliceAsyncOp and emit proper barriers
  }];
  let constructor = "mlir::createTritonNvidiaGPUMaterializeLoadStorePass()";
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];
  let options = [
    Option<"numWarps", "num-warps",
           "int32_t", /*default*/"4",
           "number of warps per block">,
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
// Assigns CTA layout/roles within a CGA cluster.
def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> {
  let summary = "plan CTA";
  let description = [{
    Plan CTAs in CGA
  }];
  let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";
  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];
}
// Conservative pre-check: marks the module as WS-capable (or not) so the
// later warp-specialization passes know whether to run.
def TritonGPUWSFeasibilityChecking : Pass<"triton-nvidia-gpu-ws-feasibility-checking", "mlir::ModuleOp"> {
  let summary = "Attach attr named TritonNvidiaGPUDialect::getWSSupportedAttrName() if auto WS supported";
  let description = [{
    Since not every legal triton kernels can be auto WS, this pass does some (conservative) check
    and attaches an attribute named TritonNvidiaGPUDialect::getWSSupportedAttrName() on
    the input module op if the kernel is supported.
  }];
  let constructor = "mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass()";
  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];
  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}
// Tags ops with an `async_agent` label so later warp-specialization passes
// can split the kernel into producer (load) and consumer (math) agents.
// Fix: the description examples misspelled the load op as "tt,load"; the
// correct mnemonic is "tt.load".
def TritonGPUWSDecomposing : Pass<"triton-nvidia-gpu-ws-decomposing", "mlir::ModuleOp"> {
  let summary = "Clustering on the ops according to their performance hotspots";
  let description = [{
    Based on compute capability and heuristics,
    this pass will identify some operations to be executed in different agents,
    by marking them with async 'label'. E.g.,
    input:
    %1 = tt.load %0 ...
    %4 = tt.dot %1, %2, %3 ...
    output:
    %1 = tt.load %0 {async_agent = 0} ...
    %4 = tt.dot %1, %2, %3 {async_agent = 1} : ...
  }];
  let constructor = "mlir::createTritonNvidiaGPUWSDecomposingPass()";
  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];
  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
// Software pipelining specialized for warp-specialized kernels.
def TritonGPUWSPipeline : Pass<"triton-nvidia-gpu-ws-pipeline", "mlir::ModuleOp"> {
  let summary = "Warp specialization pipeline";
  let description = [{
  }];
  let constructor = "mlir::createTritonNvidiaGPUWSPipelinePass()";
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];
  let options = [
    Option<"numStages", "num-stages",
           "int32_t", /*default*/"3",
           "number of pipeline stages">,
    Option<"numWarps", "num-warps",
           "int32_t", /*default*/"12",
           "number of warps per block">,
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}
// Inserts mutex-based alternation between the two math sub-agents of a
// persistent kernel (cutlass-style "2 Math WG" scheme).
// Fix: "syncronization" misspelling corrected to "synchronization" in the
// summary and description.
def TritonGPUWSMutex : Pass<"triton-nvidia-gpu-ws-mutex", "mlir::ModuleOp"> {
  let summary = "Warp specialization mutex synchronization";
  let description = [{
    create mutex synchronization for persistent kernel. (as "2 Math WG" persistent kernel in cutlass)
    For example, the agent containing dot and store will be divided into two sub-agent,
    which execute dot and store alternately. i.e.:
    sub-agent-0: dot | store | dot | ... | store
    sub-agent-1: | dot | store | ... | dot | store
  }];
  let constructor = "mlir::createTritonNvidiaGPUWSMutexPass()";
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];
  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
// Lowers the abstract warp-specialization structure into concrete ops.
def TritonGPUWSMaterialization : Pass<"triton-nvidia-gpu-ws-materialization", "mlir::ModuleOp"> {
  let summary = "Warp specialization materialization";
  let description = [{
  }];
  let constructor = "mlir::createTritonNvidiaGPUWSMaterializationPass()";
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}
// Inserts fences between the generic and async proxies where needed.
def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
  let summary = "Insert fences across generic and async proxy";
  let description = [{
  }];
  let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";
  let dependentDialects = [
    "mlir::triton::gpu::TritonGPUDialect",
    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
  ];
  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"90",
           "device compute capability">
  ];
}
// Lowers tensor-pointer (make_tensor_ptr/advance) load/stores to legacy
// pointer+mask+other form.
def TritonGPURewriteTensorPointer : Pass</*cli-arg*/"tritongpu-rewrite-tensor-pointer", /*Op*/"mlir::ModuleOp"> {
  let summary = "Rewrite load/stores with tensor pointers into legacy load/stores";
  let description = [{
    This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy
    semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute
    the pointer/mask/other for each load/store.
  }];
  let constructor = "mlir::createTritonGPURewriteTensorPointerPass()";
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect"];
  let options = [
    Option<"computeCapability", "compute-capability",
           "int32_t", /*default*/"80",
           "device compute capability">
  ];
}
// Re-attaches WS attributes that intermediate transforms may have dropped
// before lowering to LLVM.
def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> {
  let summary = "Fixup missing WS related attributes";
  let description = [{
    WS related attributes are attached to some key operations and are used when lowering to llvm.
    However these attributes maybe be dropped in the following IR transform. This pass tries to
    fixup the missing attributes.
  }];
  let constructor = "mlir::createTritonNvidiaGPUWSFixupMissingAttrs()";
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];
}
#endif

View File

@@ -0,0 +1,95 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/MapVector.h"
namespace mlir {

// 0 is reserved for default sync.
// TODO: comprehensive mechanism to globally manage namedbarrier.
static int const nameBarrierIdBegin = 1;
// NOTE(review): non-const `static` in a header gives every translation unit
// its own mutable copy — confirm this is intended.
static int nameBarrierIdEnd = 16;

/// Helper functions for async agent
// An AgentId labels which warp-specialization agent an op belongs to
// (see kLoadAgentId / kDotAgentId below).
typedef int AgentId;

SmallVector<AgentId> getAgentIds(Operation *op);
bool hasAgentId(Operation *op, AgentId agentId);
void setAgentIds(Operation *op, ArrayRef<AgentId> agentIds);
SmallVector<AgentId> collectAgentIds(Operation *op);
void addAgentIds(Operation *op, ArrayRef<int> agents);
SmallVector<int> getMutexBarIds(Operation *op);
SmallVector<int> getMutexNumThreads(Operation *op);

// OpBuilder that stamps a configurable set of agent ids onto every op
// created through createWithAgentIds().
class OpBuilderWithAgentIds : public OpBuilder {
public:
  OpBuilderWithAgentIds(MLIRContext *context) : OpBuilder(context) {}
  void setAgentIdsFromArray(ArrayRef<AgentId> newAgentIds) {
    agentIds = SmallVector<AgentId>(newAgentIds.begin(), newAgentIds.end());
  }
  // Copy the agent ids already attached to `op`.
  void setAgentIdsFromOp(Operation *op) {
    setAgentIdsFromArray(getAgentIds(op));
  }
  // Use the union of the agent ids of all users of `value`.
  void setAgentIdsFromValueUsers(Value value) {
    SetVector<AgentId> agentIdSet;
    for (Operation *user : value.getUsers())
      for (AgentId agentId : getAgentIds(user))
        agentIdSet.insert(agentId);
    setAgentIdsFromArray(agentIdSet.getArrayRef());
  }
  // Create an op and tag it with the current agent ids (if any are set).
  template <typename OpTy, typename... Args>
  OpTy createWithAgentIds(Args &&...args) {
    OpTy op = create<OpTy>(std::forward<Args>(args)...);
    if (!agentIds.empty())
      setAgentIds(op, agentIds);
    return op;
  }
private:
  SmallVector<AgentId> agentIds;
};

/// Constant agent ids
constexpr AgentId kLoadAgentId = 0;
constexpr AgentId kDotAgentId = 1;

bool isWSCandidateLoad(Operation *op);
bool isWSSupported(ModuleOp m, int computeCapability);

LogicalResult getDependentValues(Value val, DenseSet<Value> &depSet,
                                 const DenseSet<Value> &stopSet = {});
LogicalResult getDependentValues(Operation *op, DenseSet<Value> &depSet,
                                 const DenseSet<Value> &stopSet = {});
DenseSet<Operation *> getDependentOps(DenseSet<Value> &depSet);

} // namespace mlir
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_

View File

@@ -0,0 +1,19 @@
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
#define TRITON_TARGET_AMDGCNTRANSLATION_H
#include <string>
#include <tuple>
namespace llvm {
class Module;
} // namespace llvm

namespace triton {

// Translate LLVM IR to AMDGCN code.
// `cc` names the target architecture (presumably a gfx arch string — confirm
// against the implementation). Returns two strings; their exact contents
// (assembly vs. binary path) are defined by the implementation — verify
// before relying on the ordering.
std::tuple<std::string, std::string>
translateLLVMIRToAMDGCN(llvm::Module &module, std::string cc);

} // namespace triton

View File

@@ -1,5 +1,7 @@
#ifndef TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#define TRITON_TARGET_LLVM_IR_LLVM_IR_TRANSLATION_H
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Target/PTX/TmaMetadata.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
@@ -26,12 +28,16 @@ void addExternalLibs(mlir::ModuleOp &module,
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability,
bool isROCM);
mlir::triton::gpu::TMAMetadataTy &tmaInfos,
Target target);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
bool isROCM);
Target target);
bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path, Target target);
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,107 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_TARGET_PTX_TMAMETADATA_H
#define TRITON_TARGET_PTX_TMAMETADATA_H
#include "python/triton/third_party/cuda/include/cuda.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <map>
#include <utility>
#include <vector>
namespace mlir {
namespace triton {
namespace gpu {
// Host-side description of one TMA (Tensor Memory Accelerator)
// descriptor: the fields the runtime needs to build a CUtensorMap for a
// kernel argument, plus the kernel-argument indices telling it where to
// read the concrete values at launch time.
struct TMAInfo {
// --------------------------------------------
// information to be filled into CUtensorMaps
// Element data type of the tensor; presumably a CUtensorMapDataType
// enum value — confirm against the runtime that consumes it.
int tensorDataType;
// Number of tensor dimensions (sizes the per-dim vectors below).
uint32_t tensorRank;
// the argument indices for the runtime to get globalAddresses
size_t globalAddressArgIdx;
// the argument indices for the runtime to get globalDims, -1 stands for this
// dim is padded
std::vector<int32_t> globalDimsArgIdx;
// the argument indices for the runtime to get globalStrides, -1 stands for
// this dim is padded; the runtime needs to map the value to internal format
std::vector<int32_t> globalStridesArgIdx;
// Per-dimension extent of the box (tile) moved by one TMA operation.
std::vector<uint32_t> boxDims;
// Per-dimension element stride within the box.
std::vector<uint32_t> elementStrides;
// The following mirror CUtensorMap encode parameters (interleave
// layout, shared-memory swizzle mode, L2 promotion size, out-of-bounds
// fill mode); values correspond to the CUDA driver enums.
int interleave;
int swizzle;
int l2Promotion;
int oobFill;
// --------------------------------------------
// the argument indices for the runtime to send the address of tma_desc to the
// binary
int TMADescArgIdx;
// Debugging aid: print `vec` on stderr as "<info>: v0,v1,...".
template <typename T>
void dump_vec(const std::vector<T> &vec, llvm::StringRef info) const {
llvm::errs() << info << ": ";
for (const T &e : vec)
llvm::errs() << e << ",";
llvm::errs() << "\n";
}
// Debugging aid: print every field of this TMAInfo to stderr.
void dump() const {
llvm::errs() << "TMA Info: ----------"
<< "\n";
llvm::errs() << "-- tensorDataType: " << tensorDataType
<< ", tensorRank: " << tensorRank << "\n";
llvm::errs() << "-- globalAddressArgIdx: " << globalAddressArgIdx << "\n";
llvm::errs() << "-- TMADescArgIdx: " << TMADescArgIdx << "\n";
dump_vec<int32_t>(globalDimsArgIdx, "-- globalDimsArgIdx");
dump_vec<int32_t>(globalStridesArgIdx, "-- globalStridesArgIdx");
dump_vec<uint32_t>(boxDims, "-- boxDims");
dump_vec<uint32_t>(elementStrides, "-- elementStrides");
llvm::errs() << "-- interleave: " << interleave << "\n";
llvm::errs() << "-- swizzle: " << swizzle << "\n";
llvm::errs() << "-- l2Promotion: " << l2Promotion << "\n";
llvm::errs() << "-- oobFill: " << oobFill << "\n";
};
};
using TMAMetadataTy = std::vector<TMAInfo>;
} // namespace gpu
} // namespace triton
} // namespace mlir
#endif // TRITON_TARGET_PTX_TMAMETADATA_H

View File

@@ -24,10 +24,16 @@
#include <algorithm>
#include <cstdlib>
#include <set>
#include <string>
namespace triton {
const std::set<std::string> ENV_VARS = {
"ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION",
"ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP",
"AMDGCN_ENABLE_DUMP"};
namespace tools {
inline std::string getenv(const char *name) {
@@ -39,6 +45,9 @@ inline std::string getenv(const char *name) {
}
inline bool getBoolEnv(const std::string &env) {
std::string msg = "Environment variable " + env + " is not recognized";
assert(triton::ENV_VARS.find(env.c_str()) != triton::ENV_VARS.end() &&
msg.c_str());
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),

View File

@@ -2,6 +2,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
namespace mlir {
@@ -27,17 +28,21 @@ void SharedMemoryAliasAnalysis::visitOperation(
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// XXX(Keren): the following ops are always aliasing for now
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp>(op)) {
if (isa<triton::gpu::ExtractSliceOp, triton::TransOp,
triton::nvidia_gpu::ExtractMBarrierOp>(op)) {
// extract_slice %src
// trans %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
op)) {
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp,
triton::nvidia_gpu::InsertSliceAsyncV2Op>(op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isa<triton::nvidia_gpu::StoreAsyncOp>(op)) {
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (triton::gpu::isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;

View File

@@ -16,6 +16,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
@@ -58,11 +59,23 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
// MmaToDotShortcut doesn't use shared mem
if (srcLayout.isa<MmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>())
if (isMmaToDotShortcut(srcTy, dstTy))
return {};
if (shouldUseDistSmem(srcLayout, dstLayout)) {
// TODO: padding to avoid bank conflicts
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
}
// MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem
if (auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>()) {
if (dstLayout.isa<DotOperandEncodingAttr>()) {
if (isMmaToDotShortcut(srcTy, dstTy)) {
return {};
}
} else if (auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>()) {
if (isMmaToMmaShortcut(srcTy, dstTy)) {
return {};
}
}
}
#ifdef USE_ROCM
if (srcLayout.isa<MfmaEncodingAttr>() &&
@@ -82,18 +95,18 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
auto srcShape = srcTy.getShape();
auto dstShape = dstTy.getShape();
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcShape);
auto dstShapePerCTA = getShapePerCTA(dstLayout, dstShape);
auto srcShapePerCTA = getShapePerCTA(srcTy);
auto dstShapePerCTA = getShapePerCTA(dstTy);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] =
std::max(std::min<unsigned>(srcTy.getShape()[d], srcShapePerCTA[d]),
std::min<unsigned>(dstTy.getShape()[d], dstShapePerCTA[d]));
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
}
if (rank == 1)
return paddedRepShape;
@@ -105,6 +118,12 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
return paddedRepShape;
}
// Scratch shared-memory shape (one extent per dimension) needed to lower
// a StoreAsyncOp: the per-CTA shape of its source tensor.
SmallVector<unsigned>
getScratchConfigForStoreAsync(triton::nvidia_gpu::StoreAsyncOp op) {
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
// getShapePerCTA yields int64_t; narrow to the unsigned vector callers expect.
return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
}
// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
SmallVector<unsigned> smemShape;
@@ -212,20 +231,45 @@ private:
// For example: %a = scf.if -> yield
// %a must be allocated elsewhere by other operations.
// FIXME(Keren): extract and insert are always alias for now
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op)) {
if (!maybeSharedAllocationOp(op) || maybeAliasOp(op))
return;
}
// XXX(Keren): Why this hard-coded alignment?
size_t kAlignment = 8;
for (Value result : op->getResults()) {
if (triton::gpu::isSharedEncoding(result)) {
// Bytes could be a different value once we support padding or other
// allocation policies.
auto tensorType = result.getType().dyn_cast<RankedTensorType>();
auto bytes = tensorType.getNumElements() *
auto shapePerCTA = triton::gpu::getShapePerCTA(tensorType);
auto bytes = product<int64_t>(shapePerCTA) *
tensorType.getElementTypeBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes);
// XXX(Keren): magic numbers 256 and 1024
// benzh@maybe alignment should be passed in.
// Software swizzling calculates phase based on offset, while hardware
// swizzling do that based on physical address. Thus only by setting the
// alignment to 1024 can ensure the correctness. 
if (bytes > 256)
kAlignment = 1024;
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, bytes,
kAlignment);
}
}
if (isa<triton::nvidia_gpu::AllocMBarrierOp>(op)) {
Value result = op->getResult(0);
if (!result.getType().isa<RankedTensorType>())
// In case AllocMBarrierOp is allocating scalar mbarriers
allocation->addBuffer<BufferT::BufferKind::Explicit>(result, 8,
kAlignment);
}
}
template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes,
unsigned alignment) {
if (bytes > 0)
allocation->addBuffer<T>(op, bytes, alignment);
}
template <BufferT::BufferKind T>
@@ -236,14 +280,17 @@ private:
/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
const size_t scratchAlignment = 128;
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
@@ -267,7 +314,20 @@ private:
srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto storeAsyncOp =
dyn_cast<triton::nvidia_gpu::StoreAsyncOp>(op)) {
auto srcTy = storeAsyncOp.getSrc().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
if (!srcEncoding.isa<MmaEncodingAttr>()) {
return;
}
auto smemShape = getScratchConfigForStoreAsync(storeAsyncOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes, 1024);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
@@ -284,7 +344,8 @@ private:
elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
auto value = op->getOperand(0);
@@ -296,13 +357,15 @@ private:
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
scratchAlignment);
}
}
@@ -444,6 +507,12 @@ private:
// Analyze liveness of explicit buffers
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
// Shared memory allocated by mbarrier cannot be reused
if (value.getDefiningOp() &&
isa<triton::nvidia_gpu::AllocMBarrierOp>(value.getDefiningOp()))
return Interval(std::numeric_limits<size_t>::min(),
std::numeric_limits<size_t>::max());
auto liveOperations = liveness.resolveLiveness(value);
LivenessR ranges;
std::for_each(liveOperations.begin(), liveOperations.end(),
@@ -520,17 +589,22 @@ private:
auto xRange = bufferRange[buffer];
bool res = xRange.intersects(range);
for (auto val : tripleMap)
res = res && !val.second.intersects(xRange);
res = res &&
!val.second.intersects(xRange); // only one buffer intersect
return res;
});
if (bufferIt != xBuffers.end()) {
auto buffer = *bufferIt;
auto xSize = buffer->size;
auto xRange = bufferRange.lookup(buffer);
bufferStart[buffer] = size;
tripleMap.insert(
{size + xSize, Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
// TODO(Keren): A buffer's size shouldn't be determined here, have to
// clean it up
size_t alignment = buffer->alignment;
size_t alignSize = ((size + alignment - 1) / alignment) * alignment;
bufferStart[buffer] = alignSize;
tripleMap.insert({alignSize + xSize,
Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
// We could either insert (range.start, xRange.start) or (range.start,
// xRange.end), both are correct and determine the potential buffer
// offset, and the graph coloring algorithm will solve the interference,

View File

@@ -667,14 +667,10 @@ public:
AxisInfo
getAxisInfo(OpTy op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!resTy)
return AxisInfo();
auto shape = resTy.getShape();
auto rank = shape.size();
auto condConstancy = operands[0]->getValue().getConstancy();
auto lhsInfo = operands[1]->getValue();
auto rhsInfo = operands[2]->getValue();
auto rank = lhsInfo.getRank();
AxisInfo::DimVectorT contiguity, divisibility, constancy;
std::optional<int64_t> constantValue;

View File

@@ -14,4 +14,5 @@ add_mlir_library(TritonAnalysis
MLIRLLVMDialect
TritonIR
TritonGPUIR
TritonNvidiaGPUIR
)

View File

@@ -2,7 +2,12 @@
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "../lib/Conversion/TritonGPUToLLVM/Utility.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <deque>
@@ -103,20 +108,37 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
return;
}
if (isa<gpu::BarrierOp>(op)) {
// TODO(Keren): Don't expose LLVM Dialect ops here
if (isa<gpu::BarrierOp>(op) ||
(isa<LLVM::InlineAsmOp>(op) &&
(dyn_cast<LLVM::InlineAsmOp>(op).getAsmString().find("bar.sync") !=
std::string::npos))) {
// If the current op is a barrier, we sync previous reads and writes
blockInfo->sync();
return;
}
if (isa<triton::gpu::AsyncWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
if (isa<triton::gpu::AsyncWaitOp, triton::gpu::AsyncBulkWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode()) &&
!(isa<LLVM::InlineAsmOp>(op->getNextNode()) &&
(dyn_cast<LLVM::InlineAsmOp>(op->getNextNode())
.getAsmString()
.find("bar.sync") != std::string::npos))) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
blockInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
barSync(*builder, op, barId, 128);
} else {
builder->create<gpu::BarrierOp>(op->getLoc());
}
blockInfo->sync();
return;
}
@@ -169,12 +191,23 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
if (blockInfo->isIntersected(curBlockInfo)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());
// TODO(Keren): Don't expose LLVM Dialect ops here
// TODO[shuhaoj]: Change hard code style of numThreads. Hide async_agent
// attr. Better way to determine barId (number of agents are limited).
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
barSync(*builder, op, barId, 128);
} else {
builder->create<gpu::BarrierOp>(op->getLoc());
}
blockInfo->sync();
}
// Update the region info, even if barrier is inserted, we have to maintain
// the current op's read/write buffers.
blockInfo->join(curBlockInfo);
}
} // namespace mlir

View File

@@ -1,10 +1,14 @@
#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include <deque>
@@ -37,6 +41,51 @@ bool ReduceOpHelper::isFastReduction() {
getParentOrder(getSrcLayout())[0];
}
// Cases where distributed shared memory is not required in ConvertLayout:
// (1) numCTAs == 1
// (2) numCTAs > 1 but srcCTALayout == dstCTALayout
// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented
// in the future
// Decide whether a layout conversion requires distributed shared memory
// (dsmem): only when more than one CTA participates and the source and
// destination CTA layouts differ.
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
  unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout);
  assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) &&
         "Invalid layout conversion: the numbers of CTAs of src and dst "
         "layouts are different");

  // Case (1): with a single CTA there is nothing to exchange across CTAs.
  if (numCTAs == 1)
    return false;

  // A sliced source whose parent spreads the sliced dim over several CTAs
  // is not supported yet.
  if (auto slice = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    unsigned dim = slice.getDim();
    if (triton::gpu::getCTAsPerCGA(slice.getParent())[dim] != 1)
      assert(0 && "Layout conversion to be implemented");
  }

  // A sliced destination whose parent spreads the sliced dim over several
  // CTAs always requires dsmem.
  if (auto slice = dstLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    unsigned dim = slice.getDim();
    if (triton::gpu::getCTAsPerCGA(slice.getParent())[dim] != 1)
      return true;
  }

  // The two guards above make getCTALayout legal to call on both layouts.
  // Case (2): identical CTA layouts mean every CTA already owns the data
  // it needs, so dsmem is unnecessary; otherwise it is required.
  return triton::gpu::getCTALayout(srcLayout) !=
         triton::gpu::getCTALayout(dstLayout);
}
unsigned ReduceOpHelper::getInterWarpSize() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
@@ -125,7 +174,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
}
return bytesPerElem * elems;
}
@@ -136,7 +185,7 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
return true;
}
}
@@ -286,6 +335,8 @@ bool maybeSharedAllocationOp(Operation *op) {
return dialect &&
(dialect->getTypeID() ==
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
dialect->getTypeID() ==
mlir::TypeID::get<triton::nvidia_gpu::TritonNvidiaGPUDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<arith::ArithDialect>() ||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
@@ -294,6 +345,8 @@ bool maybeSharedAllocationOp(Operation *op) {
bool maybeAliasOp(Operation *op) {
return isa<triton::gpu::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op) ||
isa<triton::nvidia_gpu::StoreAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op);
}
@@ -303,7 +356,25 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
<<<<<<< HEAD
=======
if (version == 3) {
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
auto retType = op.getResult().getType().cast<RankedTensorType>();
auto retShapePerCTA = triton::gpu::getShapePerCTA(retType);
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
if (!(numWarps % 4 == 0 && retShapePerCTA[0] % 64 == 0 &&
retShapePerCTA[1] % 8 == 0 &&
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
return false;
}
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if (aElemTy.isF32() && bElemTy.isF32()) {
return (op.getAllowTF32() && version == 2) || version == 3;
}
@@ -345,25 +416,22 @@ bool supportMFMA(triton::DotOp op) {
#endif
bool supportMMA(Value value, int version) {
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
// Tell whether a DotOp support MMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
// types of both the operands are identical here.
assert((version == 1 || version == 2) &&
assert((version == 1 || version == 2 || version == 3) &&
"Unexpected MMA layout version found");
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
// FP8 is not natively supported on all mma versions but it can always be
// promoted to fp16 therefore we can always support it.
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
}
Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
@@ -378,6 +446,7 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
<<<<<<< HEAD
#ifdef USE_ROCM
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto srcLayout = srcTy.getEncoding();
@@ -395,6 +464,18 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
#endif
=======
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto src = srcTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
auto dst = dstTy.getEncoding().cast<triton::gpu::MmaEncodingAttr>();
auto srcElemsPerThread = triton::gpu::getTotalElemsPerThread(srcTy);
auto dstElemsPerThread = triton::gpu::getTotalElemsPerThread(dstTy);
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
return src.getVersionMajor() == 3 && src.getWarpsPerCTA()[1] == 1 &&
dst.getVersionMajor() == 3 && dst.getWarpsPerCTA()[1] == 1 &&
srcElemsPerThread == dstElemsPerThread;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
@@ -455,9 +536,11 @@ struct DFSState {
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
/// We mark each op as ready if all its operands are seen. If an op is ready,
/// we add it to the queue. Otherwise, we keep adding its operands to the
/// ancestors set.
/// We mark each op as ready if all its operands and parents ops are seen. If
/// an op is ready, we add it to the queue. Otherwise, we keep adding its
/// operands to the ancestors set.
/// We always want an op to be scheduled after all its parents to handle
/// correctly cases with scf operations.
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
SmallVector<Operation *, 4> &readyQueue) {
bool ready = true;
@@ -468,6 +551,14 @@ struct DFSState {
ready = false;
}
}
Operation *parent = op->getParentOp();
while (parent) {
if (!seen.count(parent)) {
subGraph.push_back(parent);
ready = false;
}
parent = parent->getParentOp();
}
if (ready)
readyQueue.push_back(op);
}
@@ -615,4 +706,81 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
return solver;
}
// Helper for getMakeTensorPtrOp: resolve the MakeTensorPtrOp that
// (transitively) produced value `v`, where `op` is v's defining op.
static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
return makeTensorPtrOp;
}
// tt.advance only offsets an existing tensor pointer; look through it.
if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
return getMakeTensorPtrOp(advanceOp.getPtr());
}
// Region-carrying ops (e.g. scf.for/if): the result comes from a yield,
// so chase the yielded operand at the same result index.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
auto idx = v.cast<OpResult>().getResultNumber();
llvm::SmallVector<scf::YieldOp> yieldOps;
op->walk([&](Operation *op) {
if (auto yieldOp = dyn_cast<scf::YieldOp>(op))
yieldOps.push_back(yieldOp);
});
// benzh@ if there are multiple yields, all yields' operands should come
// from the same arg.
// NOTE(review): assumes at least one yield exists and that every yield
// forwards the same tensor pointer; only yieldOps[0] is inspected.
Value newValue = yieldOps[0].getOperands()[idx];
return getMakeTensorPtrOp(newValue);
}
llvm_unreachable("Unable to getMakeTensorPtr()");
}
// Resolve the triton::MakeTensorPtrOp that ultimately defines the tensor
// pointer `v`, walking through tt.advance, region ops (scf.for/if) and
// unstructured control flow (cf.br / cf.cond_br) as needed.
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
using BranchOps = llvm::SetVector<std::pair<Operation *, int>>;
llvm::DenseMap<Block *, BranchOps> blockToCFOps;
auto moduleOp =
v.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
// Index every CF branch by destination block so a block argument can be
// traced back to the operand the branch passed in. The int tags which
// operand list to use: -1 = unconditional, 1 = true successor,
// 0 = false successor.
moduleOp.walk([&](Operation *op) {
if (auto br = dyn_cast<cf::BranchOp>(op)) {
Block *block = br.getDest();
blockToCFOps[block].insert({op, -1});
}
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
Block *blockT = condBr.getTrueDest();
Block *blockF = condBr.getFalseDest();
blockToCFOps[blockT].insert({condBr, 1});
blockToCFOps[blockF].insert({condBr, 0});
}
});
if (Operation *definingOp = v.getDefiningOp()) {
return getMakeTensorPtrOpImpl(definingOp, v);
} else if (BlockArgument arg = v.cast<BlockArgument>()) {
unsigned argNum = arg.getArgNumber();
Operation *argOwner = arg.getOwner()->getParentOp();
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
// Loop-carried arg: map it back to the corresponding init operand.
return getMakeTensorPtrOp(
forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
} else if (auto funcOp = dyn_cast<mlir::triton::FuncOp>(argOwner)) {
// Block argument inside a function body: follow a recorded branch
// into this block.
// NOTE(review): only blockToCFOps[block][0] is consulted; assumes a
// single (or consistent) predecessor branch — confirm for blocks
// with multiple predecessors.
Block *block = arg.getOwner();
Operation *op;
int tOrF;
std::tie(op, tOrF) = blockToCFOps[block][0];
if (auto br = dyn_cast<cf::BranchOp>(op)) {
return getMakeTensorPtrOp(br.getDestOperands()[argNum]);
}
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
if (tOrF) {
return getMakeTensorPtrOp(condBr.getTrueDestOperands()[argNum]);
} else {
return getMakeTensorPtrOp(condBr.getFalseDestOperands()[argNum]);
}
}
} else {
// Other region holders: fall back to the owner's operand at the same
// index.
return getMakeTensorPtrOp(argOwner->getOperand(argNum));
}
}
llvm_unreachable("Unable to getMakeTensorPtr()");
}
} // namespace mlir

View File

@@ -1,2 +1,3 @@
add_subdirectory(TritonToTritonGPU)
add_subdirectory(TritonGPUToLLVM)
add_subdirectory(NVGPUToLLVM)

View File

@@ -0,0 +1,27 @@
add_mlir_conversion_library(NVGPUToLLVM
NVGPUToLLVMPass.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/NVGPUToLLVM
${PROJECT_BINARY_DIR}/include/triton/Conversion/NVGPUToLLVM
DEPENDS
NVGPUConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms
TritonAnalysis
TritonIR
TritonGPUIR
TritonGPUTransforms
TritonNvidiaGPUTransforms
NVGPUIR
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,217 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "BarrierOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
// --------------------------------------------------------------------------
// -- MBarrier related Ops lowering, to be moved to a separate file ---------
// --------------------------------------------------------------------------
// Lowers triton_nvidia_gpu.alloc_mbarrier: reserves shared memory for one or
// more mbarriers and emits an nvgpu.mbarrier_init for each, predicated so
// that only thread 0 performs the initialization. A ranked-tensor result
// allocates one mbarrier per element of the (rank-1) tensor.
struct AllocMBarrierOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                       triton::nvidia_gpu::AllocMBarrierOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::AllocMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::AllocMBarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Base of the shared-memory slice reserved for this op's result by the
    // allocation analysis.
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
    auto resultTy = op.getType();
    auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
    Type elemPtrTy;
    if (resultTensorTy) {
      auto llvmElemTy =
          getTypeConverter()->convertType(resultTensorTy.getElementType());
      elemPtrTy = ptr_ty(llvmElemTy, 3); // addrspace(3) == shared memory
    } else {
      // Scalar result: the converted result type is already a pointer type.
      elemPtrTy = getTypeConverter()->convertType(resultTy);
    }
    smemBase = bitcast(smemBase, elemPtrTy);
    auto threadId = getThreadId(rewriter, loc);
    // Only one thread initializes the mbarriers.
    auto pred = icmp_eq(threadId, i32_val(0));
    int numMBarriers = 1;
    if (resultTensorTy) {
      assert(resultTensorTy.getRank() == 1 &&
             "unexpected rank for AllocMBarrierOp");
      numMBarriers = resultTensorTy.getShape()[0];
    }
    // Initialize mbarrier i at offset i from the shared-memory base.
    for (int i = 0; i < numMBarriers; ++i) {
      Value smem = smemBase;
      if (i > 0) {
        smem = gep(elemPtrTy, smem, i32_val(i));
      }
      rewriter.create<triton::nvgpu::MBarrierInitOp>(loc, smem, pred,
                                                     op.getCount());
    }
    if (resultTensorTy) {
      // Tensor result: wrap the base pointer in the shared-memory descriptor
      // struct expected by downstream conversions.
      auto smemObj = SharedMemoryObject(smemBase, resultTensorTy.getShape(),
                                        {0}, loc, rewriter);
      auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
      rewriter.replaceOp(op, retVal);
    } else {
      rewriter.replaceOp(op, smemBase);
    }
    return success();
  }
};
// Lowers triton_nvidia_gpu.mbarrier_arrive to nvgpu.mbarrier_arrive, picking
// the arrive flavor from the op's attributes/operands: cp_async tracking,
// remote-CTA arrive, expect-tx, or a plain arrive.
struct MBarrierArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                        triton::nvidia_gpu::MBarrierArriveOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::MBarrierArriveOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    uint32_t txCount = op.getTxCount();
    Value remoteCtaId = adaptor.getRemoteCtaId();
    // Classify the arrive flavor; the checks are ordered by precedence.
    auto arriveKind = triton::nvgpu::MBarriveType::normal;
    if (op.getTrackAsyncOp()) {
      arriveKind = triton::nvgpu::MBarriveType::cp_async;
    } else if (remoteCtaId) {
      assert(txCount == 0 &&
             "remote arrive of transaction mbarrier is not implemented yet");
      arriveKind = triton::nvgpu::MBarriveType::remote;
    } else if (txCount > 0) {
      arriveKind = triton::nvgpu::MBarriveType::expect_tx;
    }
    // Default to an always-true predicate when none was provided.
    Value pred = adaptor.getPred();
    if (!pred)
      pred = int_val(/*width*/ 1, 1);
    rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierArriveOp>(
        op, adaptor.getMbarrier(), pred, remoteCtaId, arriveKind, txCount);
    return success();
  }
};
// Lowers triton_nvidia_gpu.mbarrier_wait 1:1 to nvgpu.mbarrier_wait,
// forwarding the already-converted mbarrier pointer and phase operand.
struct MBarrierWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                      triton::nvidia_gpu::MBarrierWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::MBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::MBarrierWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // replaceOpWithNewOp reuses op's location, so no local `loc` is needed
    // (the original declared one and never used it).
    rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierWaitOp>(
        op, adaptor.getMbarrier(), adaptor.getPhase());
    return success();
  }
};
// Lowers triton_nvidia_gpu.extract_mbarrier: computes the shared-memory
// address of the index-th mbarrier inside an mbarrier tensor.
struct ExtractMBarrierOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<
          triton::nvidia_gpu::ExtractMBarrierOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::ExtractMBarrierOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ExtractMBarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto elemTy =
        op.getTensor().getType().cast<RankedTensorType>().getElementType();
    // Pointer-to-shared (addrspace 3) type for one mbarrier element.
    auto ptrTy =
        LLVM::LLVMPointerType::get(getTypeConverter()->convertType(elemTy), 3);
    // The converted tensor is an LLVM struct whose field 0 is the base
    // pointer of the shared-memory allocation.
    Value base = extract_val(ptrTy, adaptor.getTensor(),
                             rewriter.getDenseI64ArrayAttr(0));
    Value elemPtr = gep(ptrTy, base, adaptor.getIndex());
    rewriter.replaceOp(op, elemPtr);
    return success();
  }
};
// Lowers triton_nvidia_gpu.NamedBarrierArriveOp 1:1 to the corresponding
// nvgpu op, forwarding the barrier id and participating-thread count.
struct NamedBarrierArriveOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<
          triton::nvidia_gpu::NamedBarrierArriveOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::NamedBarrierArriveOp>::
      ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Dropped the original's unused `Location loc` local; replaceOpWithNewOp
    // already takes its location from `op`.
    rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierArriveOp>(
        op, adaptor.getBar(), adaptor.getNumThreads());
    return success();
  }
};
// Lowers triton_nvidia_gpu.NamedBarrierWaitOp 1:1 to the corresponding
// nvgpu op, forwarding the barrier id and participating-thread count.
struct NamedBarrierWaitOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<
          triton::nvidia_gpu::NamedBarrierWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Dropped the original's unused `Location loc` local; replaceOpWithNewOp
    // already takes its location from `op`.
    rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierWaitOp>(
        op, adaptor.getBar(), adaptor.getNumThreads());
    return success();
  }
};
// Lowers triton_nvidia_gpu.fence_async_shared 1:1 to nvgpu.fence_async_shared,
// forwarding the cluster-scope flag (bCluster).
struct FenceAsyncSharedOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<
          triton::nvidia_gpu::FenceAsyncSharedOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::FenceAsyncSharedOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Dropped the original's unused `Location loc` local; replaceOpWithNewOp
    // already takes its location from `op`.
    rewriter.replaceOpWithNewOp<triton::nvgpu::FenceAsyncSharedOp>(
        op, adaptor.getBCluster());
    return success();
  }
};
// Registers the mbarrier / named-barrier / fence lowering patterns defined
// above. All patterns are constructed from the same (typeConverter,
// allocation, benefit) triple, so a single variadic add suffices.
// `numWarps` and `axisInfoAnalysis` belong to the shared populate-function
// signature but are not needed by these patterns.
void populateBarrierOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation, PatternBenefit benefit) {
  patterns.add<AllocMBarrierOpConversion, MBarrierArriveOpConversion,
               MBarrierWaitOpConversion, ExtractMBarrierOpConversion,
               NamedBarrierArriveOpConversion, NamedBarrierWaitOpConversion,
               FenceAsyncSharedOpConversion>(typeConverter, allocation,
                                             benefit);
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BARRIER_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
// Registers the conversion patterns that lower TritonNvidiaGPU barrier ops
// (mbarrier alloc/arrive/wait/extract, named barriers, async-shared fences)
// to the NVGPU dialect. Implemented in BarrierOpToLLVM.cpp.
void populateBarrierOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation, PatternBenefit benefit);
#endif

View File

@@ -1,8 +1,19 @@
add_mlir_conversion_library(TritonGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
ConvertLayoutOpToLLVM.cpp
DotOpToLLVM/FMA.cpp
DotOpToLLVM/MMAv1.cpp
DotOpToLLVM/MMAv2.cpp
DotOpToLLVM/WGMMA.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
BarrierOpToLLVM.cpp
TritonGPUToLLVM.cpp
GCNAsmFormat.cpp
PTXAsmFormat.cpp
TritonGPUToLLVMPass.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -21,9 +32,12 @@ add_mlir_conversion_library(TritonGPUToLLVM
PTXAsmFormat.cpp
ReduceOpToLLVM.cpp
ScanOpToLLVM.cpp
Utility.cpp
TypeConverter.cpp
Utility.cpp
ViewOpToLLVM.cpp
TensorPtrOpsToLLVM.cpp
ClusterOpsToLLVM.cpp
RegReallocOpToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
@@ -46,4 +60,6 @@ add_mlir_conversion_library(TritonGPUToLLVM
TritonIR
TritonGPUIR
TritonGPUTransforms
TritonNvidiaGPUTransforms
NVGPUIR
)

View File

@@ -0,0 +1,62 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "ClusterOpsToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
// Rewrites triton_nvidia_gpu.cluster_arrive as nvgpu.cluster_arrive,
// carrying over the `relaxed` flag unchanged.
struct ClusterArriveOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                       triton::nvidia_gpu::ClusterArriveOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::ClusterArriveOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool relaxed = op.getRelaxed();
    rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterArriveOp>(op, relaxed);
    return success();
  }
};
// Rewrites triton_nvidia_gpu.cluster_wait as nvgpu.cluster_wait; the op
// carries no operands or attributes, so this is a pure 1:1 replacement.
struct ClusterWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern<
                                     triton::nvidia_gpu::ClusterWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::ClusterWaitOp>::ConvertTritonGPUOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterWaitOp>(op);
    return success();
  }
};
// Registers the cluster-barrier lowering patterns (cluster_arrive and
// cluster_wait). The extra parameters (numWarps, axisInfoAnalysis,
// allocation) match the shared populate-function signature but are unused
// by these patterns.
// Changes vs. original: merged the two add calls into one variadic add and
// removed the redundant trailing `return;` in a void function.
void populateClusterOpsToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation, PatternBenefit benefit) {
  patterns.add<ClusterArriveOpConversion, ClusterWaitOpConversion>(
      typeConverter, benefit);
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_CLUSTER_OPS_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
// Registers the patterns that lower TritonNvidiaGPU cluster barrier ops
// (cluster_arrive / cluster_wait) to the NVGPU dialect. Implemented in
// ClusterOpsToLLVM.cpp.
void populateClusterOpsToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation, PatternBenefit benefit);
#endif

View File

@@ -1,14 +1,20 @@
#include "ConvertLayoutOpToLLVM.h"
#include "Utility.h"
using ::mlir::LLVM::delinearize;
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::linearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
@@ -82,6 +88,13 @@ public:
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
// forwarding on mma->mma shortcut, lower distributed->distributed otherwise
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
if (isMmaToMmaShortcut(srcTy, dstTy)) {
rewriter.replaceOp(op, op.getSrc());
return success();
}
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
@@ -105,23 +118,25 @@ public:
}
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
SmallVector<Value>
getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter, unsigned elemId,
RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTATile) const {
auto shape = type.getShape();
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type, false);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
multiDimOffset[d] =
add(multiDimOffsetFirstElem[d],
i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
@@ -143,7 +158,7 @@ private:
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
sliceLayout.paddedShape(shapePerCTATile));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
@@ -154,6 +169,8 @@ private:
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
auto instrShape = mmaLayout.getInstrShape();
SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
@@ -162,27 +179,35 @@ private:
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto order = triton::gpu::getOrder(mmaLayout);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
if (mmaLayout.isHopper()) {
multiDimWarpId[0] = urem(warpId, i32_val(warpsPerCTA[0]));
multiDimWarpId[1] = udiv(warpId, i32_val(warpsPerCTA[0]));
} else {
auto order = triton::gpu::getOrder(mmaLayout);
multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order);
}
Value _1 = i32_val(1);
Value _2 = i32_val(2);
Value _4 = i32_val(4);
Value _8 = i32_val(8);
Value _16 = i32_val(16);
if (mmaLayout.isAmpere()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
multiDimWarpId[0] =
urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0]));
multiDimWarpId[1] =
urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1]));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
Value rowWarpOffset = mul(multiDimWarpId[0], i32_val(instrShape[0]));
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
Value colWarpOffset = mul(multiDimWarpId[1], _8);
Value colWarpOffset = mul(multiDimWarpId[1], i32_val(instrShape[1]));
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.isVolta()) {
@@ -193,13 +218,27 @@ private:
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
if (mmaLayout.isAmpere()) {
if (mmaLayout.isHopper()) {
unsigned elemIdRem4 = elemId % 4;
unsigned nGrpId = elemId / 4;
multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId));
multiDimOffset[0] =
add(multiDimOffset[0],
i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
multiDimOffset[1] =
add(multiDimOffset[1],
i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1]));
} else if (mmaLayout.isAmpere()) {
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
multiDimOffset[0] =
add(multiDimOffset[0],
i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
multiDimOffset[1] =
add(multiDimOffset[1],
i32_val(multiDimCTAInRepId[1] * shapePerCTATile[1]));
} else if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
@@ -240,11 +279,12 @@ private:
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
SmallVector<unsigned> numCTATiles(rank);
auto shapePerCTATile = getShapePerCTATile(layout);
auto shapePerCTA = getShapePerCTA(layout, type.getShape());
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
numCTATiles[d] = ceil<unsigned>(shapePerCTA[d], shapePerCTATile[d]);
}
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
@@ -267,17 +307,16 @@ private:
}
auto linearCTAId =
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
getLinearIndex<unsigned>(multiDimCTAId, numCTATiles, order);
// TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
multiDimCTAInRepId, shapePerCTATile);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
@@ -334,7 +373,8 @@ private:
SmallVector<unsigned> numCTAs(rank, 1);
SmallVector<unsigned> numCTAsEachRep(rank, 1);
SmallVector<unsigned> shapePerCTA = getShapePerCTA(layout, shape);
SmallVector<unsigned> shapePerCTATile = getShapePerCTATile(layout, shape);
SmallVector<int64_t> shapePerCTA = getShapePerCTA(layout, shape);
auto elemTy = type.getElementType();
int ctaId = 0;
@@ -364,7 +404,7 @@ private:
// duplicate in Volta.
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
multiDimCTAInRepId, shapePerCTATile);
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
}
@@ -372,7 +412,7 @@ private:
// do transpose
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numM = aEncoding.getMMAv1NumOuter(shapePerCTA);
int numN = accumSizePerThread / numM;
for (int r = 0; r < numM; r++) {
@@ -411,6 +451,91 @@ private:
}
}
  // Converts between two distributed layouts whose data spans multiple CTAs
  // of a cluster, using distributed shared memory (DSMEM): each CTA first
  // stores its elements into its local shared-memory slice, the whole
  // cluster synchronizes, then every CTA gathers its destination elements
  // from the owning (possibly remote) CTA via load_dsmem.
  LogicalResult
  lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    Value src = op.getSrc();
    Value dst = op.getResult();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto dstTy = dst.getType().cast<RankedTensorType>();
    auto srcLayout = srcTy.getEncoding();
    auto dstLayout = dstTy.getEncoding();
    auto srcShapePerCTA = getShapePerCTA(srcTy);
    auto srcCTAsPerCGA = triton::gpu::getCTAsPerCGA(srcLayout);
    auto srcCTAOrder = triton::gpu::getCTAOrder(srcLayout);
    unsigned rank = srcShapePerCTA.size();
    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
    // addrspace(3) == shared memory.
    auto elemPtrTy = ptr_ty(llvmElemTy, 3);
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
    smemBase = bitcast(smemBase, elemPtrTy);
    // Local scratch buffer is shaped like this CTA's portion of the source.
    auto smemShape = convertType<unsigned, int64_t>(srcShapePerCTA);
    // Store to local shared memory
    {
      auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                         rewriter, srcTy);
      // CTA-local indices (withCTAOffset=false): each CTA writes only its
      // own slice of shared memory.
      auto inIndices =
          emitIndices(loc, rewriter, srcLayout, srcTy, /*withCTAOffset*/ false);
      assert(inIndices.size() == inVals.size() &&
             "Unexpected number of indices emitted");
      for (unsigned i = 0; i < inIndices.size(); ++i) {
        Value offset = linearize(rewriter, loc, inIndices[i], smemShape);
        Value ptr = gep(elemPtrTy, smemBase, offset);
        store(inVals[i], ptr);
      }
    }
    // Cluster barrier: all CTAs must finish their stores before any CTA
    // starts reading remote shared memory.
    rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
    rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
    // Load from remote shared memory
    {
      // Per-dimension CTA extents as i32 SSA values, reused in the loop.
      SmallVector<Value> srcShapePerCTACache;
      for (unsigned i = 0; i < rank; ++i)
        srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i]));
      SmallVector<Value> outVals;
      // Global indices (withCTAOffset=true): identify which CTA owns each
      // destination element.
      auto outIndices =
          emitIndices(loc, rewriter, dstLayout, dstTy, /*withCTAOffset*/ true);
      for (unsigned i = 0; i < outIndices.size(); ++i) {
        auto coord = outIndices[i];
        assert(coord.size() == rank && "Unexpected rank of index emitted");
        // Split each global coordinate into (owning CTA, local coordinate).
        SmallVector<Value> multiDimCTAId, localCoord;
        for (unsigned d = 0; d < rank; ++d) {
          multiDimCTAId.push_back(udiv(coord[d], srcShapePerCTACache[d]));
          localCoord.push_back(urem(coord[d], srcShapePerCTACache[d]));
        }
        Value remoteCTAId =
            linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder);
        Value localOffset = linearize(rewriter, loc, localCoord, smemShape);
        Value ptr = gep(elemPtrTy, smemBase, localOffset);
        outVals.push_back(load_dsmem(ptr, remoteCTAId));
      }
      Value result =
          getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
      rewriter.replaceOp(op, result);
    }
    // Cluster barrier: keep the scratch buffer alive until every CTA has
    // finished reading from it.
    rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
    rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
    return success();
  }
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
@@ -424,6 +549,10 @@ private:
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (shouldUseDistSmem(srcLayout, dstLayout))
return lowerDistToDistWithDistSmem(op, adaptor, rewriter);
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
@@ -435,8 +564,9 @@ private:
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout, srcTy.getShape());
auto dstShapePerCTA = getShapePerCTA(dstLayout, shape);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
auto shapePerCTA = getShapePerCTA(srcLayout, shape);
// For Volta, all the coords for a CTA are calculated.
bool isSrcMmaV1{}, isDstMmaV1{};
@@ -456,15 +586,17 @@ private:
}
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned inPerCTA =
std::min<unsigned>(shapePerCTA[d], srcShapePerCTATile[d]);
unsigned outPerCTA =
std::min<unsigned>(shapePerCTA[d], dstShapePerCTATile[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
numReplicates[d] = ceil<unsigned>(shapePerCTA[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
inNumCTAs[d] = ceil<unsigned>(shapePerCTA[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shapePerCTA[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
@@ -473,7 +605,8 @@ private:
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
if (getElementTypeOrSelf(op.getType()).isa<mlir::Float8E4M3B11FNUZType>()) {
if (getElementTypeOrSelf(op.getType())
.isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>()) {
assert(inVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
assert(outVec % 4 == 0 && "conversion not supported for FP8E4M3B15");
}
@@ -485,8 +618,25 @@ private:
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
if (repId != 0)
barrier();
if (repId != 0) {
// TODO[shuhaoj]: change hard code style of numThreads. Hide async
// attr. Better way to determine barId (number of agents are limited).
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(128));
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
kNumThreads);
} else {
barrier();
}
}
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||
#ifdef USE_ROCM
@@ -506,7 +656,23 @@ private:
return failure();
}
barrier();
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
// attr. Better way to determine barId (number of agents are limited).
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(128));
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
kNumThreads);
} else {
barrier();
}
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
#ifdef USE_ROCM
@@ -580,7 +746,7 @@ private:
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcLayout = srcTy.getEncoding();
@@ -592,13 +758,102 @@ private:
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto dstStrides =
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
int32_t elemSize = elemTy.getIntOrFloatBitWidth();
auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
if (mmaLayout && mmaLayout.isHopper() && elemSize == 16 &&
inOrd == outOrd && numElems >= 16) {
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
auto srcShapePerCTA = getShapePerCTA(mmaLayout, srcShape);
auto instrShape = mmaLayout.getInstrShape();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
uint32_t repM =
ceil<unsigned>(srcShapePerCTA[0], instrShape[0] * warpsPerCTA[0]);
uint32_t numElemsPerRep = numElems / repM;
// rowStride in bytes
uint32_t rowStrideInBytes = dstShapePerCTA[outOrd[0]] * 2;
uint32_t swizzlingByteWidth = rowStrideInBytes;
if (swizzlingByteWidth > 128)
swizzlingByteWidth = 128;
unsigned numElemsPerSwizzlingRow = swizzlingByteWidth * 8 / elemSize;
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]];
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];
Value threadId = getThreadId(rewriter, loc);
Value warpId = udiv(threadId, i32_val(32));
Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])),
i32_val(srcShape[0] / instrShape[0]));
unsigned inVec =
inOrd == outOrd ? triton::gpu::getContigPerThread(mmaLayout)[inOrd[0]]
: 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
assert(minVec == 2);
auto wordTy = vec_ty(elemTy, minVec);
for (int rep = 0; rep < repM; ++rep) {
Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])),
i32_val(rep * rowsPerRep));
uint32_t elemIdxOffset = rep * numElemsPerRep;
for (unsigned idx = 0; idx < numElemsPerRep; idx += 8) {
uint32_t elemIdx = elemIdxOffset + idx;
Value offset = rewriter.create<triton::nvgpu::OffsetOfStmatrixV4Op>(
loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset,
numElemsPerSwizzlingRow, true);
Value addr = gep(elemPtrTy, smemBase, offset);
Value words[4];
for (unsigned i = 0; i < 8; ++i) {
if (i % minVec == 0)
words[i / 2] = undef(wordTy);
words[i / 2] = insert_element(
wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec));
}
rewriter.create<triton::nvgpu::StoreMatrixOp>(
loc, bitcast(addr, ptrI8SharedTy),
ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty),
bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)});
}
}
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
// attr. Better way to determine barId (number of agents are limited).
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(barId));
auto kNumThreads = rewriter.create<LLVM::ConstantOp>(
loc, i32_ty, rewriter.getI32IntegerAttr(128));
rewriter.create<triton::nvgpu::NamedBarrierWaitOp>(loc, bar,
kNumThreads);
} else {
barrier();
}
} else {
auto dstStrides =
getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices,
dst, smemBase, elemTy, loc, rewriter);
}
auto smemObj =
SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
SharedMemoryObject(smemBase, dstShapePerCTA, outOrd, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
@@ -807,19 +1062,16 @@ private:
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor());
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
res = SharedToDotOperandMMAv2::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());
} else if (!isOuter && mmaLayout.isVolta() &&
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
smemObj, getTypeConverter(), getThreadId(rewriter, loc));
} else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
@@ -841,10 +1093,11 @@ private:
}
return res;
}
}; // namespace triton::gpu::ConvertLayoutOp
}; // namespace triton::gpu::ConvertLayoutOp>
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {

View File

@@ -10,6 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -2,8 +2,10 @@
#include "../Utility.h"
using ValueTable = std::map<std::pair<int, int>, Value>;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
@@ -14,31 +16,32 @@ using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
SmallVector<Value>
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTA,
getThreadIds(Value threadId, ArrayRef<unsigned int> shapePerCTATile,
ArrayRef<unsigned int> sizePerThread, ArrayRef<unsigned int> order,
ConversionPatternRewriter &rewriter, Location loc) {
int dim = order.size();
SmallVector<Value> threadIds(dim);
for (unsigned k = 0; k < dim - 1; k++) {
Value dimK = i32_val(shapePerCTA[order[k]] / sizePerThread[order[k]]);
Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]);
Value rem = urem(threadId, dimK);
threadId = udiv(threadId, dimK);
threadIds[order[k]] = rem;
}
Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
Value dimK = i32_val(shapePerCTATile[order[dim - 1]]);
threadIds[order[dim - 1]] = urem(threadId, dimK);
return threadIds;
}
int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
// Get shapePerCTATile for M or N axis.
int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) {
auto order = layout.getOrder();
auto shapePerCTA = getShapePerCTA(layout);
auto shapePerCTATile = getShapePerCTATile(layout);
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
return isM ? mShapePerCTA : nShapePerCTA;
int mShapePerCTATile =
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
int nShapePerCTATile =
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
return isM ? mShapePerCTATile : nShapePerCTATile;
}
// Get sizePerThread for M or N axis.
@@ -91,7 +94,7 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
ConversionPatternRewriter &rewriter) {
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto aShape = aTensorTy.getShape();
auto aShapePerCTA = getShapePerCTA(aTensorTy);
auto aOrder = aLayout.getOrder();
auto order = dLayout.getOrder();
@@ -104,10 +107,10 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
int aNumPtr = 8;
int K = aShape[1];
int M = aShape[0];
int K = aShapePerCTA[1];
int M = aShapePerCTA[0];
auto shapePerCTA = getShapePerCTA(dLayout);
auto shapePerCTATile = getShapePerCTATile(dLayout);
auto sizePerThread = getSizePerThread(dLayout);
Value _0 = i32_val(0);
@@ -115,8 +118,8 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
Value mContig = i32_val(sizePerThread[order[1]]);
// threadId in blocked layout
auto threadIds =
getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc);
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
rewriter, loc);
Value threadIdM = threadIds[0];
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
@@ -135,11 +138,11 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
SmallVector<Value> vas;
int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/);
int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/);
for (unsigned k = 0; k < K; ++k)
for (unsigned m = 0; m < M; m += mShapePerCTA)
for (unsigned m = 0; m < M; m += mShapePerCTATile)
for (unsigned mm = 0; mm < mSizePerThread; ++mm) {
Value offset =
add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK));
@@ -156,7 +159,7 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
ConversionPatternRewriter &rewriter) {
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto bShape = bTensorTy.getShape();
auto bShapePerCTA = getShapePerCTA(bTensorTy);
auto bOrder = bLayout.getOrder();
auto order = dLayout.getOrder();
@@ -169,10 +172,10 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
Value strideB0 = isBRow ? strideBN : strideBK;
Value strideB1 = isBRow ? strideBK : strideBN;
int bNumPtr = 8;
int K = bShape[0];
int N = bShape[1];
int K = bShapePerCTA[0];
int N = bShapePerCTA[1];
auto shapePerCTA = getShapePerCTA(dLayout);
auto shapePerCTATile = getShapePerCTATile(dLayout);
auto sizePerThread = getSizePerThread(dLayout);
Value _0 = i32_val(0);
@@ -180,8 +183,8 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
Value nContig = i32_val(sizePerThread[order[0]]);
// threadId in blocked layout
auto threadIds =
getThreadIds(thread, shapePerCTA, sizePerThread, order, rewriter, loc);
auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order,
rewriter, loc);
Value threadIdN = threadIds[1];
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
@@ -200,11 +203,11 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
SmallVector<Value> vbs;
int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/);
int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/);
for (unsigned k = 0; k < K; ++k)
for (unsigned n = 0; n < N; n += nShapePerCTA)
for (unsigned n = 0; n < N; n += nShapePerCTATile)
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
Value offset =
add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK));

View File

@@ -203,8 +203,8 @@ static Value loadA(Value tensor, const SharedMemoryObject &smemObj,
SmallVector<Value> elems;
elems.reserve(has.size() * 2);
for (auto item : has) { // has is a map, the key should be ordered.
elems.push_back(item.second.first);
elems.push_back(item.second.second);
elems.push_back(bitcast(item.second.first, i32_ty));
elems.push_back(bitcast(item.second.second, i32_ty));
}
Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
@@ -327,8 +327,8 @@ static Value loadB(Value tensor, const SharedMemoryObject &smemObj,
SmallVector<Value> elems;
for (auto &item : hbs) { // has is a map, the key should be ordered.
elems.push_back(item.second.first);
elems.push_back(item.second.second);
elems.push_back(bitcast(item.second.first, i32_ty));
elems.push_back(bitcast(item.second.second, i32_ty));
}
Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);

View File

@@ -19,7 +19,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
// Data loader for mma.16816 instruction.
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int warpsPerTile, ArrayRef<uint32_t> order,
MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
@@ -93,6 +93,8 @@ private:
int inWarpMatOffset;
// Offset in number of matrices to increment on non-k dim across warps
int warpMatOffset;
int nPerWarp;
};
SmallVector<Value>
@@ -131,10 +133,18 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
// address (s0,s1) annotates.
Value matOff[2];
matOff[kOrder ^ 1] = add(
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
// When B's shape(k, n) is (16, 8) and ldmatrix.x4 is used, the shared memory
// access will be out of bound. In the future we should change this case to
// ldmatrix.x2
if (kOrder == 0 && nPerWarp == 8) {
matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset));
} else {
matOff[kOrder ^ 1] = add(
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
mul(nkMatArr,
i32_val(
inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
}
matOff[kOrder] = kMatArr;
// Physical offset (before swizzling)
@@ -390,13 +400,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
}
MMA16816SmemLoader::MMA16816SmemLoader(
int warpsPerTile, ArrayRef<uint32_t> order, ArrayRef<uint32_t> warpsPerCTA,
uint32_t kOrder, int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder, int kWidth,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
: order(order.begin(), order.end()),
: nPerWarp(nPerWarp), order(order.begin(), order.end()),
warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
instrShape(instrShape.begin(), instrShape.end()),
@@ -490,6 +500,7 @@ std::function<void(int, int)> getLoadMatrixFn(
bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shapePerCTA = getShapePerCTA(tensorTy);
Type eltTy = tensorTy.getElementType();
// We assumes that the input operand of Dot should be from shared layout.
// TODO(Superjomn) Consider other layouts if needed later.
@@ -500,24 +511,19 @@ std::function<void(int, int)> getLoadMatrixFn(
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
auto order = sharedLayout.getOrder();
if (tensor.getType()
.cast<RankedTensorType>()
.getElementType()
.isa<mlir::Float8E4M3B11FNUZType>()) {
bool noTrans = (isA ^ order[0] == 0);
assert(noTrans && "float8e4b15 must have row-col layout");
}
if (kWidth != (4 / elemBytes))
assert(vecPhase == 1 || vecPhase == 4 * kWidth);
int nPerWarp =
std::max<int>(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8);
// (a, b) is the coordinate.
auto load = [=, &rewriter, &vals](int a, int b) {
MMA16816SmemLoader loader(
warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(),
kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/,
instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter,
typeConverter, loc);
MMA16816SmemLoader loader(nPerWarp, warpsPerTile, sharedLayout.getOrder(),
mmaLayout.getWarpsPerCTA(), kOrder, kWidth,
smemObj.strides, shapePerCTA /*tileShape*/,
instrShape, matShape, perPhase, maxPhase,
elemBytes, rewriter, typeConverter, loc);
// Offset of a slice within the original tensor in shared memory
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
SmallVector<Value> offs =
@@ -559,17 +565,19 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread,
bool isA) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shapePerCTA = getShapePerCTA(tensorTy);
int bitwidth = tensorTy.getElementTypeBitWidth();
auto mmaLayout = encoding.getParent().cast<MmaEncodingAttr>();
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
tensorTy.getShape().end());
ValueTable vals;
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
<<<<<<< HEAD
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
=======
auto numRep = encoding.getMMAv2Rep(shapePerCTA, bitwidth);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
int kWidth = encoding.getKWidth();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
@@ -579,14 +587,14 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warp, warpsPerCTA, order);
Value warpM = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
Value warpN = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
Value warpM = urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / 16));
Value warpN = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 8));
int warpsPerTile;
if (isA)
warpsPerTile = std::min<int>(warpsPerCTA[0], shape[0] / 16);
warpsPerTile = std::min<int>(warpsPerCTA[0], shapePerCTA[0] / 16);
else
warpsPerTile = std::min<int>(warpsPerCTA[1], shape[1] / 16);
warpsPerTile = std::min<int>(warpsPerCTA[1], shapePerCTA[1] / 16);
std::function<void(int, int)> loadFn;
if (isA)

View File

@@ -4,7 +4,9 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
@@ -15,15 +17,31 @@ LogicalResult convertMMA884(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
<<<<<<< HEAD
#ifdef USE_ROCM
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
#endif
=======
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Value thread);
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Value thread);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -32,14 +50,15 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// D = A * B + C
Value A = op.getA();
Value D = op.getResult();
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
auto AShapePerCTA = getShapePerCTA(A.getType());
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
unsigned K = AShapePerCTA[reduceAxis];
bool isOuter = K == 1;
MmaEncodingAttr mmaLayout = D.getType()
@@ -49,8 +68,17 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) {
if (mmaLayout.isVolta())
return convertMMA884(op, adaptor, getTypeConverter(), rewriter);
if (mmaLayout.isTuring())
return convertMMA1688(op, adaptor, getTypeConverter(), rewriter);
if (mmaLayout.isAmpere())
return convertMMA16816(op, adaptor, getTypeConverter(), rewriter);
<<<<<<< HEAD
=======
if (mmaLayout.isHopper())
return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
getThreadId(rewriter, loc));
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotOp to LLVM.");
}
@@ -76,9 +104,68 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
}
};
// Lowers triton::nvidia_gpu::DotAsyncOp (asynchronous dot, i.e. Hopper
// WGMMA) to the LLVM dialect. Only MMAv3 (Hopper) layouts are handled;
// any other configuration aborts compilation.
struct DotAsyncOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotAsyncOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::DotAsyncOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    // D = A * B + C. The operands are assumed to come from shared memory.
    Value operandA = op.getA();
    Value result = op.getResult();
    // A K-extent of 1 means an outer product, which this path rejects.
    auto aShapePerCTA = getShapePerCTA(operandA.getType());
    size_t reduceAxis = 1;
    unsigned kExtent = aShapePerCTA[reduceAxis];
    bool isOuter = (kExtent == 1);

    auto mmaLayout = result.getType()
                         .cast<RankedTensorType>()
                         .getEncoding()
                         .dyn_cast<MmaEncodingAttr>();
    bool mmaOk =
        mmaLayout && supportMMA(op.getOperand(0), mmaLayout.getVersionMajor());
    if (!isOuter && mmaOk) {
      if (mmaLayout.isHopper()) {
        return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter,
                                 getThreadId(rewriter, loc));
      }
      llvm::report_fatal_error(
          "Unsupported MMA kind found when converting DotAsyncOp to LLVM.");
    }
    llvm::report_fatal_error(
        "Unsupported DotAsyncOp found when converting TritonGPU to LLVM.");
  }
};
// Lowers triton::nvidia_gpu::DotWaitOp to an nvgpu WGMMA wait-group op.
// The source op produces no values, so it can simply be erased after the
// replacement op is emitted.
struct DotWaitOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::DotWaitOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::DotWaitOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Forward the number of outstanding groups to the hardware wait.
    auto numPendings = op.getPendings();
    rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(op.getLoc(), numPendings);
    // Safe to remove the op since it doesn't have any return value.
    rewriter.eraseOp(op);
    return success();
  }
};
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, allocation, benefit);
patterns.add<DotAsyncOpConversion>(typeConverter, allocation, benefit);
patterns.add<DotWaitOpConversion>(typeConverter, allocation, benefit);
}

View File

@@ -7,7 +7,8 @@ using namespace mlir;
using namespace mlir::triton;
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit);

View File

@@ -5,19 +5,20 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ValueTableFMA = std::map<std::pair<int, int>, Value>;
static ValueTableFMA getValueTableFromStructFMA(
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
Value val, int K, int n0, int shapePerCTATile, int sizePerThread,
ConversionPatternRewriter &rewriter, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) {
ValueTableFMA res;
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
int index = 0;
for (unsigned k = 0; k < K; ++k) {
for (unsigned m = 0; m < n0; m += shapePerCTA)
for (unsigned m = 0; m < n0; m += shapePerCTATile)
for (unsigned mm = 0; mm < sizePerThread; ++mm) {
res[{m + mm, k}] = elems[index++];
}
@@ -40,8 +41,8 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto dTensorTy = D.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto bShape = bTensorTy.getShape();
auto aShapePerCTA = getShapePerCTA(aTensorTy);
auto bShapePerCTA = getShapePerCTA(bTensorTy);
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -53,41 +54,42 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
Value llB = adaptor.getB();
auto sizePerThread = getSizePerThread(dLayout);
auto shapePerCTA = getShapePerCTA(dLayout);
auto shapePerCTATile = getShapePerCTATile(dLayout);
int K = aShape[1];
int M = aShape[0];
int N = bShape[1];
int K = aShapePerCTA[1];
int M = aShapePerCTA[0];
int N = bShapePerCTA[1];
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int mShapePerCTATile =
order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
int mSizePerThread =
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nShapePerCTATile =
order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]];
int nSizePerThread =
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
auto has =
getValueTableFromStructFMA(llA, K, M, mShapePerCTA, mSizePerThread,
getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread,
rewriter, loc, typeConverter, aTensorTy);
auto hbs =
getValueTableFromStructFMA(llB, K, N, nShapePerCTA, nSizePerThread,
getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread,
rewriter, loc, typeConverter, bTensorTy);
SmallVector<Value> ret = cc;
bool isCRow = order[0] == 1;
for (unsigned k = 0; k < K; k++) {
for (unsigned m = 0; m < M; m += mShapePerCTA)
for (unsigned n = 0; n < N; n += nShapePerCTA)
for (unsigned m = 0; m < M; m += mShapePerCTATile)
for (unsigned n = 0; n < N; n += nShapePerCTATile)
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
int mIdx = m / mShapePerCTA * mSizePerThread + mm;
int nIdx = n / nShapePerCTA * nSizePerThread + nn;
int mIdx = m / mShapePerCTATile * mSizePerThread + mm;
int nIdx = n / nShapePerCTATile * nSizePerThread + nn;
int z = isCRow ? mIdx * N / nShapePerCTA * mSizePerThread + nIdx
: nIdx * M / mShapePerCTA * nSizePerThread + mIdx;
int z = isCRow
? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx
: nIdx * M / mShapePerCTATile * nSizePerThread + mIdx;
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
hbs[{n + nn, k}], ret[z]);
}

View File

@@ -141,7 +141,15 @@ TensorCoreType getMmaType(triton::DotOp op) {
return TensorCoreType::NOT_APPLICABLE;
}
inline static const std::map<TensorCoreType, std::string> mmaInstrPtx = {
inline static const std::map<TensorCoreType, std::string> mmaInstrPtxTuring = {
{TensorCoreType::FP32_FP16_FP16_FP32,
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"},
{TensorCoreType::FP16_FP16_FP16_FP16,
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16"},
};
inline static const std::map<TensorCoreType, std::string> mmaInstrPtxAmpere = {
{TensorCoreType::FP32_FP16_FP16_FP32,
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{TensorCoreType::FP32_BF16_BF16_FP32,
@@ -164,22 +172,23 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc,
Value a, Value b, Value c, Value d, Value loadedA,
Value loadedB, Value loadedC, DotOp op,
DotOpAdaptor adaptor) {
DotOpAdaptor adaptor, bool isTuring) {
MLIRContext *ctx = c.getContext();
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
auto dShape = dTensorTy.getShape();
auto aShapePerCTA = triton::gpu::getShapePerCTA(aTensorTy);
auto bShapePerCTA = triton::gpu::getShapePerCTA(bTensorTy);
auto dShapePerCTA = triton::gpu::getShapePerCTA(dTensorTy);
int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
auto repA =
aTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
aTensorTy.getShape(), bitwidth);
aShapePerCTA, bitwidth);
auto repB =
bTensorTy.getEncoding().cast<DotOperandEncodingAttr>().getMMAv2Rep(
bTensorTy.getShape(), bitwidth);
bShapePerCTA, bitwidth);
assert(repA[1] == repB[0]);
int repM = repA[0], repN = repB[1], repK = repA[1];
@@ -196,23 +205,18 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
auto mmaType = getMmaType(op);
const auto &mmaInstructions =
isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere;
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = repN * 2;
PTXBuilder builder;
auto &mma = *builder.create(mmaInstrPtx.at(mmaType));
auto &mma = *builder.create(mmaInstructions.at(mmaType));
// using =r for float32 works but leads to less readable ptx.
bool isIntMMA = dTensorTy.getElementType().isInteger(32);
bool isAccF16 = dTensorTy.getElementType().isF16();
auto retArgs =
builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f");
auto aArgs = builder.newListOperand({
{ha[{m, k}], "r"},
{ha[{m + 1, k}], "r"},
{ha[{m, k + 1}], "r"},
{ha[{m + 1, k + 1}], "r"},
});
auto bArgs =
builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
auto cArgs = builder.newListOperand();
for (int i = 0; i < numMmaRets; ++i) {
cArgs->listAppend(builder.newOperand(
@@ -221,7 +225,32 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
// reuse the output registers
}
mma(retArgs, aArgs, bArgs, cArgs);
if (isTuring) {
auto aArgs1 = builder.newListOperand({
{ha[{m, k}], "r"},
{ha[{m + 1, k}], "r"},
});
auto bArgs1 = builder.newListOperand({
{hb[{n, k}], "r"},
});
auto aArgs2 = builder.newListOperand({
{ha[{m, k + 1}], "r"},
{ha[{m + 1, k + 1}], "r"},
});
auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}});
mma(retArgs, aArgs1, bArgs1, cArgs);
mma(retArgs, aArgs2, bArgs2, cArgs);
} else {
auto aArgs = builder.newListOperand({
{ha[{m, k}], "r"},
{ha[{m + 1, k}], "r"},
{ha[{m, k + 1}], "r"},
{ha[{m + 1, k + 1}], "r"},
});
auto bArgs =
builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
mma(retArgs, aArgs, bArgs, cArgs);
}
Value mmaOut =
builder.launch(rewriter, loc, getMmaRetType(mmaType, op.getContext()));
@@ -258,10 +287,9 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
return success();
}
// Convert to mma.m16n8k16
LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
LogicalResult convertMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, bool isTuring) {
auto loc = op.getLoc();
auto mmaLayout = op.getResult()
.getType()
@@ -287,5 +315,19 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
loadC(op.getC(), adaptor.getC(), typeConverter, op.getLoc(), rewriter);
return convertDot(typeConverter, rewriter, op.getLoc(), A, B, C, op.getD(),
loadedA, loadedB, loadedC, op, adaptor);
loadedA, loadedB, loadedC, op, adaptor, isTuring);
}
// Convert to mma.m16n8k8 (Turing, sm_75). Thin wrapper that delegates to
// the shared MMA lowering with the Turing flag enabled, since Turing only
// supports the k8 instruction shape.
LogicalResult convertMMA1688(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                             TritonGPUToLLVMTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter) {
  constexpr bool isTuring = true;
  return convertMMA(op, adaptor, typeConverter, rewriter, isTuring);
}
// Convert to mma.m16n8k16 (Ampere and newer). Thin wrapper that delegates
// to the shared MMA lowering with the Turing flag disabled.
LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                              TritonGPUToLLVMTypeConverter *typeConverter,
                              ConversionPatternRewriter &rewriter) {
  constexpr bool isTuring = false;
  return convertMMA(op, adaptor, typeConverter, rewriter, isTuring);
}

View File

@@ -0,0 +1,431 @@
/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "DotOpToLLVM.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Map the element type of the MMA result tensor `d` to the corresponding
// WGMMA accumulator element type. Aborts on unsupported element types.
triton::nvgpu::WGMMAEltType getMmaRetType(Value d) {
  auto elemTy = d.getType().cast<RankedTensorType>().getElementType();
  if (elemTy.isF32())
    return triton::nvgpu::WGMMAEltType::f32;
  if (elemTy.isF16())
    return triton::nvgpu::WGMMAEltType::f16;
  if (elemTy.isInteger(32))
    return triton::nvgpu::WGMMAEltType::s32;
  llvm::report_fatal_error("Unsupported mma result type found");
}
// Map the element type of an MMA operand tensor to the corresponding WGMMA
// operand element type. F32 operands map to tf32 only when `allowTF32` is
// set; otherwise they are rejected. Aborts on unsupported element types.
triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
  auto elemTy = a.getType().cast<RankedTensorType>().getElementType();
  if (elemTy.isF16())
    return triton::nvgpu::WGMMAEltType::f16;
  if (elemTy.isBF16())
    return triton::nvgpu::WGMMAEltType::bf16;
  if (elemTy.isF32() && allowTF32)
    return triton::nvgpu::WGMMAEltType::tf32;
  if (elemTy.isInteger(8))
    return triton::nvgpu::WGMMAEltType::s8;
  if (elemTy.isFloat8E5M2())
    return triton::nvgpu::WGMMAEltType::e5m2;
  if (elemTy.isFloat8E4M3FNUZ())
    return triton::nvgpu::WGMMAEltType::e4m3;
  llvm::report_fatal_error("Unsupported mma operand type found");
}
// Derive the WGMMA shared-memory descriptor swizzling mode from the shared
// layout's (perPhase, maxPhase) pair. `widthInByte` is the matrix row width
// in bytes; the selected swizzling span must not exceed it. Aborts on
// unrecognized layouts.
mlir::triton::nvgpu::WGMMADescMode
getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) {
  using mlir::triton::nvgpu::WGMMADescMode;
  const int perPhase = layout.getPerPhase();
  const int maxPhase = layout.getMaxPhase();

  WGMMADescMode mode;
  uint32_t swizzlingByteWidth;
  if (perPhase == 4 && maxPhase == 2) {
    mode = WGMMADescMode::swizzle32;
    swizzlingByteWidth = 32;
  } else if (perPhase == 2 && maxPhase == 4) {
    mode = WGMMADescMode::swizzle64;
    swizzlingByteWidth = 64;
  } else if (perPhase == 1 && maxPhase == 8) {
    mode = WGMMADescMode::swizzle128;
    swizzlingByteWidth = 128;
  } else {
    llvm::report_fatal_error("Unsupported shared layout.");
  }
  // TODO[biaow]: remove it once we support swizzling size larger than matrix
  // width, which requires padding the matrix width to the swizzling size when
  // allocating shared memory.
  assert(swizzlingByteWidth <= widthInByte &&
         "swizzling size larger than matrix width is not supported.");
  return mode;
}
// Builds WGMMA shared-memory matrix descriptors for one dot operand.
// The constructor creates a base descriptor once; smemLoad() then derives the
// descriptor for an individual instruction tile by adding that tile's offset
// (expressed in 16-byte units) to the base descriptor value.
class DotOpMmaV3SmemLoader {
public:
  // tensor:     the operand tensor (must carry a SharedEncodingAttr).
  // smemObj:    shared-memory object holding the operand's base pointer.
  // shape:      per-CTA shape of the operand.
  // warpId:     this warp's index along the operand's distributed dimension.
  // dimWpt:     warps-per-CTA along that same dimension.
  // trans:      whether the operand is stored transposed (order differs).
  // instrShape: {distributed-dim extent, K extent} of one WGMMA instruction.
  DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj,
                       SmallVector<int64_t> shape, Value warpId,
                       unsigned int dimWpt, bool trans,
                       SmallVector<unsigned int> instrShape,
                       ConversionPatternRewriter &rewriter, Location loc)
      : base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt),
        trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) {
    auto tensorTy = tensor.getType().cast<RankedTensorType>();
    auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
    ord = sharedLayout.getOrder();
    const int perPhase = sharedLayout.getPerPhase();
    const int maxPhase = sharedLayout.getMaxPhase();
    elemBytes = tensorTy.getElementTypeBitWidth() / 8;
    // One swizzling row spans 128 / perPhase bytes.
    elemsPerSwizzlingRow = 128 / perPhase / elemBytes;
    elemsPerSwizzlingRowVal = i32_val(elemsPerSwizzlingRow);
    uint32_t widthInByte = shape[ord[0]] * elemBytes;
    mode = getModeFromLayout(sharedLayout, widthInByte);
    baseDesc = rewriter.create<triton::nvgpu::WGMMADescCreateOp>(
        loc, i64_ty, base, i32_val(shape[ord[1]]), mode);
  }
  // Return the i64 descriptor for the tile at repetition `a` along the
  // warp-distributed dimension and repetition `b` along K. When `trans` is
  // set the two coordinates are swapped before the offset computation.
  Value smemLoad(int a, int b) {
    Value k = i32_val(b * instrShape[1]);
    Value m = add(i32_val(a * dimWpt * instrShape[0]),
                  mul(warpId, i32_val(instrShape[0])));
    if (trans) {
      std::swap(k, m);
    }
    // Offset = full swizzling-row blocks along K, plus the row offset, plus
    // the remainder within the current swizzling row (all in elements).
    Value leading_offset = mul(udiv(k, elemsPerSwizzlingRowVal),
                               i32_val(shape[ord[1]] * elemsPerSwizzlingRow));
    Value stride_offset = mul(m, elemsPerSwizzlingRowVal);
    Value offset = add(add(leading_offset, stride_offset),
                       urem(k, elemsPerSwizzlingRowVal));
    Value off1 = mul(i32_val(elemBytes), offset);
    // Descriptor offsets are encoded in 16-byte granularity.
    Value off_ = zext(i64_ty, udiv(off1, i32_val(16)));
    return add(baseDesc, off_);
  }
private:
  Value base;                       // shared-memory base pointer
  SmallVector<int64_t> shape;       // per-CTA operand shape
  Value warpId;                     // warp index along distributed dim
  int dimWpt;                       // warps-per-CTA along distributed dim
  bool trans;                       // operand stored transposed
  Value elemsPerSwizzlingRowVal;    // elemsPerSwizzlingRow as an i32 Value
  mlir::triton::nvgpu::WGMMADescMode mode;
  SmallVector<unsigned int> instrShape;
  ArrayRef<unsigned> ord;           // dimension order of the shared layout
  ConversionPatternRewriter &rewriter;
  Location loc;
  int elemsPerSwizzlingRow;
  int elemBytes;                    // element size in bytes
  Value baseDesc;                   // descriptor for tile (0, 0)
};
// Build the shared-memory descriptor loader for operand A of a WGMMA dot.
// `thread` is the linear thread id; only the warp index along M is needed to
// pick the instruction tile this warp reads. Returns the configured loader.
// (Removed dead locals numRepM/numRepK — they were computed but never used.)
DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter, Location loc,
                           const MmaEncodingAttr &mmaEncoding, Value tensor,
                           const SharedMemoryObject &smemObj, Value thread) {
  auto aTensorTy = tensor.getType().cast<RankedTensorType>();
  auto aSharedLayout = aTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
  assert(aSharedLayout && "only support load dot operand from shared.");
  auto instrShape = mmaEncoding.getInstrShape();
  auto wpt = mmaEncoding.getWarpsPerCTA();
  auto aOrd = aSharedLayout.getOrder();
  // A counts as transposed when its fastest-varying dimension is M (dim 0).
  bool transA = aOrd[0] == 0;
  auto shapePerCTA = getShapePerCTA(aTensorTy);
  Value warp = udiv(thread, i32_val(32));
  Value warpM = urem(warp, i32_val(wpt[0]));
  // Wrap around when more warps exist along M than instruction tiles.
  Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
  return {tensor,
          smemObj,
          shapePerCTA,
          warpId,
          wpt[0],
          transA,
          {instrShape[0], instrShape[2]},
          rewriter,
          loc};
}
// Build the shared-memory descriptor loader for operand B of a WGMMA dot.
// Mirrors loadA but selects the warp index along N (the second CTA dim).
// (Removed dead locals numRepK/numRepN — they were computed but never used.)
DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter, Location loc,
                           MmaEncodingAttr &mmaEncoding, Value tensor,
                           const SharedMemoryObject &smemObj, Value thread) {
  auto bTensorTy = tensor.getType().cast<RankedTensorType>();
  auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
  assert(bSharedLayout && "only support load B from shared.");
  auto instrShape = mmaEncoding.getInstrShape();
  auto wpt = mmaEncoding.getWarpsPerCTA();
  auto bOrd = bSharedLayout.getOrder();
  // B counts as transposed when its fastest-varying dimension is N (dim 1).
  bool transB = bOrd[0] == 1;
  auto shapePerCTA = triton::gpu::getShapePerCTA(bTensorTy);
  Value warp = udiv(thread, i32_val(32));
  Value warpMN = udiv(warp, i32_val(wpt[0]));
  Value warpN = urem(warpMN, i32_val(wpt[1]));
  // Wrap around when more warps exist along N than instruction tiles.
  Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
  return {tensor,
          smemObj,
          shapePerCTA,
          warpId,
          wpt[1],
          transB,
          {instrShape[1], instrShape[2]},
          rewriter,
          loc};
}
// Return a vector of Value of the accumulator start at startIndex and pack the
// values into 32bits in case the accumulator is fp16.
llvm::SmallVector<Value> loadC(ConversionPatternRewriter &rewriter,
                               Location loc, const SmallVector<Value> &elements,
                               int startIndex, int numElements) {
  // Non-fp16 accumulators pass through unmodified, one register per value.
  if (!elements[0].getType().isF16()) {
    llvm::SmallVector<Value> mmaOut;
    mmaOut.reserve(numElements);
    for (int i = 0; i < numElements; ++i)
      mmaOut.push_back(elements[startIndex + i]);
    return mmaOut;
  }
  // fp16 accumulators are packed pairwise into <2 x f16> vectors and bitcast
  // to i32 so each register carries two values.
  Type cPackTy = vec_ty(rewriter.getF16Type(), 2);
  llvm::SmallVector<Value> mmaOut;
  mmaOut.reserve(numElements / 2);
  for (int i = 0; i < numElements / 2; ++i) {
    Value lo = elements[startIndex + 2 * i];
    Value hi = elements[startIndex + 2 * i + 1];
    Value pack = rewriter.create<LLVM::UndefOp>(loc, cPackTy);
    pack = insert_element(cPackTy, pack, lo, i32_val(0));
    pack = insert_element(cPackTy, pack, hi, i32_val(1));
    mmaOut.push_back(bitcast(pack, rewriter.getIntegerType(32)));
  }
  return mmaOut;
}
// If the accumulator is fp16 unpack it from 32-bit integers.
SmallVector<Value> unpackAccumulator(ConversionPatternRewriter &rewriter,
                                     Location loc,
                                     const SmallVector<Value> &packed,
                                     RankedTensorType tensorTy) {
  if (!tensorTy.getElementType().isF16())
    return packed;
  // fp16 accumulators arrive packed two-per-register; split each one back
  // into its two f16 halves, preserving order.
  Type f16Ty = rewriter.getF16Type();
  SmallVector<Value> unpacked;
  for (Value reg : packed) {
    Value pair = bitcast(reg, vec_ty(f16Ty, 2));
    unpacked.push_back(extract_element(f16Ty, pair, i32_val(0)));
    unpacked.push_back(extract_element(f16Ty, pair, i32_val(1)));
  }
  return unpacked;
}
// Lower a (possibly asynchronous) dot to a sequence of WGMMA operations.
// A and B are read from shared memory through matrix descriptors; the
// accumulator C travels in registers and is chained through the K loop.
// When `sync` is true a wait-group op is emitted so the result is ready when
// this function returns; otherwise only a commit-group is emitted.
// Fixes vs. original: removed the dead `ValueRange operands{a, b, d};` local,
// and renamed the loop-local a/b/d that shadowed the function parameters.
LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Operation *op, Value a, Value b, Value c, Value d,
                         Value loadedA, Value loadedB, Value loadedC,
                         bool allowTF32, const SharedMemoryObject &smemObjA,
                         const SharedMemoryObject &smemObjB, bool sync,
                         Value thread) {
  auto aTensorTy = a.getType().cast<RankedTensorType>();
  auto bTensorTy = b.getType().cast<RankedTensorType>();
  auto dTensorTy = d.getType().cast<RankedTensorType>();
  auto aSharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
  auto bSharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
  auto mmaEncoding = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
  auto aOrd = aSharedLayout.getOrder();
  auto bOrd = bSharedLayout.getOrder();
  bool transA = aOrd[0] == 0;
  bool transB = bOrd[0] == 1;
  auto dShapePerCTA = getShapePerCTA(dTensorTy);
  auto instrShape = mmaEncoding.getInstrShape();
  // Accumulator registers produced by one WGMMA: 2 per 8 result columns.
  auto accSize = 2 * (instrShape[1] / 4);
  int M = 4 * instrShape[0];
  int N = instrShape[1];
  int K = instrShape[2];
  auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
  int numRepM = ceil<unsigned>(dShapePerCTA[0], shapePerCTATile[0]);
  int numRepN = ceil<unsigned>(dShapePerCTA[1], shapePerCTATile[1]);
  int numRepK = ceil<unsigned>(aTensorTy.getShape()[1], instrShape[2]);
  DotOpMmaV3SmemLoader aLoader =
      loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread);
  DotOpMmaV3SmemLoader bLoader =
      loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread);
  auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);
  triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(d);
  triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(a, allowTF32);
  triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(b, allowTF32);
  triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col
                                              : triton::nvgpu::WGMMALayout::row;
  triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row
                                              : triton::nvgpu::WGMMALayout::col;
  auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
  int numTMADescs =
      func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
  // Without TMA loads the operands were written with plain shared-memory
  // stores, so fence them before WGMMA reads the data.
  if (numTMADescs == 0)
    rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
  rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
  SmallVector<Value> mmaResults;
  for (int m = 0; m < numRepM; ++m) {
    for (int n = 0; n < numRepN; ++n) {
      llvm::SmallVector<Value> mmaOut =
          loadC(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize);
      llvm::SmallVector<Type> elemTypes;
      for (Value accEl : mmaOut)
        elemTypes.push_back(accEl.getType());
      auto accTy =
          LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
      // Chain the accumulator through K: each WGMMA consumes the previous
      // iteration's result as its C operand.
      Value acc = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy);
      for (int k = 0; k < numRepK; ++k) {
        auto descA = aLoader.smemLoad(m, k);
        auto descB = bLoader.smemLoad(n, k);
        acc = rewriter.create<triton::nvgpu::WGMMAOp>(
            loc, accTy, descA, descB, acc, M, N, K, eltTypeC, eltTypeA,
            eltTypeB, layoutA, layoutB);
      }
      auto accElems = typeConverter->unpackLLElements(loc, acc, rewriter, accTy);
      for (Value el : accElems)
        mmaResults.push_back(el);
    }
  }
  rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
  if (sync)
    rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
  SmallVector<Value> results =
      unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);
  // replace with new packed result
  Type structTy = LLVM::LLVMStructType::getLiteral(
      mmaEncoding.getContext(),
      SmallVector<Type>(results.size(), dTensorTy.getElementType()));
  auto res = typeConverter->packLLElements(loc, results, rewriter, structTy);
  rewriter.replaceOp(op, res);
  return success();
}
// Loading $c to registers, returns a Value.
// Validates that the $c operand carries an MMA layout and that its LLVM
// struct provides exactly the number of members the accumulator layout
// expects, then forwards the LLVM value unchanged.
// (Removed dead local `wpt` — it was computed but never used.)
Value loadC(Value tensor, Value llTensor) {
  auto tensorTy = tensor.getType().cast<RankedTensorType>();
  auto mmaEncoding = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
  assert(mmaEncoding && "Currently, we only support $c with a mma layout.");
  auto instrShape = mmaEncoding.getInstrShape();
  auto shapePerCTA = getShapePerCTA(tensorTy);
  auto shapePerCTATile = getShapePerCTATile(mmaEncoding);
  int numRepM = ceil<unsigned>(shapePerCTA[0], shapePerCTATile[0]);
  int numRepN = ceil<unsigned>(shapePerCTA[1], shapePerCTATile[1]);
  // 2 registers per 8 result columns of one instruction, times the number of
  // instruction tiles covering the CTA tile.
  size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN;
  auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
  assert(structTy.getBody().size() == fcSize &&
         "DotOp's $c operand should pass the same number of values as $d in "
         "mma layout.");
  return llTensor;
}
// Lower a synchronous triton::DotOp whose operands live in shared memory to
// WGMMA instructions, delegating to convertDot with sync = true.
LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           TritonGPUToLLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter, Value thread) {
  auto loc = op.getLoc();
  Value A = op.getA();
  Value B = op.getB();
  Value C = op.getC();
  assert(A.getType().cast<RankedTensorType>()
             .getEncoding()
             .isa<SharedEncodingAttr>() &&
         B.getType().cast<RankedTensorType>()
             .getEncoding()
             .isa<SharedEncodingAttr>() &&
         "Both $a and %b should be Shared layout.");
  Value llA = adaptor.getA();
  Value llB = adaptor.getB();
  Value llC = loadC(C, adaptor.getC());
  auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
  auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
  return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
                    op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
                    smemObjB, /*sync=*/true, thread);
}
// Lower an asynchronous dot (triton::nvidia_gpu::DotAsyncOp) to WGMMA
// without waiting on the result group (convertDot is called with
// sync = false); completion is handled elsewhere.
LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
                                triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor,
                                TritonGPUToLLVMTypeConverter *typeConverter,
                                ConversionPatternRewriter &rewriter,
                                Value thread) {
  auto loc = op.getLoc();
  Value A = op.getA();
  Value B = op.getB();
  Value C = op.getC();
  assert(A.getType().cast<RankedTensorType>()
             .getEncoding()
             .isa<SharedEncodingAttr>() &&
         B.getType().cast<RankedTensorType>()
             .getEncoding()
             .isa<SharedEncodingAttr>() &&
         "Both $a and %b should be Shared layout.");
  Value llA = adaptor.getA();
  Value llB = adaptor.getB();
  Value llC = loadC(C, adaptor.getC());
  auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
  auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
  return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C,
                    op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA,
                    smemObjB, /*sync=*/false, thread);
}

View File

@@ -47,13 +47,11 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
const std::string Fp16_to_Fp8E5M2 =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0x7fff7fff; \n" // a0 &= 0x7fff7fff
"and.b32 a1, $2, 0x7fff7fff; \n" // (strip sign)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"lop3.b32 a0, $1, 0x80008000, a0, 0xea; \n" // a0 = a0|(0x80008000&in0)
"lop3.b32 a1, $2, 0x80008000, a1, 0xea; \n" // (restore sign)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
#endif
@@ -318,15 +316,16 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const std::string Fp8E4M3B15_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"prmt.b32 a0, 0, $2, 0x5746; \n"
"and.b32 b0, a0, 0x7f007f00; \n"
"and.b32 b1, a0, 0x00ff00ff; \n"
"and.b32 a1, a0, 0x00800080; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"add.u32 b1, b1, a1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"shl.b32 $1, b1, 7; \n"
"} \n";
<<<<<<< HEAD
#endif
#ifdef USE_ROCM
@@ -396,6 +395,43 @@ const std::string Fp16_to_Fp8E4M3B15 =
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
#endif
=======
// Build the inline-PTX string converting 4 packed fp16 values (two 32-bit
// registers $1, $2) to 4 fp8e4m3b15 values packed into $0.
// `has_minx2` selects the hardware min.f16x2 instruction for clamping to the
// format's max value (0x3F80 per lane); otherwise the clamp is emulated with
// setp/selp on the individual 16-bit halves.
const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
  std::string ret;
  ret += "{ \n"
         ".reg .pred p<4>; \n"
         ".reg .b32 a<2>, b<2>; \n"
         ".reg .b16 c<4>; \n"
         ".reg .b16 max_val_f16; \n"
         ".reg .b32 max_val_f16x2; \n"
         "mov.b16 max_val_f16, 0x3F80; \n"
         "mov.b32 max_val_f16x2, 0x3F803F80; \n"
         "and.b32 a0, $1, 0x7fff7fff; \n"
         "and.b32 a1, $2, 0x7fff7fff; \n";
  // Clamp magnitudes to max_val: single instruction when min.f16x2 exists.
  if (has_minx2)
    ret += "min.f16x2 a0, a0, max_val_f16x2; \n"
           "min.f16x2 a1, a1, max_val_f16x2; \n";
  else
    ret += "setp.lt.f16x2 p0|p1, a0, max_val_f16x2; \n"
           "setp.lt.f16x2 p2|p3, a1, max_val_f16x2; \n"
           "mov.b32 {c0, c1}, a0; \n"
           "mov.b32 {c2, c3}, a1; \n"
           "selp.b16 c0, c0, max_val_f16, p0; \n"
           "selp.b16 c1, c1, max_val_f16, p1; \n"
           "selp.b16 c2, c2, max_val_f16, p2; \n"
           "selp.b16 c3, c3, max_val_f16, p3; \n"
           "mov.b32 a0, {c0, c1}; \n"
           "mov.b32 a1, {c2, c3}; \n";
  // Round to nearest, restore the sign bits from the inputs, and gather the
  // four high bytes into the output register.
  ret += "mad.lo.u32 a0, a0, 2, 0x00800080; \n"
         "mad.lo.u32 a1, a1, 2, 0x00800080; \n"
         "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
         "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
         "prmt.b32 $0, b0, b1, 0x7531; \n"
         "}";
  return ret;
}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
/* ----- FP8E4M3B15X4 ------ */
// NOTE: NOT USED RIGHT NOW
@@ -418,12 +454,11 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const std::string Fp8E4M3B15x4_to_Fp16 =
"{ \n"
".reg .b32 a<2>; \n"
"shl.b32 a0, $2, 1; \n"
"add.u32 a0, $2, $2; \n"
"shl.b32 a1, $2, 7; \n"
"and.b32 $0, a0, 0x80008000; \n"
"lop3.b32 $0, $0, a1, 0x3f803f80, 0xf8; \n"
"and.b32 $1, $2, 0x80008000; \n"
"lop3.b32 $1, $1, $2, 0x3f803f80, 0xf8; \n"
"and.b32 $1, $2, 0xbf80bf80; \n"
"}";
#endif
@@ -449,11 +484,11 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
"shr.b32 a1, $1, 7; \n"
"and.b32 $0, a0, 0x40004000; \n"
"lop3.b32 $0, $0, a1, 0x007f007f, 0xf8; \n"
"lop3.b32 $0, $0, $2, 0x80008000, 0xf8; \n"
"lop3.b32 $0, $0, $2, 0x3f803f80, 0xf8; \n"
"lop3.b32 $0, $0, $2, 0xbf80bf80, 0xf8; \n"
"}";
#endif
<<<<<<< HEAD
/* ----- FP8E4M3 ------ */
// Note: when handled by software, this format
// does not handle denormals and has
@@ -753,6 +788,16 @@ const std::string Bf16_to_Fp8E4M3 =
"or.b32 $0, nosign, sign; \n" // restore sign
"}";
#endif
=======
// Fp8E4M3 (x2) -> Fp16 (x2) (packed)
// Single-instruction conversion via the hardware cvt.e4m3x2 path.
// NOTE(review): availability of cvt.*.e4m3x2 depends on GPU architecture —
// callers gate these on compute capability; confirm against the PTX ISA.
const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
                                      "cvt.rn.f16x2.e4m3x2 $0, $1; \n"
                                      "}";
// Fp16 (x2) -> Fp8E4M3 (x2) (packed)
// satfinite clamps out-of-range finite inputs to the format's max value.
const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
                                      "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
                                      "}";
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
/* ----- Packed integer to BF16 ------ */
#ifndef USE_ROCM
@@ -852,6 +897,13 @@ static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
llvm_unreachable("unimplemented code path");
}
// Return the scalar element type of `value`: the element type for ranked
// tensors, otherwise the value's own type.
inline Type getElementType(Value value) {
  Type ty = value.getType();
  auto tensorTy = ty.dyn_cast<RankedTensorType>();
  return tensorTy ? tensorTy.getElementType() : ty;
}
inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
@@ -903,40 +955,49 @@ inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
}
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value &, const Value &)>
const SmallVector<Value> &)>
ConverterT;
static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
Type outType) {
Type outType,
const int inVecWidthBits = 32,
const int outVecWidthBits = 32) {
ConverterT converter =
[ptxAsm, inType, outType, inVecWidthBits,
outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) -> SmallVector<Value> {
int numElements = v.size();
assert(numElements == 4 || numElements == 2 && "invalid vector size");
ConverterT converter = [ptxAsm, inType, outType](
Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) -> SmallVector<Value> {
SmallVector<Value> v = {v0, v1, v2, v3};
auto ctx = rewriter.getContext();
int inBitwidth = inType.getIntOrFloatBitWidth();
int outBitwidth = outType.getIntOrFloatBitWidth();
// first, we pack `v` into 32-bit ints
int inVecWidth = 32 / inBitwidth;
int inVecWidth = inVecWidthBits / inBitwidth;
auto inVecTy = vec_ty(inType, inVecWidth);
SmallVector<Value> inPacked(4 / inVecWidth, undef(inVecTy));
for (size_t i = 0; i < 4; i++)
SmallVector<Value> inPacked(numElements / inVecWidth, undef(inVecTy));
for (size_t i = 0; i < numElements; i++)
inPacked[i / inVecWidth] = insert_element(
inVecTy, inPacked[i / inVecWidth], v[i], i32_val(i % inVecWidth));
for (size_t i = 0; i < inPacked.size(); i++)
inPacked[i] = bitcast(inPacked[i], i32_ty);
inPacked[i] = bitcast(inPacked[i], int_ty(inVecWidthBits));
// then, we run the provided inline PTX
int outVecWidth = 32 / outBitwidth;
int outNums = 4 / outVecWidth;
int outVecWidth = outVecWidthBits / outBitwidth;
int outNums = numElements / outVecWidth;
PTXBuilder builder;
SmallVector<PTXBuilder::Operand *> operands;
for (int i = 0; i < outNums; i++)
operands.push_back(builder.newOperand("=r"));
for (Value inVal : inPacked)
operands.push_back(builder.newOperand(inVal, "r"));
auto outConstriant = outVecWidthBits == 16 ? "=h" : "=r";
auto inConstraint = inVecWidthBits == 16 ? "h" : "r";
for (int i = 0; i < outNums; i++) {
operands.push_back(builder.newOperand(outConstriant));
}
for (Value inVal : inPacked) {
operands.push_back(builder.newOperand(inVal, inConstraint));
}
auto &ptxOp = *builder.create(ptxAsm);
ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
auto outVecTy = vec_ty(outType, outVecWidth);
@@ -951,7 +1012,7 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType,
}
// unpack the output
SmallVector<Value> ret;
for (size_t i = 0; i < 4; i++)
for (size_t i = 0; i < numElements; i++)
ret.push_back(extract_element(outType, outPacked[i / outVecWidth],
i32_val(i % outVecWidth)));
return ret;
@@ -1042,6 +1103,9 @@ public:
return success();
}
private:
int computeCapability;
};
template <typename SourceOp, typename DestOp>
@@ -1054,11 +1118,6 @@ struct ElementwiseOpConversion
using Base::Base;
using OpAdaptor = typename Base::OpAdaptor;
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
typeConverter, benefit) {}
// An interface to support variant DestOp builder.
SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
@@ -1075,6 +1134,11 @@ struct FpToFpOpConversion
using ElementwiseOpConversionBase<
triton::FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase;
explicit FpToFpOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
int computeCapability, PatternBenefit benefit = 1)
: ElementwiseOpConversionBase(typeConverter, benefit),
computeCapability(computeCapability) {}
static Value convertBf16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
@@ -1168,6 +1232,7 @@ struct FpToFpOpConversion
auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F16TyID = TypeID::get<mlir::Float16Type>();
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
auto F32TyID = TypeID::get<mlir::Float32Type>();
@@ -1179,19 +1244,29 @@ struct FpToFpOpConversion
#endif
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
{{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
// F16 -> F8
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3},
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
// F8 -> BF16
{{F8E4M3TyID, BF16TyID}, Fp8E4M3_to_Bf16},
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
// BF16 -> F8
{{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3},
{{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
};
int inVecWidthBits = 32;
int outVecWidthBits = 32;
if (srcTy.isFloat8E4M3FNUZ()) {
inVecWidthBits = 16;
outVecWidthBits = 32;
}
if (dstTy.isFloat8E4M3FNUZ()) {
inVecWidthBits = 32;
outVecWidthBits = 16;
}
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
if (srcMap.count(key) == 0) {
@@ -1199,6 +1274,7 @@ struct FpToFpOpConversion
<< "\n";
llvm_unreachable("");
}
<<<<<<< HEAD
#ifdef USE_ROCM
return srcMap.lookup(key);
#else
@@ -1206,27 +1282,46 @@ struct FpToFpOpConversion
getTypeConverter()->convertType(srcTy),
getTypeConverter()->convertType(dstTy));
#endif
=======
if (computeCapability < 90 &&
(srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) {
llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
"compute capability >= 90"
<< "\n";
llvm_unreachable("");
}
return makeConverterFromPtx(srcMap.lookup(key),
getTypeConverter()->convertType(srcTy),
getTypeConverter()->convertType(dstTy),
inVecWidthBits, outVecWidthBits);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
}
SmallVector<Value> createDestOps(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
assert(operands.size() % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto srcElementType = getElementType(op.getFrom());
auto dstElementType = getElementType(op.getResult());
int numElements = 4;
if (srcElementType.isFloat8E4M3FNUZ() ||
dstElementType.isFloat8E4M3FNUZ()) {
numElements = 2;
}
assert(operands.size() % numElements == 0 &&
"FP8 casting only support tensors with aligned sizes");
bool isSrcFP32 = srcElementType.isF32();
bool isDstFP32 = dstElementType.isF32();
auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType,
isDstFP32 ? f16_ty : dstElementType);
SmallVector<Value> inVals = {operands[0][0], operands[1][0], operands[2][0],
operands[3][0]};
SmallVector<Value> inVals;
for (unsigned i = 0; i < numElements; i++) {
inVals.push_back(operands[i][0]);
}
if (isSrcFP32)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v);
SmallVector<Value> outVals =
cvtFunc(loc, rewriter, inVals[0], inVals[1], inVals[2], inVals[3]);
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
assert(outVals.size() == inVals.size());
if (isDstFP32)
for (Value &v : outVals)
@@ -1234,6 +1329,9 @@ struct FpToFpOpConversion
// Pack values
return outVals;
}
private:
int computeCapability;
};
template <typename OP>
@@ -1336,15 +1434,16 @@ struct CmpFOpConversion
}
};
template <class T>
struct ExternElementwiseOpConversion
: public ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>> {
using Base = ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>>;
: public ElementwiseOpConversionBase<ExternElementwiseOp,
ExternElementwiseOpConversion> {
using Base = ElementwiseOpConversionBase<ExternElementwiseOp,
ExternElementwiseOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
typedef typename Base::OpAdaptor OpAdaptor;
SmallVector<Value> createDestOps(T op, OpAdaptor adaptor,
SmallVector<Value> createDestOps(ExternElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
@@ -1365,8 +1464,9 @@ private:
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, T op,
StringRef funcName, Type funcType) const {
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
ExternElementwiseOp op, StringRef funcName,
Type funcType) const {
using LLVM::LLVMFuncOp;
auto funcAttr = StringAttr::get(op->getContext(), funcName);
@@ -1385,6 +1485,86 @@ private:
}
};
// Lowers ElementwiseInlineAsmOp to an LLVM::InlineAsmOp, packing sub-32-bit
// operands into vector registers so each asm input maps to one register.
struct ElementwiseInlineAsmOpConversion
    : public ElementwiseOpConversionBase<ElementwiseInlineAsmOp,
                                         ElementwiseInlineAsmOpConversion> {
  using Base = ElementwiseOpConversionBase<ElementwiseInlineAsmOp,
                                           ElementwiseInlineAsmOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;
  typedef typename Base::OpAdaptor OpAdaptor;
  // If operand size is smaller than 32bits pack by groups of 32bits.
  // Otherwise have separate inputs.
  SmallVector<Value> packOperands(ElementwiseInlineAsmOp op,
                                  MultipleOperandsRange operands,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc) const {
    SmallVector<Value> packedOperands;
    unsigned numPackedElements = op.getPackedElement();
    // Outer loop over asm operands, inner loop over the packed elements of
    // each operand; elements are grouped numElementPerReg at a time.
    for (int i = 0, e = op.getNumOperands(); i < e; i++) {
      unsigned bitWidth =
          getElementType(op.getOperand(i)).getIntOrFloatBitWidth();
      unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1;
      // Never pack more elements than the op processes per asm invocation.
      numElementPerReg = std::min(numElementPerReg, numPackedElements);
      for (int j = 0; j < numPackedElements; j += numElementPerReg) {
        if (numElementPerReg == 1) {
          // 32-bit (or wider) elements pass through unpacked.
          packedOperands.push_back(operands[j][i]);
          continue;
        }
        Type t = vec_ty(
            getTypeConverter()->convertType(getElementType(op.getOperand(i))),
            numElementPerReg);
        Value packed = undef(t);
        for (int k = 0; k < numElementPerReg; k++) {
          packed = insert_element(packed, operands[j + k][i], i32_val(k));
        }
        packedOperands.push_back(packed);
      }
    }
    return packedOperands;
  }
  // Emit one LLVM inline-asm op per group of packedElement results and
  // unpack its (possibly vector) result back into scalar Values.
  SmallVector<Value> createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    int numPackedElements = op.getPackedElement();
    if (operands.size() % numPackedElements != 0)
      llvm::report_fatal_error("Inline asm op has more packed elements than "
                               "number of elements per thread.");
    SmallVector<Value> packedOperands =
        packOperands(op, operands, rewriter, loc);
    Type dstType =
        getTypeConverter()->convertType(getElementType(op.getResult()));
    Type retType = dstType;
    if (numPackedElements > 1)
      retType = vec_ty(retType, numPackedElements);
    Value result = rewriter
                       .create<LLVM::InlineAsmOp>(
                           loc, retType,
                           packedOperands, // operands
                           op.getAsmString(), // asm_string
                           op.getConstraints(), // constraints
                           !op.getPure(), // has_side_effects
                           false, // is_align_stack
                           LLVM::AsmDialectAttr::get(
                               rewriter.getContext(),
                               LLVM::AsmDialect::AD_ATT), // asm_dialect
                           ArrayAttr() // operand_attrs
                           )
                       ->getResult(0);
    SmallVector<Value> results;
    if (numPackedElements > 1) {
      // Vector result: split back into scalars.
      for (int i = 0; i < numPackedElements; i++)
        results.push_back(extract_element(result, i32_val(i)));
    } else {
      results = {result};
    }
    return results;
  }
};
struct FDivOpConversion
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
using Base =
@@ -1577,9 +1757,15 @@ struct SIToFPOpConversion
auto cvtFunc = makeConverterFromPtx(
S8_to_Bf16, getTypeConverter()->convertType(inElemTy),
getTypeConverter()->convertType(outElemTy));
<<<<<<< HEAD
auto outVals = cvtFunc(loc, rewriter, operands[0][0], operands[1][0],
operands[2][0], operands[3][0]);
#endif
=======
SmallVector<Value> inVals = {operands[0][0], operands[1][0],
operands[2][0], operands[3][0]};
auto outVals = cvtFunc(loc, rewriter, inVals);
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
assert(outVals.size() == 4);
return outVals;
} else if (outElemTy.isBF16()) {
@@ -1768,7 +1954,10 @@ struct IndexCastOpLowering
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
int computeCapability, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
@@ -1792,7 +1981,11 @@ void populateElementwiseOpToLLVMPatterns(
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
@@ -1828,16 +2021,171 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
patterns.add<IndexCastOpLowering>(typeConverter, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, computeCapability, benefit);
patterns.add<ExternElementwiseOpConversion<triton::PureExternElementwiseOp>>(
typeConverter, benefit);
patterns
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
typeConverter, benefit);
patterns.add<ExternElementwiseOpConversion>(typeConverter, benefit);
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For other input types, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}
// Rewrites LLVM::FPExtOp through the shared elementwise conversion machinery.
// Only the f16 -> f32 widening is intercepted; it is lowered via
// FpToFpOpConversion's dedicated helper.
struct FPExtOpConversion
    : ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // An FPExtOp is illegal (i.e. handled by this pattern) exactly when it is
  // the f16 -> f32 extension; every other extension is left as-is.
  static bool isLegalOp(LLVM::FPExtOp op) {
    const bool isF16ToF32 =
        op.getResult().getType().isF32() && op.getOperand().getType().isF16();
    return !isF16ToF32;
  }

  // Convert one scalar operand; delegates to the shared fp16 -> fp32 helper.
  SmallVector<Value> createDestOps(LLVM::FPExtOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    return {
        FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0][0])};
  }
};
// Rewrites LLVM::FPTruncOp through the shared elementwise conversion
// machinery. Only the f32 -> f16 truncation is intercepted; it is lowered via
// FpToFpOpConversion's dedicated helper.
struct FPTruncOpConversion
    : ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
  using Base =
      ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // An FPTruncOp is illegal (handled here) exactly when it truncates
  // f32 -> f16; every other truncation is left as-is.
  static bool isLegalOp(LLVM::FPTruncOp op) {
    const bool isF32ToF16 =
        op.getResult().getType().isF16() && op.getOperand().getType().isF32();
    return !isF32ToF16;
  }

  // Convert one scalar operand; delegates to the shared fp32 -> fp16 helper.
  SmallVector<Value> createDestOps(LLVM::FPTruncOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    return {
        FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0][0])};
  }
};
// Rewrites LLVM::TruncOp through the shared elementwise conversion machinery.
// Only the i32 -> i16 truncation is intercepted; it is lowered to an explicit
// `cvt.u16.u32` PTX instruction.
struct TruncOpConversion
    : ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // A TruncOp is illegal (handled here) exactly for i32 -> i16.
  static bool isLegalOp(LLVM::TruncOp op) {
    return !(op.getResult().getType().isInteger(16) &&
             op.getOperand().getType().isInteger(32));
  }

  // Emit `cvt.u16.u32 %dst, %src` for one scalar operand.
  SmallVector<Value> createDestOps(LLVM::TruncOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    PTXBuilder ptxBuilder;
    auto &cvtInstr = *ptxBuilder.create("cvt.u16.u32");
    auto *dstOpr = ptxBuilder.newOperand("=h"); // 16-bit result register
    auto *srcOpr = ptxBuilder.newOperand(operands[0][0], "r"); // 32-bit input
    cvtInstr(dstOpr, srcOpr);
    return {ptxBuilder.launch(rewriter, loc, i16_ty, false)};
  }
};
// Rewrites LLVM::SExtOp through the shared elementwise conversion machinery.
// Only the i16 -> i32 sign extension is intercepted; it is lowered to an
// explicit `cvt.s32.s16` PTX instruction.
struct SExtOpConversion
    : ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // An SExtOp is illegal (handled here) exactly for i16 -> i32.
  static bool isLegalOp(LLVM::SExtOp op) {
    return !(op.getResult().getType().isInteger(32) &&
             op.getOperand().getType().isInteger(16));
  }

  // Emit `cvt.s32.s16 %dst, %src` for one scalar operand.
  SmallVector<Value> createDestOps(LLVM::SExtOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    PTXBuilder ptxBuilder;
    auto &cvtInstr = *ptxBuilder.create("cvt.s32.s16");
    auto *dstOpr = ptxBuilder.newOperand("=r"); // 32-bit result register
    auto *srcOpr = ptxBuilder.newOperand(operands[0][0], "h"); // 16-bit input
    cvtInstr(dstOpr, srcOpr);
    return {ptxBuilder.launch(rewriter, loc, i32_ty, false)};
  }
};
// Rewrites LLVM::ZExtOp through the shared elementwise conversion machinery.
// Only the i16 -> i32 zero extension is intercepted; it is lowered to an
// explicit `cvt.u32.u16` PTX instruction.
struct ZExtOpConversion
    : ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
  using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  // A ZExtOp is illegal (handled here) exactly for i16 -> i32.
  static bool isLegalOp(LLVM::ZExtOp op) {
    return !(op.getResult().getType().isInteger(32) &&
             op.getOperand().getType().isInteger(16));
  }

  // Emit `cvt.u32.u16 %dst, %src` for one scalar operand.
  SmallVector<Value> createDestOps(LLVM::ZExtOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter,
                                   Type elemTy, MultipleOperandsRange operands,
                                   Location loc) const {
    PTXBuilder ptxBuilder;
    auto &cvtInstr = *ptxBuilder.create("cvt.u32.u16");
    auto *dstOpr = ptxBuilder.newOperand("=r"); // 32-bit result register
    auto *srcOpr = ptxBuilder.newOperand(operands[0][0], "h"); // 16-bit input
    cvtInstr(dstOpr, srcOpr);
    return {ptxBuilder.launch(rewriter, loc, i32_ty, false)};
  }
};
// Returns false exactly for the cast ops that the PTX-specialized conversion
// patterns above will rewrite; any other op is considered legal and left
// untouched by the conversion target.
bool isLegalElementwiseOp(Operation *op) {
  if (auto fpExt = dyn_cast<LLVM::FPExtOp>(op))
    return FPExtOpConversion::isLegalOp(fpExt);
  if (auto fpTrunc = dyn_cast<LLVM::FPTruncOp>(op))
    return FPTruncOpConversion::isLegalOp(fpTrunc);
  if (auto trunc = dyn_cast<LLVM::TruncOp>(op))
    return TruncOpConversion::isLegalOp(trunc);
  if (auto sext = dyn_cast<LLVM::SExtOp>(op))
    return SExtOpConversion::isLegalOp(sext);
  if (auto zext = dyn_cast<LLVM::ZExtOp>(op))
    return ZExtOpConversion::isLegalOp(zext);
  return true;
}
// Registers the PTX-specialized cast-lowering patterns (fpext/fptrunc and the
// 16/32-bit integer trunc/sext/zext conversions) into the given pattern set.
void populateElementwiseOpToPTXPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<FPExtOpConversion, FPTruncOpConversion, TruncOpConversion,
               SExtOpConversion, ZExtOpConversion>(typeConverter, benefit);
}

View File

@@ -8,8 +8,15 @@ using namespace mlir::triton;
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit);
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
int computeCapability, PatternBenefit benefit);
bool isLegalElementwiseOp(Operation *op);
void populateElementwiseOpToPTXPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit);
#endif

View File

@@ -3,11 +3,21 @@
#include "ConvertLayoutOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include "Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getCTALayout;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -64,6 +74,9 @@ struct LoadOpConversion
Value other = op.getOther();
// adaptor values
assert(!isTensorPointerType(ptr.getType()) &&
"Cannot convert load with a tensor pointer into LLVM; "
"this case should be transformed to normal load before lowering");
Value llPtr = adaptor.getPtr();
Value llMask = adaptor.getMask();
Value llOther = adaptor.getOther();
@@ -433,6 +446,528 @@ struct StoreOpConversion
return success();
}
};
// TODO: refactor to save common logic with insertsliceasyncv2
//
// Lowers triton_nvidia_gpu.store_async to TMA tiled stores
// (triton::nvgpu::TMAStoreTiledOp), recording one TMAInfo descriptor per
// lowered op in `tmaMetadata`. Two paths exist:
//  - source in shared encoding -> lowerStoreAsync (plain shared -> global);
//  - source in MMA encoding    -> lowerStoreAsyncWithSlice (stages values
//    through shared memory with stmatrix first).
struct StoreAsyncOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::StoreAsyncOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::StoreAsyncOp>::ConvertTritonGPUOpToLLVMPattern;

  StoreAsyncOpConversion(TritonGPUToLLVMTypeConverter &converter,
                         ModuleAllocation &allocation,
                         mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
                         const TensorPtrMapT *tensorPtrMap,
                         PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::StoreAsyncOp>(
            converter, allocation, tmaMetadata, benefit),
        tensorPtrMap(tensorPtrMap) {}

  // Dispatch on the source tensor's encoding (see struct comment above).
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
    auto srcEncoding = srcTy.getEncoding();
    if (srcEncoding.isa<MmaEncodingAttr>()) {
      return lowerStoreAsyncWithSlice(op, adaptor, rewriter);
    } else {
      return lowerStoreAsync(op, adaptor, rewriter);
    }
  }

  // Lowering for a source already in shared encoding: build the TMAInfo
  // descriptor for this store, then issue one TMAStoreTiledOp per box
  // (predicated on thread 0), reading directly from the source smem buffer.
  LogicalResult lowerStoreAsync(triton::nvidia_gpu::StoreAsyncOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();

    auto dst = op.getDst();
    auto src = op.getSrc();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto elemTy = srcTy.getElementType();

    auto rank = srcTy.getRank();
    // The store async op only supports tensors with rank <= 5.
    // Reference:
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
    assert(rank > 0 && rank <= 5);

    auto moduleOp = op->getParentOfType<ModuleOp>();
    assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp");

    auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp");

    int numTMADescs = getNumTMADescs(llFuncOp);
    assert(numTMADescs > 0);

    auto sharedLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
    assert(sharedLayout && "expected shared encoding");

    // Fill in the host-side TMA descriptor for this store.
    mlir::triton::gpu::TMAInfo tmaInfo;

    tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
    tmaInfo.tensorRank = rank;

    assert(tmaMetadata);
    auto inOrder = sharedLayout.getOrder();
    unsigned TMADescIdx = tmaMetadata->size();
    unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
    auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
    auto dstOrder = makeTensorPtr.getOrder();

    // The TMA descriptors are appended after the original kernel arguments.
    unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
    tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
    tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;

    // Maps a logical dimension index to its position in `order`.
    auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
      auto it = std::find(order.begin(), order.end(), i);
      assert(it != order.end());
      return std::distance(order.begin(), it);
    };

    std::vector<int32_t> globalDimsArgIdx;
    std::vector<int32_t> globalStridesArgIdx;
    // constant values are mapped to (-1 - value)
    for (int i = 0; i < rank; ++i) {
      int32_t argIdx = -1;
      auto dim = getDimOfOrder(dstOrder, i);
      argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
      globalDimsArgIdx.emplace_back(argIdx);
      // handle constant stride
      argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
      globalStridesArgIdx.emplace_back(argIdx);
    }
    tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
    tmaInfo.globalStridesArgIdx = globalStridesArgIdx;

    // Compute the TMA box dims; the innermost box is capped at one 128-byte
    // cacheline, splitting the store into `numBox` boxes if necessary.
    std::vector<uint32_t> boxDims;
    auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
    auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
    auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
    auto tensorShape = makeTensorPtr.getResult()
                           .getType()
                           .cast<triton::PointerType>()
                           .getPointeeType()
                           .cast<RankedTensorType>()
                           .getShape();
    auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
    const uint32_t bytesPerCacheline = 128;
    uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8;
    uint32_t numBox{1};
    for (int i = 0; i < rank; ++i) {
      auto dim = getDimOfOrder(dstOrder, i);
      auto tNumElems = shapePerCTA[dim];
      if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) {
        tNumElems = bytesPerCacheline / bytesPerElem;
        numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
      }
      boxDims.emplace_back(tNumElems);
    }
    std::vector<uint32_t> elementStrides(rank, 1);
    tmaInfo.boxDims = boxDims;
    tmaInfo.elementStrides = elementStrides;

    // Derive the TMA swizzle mode from the shared layout's phase parameters.
    CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
    assert(
        ((elemTy.getIntOrFloatBitWidth() == 16 && sharedLayout.getVec() == 8) or
         (elemTy.getIntOrFloatBitWidth() == 32 &&
          sharedLayout.getVec() == 4)) &&
        "Unexpected shared layout for StoreAsyncOp");
    if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
    else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
    else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
    else
      llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp");

    tmaInfo.swizzle = swizzle;
    tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
    tmaInfo.l2Promotion =
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
    tmaInfo.oobFill =
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;

    tmaMetadata->emplace_back(tmaInfo);

    Value llDst = adaptor.getDst();
    Value llSrc = adaptor.getSrc();
    auto srcShape = srcTy.getShape();
    auto smemObj = getSharedMemoryObjectFromStruct(loc, llSrc, rewriter);

    SmallVector<Value> offsetVals;
    for (auto i = 0; i < srcShape.size(); ++i) {
      offsetVals.emplace_back(i32_val(0));
    }

    Value tmaDesc =
        llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
    auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
        typeConverter->convertType(rewriter.getI8Type()), 3);

    // Only one thread per CTA issues the TMA store.
    auto threadId = getThreadId(rewriter, loc);
    Value pred = icmp_eq(threadId, i32_val(0));

    auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
                                                        dst.getType());
    uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1,
                                         std::multiplies<uint32_t>());

    Value clusterCTAId = getClusterCTAId(rewriter, loc);
    SmallVector<Value> multiDimClusterCTAId =
        delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

    // Make prior shared-memory writes visible to the async proxy.
    rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);

    for (uint32_t b = 0; b < numBox; ++b) {
      SmallVector<Value> coord;
      // raw coord
      for (int i = 0; i < rank; ++i) {
        auto dim = getDimOfOrder(dstOrder, i);
        coord.push_back(llCoord[dim]);
      }
      // coord with box and cta offset
      for (int i = 0; i < rank; ++i) {
        auto dim = getDimOfOrder(dstOrder, i);
        if (i == 0) {
          coord[i] = add(coord[i], i32_val(b * boxDims[i]));
          auto CTAOffset =
              mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i]));
          coord[i] = add(coord[i], CTAOffset);
        } else {
          coord[i] = add(coord[i],
                         mul(multiDimClusterCTAId[dim], i32_val(boxDims[i])));
        }
      }
      Value srcOffset = i32_val(b * boxStride);
      auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
      Value srcPtrBase = gep(srcPtrTy, smemObj.base, srcOffset);
      auto addr = bitcast(srcPtrBase, ptrI8SharedTy);
      rewriter.create<triton::nvgpu::TMAStoreTiledOp>(loc, tmaDesc, addr, pred,
                                                      coord);
    }
    rewriter.eraseOp(op);
    return success();
  }

  // Lowering for a source in MMA encoding: per (box, repM-slice), first copy
  // the register values into shared memory via stmatrix-style stores, fence,
  // then issue a TMA tiled store for that slice. One thread per warp issues
  // the TMA ops.
  LogicalResult
  lowerStoreAsyncWithSlice(triton::nvidia_gpu::StoreAsyncOp op,
                           OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();

    auto dst = op.getDst();
    auto src = op.getSrc();
    auto srcTy = src.getType().cast<RankedTensorType>();
    auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
    auto dstTensorTy = makeTensorPtr.getResult()
                           .getType()
                           .cast<triton::PointerType>()
                           .getPointeeType()
                           .cast<RankedTensorType>();
    auto tensorShape = dstTensorTy.getShape();
    auto dstOrder = makeTensorPtr.getOrder();
    auto dstElemTy = dstTensorTy.getElementType();

    auto rank = srcTy.getRank();
    // The store async op only supports tensors with rank <= 5.
    // Reference:
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
    assert(rank > 0 && rank <= 5);

    auto moduleOp = op->getParentOfType<ModuleOp>();
    assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp");

    auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp");

    int numTMADescs = getNumTMADescs(llFuncOp);
    assert(numTMADescs > 0);

    auto ctaLayout = getCTALayout(dstTensorTy.getEncoding());
    // The order of smem should be consistent with gmem.
    SmallVector<unsigned> sharedOrder;
    for (auto o : makeTensorPtr.getOrder()) {
      sharedOrder.emplace_back(o);
    }
    auto sharedLayout = SharedEncodingAttr::get(ctx, tensorShape, sharedOrder,
                                                ctaLayout, dstElemTy);

    // Fill in the host-side TMA descriptor for this store.
    mlir::triton::gpu::TMAInfo tmaInfo;

    tmaInfo.tensorDataType = getCUtensorMapDataType(dstElemTy);
    tmaInfo.tensorRank = rank;
    assert(tmaMetadata);
    unsigned TMADescIdx = tmaMetadata->size();
    unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
    unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
    tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
    tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;

    // Maps a logical dimension index to its position in `order`.
    auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
      auto it = std::find(order.begin(), order.end(), i);
      assert(it != order.end());
      return std::distance(order.begin(), it);
    };

    std::vector<int32_t> globalDimsArgIdx;
    std::vector<int32_t> globalStridesArgIdx;
    // constant values are mapped to (-1 - value)
    for (int i = 0; i < rank; ++i) {
      int32_t argIdx = -1;
      auto dim = getDimOfOrder(dstOrder, i);
      argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
      globalDimsArgIdx.emplace_back(argIdx);
      // handle constant stride
      argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
      globalStridesArgIdx.emplace_back(argIdx);
    }
    tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
    tmaInfo.globalStridesArgIdx = globalStridesArgIdx;

    std::vector<uint32_t> boxDims;
    auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
    auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
    auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
    auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);

    auto srcLayout = srcTy.getEncoding();
    auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
    unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);

    // repM = number of M-dimension repetitions of the MMA tile per CTA;
    // the store is staged one repetition at a time.
    auto instrShape = mmaLayout.getInstrShape();
    auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
    uint32_t repM =
        ceil<unsigned>(shapePerCTA[0], instrShape[0] * warpsPerCTA[0]);
    uint32_t numElemsPerRep = numElems / repM;

    // Box dims: innermost capped at one 128-byte cacheline (splitting into
    // `numBox` boxes); the second dim covers only one rep of one warp row.
    const uint32_t bytesPerCacheline = 128;
    uint32_t bytesPerElem = dstElemTy.getIntOrFloatBitWidth() / 8;
    uint32_t numBox{1};
    for (int i = 0; i < rank; ++i) {
      auto dim = getDimOfOrder(dstOrder, i);
      auto tNumElems = shapePerCTA[dim];
      if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) {
        tNumElems = bytesPerCacheline / bytesPerElem;
        numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
      }
      if (i == 1) {
        tNumElems = tNumElems / repM / warpsPerCTA[0];
      }
      boxDims.emplace_back(tNumElems);
    }
    std::vector<uint32_t> elementStrides(rank, 1);
    tmaInfo.boxDims = boxDims;
    tmaInfo.elementStrides = elementStrides;

    // Derive the TMA swizzle mode from the shared layout's phase parameters.
    CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
    assert(((dstElemTy.getIntOrFloatBitWidth() == 16 &&
             sharedLayout.getVec() == 8) or
            (dstElemTy.getIntOrFloatBitWidth() == 32 &&
             sharedLayout.getVec() == 4)) &&
           "Unexpected shared layout for StoreAsyncOp");
    if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
    else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
    else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
      swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
    else
      llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp");
    tmaInfo.swizzle = swizzle;
    tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
    tmaInfo.l2Promotion =
        CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
    tmaInfo.oobFill =
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
    tmaMetadata->emplace_back(tmaInfo);

    Value llDst = adaptor.getDst();
    Value llSrc = adaptor.getSrc();
    auto srcShape = srcTy.getShape();
    auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3);
    // Scratch shared-memory buffer used to stage one slice before the TMA
    // store.
    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
    smemBase = bitcast(smemBase, dstElemPtrTy);

    SmallVector<Value> offsetVals;
    for (auto i = 0; i < srcShape.size(); ++i) {
      offsetVals.emplace_back(i32_val(0));
    }

    Value tmaDesc =
        llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
    auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
        typeConverter->convertType(rewriter.getI8Type()), 3);

    // Lane 0 of each warp issues the TMA stores.
    auto threadId = getThreadId(rewriter, loc);
    Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0));

    auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
                                                        dst.getType());
    uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1,
                                         std::multiplies<uint32_t>());
    boxStride = boxStride * repM * warpsPerCTA[0];

    Value clusterCTAId = getClusterCTAId(rewriter, loc);
    SmallVector<Value> multiDimClusterCTAId =
        delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

    // rowStride in bytes
    uint32_t rowStrideInBytes = shapePerCTA[dstOrder[0]] * bytesPerElem;
    uint32_t swizzlingByteWidth =
        std::min<uint32_t>(rowStrideInBytes, bytesPerCacheline);
    unsigned numElemsPerSwizzlingRow = swizzlingByteWidth / bytesPerElem;
    unsigned leadingDimOffset =
        numElemsPerSwizzlingRow * shapePerCTA[dstOrder[1]];

    uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];

    // warpId0: this warp's position along the M dimension of the MMA tile.
    Value warpId = udiv(threadId, i32_val(32));
    Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])),
                         i32_val(srcShape[0] / instrShape[0]));
    auto srcOrder = triton::gpu::getOrder(srcLayout);
    unsigned inVec =
        srcOrder == sharedLayout.getOrder()
            ? triton::gpu::getContigPerThread(srcLayout)[srcOrder[0]]
            : 1;
    unsigned outVec = sharedLayout.getVec();
    unsigned minVec = std::min(outVec, inVec);
    // The packing below (words[i / 2], i % minVec) assumes pairs of elements
    // per 32-bit word.
    assert(minVec == 2);

    auto wordTy = vec_ty(dstElemTy, minVec);

    auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                       rewriter, srcTy);
    for (uint32_t b = 0; b < numBox; ++b) {
      for (int rep = 0; rep < repM; ++rep) {
        Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])),
                              i32_val(rep * rowsPerRep));
        uint32_t elemIdxOffset = rep * numElemsPerRep;

        // Stage this rep's register values into shared memory, 8 elements
        // (four 32-bit words) per StoreMatrixOp.
        for (unsigned idx = 0; idx < numElemsPerRep / numBox; idx += 8) {
          uint32_t elemIdx = elemIdxOffset + b * numElemsPerRep / numBox + idx;

          Value offset = rewriter.create<triton::nvgpu::OffsetOfStmatrixV4Op>(
              loc, i32_ty, threadId, rowOfWarp,
              i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset,
              numElemsPerSwizzlingRow, true);

          Value addr = gep(dstElemPtrTy, smemBase, offset);
          Value words[4];
          for (unsigned i = 0; i < 8; ++i) {
            if (i % minVec == 0)
              words[i / 2] = undef(wordTy);
            words[i / 2] = insert_element(
                wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec));
          }

          rewriter.create<triton::nvgpu::StoreMatrixOp>(
              loc, bitcast(addr, ptrI8SharedTy),
              ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty),
                         bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)});
        }
        // Make the staged smem data visible to the TMA engine.
        rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);

        SmallVector<Value> coord;
        // raw coord
        for (int i = 0; i < rank; ++i) {
          auto dim = getDimOfOrder(dstOrder, i);
          coord.push_back(llCoord[dim]);
        }
        // coord with box and cta offset
        for (int i = 0; i < rank; ++i) {
          auto dim = getDimOfOrder(dstOrder, i);
          if (i == 0) {
            coord[i] = add(coord[i], i32_val(b * boxDims[i]));
            auto CTAOffset =
                mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i]));
            coord[i] = add(coord[i], CTAOffset);
          } else {
            Value blockOffset = i32_val(rep * instrShape[0] * warpsPerCTA[0]);
            Value warpOffset = mul(warpId0, i32_val(instrShape[0]));
            coord[i] = add(add(coord[i], add(blockOffset, warpOffset)),
                           mul(multiDimClusterCTAId[dim],
                               i32_val(boxDims[i] * repM * warpsPerCTA[0])));
          }
        }
        Value srcOffset =
            add(i32_val(b * boxStride + rep * instrShape[0] * warpsPerCTA[0] *
                                            instrShape[1] * warpsPerCTA[1] /
                                            numBox),
                mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow)));
        auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3);
        Value srcPtrBase = gep(srcPtrTy, smemBase, srcOffset);
        auto addr = bitcast(srcPtrBase, ptrI8SharedTy);
        rewriter.create<triton::nvgpu::TMAStoreTiledOp>(loc, tmaDesc, addr,
                                                        pred, coord);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Maps an MLIR element type to the corresponding CUtensorMap data type.
  // Only f16 and f32 are supported; anything else is a fatal error.
  CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
    if (ty.isF16()) {
      return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if (ty.isF32()) {
      return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    } else {
      llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp");
      return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    }
  }

  // Traces `v` back to the kernel argument it originates from and returns
  // that argument's index. Constants are encoded as (-1 - value); note the
  // value is returned through an `unsigned`, so callers must interpret it as
  // the signed encoding used by TMAInfo.
  unsigned getArgIdx(Value v) const {
    if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
      return -1 -
             op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
    }
    if (v.getDefiningOp() &&
        isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
      return getArgIdx(v.getDefiningOp()->getOperand(0));
    } else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
      // in entryblock and is BlockArgument; Because arguments of func are
      // arguments of entryblock bb0 in MLIR
      return v.cast<BlockArgument>().getArgNumber();
    } else if (v.getParentBlock()->isEntryBlock() &&
               (!v.isa<BlockArgument>())) {
      // in entryblock but not BlockArgument
      return getArgIdx(v.getDefiningOp()->getOperand(0));
    } else if (!v.getParentBlock()->isEntryBlock()) {
      // in non-entryblock
      return getArgIdx(v.getDefiningOp()->getOperand(0));
    } else {
      llvm::report_fatal_error(
          "Operand of `MakeTensorPtrOp` is not the function's argument");
      return 0;
    }
  }

  // Total number of TMA descriptors (load + store) attached to `func` via the
  // module-level attributes; fatal error if either attribute is missing.
  int getNumTMADescs(LLVM::LLVMFuncOp func) const {
    if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
      llvm::report_fatal_error("TritonGPU module should contain a "
                               "triton_gpu.num-tma-load attribute");
      return -1;
    }
    if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
      llvm::report_fatal_error("TritonGPU module should contain a "
                               "triton_gpu.num-tma-store attribute");
      return -1;
    }
    return func->getAttr(kAttrNumTMAStoreDescsName)
               .cast<IntegerAttr>()
               .getInt() +
           func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
  }

  // Maps each lowered op to the MakeTensorPtrOp that produced its tensor
  // pointer (owned by the caller).
  const TensorPtrMapT *tensorPtrMap;
};
struct AtomicCASOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
@@ -1126,11 +1661,389 @@ struct InsertSliceAsyncOpConversion
}
};
struct InsertSliceAsyncV2OpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::InsertSliceAsyncV2Op> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::InsertSliceAsyncV2Op>::
ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncV2OpConversion(TritonGPUToLLVMTypeConverter &converter,
ModuleAllocation &allocation,
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::InsertSliceAsyncV2Op>(converter, allocation,
tmaMetadata, benefit),
tensorPtrMap(tensorPtrMap) {}
LogicalResult
matchAndRewrite(triton::nvidia_gpu::InsertSliceAsyncV2Op op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getResult().getType().cast<RankedTensorType>();
auto elemTy = resultTy.getElementType();
auto rank = resultTy.getRank() - 1;
// TODO: support any valid rank in (3, 4, 5)
// The sotre async op only supports tensor with ranke <= 5.
// Reference:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
assert(rank > 0 && rank <= 5);
SmallVector<unsigned> shape;
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for InsertSliceAsyncV2Op");
auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(llFuncOp && "LLVMFuncOp not found for InsertSliceAsyncV2Op");
int numTMADescs = getNumTMADescs(llFuncOp);
assert(numTMADescs > 0);
auto sharedLayout = resultTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(sharedLayout && "unexpected layout of InsertSliceAsyncV2Op");
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
mlir::triton::gpu::TMAInfo tmaInfo;
tmaInfo.tensorDataType = getCUtensorMapDataType(elemTy);
tmaInfo.tensorRank = rank;
assert(tmaMetadata);
unsigned TMADescIdx = tmaMetadata->size();
unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
auto inOrder = makeTensorPtr.getOrder();
unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;
auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
auto it = std::find(order.begin(), order.end(), i);
assert(it != order.end());
return std::distance(order.begin(), it);
};
std::vector<int32_t> globalDimsArgIdx;
std::vector<int32_t> globalStridesArgIdx;
// constant values are mapped to (-1 - value)
for (int i = 0; i < rank; ++i) {
int32_t argIdx = -1;
auto dim = getDimOfOrder(inOrder, i);
argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
globalDimsArgIdx.emplace_back(argIdx);
// handle constant stride
argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
globalStridesArgIdx.emplace_back(argIdx);
}
tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
std::vector<uint32_t> boxDims;
auto tensorShape = makeTensorPtr.getResult()
.getType()
.cast<triton::PointerType>()
.getPointeeType()
.cast<RankedTensorType>()
.getShape();
SmallVector<unsigned> numMcast(rank);
unsigned accNumMcast = 1;
for (unsigned i = 0; i < rank; ++i) {
numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
accNumMcast *= numMcast[i];
}
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
for (size_t i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(inOrder, i);
// in case of TMA multicast, we should always slice along higher order
// dimensions
if (i == rank - 1) {
assert(shapePerCTA[dim] >= accNumMcast &&
"cases when the size of the highest order is smaller "
"than numMcasts is not implemented");
boxDims.emplace_back(shapePerCTA[dim] / accNumMcast);
} else {
boxDims.emplace_back(shapePerCTA[dim]);
}
}
std::vector<uint32_t> elementStrides(rank, 1);
tmaInfo.elementStrides = elementStrides;
CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
else
llvm::report_fatal_error(
"Unsupported shared layout for InsertSliceAsyncV2Op");
tmaInfo.swizzle = swizzle;
tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
tmaInfo.l2Promotion =
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
tmaInfo.oobFill =
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
uint32_t numBoxes = 1;
uint32_t elemSizeOfBytes = elemTy.getIntOrFloatBitWidth() / 8;
if (swizzle == CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
while (elemSizeOfBytes * boxDims[0] > 128) {
boxDims[0] = boxDims[0] / 2;
numBoxes *= 2;
}
}
tmaInfo.boxDims = boxDims;
tmaMetadata->emplace_back(tmaInfo);
uint32_t elemsPerBox =
std::accumulate(boxDims.begin(), boxDims.end(), 1, std::multiplies{});
Value clusterCTAId = getClusterCTAId(rewriter, loc);
SmallVector<Value> multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
Value llDst = adaptor.getDst();
Value llIndex = adaptor.getIndex();
Value src = op.getSrc();
Value dst = op.getDst();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
// the offset of coord considering multicast slicing
SmallVector<Value> mcastOffsetVals;
// The index of slice is this CTAId is responsible for
SmallVector<Value> multiDimSliceIdx(rank);
for (auto i = 0; i < rank; ++i)
multiDimSliceIdx[i] =
udiv(multiDimClusterCTAId[i], i32_val(CTASplitNum[i]));
Value sliceIdx =
linearize(rewriter, loc, multiDimSliceIdx, numMcast, CTAOrder);
Value sliceCoord;
for (auto i = 0; i < rank; ++i) {
if (inOrder[i] == rank - 1) {
// TODO[goostavz]: Cases when the size of the highest order is smaller
// than numMcasts is not implemented.
sliceCoord = mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast));
mcastOffsetVals.emplace_back(
mul(sliceIdx, i32_val(shapePerCTA[i] / accNumMcast)));
} else {
mcastOffsetVals.emplace_back(i32_val(0));
}
}
uint32_t elemsPerSlice = std::accumulate(
shapePerCTA.begin(), shapePerCTA.end(), 1, std::multiplies{});
Value dstOffsetCommon = mul(llIndex, i32_val(elemsPerSlice));
// [benzh] sliceCoord should be higher dimension's multiplier accumulate.
// currently only support rank == 2.
dstOffsetCommon =
add(dstOffsetCommon, mul(sliceCoord, i32_val(boxDims[0])));
auto dstPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
Value tmaDesc =
llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
// TODO: sink this logic into Triton::NVGPU dialect and support more
// cache-policy modes
Value l2Desc = int_val(64, 0x1000000000000000ll);
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
SmallVector<Value> coordCommon;
auto llCoord = getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, src.getType());
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(inOrder, i);
Value coordDim = bitcast(llCoord[dim], i32_ty);
if (CTASplitNum[dim] != 1) {
// Add offset for each CTA
// boxDims[i] * (multiDimClusterCTAId[i] % CTASplitNum[i]);
auto CTAOffset =
mul(i32_val(shapePerCTA[dim]),
urem(multiDimClusterCTAId[dim], i32_val(CTASplitNum[dim])));
coordDim = add(coordDim, CTAOffset);
}
if (i == rank - 1)
// Add offset in case of multicast slicing
coordCommon.push_back(add(coordDim, mcastOffsetVals[dim]));
else
coordCommon.push_back(coordDim);
}
auto threadId = getThreadId(rewriter, loc);
Value pred = icmp_eq(threadId, i32_val(0));
auto mask = adaptor.getMask();
if (mask) {
// TODO(thomas): What is the right implementation for this case?
assert(mask.getType().isInteger(1) &&
"need to implement cases with tensor mask");
pred = rewriter.create<arith::AndIOp>(loc, pred, mask);
}
Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId);
for (size_t i = 0; i < numBoxes; ++i) {
Value dstOffset =
add(dstOffsetCommon, i32_val(i * elemsPerBox * accNumMcast));
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
SmallVector<Value> coord = coordCommon;
coord[0] = add(coordCommon[0], i32_val(i * boxDims[0]));
rewriter.create<triton::nvgpu::TMALoadTiledOp>(
loc, bitcast(dstPtrBase, ptrI8SharedTy), adaptor.getMbar(), tmaDesc,
l2Desc, pred, coord, mcastMask);
}
rewriter.replaceOp(op, llDst);
return success();
}
private:
  // Compute the 16-bit CTA multicast mask for a TMA multicast load: bit k is
  // set iff CTA k of the CGA should receive the data this CTA loads.
  // Returns nullptr when no multicast is needed (every CTA loads its own
  // slice). The mask is a constant when no dimension is split across CTAs.
  Value getMCastMask(const SharedEncodingAttr &sharedLayout,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value clusterCTAId) const {
    auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
    auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
    auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();

    // Short path when no multicast is needed
    if (CTAsPerCGA == CTASplitNum)
      return nullptr;

    // Short path when bcastMask is a constant
    bool isConstMcastMask = true;
    for (unsigned s : CTASplitNum) {
      if (s > 1) {
        isConstMcastMask = false;
        break;
      }
    }
    if (isConstMcastMask) {
      // Nothing is split: broadcast to every CTA in the CGA, i.e. set the low
      // numCTAs bits of the mask.
      unsigned numCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(),
                                         1, std::multiplies{});
      return int_val(/*width*/ 16, (1u << numCTAs) - 1);
    }

    SmallVector<Value> multiDimCTAId =
        delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
    auto rank = CTAOrder.size();
    SmallVector<SmallVector<Value>> multiDimMask(rank);
    unsigned accNumMcast = 1;
    SmallVector<unsigned> numMcast(rank);
    for (unsigned i = 0; i < rank; ++i) {
      // For the ith dimension, CTAsPerCGA[i]/CTASplitNum[i] vals is to be
      // broadcasted, which for this CTAId is:
      //     multiDimCTAId[i] % CTASplitNum[i] + (0 ..
      //     (CTAsPerCGA[i]/CTASplitNum[i] - 1)) * CTASplitNum[i]
      // TODO: will there be cases if CTAsPerCGA[i]/CTASplitNum[i] < 1?
      Value rem = urem(multiDimCTAId[i], i32_val(CTASplitNum[i]));
      numMcast[i] = CTAsPerCGA[i] / CTASplitNum[i];
      accNumMcast *= numMcast[i];
      for (unsigned j = 0; j < numMcast[i]; ++j) {
        if (j == 0) {
          multiDimMask[i].push_back(rem);
        } else {
          multiDimMask[i].push_back(add(rem, i32_val(j * CTASplitNum[i])));
        }
      }
    }

    // Enumerate the Cartesian product of per-dimension target CTA ids,
    // linearize each combination back to a flat CTA id and set its mask bit.
    Value bcastMask = int_val(/*width*/ 16, 0);
    Value _1_i16 = int_val(/*width*/ 16, 1);
    for (unsigned i = 0; i < accNumMcast; ++i) {
      SmallVector<unsigned> multiDimIdx =
          getMultiDimIndex<unsigned>(i, numMcast, CTAOrder);
      SmallVector<Value> multiDimMaskedCTAId(rank);
      for (unsigned dim = 0; dim < rank; ++dim) {
        multiDimMaskedCTAId[dim] = multiDimMask[dim][multiDimIdx[dim]];
      }
      Value bcastCTAId =
          linearize(rewriter, loc, multiDimMaskedCTAId, CTAsPerCGA, CTAOrder);
      // bcastMask |= 1u << bcastCTAId;
      bcastMask = or_(bcastMask, shl(_1_i16, trunc(i16_ty, bcastCTAId)));
    }

    return bcastMask;
  }
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
if (ty.isF16()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if (ty.isF32()) {
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
} else {
llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op");
return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
}
unsigned getArgIdx(Value v) const {
if (auto op = v.getDefiningOp<mlir::arith::ConstantOp>()) {
return -1 -
op.getValue().dyn_cast<IntegerAttr>().getValue().getZExtValue();
}
if (v.getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(v.getDefiningOp())) {
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else if (v.getParentBlock()->isEntryBlock() && v.isa<BlockArgument>()) {
// in entryblock and is BlockArgument; Because argument of func are
// arugments of entryblock bb0 in MLIR
return v.cast<BlockArgument>().getArgNumber();
} else if (v.getParentBlock()->isEntryBlock() &&
(!v.isa<BlockArgument>())) {
// in entryblock but not BlockArgument
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else if (!v.getParentBlock()->isEntryBlock()) {
// in non-entryblock
return getArgIdx(v.getDefiningOp()->getOperand(0));
} else {
llvm::report_fatal_error(
"Operand of `MakeTensorPtrOp` is not the function's argument");
return 0;
}
}
int getNumTMADescs(LLVM::LLVMFuncOp func) const {
if (!func->hasAttr(kAttrNumTMALoadDescsName)) {
llvm::report_fatal_error("TritonGPU module should contain a "
"triton_gpu.num-tma-load attribute");
return -1;
}
if (!func->hasAttr(kAttrNumTMAStoreDescsName)) {
llvm::report_fatal_error("TritonGPU module should contain a "
"triton_gpu.num-tma-store attribute");
return -1;
}
return func->getAttr(kAttrNumTMAStoreDescsName)
.cast<IntegerAttr>()
.getInt() +
func->getAttr(kAttrNumTMALoadDescsName).cast<IntegerAttr>().getInt();
}
const TensorPtrMapT *tensorPtrMap;
};
void populateLoadStoreOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit) {
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
@@ -1141,4 +2054,8 @@ void populateLoadStoreOpToLLVMPatterns(
indexCacheInfo, benefit);
patterns.add<InsertSliceAsyncOpConversion>(
typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
patterns.add<InsertSliceAsyncV2OpConversion>(
typeConverter, allocation, tmaMetadata, tensorPtrMap, benefit);
patterns.add<StoreAsyncOpConversion>(typeConverter, allocation, tmaMetadata,
tensorPtrMap, benefit);
}

View File

@@ -8,8 +8,10 @@ using namespace mlir::triton;
void populateLoadStoreOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
mlir::triton::gpu::TMAMetadataTy *tmaMetadata,
const TensorPtrMapT *tensorPtrMap, PatternBenefit benefit);
#endif

View File

@@ -51,6 +51,14 @@ PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) {
return opr;
}
// Create an operand tied to an already-registered operand. In inline-asm
// constraint syntax a bare digit string ("0", "1", ...) means "use the same
// register/location as operand #operandIndex", which is what the constraint
// set below encodes.
PTXBuilder::Operand *PTXBuilder::newOperand(unsigned operandIndex) {
  assert(operandIndex < oprCounter && "operand index out of range");
  auto *opr = newOperand();
  // NOTE(review): newOperand() may already assign an index; confirm the extra
  // oprCounter++ here is intended and does not skip an index.
  opr->idx = oprCounter++;
  opr->constraint = std::to_string(operandIndex);
  return opr;
}
PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
argArchive.emplace_back(std::make_unique<Operand>());
argArchive.back()->repr = [v](int idx) { return v; };

View File

@@ -1,5 +1,8 @@
#include "ReduceOpToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
using namespace mlir;
using namespace mlir::triton;
@@ -14,8 +17,13 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
struct ReduceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
ReduceOpConversion(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
int computeCapability, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp>(
typeConverter, allocation, indexCacheInfo, benefit),
computeCapability(computeCapability) {}
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
@@ -26,14 +34,12 @@ public:
}
private:
int computeCapability;
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
llvm::SmallVectorImpl<Value> &acc, ValueRange cur,
bool isFirst) const {
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
if (isFirst) {
acc.resize(cur.size());
for (unsigned i = 0; i < cur.size(); ++i) {
acc[i] = cur[i];
}
acc = SmallVector<Value>(cur.begin(), cur.end());
return;
}
@@ -120,7 +126,7 @@ private:
// writeIdx[originalAxis] = index[originalAxis] / axisSizePerThread
writeIdx[originalAxis] = udiv(index[originalAxis], axisSizePerThread);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (!mmaLayout.isAmpere()) {
if (!mmaLayout.isAmpere() && !mmaLayout.isHopper()) {
llvm::report_fatal_error("Unsupported layout");
}
if (originalAxis == 0) {
@@ -175,7 +181,6 @@ private:
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
auto smemShape = helper.getScratchConfigBasic();
unsigned elems = product<unsigned>(smemShape);
@@ -189,33 +194,10 @@ private:
elemPtrTys[i]);
}
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
// Emits indices of the original tensor that each thread
// would own
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
// Emits offsets (the offset from the base index)
// of the original tensor that each thread would own
// NOTE: Assumes offsets don't actually depend on type
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTys[0]);
// Keep track of accumulations and their indices
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
Region *combineOp = &op.getCombineOp();
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
if (isFirst)
indices[key] = srcIndices[i];
}
reduceWithinThreads(helper, srcValues, accs, indices, rewriter);
// cached int32 constants
std::map<int, Value> ints;
@@ -271,15 +253,17 @@ private:
readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
}
barrier();
sync(rewriter, loc, op);
// Combine accumulator value from another thread
SmallVector<Value> cur(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
cur[i] = load(readPtrs[i]);
}
accumulate(rewriter, *combineOp, acc, cur, false);
accumulate(rewriter, op.getCombineOp(), acc, cur, false);
sync(rewriter, loc, op);
barrier();
// Publish our new accumulator value to shared memory
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
store(acc[i], writePtrs[i]);
@@ -287,7 +271,7 @@ private:
}
}
barrier();
sync(rewriter, loc, op);
// set output values
SmallVector<Value> results(op.getNumOperands());
@@ -324,79 +308,197 @@ private:
return success();
}
// Use warp shuffle for reduction within warps and shared memory for data
// exchange across warps
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ReduceOpHelper helper(op);
Location loc = op->getLoc();
unsigned axis = adaptor.getAxis();
auto srcTys = op.getInputTypes();
auto srcLayout = helper.getSrcLayout();
if (!helper.isSupportedLayout()) {
assert(false && "Unexpected srcLayout in ReduceOpConversion");
void sync(ConversionPatternRewriter &rewriter, Location loc,
triton::ReduceOp op) const {
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
// attr.
if (getWSAgentId(op)) {
barSync(rewriter, op, getAgentIds(op).front(), 128);
} else {
barrier();
}
auto srcOrd = triton::gpu::getOrder(srcLayout);
auto srcShape = helper.getSrcShape();
}
SmallVector<Type> elemPtrTys(srcTys.size());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto ty = srcTys[i].getElementType();
auto llvmElemTy = getTypeConverter()->convertType(ty);
elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
SmallVector<Value> smemBases(op.getNumOperands());
bool isWarpSync = helper.isWarpSynchronous();
if (!isWarpSync) {
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] =
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
elemPtrTys[i]);
}
}
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
  // Check if the reduction can use a redux op and return the kind.
  // redux.sync is only matched when ALL of the following hold:
  //   - compute capability >= 80 (sm_80 introduced redux.sync),
  //   - the reduce has exactly one operand and one result,
  //   - the combine region is a single binary integer op (width <= 32)
  //     applied directly to the two block arguments, in order,
  //   - that op is one of add/and/or/xor/min/umin/max/umax.
  std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op) const {
    if (computeCapability < 80)
      return std::nullopt;
    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
      return std::nullopt;
    Block *block = &(*op.getCombineOp().begin());
    Operation *yield = block->getTerminator();
    // The yielded value must come from a single two-operand, one-result op.
    Operation *reduceOp = yield->getOperand(0).getDefiningOp();
    if (!reduceOp || reduceOp->getNumOperands() != 2 ||
        reduceOp->getNumResults() != 1)
      return std::nullopt;
    // redux.sync only supports integers up to 32 bits wide.
    auto intType = reduceOp->getResultTypes()[0].dyn_cast<IntegerType>();
    if (!intType || intType.getWidth() > 32)
      return std::nullopt;
    // The op must consume the region's two block arguments directly.
    if (reduceOp->getOperand(0) != block->getArgument(0) ||
        reduceOp->getOperand(1) != block->getArgument(1))
      return std::nullopt;
    if (isa<arith::AddIOp>(reduceOp))
      return NVVM::ReduxKind::ADD;
    if (isa<arith::AndIOp>(reduceOp))
      return NVVM::ReduxKind::AND;
    if (isa<arith::OrIOp>(reduceOp))
      return NVVM::ReduxKind::OR;
    if (isa<arith::XOrIOp>(reduceOp))
      return NVVM::ReduxKind::XOR;
    if (isa<arith::MinSIOp>(reduceOp))
      return NVVM::ReduxKind::MIN;
    if (isa<arith::MinUIOp>(reduceOp))
      return NVVM::ReduxKind::UMIN;
    if (isa<arith::MaxSIOp>(reduceOp))
      return NVVM::ReduxKind::MAX;
    if (isa<arith::MaxUIOp>(reduceOp))
      return NVVM::ReduxKind::UMAX;
    return std::nullopt;
  }
// Reduce along op axis for elements that are in the same thread. The
// accumulated value is stored in accs.
void reduceWithinThreads(
ReduceOpHelper &helper, SmallVector<SmallVector<Value>> &srcValues,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
RankedTensorType operandType = op.getInputTypes()[0];
// Assumes offsets don't actually depend on type
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcTys[0]);
emitOffsetForLayout(helper.getSrcLayout(), operandType);
unsigned srcElems = getTotalElemsPerThread(operandType);
auto *combineOp = &op.getCombineOp();
auto srcIndices =
emitIndices(op.getLoc(), rewriter, helper.getSrcLayout(), operandType);
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
key[op.getAxis()] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
if (isFirst)
indices[key] = srcIndices[i];
}
}
// Apply warp reduction across the given number of contiguous lanes using op
// region and the accumulator values as source.
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce) const {
if (auto kind = matchReduxKind(op)) {
// Based on benchmarking on A100 redux op gives a speed up only when doing
// a single reduction (not partioned) and when the mask is static.
// Therefore we currently only enable it to reduce across all the lanes.
if (numLaneToReduce == 32) {
assert(acc.size() == 1);
Value mask = i32_val(0xFFFFFFFF);
// Even though we currently don't use redux for partitioned reduction
// the code below supports it in case we want to tweak the heuristic.
if (numLaneToReduce < 32) {
// For partitioned reduction we need to caluclate the mask so that
// each group of numLaneToReduce threads has the correct mask.
unsigned bitmask = (1 << numLaneToReduce) - 1;
Value threadId = getThreadId(rewriter, loc);
Value laneId = urem(threadId, i32_val(32));
mask = shl(i32_val(bitmask),
and_(laneId, i32_val(~(numLaneToReduce - 1))));
}
for (unsigned i = 0; i < acc.size(); ++i) {
unsigned bitwidth = acc[i].getType().cast<IntegerType>().getWidth();
if (bitwidth < 32) {
if (*kind == NVVM::ReduxKind::MIN || *kind == NVVM::ReduxKind::MAX)
acc[i] = sext(i32_ty, acc[i]);
else
acc[i] = zext(i32_ty, acc[i]);
}
acc[i] = rewriter.create<NVVM::ReduxOp>(loc, acc[i].getType(), acc[0],
*kind, mask);
if (bitwidth < 32)
acc[i] = trunc(int_ty(bitwidth), acc[i]);
}
return;
}
}
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
}
}
// Reduce across threads within each warp.
void
reduceWithinWarps(ReduceOpHelper &helper,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = accs[key];
warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps);
}
}
  // Pack the accumulator values and replace the reduce op with the result.
  // Tensor results are re-packed into the LLVM struct form expected by the
  // type converter; a scalar result (fully-reduced tensor) takes the single
  // accumulated value directly.
  void packResults(ReduceOpHelper &helper,
                   std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
                   ConversionPatternRewriter &rewriter) const {
    triton::ReduceOp op = helper.getOperation();
    Location loc = op.getLoc();
    unsigned axis = op.getAxis();
    SmallVector<Value> results(op.getNumOperands());
    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
      if (auto resultTy =
              op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
        auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
        unsigned resultElems = getTotalElemsPerThread(resultTy);
        SmallVector<SmallVector<unsigned>> resultOffset =
            emitOffsetForLayout(resultLayout, resultTy);
        SmallVector<Value> resultVals;
        for (int j = 0; j < resultElems; j++) {
          auto key = resultOffset[j];
          // Re-insert the reduced axis (always 0) to match the keys used by
          // reduceWithinThreads.
          key.insert(key.begin() + axis, 0);
          resultVals.push_back(accs[key][i]);
        }
        results[i] = getTypeConverter()->packLLElements(loc, resultVals,
                                                        rewriter, resultTy);
      } else
        // Scalar result: there is exactly one accumulator group.
        results[i] = accs.begin()->second[i];
    }
    rewriter.replaceOp(op, results);
  }
// Return the type of the shared memory pointer for operand i.
Type getElementPtrType(triton::ReduceOp op, int i) const {
auto ty = op.getInputTypes()[i].getElementType();
auto llvmElemTy = getTypeConverter()->convertType(ty);
return LLVM::LLVMPointerType::get(llvmElemTy, 3);
}
void storeWarpReduceToSharedMemory(
ReduceOpHelper &helper,
std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
Value threadId = getThreadId(rewriter, loc);
unsigned wavefront_size = triton::gpu::getWarpSize(srcLayout);
Value warpSize = i32_val(wavefront_size);
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
auto srcLayout = helper.getSrcLayout();
auto srcShape = helper.getSrcShape();
unsigned axis = op.getAxis();
auto smemShapes = helper.getScratchConfigsFast();
auto threadsPerWarp =
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
@@ -428,9 +530,9 @@ private:
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);
std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
<<<<<<< HEAD
SmallVector<Value> acc = it.second;
// Reduce within warps
@@ -455,49 +557,38 @@ private:
finalAccs[key] = acc;
continue;
}
=======
SmallVector<Value> &acc = it.second;
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset);
auto elemPtrTy = getElementPtrType(op, i);
Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
storeShared(rewriter, loc, writePtr, acc[i], laneZero);
}
}
}
if (isWarpSync) {
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
unsigned resultElems = getTotalElemsPerThread(resultTy);
SmallVector<SmallVector<unsigned>> resultOffset =
emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> resultVals;
for (int j = 0; j < resultElems; j++) {
auto key = resultOffset[j];
key.insert(key.begin() + axis, 0);
resultVals.push_back(finalAccs[key][i]);
}
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
} else
results[i] = finalAccs.begin()->second[i];
}
rewriter.replaceOp(op, results);
return success();
}
// Load the reduction of each warp and accumulate them to a final value and
// store back to shared memory.
void accumulatePartialReductions(ReduceOpHelper &helper,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
auto srcLayout = helper.getSrcLayout();
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
Location loc = op.getLoc();
barrier();
// The second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn
// where sizeInterWarps is 2^m
//
// Each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value zero = i32_val(0);
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
unsigned numThreads =
@@ -510,23 +601,18 @@ private:
// i32_val(sizeInerWarps))
SmallVector<Value> acc(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
auto elemPtrTy = getElementPtrType(op, i);
Value readPtr = gep(elemPtrTy, smemBases[i], readOffset);
acc[i] = load(readPtr);
}
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], N);
}
accumulate(rewriter, *combineOp, acc, shfl, false);
}
warpReduce(rewriter, loc, acc, op, sizeInterWarps);
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
SmallVector<Value> writePtrs(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
auto elemPtrTy = getElementPtrType(op, i);
writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset);
}
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
@@ -550,10 +636,17 @@ private:
readOffset = add(readOffset, i32_val(numThreads));
}
}
}
barrier();
// set output values
// Load the final reduction from shared memory and replace the reduce result
// with it.
void loadReductionAndPackResult(ReduceOpHelper &helper,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
auto smemShapes = helper.getScratchConfigsFast();
auto order = getOrder(helper.getSrcLayout());
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
@@ -567,10 +660,11 @@ private:
SmallVector<Value> resultVals(resultElems);
for (size_t j = 0; j < resultElems; ++j) {
SmallVector<Value> readIdx = resultIndices[j];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
Value readPtr =
gep(getElementPtrType(op, i), smemBases[i], readOffset);
resultVals[j] = load(readPtr);
}
@@ -582,6 +676,65 @@ private:
}
}
rewriter.replaceOp(op, results);
}
  // Use warp shuffle for reduction within warps and shared memory for data
  // exchange across warps.
  // Pipeline: per-thread reduce -> per-warp shuffle reduce -> (if more than
  // one warp participates) stage partial results in shared memory, reduce
  // across warps, and read the final values back.
  LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
    ReduceOpHelper helper(op);
    assert(helper.isSupportedLayout() &&
           "Unexpected srcLayout in ReduceOpConversion");
    Location loc = op->getLoc();

    auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
    std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
    std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
    // First reduce all the values along axis within each thread.
    reduceWithinThreads(helper, srcValues, accs, indices, rewriter);

    // Then reduce across threads within a warp.
    reduceWithinWarps(helper, accs, rewriter);

    if (helper.isWarpSynchronous()) {
      // If all the values to be reduced are within the same warp there is
      // nothing left to do.
      packResults(helper, accs, rewriter);
      return success();
    }

    // Compute a shared memory base per operand.
    // Operands share one scratch buffer, packed back-to-back with a stride of
    // maxElems elements of the previous operand's type.
    auto smemShapes = helper.getScratchConfigsFast();
    unsigned elems = product<unsigned>(smemShapes[0]);
    unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
    SmallVector<Value> smemBases(op.getNumOperands());
    smemBases[0] =
        bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()),
                getElementPtrType(op, 0));
    for (unsigned i = 1; i < op.getNumOperands(); ++i) {
      smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1],
                                 i32_val(maxElems)),
                             getElementPtrType(op, i));
    }

    storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter);

    sync(rewriter, loc, op);

    // The second round of shuffle reduction
    //   now the problem size: sizeInterWarps, s1, s2, .. , sn
    //   where sizeInterWarps is 2^m
    //
    // Each thread needs to process:
    //   elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
    accumulatePartialReductions(helper, smemBases, rewriter);

    // We could avoid this barrier in some of the layouts, however this is not
    // the general case.
    // TODO: optimize the barrier in case the layouts are accepted.
    sync(rewriter, loc, op);

    // set output values
    loadReductionAndPackResult(helper, smemBases, rewriter);

    return success();
  }
@@ -589,9 +742,10 @@ private:
void populateReduceOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
int computeCapability, PatternBenefit benefit) {
patterns.add<ReduceOpConversion>(typeConverter, allocation, indexCacheInfo,
benefit);
computeCapability, benefit);
}

View File

@@ -8,8 +8,9 @@ using namespace mlir::triton;
void populateReduceOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
int computeCapability, PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,43 @@
#include "RegReallocOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
// Lowers triton_nvidia_gpu.reg_alloc 1:1 onto the NVGPU dialect op, passing
// the register count through unchanged.
struct RegAllocOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegAllocOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::RegAllocOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Removed unused `Location loc = op->getLoc();` (replaceOpWithNewOp takes
    // the location from `op`).
    rewriter.replaceOpWithNewOp<triton::nvgpu::RegAllocOp>(
        op, adaptor.getRegCount());
    return success();
  }
};
// Lowers triton_nvidia_gpu.reg_dealloc 1:1 onto the NVGPU dialect op, passing
// the register count through unchanged.
struct RegDeallocOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::nvidia_gpu::RegDeallocOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::nvidia_gpu::RegDeallocOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Removed unused `Location loc = op->getLoc();` (replaceOpWithNewOp takes
    // the location from `op`).
    rewriter.replaceOpWithNewOp<triton::nvgpu::RegDeallocOp>(
        op, adaptor.getRegCount());
    return success();
  }
};
// Registers the reg_alloc/reg_dealloc lowering patterns.
// numWarps, axisInfoAnalysis and allocation are unused but kept so the
// signature matches the other populate* entry points.
void populateRegReallocOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    const ModuleAllocation &allocation, PatternBenefit benefit) {
  patterns.add<RegAllocOpConversion>(typeConverter, benefit);
  patterns.add<RegDeallocOpConversion>(typeConverter, benefit);
  // Dropped the redundant bare `return;` at the end of this void function.
}

Some files were not shown because too many files have changed in this diff Show More