Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-02232023

Rohit Santhanam
2023-02-23 21:41:54 +00:00
43 changed files with 1279 additions and 1018 deletions

View File

@@ -3,9 +3,10 @@ name: Integration Tests
on:
workflow_dispatch:
pull_request:
branches:
- main
- triton-mlir
branches: [main]
merge_group:
branches: [main]
types: [checks_requested]
concurrency:
group: ${{ github.ref }}
@@ -21,7 +22,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
@@ -43,29 +44,33 @@ jobs:
run: |
rm -rf ~/.triton/cache/
- name: Update path
run: |
echo "$HOME/.local/bin/" >> $GITHUB_PATH
- name: Check imports
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install isort
pip3 install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install autopep8
pip3 install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install clang-format
pip3 install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/triton/*" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/triton/*" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install flake8
pip3 install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton
@@ -94,3 +99,12 @@ jobs:
cd python/
cd "build/$(ls build)"
ctest
- name: Regression tests
if: ${{ contains(matrix.runner, 'A100') }}
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
pytest -vs .
sudo nvidia-smi -i 0 -rgc

View File

@@ -51,8 +51,8 @@ include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
find_package(dlfcn-win32 REQUIRED)
set(CMAKE_DL_LIBS dlfcn-win32::dl)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
@@ -215,8 +215,7 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC})
target_link_libraries(triton
set(TRITON_LIBRARIES
TritonAnalysis
TritonTransforms
TritonGPUTransforms
@@ -237,16 +236,25 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRROCDLToLLVMIRTranslation
MLIRIR
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}
)
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs
${TRITON_LIBRARIES}
)
endif()
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
endif()
if (UNIX AND NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
endif()
if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)

View File

@@ -32,7 +32,7 @@ int main(int argc, char **argv) {
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();

View File

@@ -36,28 +36,29 @@ public:
void run();
private:
struct RegionInfo {
struct BlockInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;
RegionInfo() = default;
RegionInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
BlockInfo() = default;
BlockInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}
/// Unions two RegionInfo objects.
void join(const RegionInfo &other) {
/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
return *this;
}
/// Returns true if buffers in two RegionInfo objects are intersected.
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
/// Returns true if buffers in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/
@@ -74,6 +75,14 @@ private:
syncWriteBuffers.clear();
}
/// Compares two BlockInfo objects.
bool operator==(const BlockInfo &other) const {
return syncReadBuffers == other.syncReadBuffers &&
syncWriteBuffers == other.syncWriteBuffers;
}
bool operator!=(const BlockInfo &other) const { return !(*this == other); }
private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
@@ -99,19 +108,19 @@ private:
/// op5
/// op6
/// op7
/// region2 and region3 started with the information of region1.
/// Each region is analyzed separately and keeps their own copy of the
/// information. At op7, we union the information of region2 and region3
/// and update the information of region1.
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// TODO: Explain why we don't use ForwardAnalysis:
void resolve(Operation *operation, OpBuilder *builder);
/// Updates the RegionInfo operation based on the operation.
void transfer(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// Updates the BlockInfo operation based on the operation.
void update(Operation *operation, BlockInfo *blockInfo, OpBuilder *builder);
/// Collects the successors of the terminator
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);
private:
Allocation *allocation;
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
};
} // namespace mlir
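
For orientation, here is a minimal standalone model (plain C++ with std::set, not the Allocation-backed types above) of the bookkeeping that BlockInfo performs: join unions pending shared-memory accesses, an intersection between pending and new accesses signals a RAW/WAR/WAW hazard that forces a barrier, and sync clears the pending state once a barrier is placed. The real isIntersected resolves overlap through Allocation intervals rather than raw buffer-id equality.

// Toy illustration (not Triton code) of the BlockInfo bookkeeping above.
#include <cassert>
#include <set>

struct ToyBlockInfo {
  std::set<int> syncReadBuffers;
  std::set<int> syncWriteBuffers;

  // Union another block's pending accesses into this one.
  ToyBlockInfo &join(const ToyBlockInfo &other) {
    syncReadBuffers.insert(other.syncReadBuffers.begin(),
                           other.syncReadBuffers.end());
    syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
                            other.syncWriteBuffers.end());
    return *this;
  }

  // A barrier is required if the new accesses conflict with pending ones.
  bool isIntersected(const ToyBlockInfo &other) const {
    auto overlaps = [](const std::set<int> &a, const std::set<int> &b) {
      for (int id : a)
        if (b.count(id))
          return true;
      return false;
    };
    return overlaps(syncWriteBuffers, other.syncReadBuffers) ||  // RAW
           overlaps(syncReadBuffers, other.syncWriteBuffers) ||  // WAR
           overlaps(syncWriteBuffers, other.syncWriteBuffers);   // WAW
  }

  void sync() { syncReadBuffers.clear(); syncWriteBuffers.clear(); }
};

int main() {
  ToyBlockInfo pending, access;
  pending.syncWriteBuffers = {0};        // buffer 0 was written earlier
  access.syncReadBuffers = {0};          // the next op reads buffer 0
  assert(pending.isIntersected(access)); // RAW hazard -> insert gpu.barrier
  pending.sync();                        // the barrier clears pending state
  pending.join(access);                  // then record the new accesses
  return 0;
}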

View File

@@ -6,10 +6,10 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
//
@@ -29,7 +29,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
// extui, extsi, tructi
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
@@ -42,7 +42,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";
@@ -56,7 +56,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast between types of the same bitwidth";
@@ -71,7 +71,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";
@@ -95,7 +95,7 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
//
def TT_AddPtrOp : TT_Op<"addptr",
[NoMemoryEffect,
[Pure,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
@@ -222,7 +222,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoMemoryEffect,
def TT_SplatOp : TT_Op<"splat", [Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "splat";
@@ -236,7 +236,7 @@ def TT_SplatOp : TT_Op<"splat", [NoMemoryEffect,
let hasFolder = 1;
}
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoMemoryEffect,
def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
@@ -248,6 +248,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoMemoryEffect,
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
// view is not `pure` because it may reorder elements
def TT_ViewOp : TT_Op<"view", [NoMemoryEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
@@ -260,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoMemoryEffect,
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoMemoryEffect,
def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "broadcast. No left-padding as of now.";
@@ -274,6 +275,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoMemoryEffect,
let hasFolder = 1;
}
// cat is not `pure` because it may reorder elements
def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
@@ -285,7 +287,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
def TT_TransOp : TT_Op<"trans", [NoMemoryEffect,
def TT_TransOp : TT_Op<"trans", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
@@ -301,7 +303,7 @@ def TT_TransOp : TT_Op<"trans", [NoMemoryEffect,
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoMemoryEffect]> {
def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
@@ -309,7 +311,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoMemoryEffect]> {
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoMemoryEffect]> {
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
@@ -320,7 +322,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoMemoryEffect]> {
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoMemoryEffect,
def TT_DotOp : TT_Op<"dot", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
@@ -340,7 +342,7 @@ def TT_DotOp : TT_Op<"dot", [NoMemoryEffect,
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [NoMemoryEffect,
def TT_ReduceOp : TT_Op<"reduce", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";
@@ -364,7 +366,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoMemoryEffect,
//
// External elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoMemoryEffect, Elementwise, SameOperandsAndResultShape,
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
@@ -386,7 +388,7 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoMemoryEffect, Elementwise, SameO
// Make Range Op
//
// TODO: should have ConstantLike as Trait
def TT_MakeRangeOp : TT_Op<"make_range", [NoMemoryEffect]> {
def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
let summary = "make range";
let description = [{

View File

@@ -7,7 +7,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
@@ -18,13 +18,15 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape,
SameOperandsAndResultElementType,
NoMemoryEffect]> {
Pure]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let hasCanonicalizeMethod = 1;
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
@@ -59,7 +61,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoMemoryEffect, Elementwise,
def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
@@ -73,7 +75,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoMemoryEffect, Elementwise,
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf", [NoMemoryEffect, Elementwise,
def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
@@ -88,7 +90,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoMemoryEffect, Elementwise,
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoMemoryEffect, Elementwise,
def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";

View File

@@ -6,7 +6,9 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
@@ -17,10 +19,12 @@ std::unique_ptr<Pass> createTritonGPUReorderInstructionsPass();
std::unique_ptr<Pass> createTritonGPUDecomposeConversionsPass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();
std::unique_ptr<Pass> createTritonGPUVerifier();
std::unique_ptr<Pass> createTritonGPUFuseTranspositionsPass();
std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass();
/// Generate the code for registering passes.
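
As a rough illustration only, the newly declared constructors could be strung into an mlir::PassManager as sketched below; the ordering, the numStages value, and the compute capability are assumptions made for the sketch, and the actual pipeline is assembled elsewhere in the repository.

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Hedged sketch: wire a few of the TritonGPU passes declared above into a
// pass manager. numStages=2 and computeCapability=80 are illustrative
// defaults, not a statement of the real compilation pipeline.
static void buildExampleTritonGPUPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(/*computeCapability=*/80));
  pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
  pm.addPass(mlir::createTritonGPUFuseTranspositionsPass());
  pm.addPass(mlir::createTritonGPUPipelinePass(/*numStages=*/2));
  pm.addPass(mlir::createTritonGPUPrefetchPass());
}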

View File

@@ -7,7 +7,8 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
Unroll loops to hide global memory -> shared memory latency.
Replace `LoadOp` in loops by `InsertSliceAsyncOp` instructions that asynchronously construct the data
needed at the next iteration
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
@@ -27,7 +28,8 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";
let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
Decompose `DotOp` instructions in loops into several finer-grained `DotOp`
that may have their operands constructed at the end of the previous iteration
}];
let constructor = "mlir::createTritonGPUPrefetchPass()";
@@ -37,6 +39,41 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
"mlir::arith::ArithDialect"];
}
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
let summary = "accelerate matmul";
let description = [{
Optimize the input/output layouts of the `dot` instruction to make them compatible with hardware accelerators
(e.g., Nvidia tensor cores)
}];
let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> {
let summary = "fuse transpositions";
let description = [{
Re-arrange layouts of tensors used as matrix multiplication operands so as to promote the use of
hardware-accelerated transpositions.
}];
let constructor = "mlir::createTritonGPUFuseTranspositionsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
@@ -49,26 +86,16 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
let summary = "remove superfluous layout conversions";
let description = [{
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];
let constructor = "mlir::createTritonGPUCombineOpsPass()";
let constructor = "mlir::createTritonGPURemoveLayoutConversionsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
@@ -95,19 +122,6 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir
"mlir::triton::TritonDialect"];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
let summary = "canonicalize scf.ForOp ops";
let description = [{
This implements some optimizations that are missing in the standard scf.ForOp
canonicalizer.
}];
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> {
let summary = "Update mma encodings for Volta";

View File

@@ -2,74 +2,84 @@
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <deque>
namespace mlir {
void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
RegionInfo regionInfo;
OpBuilder builder(operation);
dfsOperation(operation, &regionInfo, &builder);
resolve(operation, &builder);
}
void MembarAnalysis::dfsOperation(Operation *operation,
RegionInfo *parentRegionInfo,
OpBuilder *builder) {
transfer(operation, parentRegionInfo, builder);
if (operation->getNumRegions()) {
// If there's any nested regions, we need to visit them.
// scf.if and scf.else: two regions
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
// assert(region.getBlocks().size() == 1 &&
// "Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
}
}
curRegionInfo.join(regionInfo);
void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
// Initialize the blockList
std::deque<Block *> blockList;
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
for (auto &op : block->getOperations()) {
// Check if the operation belongs to scf dialect, if so, we need to
// throw an error
if (op.getDialect()->getNamespace() == "scf") {
op.emitError("scf dialect is not supported in membar. Please lower it "
"to cf dialect first.");
return;
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};
}
if (block->isEntryBlock())
blockList.emplace_back(block);
});
traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may share the same
// memory with %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
// A fixed point algorithm
while (!blockList.empty()) {
auto *block = blockList.front();
blockList.pop_front();
// Make a copy of the inputBlockInfo but do not update it yet
auto inputBlockInfo = inputBlockInfoMap.lookup(block);
SmallVector<Block *> successors;
for (auto &op : block->getOperations()) {
if (op.hasTrait<OpTrait::IsTerminator>()) {
visitTerminator(&op, successors);
} else {
update(&op, &inputBlockInfo, builder);
}
}
// Get the reference because we want to update if it changed
if (outputBlockInfoMap.count(block) &&
inputBlockInfo == outputBlockInfoMap[block]) {
// If we have seen the block before and the inputBlockInfo is the same as
// the outputBlockInfo, we skip the successors
continue;
}
// Update the current block
outputBlockInfoMap[block].join(inputBlockInfo);
// Update the successors
for (auto *successor : successors) {
inputBlockInfoMap[successor].join(outputBlockInfoMap[block]);
blockList.emplace_back(successor);
}
}
}
void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
OpBuilder *builder) {
if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op)) {
// Do not insert barriers before control flow operations and
// alloc/extract/insert
void MembarAnalysis::visitTerminator(Operation *op,
SmallVector<Block *> &successors) {
if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
Block *parentBlock = branchInterface->getBlock();
for (Block *successor : parentBlock->getSuccessors()) {
successors.push_back(successor);
}
return;
}
// Otherwise, it could be a return op
assert(isa<func::ReturnOp>(op) && "Unknown terminator");
}
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
OpBuilder *builder) {
if (isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op) ||
isa<triton::TransOp>(op)) {
// alloc is an allocation op without memory write.
// FIXME(Keren): extract_slice is always alias for now
return;
@@ -77,7 +87,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
if (isa<gpu::BarrierOp>(op)) {
// If the current op is a barrier, we sync previous reads and writes
regionInfo->sync();
blockInfo->sync();
return;
}
@@ -85,26 +95,26 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
!isa<gpu::BarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
regionInfo->sync();
blockInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
regionInfo->sync();
blockInfo->sync();
return;
}
RegionInfo curRegionInfo;
BlockInfo curBlockInfo;
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always alias
// for now
curRegionInfo.syncWriteBuffers.insert(bufferId);
// FIXME(Keren): insert_slice and insert_slice_async are always
// alias for now
curBlockInfo.syncWriteBuffers.insert(bufferId);
} else {
// ConvertLayoutOp: shared memory -> registers
curRegionInfo.syncReadBuffers.insert(bufferId);
curBlockInfo.syncReadBuffers.insert(bufferId);
}
}
}
@@ -113,25 +123,25 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncWriteBuffers.insert(bufferId);
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
curRegionInfo.syncReadBuffers.insert(bufferId);
curBlockInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncReadBuffers.insert(bufferId);
}
if (regionInfo->isIntersected(curRegionInfo, allocation)) {
if (blockInfo->isIntersected(curBlockInfo, allocation)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());
regionInfo->sync();
blockInfo->sync();
}
// Update the region info, even if barrier is inserted, we have to maintain
// the current op's read/write buffers.
regionInfo->join(curRegionInfo);
blockInfo->join(curBlockInfo);
}
} // namespace mlir
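
To make the shape of the new resolve() loop concrete, here is a standalone toy (plain C++, integer ids standing in for BlockInfo) of the same worklist-until-fixed-point idea: transfer the cached input state through a block, skip the successors when the result matches the recorded output, otherwise join it into the output and re-enqueue the successors. The 4-block CFG with a back edge is an assumed example, not taken from the code above.

// Toy model of the fixed-point loop in MembarAnalysis::resolve().
#include <deque>
#include <iostream>
#include <map>
#include <set>
#include <vector>

using State = std::set<int>; // stand-in for BlockInfo

int main() {
  // A tiny CFG with a loop: 0 -> 1 -> 2 -> 1, and 2 -> 3.
  std::map<int, std::vector<int>> successors = {
      {0, {1}}, {1, {2}}, {2, {1, 3}}, {3, {}}};
  // Each block "generates" one buffer access (stand-in for update()).
  std::map<int, int> gen = {{0, 10}, {1, 11}, {2, 12}, {3, 13}};

  std::map<int, State> input, output;
  std::deque<int> worklist = {0}; // entry block

  while (!worklist.empty()) {
    int block = worklist.front();
    worklist.pop_front();
    State out = input[block];
    out.insert(gen[block]); // transfer function for this block
    if (output.count(block) && out == output[block])
      continue; // already at a fixed point for this block
    output[block].insert(out.begin(), out.end()); // join into the output
    for (int succ : successors[block]) {
      input[succ].insert(output[block].begin(), output[block].end());
      worklist.push_back(succ);
    }
  }

  for (auto &[block, state] : output) {
    std::cout << "block " << block << ":";
    for (int id : state)
      std::cout << ' ' << id;
    std::cout << '\n';
  }
  return 0;
}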

View File

@@ -337,7 +337,8 @@ namespace {
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
// interacts with constant propagation, but SparseConstantPropagation
// doesn't seem to be sufficient.
struct ConstantAnalysis : public DataFlowAnalysis {
class ConstantAnalysis : public DataFlowAnalysis {
public:
using DataFlowAnalysis::DataFlowAnalysis;
LogicalResult initialize(Operation *top) override {
@@ -359,12 +360,19 @@ struct ConstantAnalysis : public DataFlowAnalysis {
value, op->getDialect())));
return success();
}
// Dead code analysis requires that every operand has an initialized
// ConstantValue state before it is visited.
// https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
// That's why we need to set all operands to unknown constants.
setAllToUnknownConstants(op->getResults());
for (Region &region : op->getRegions())
setAllToUnknownConstants(region.getArguments());
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks())
setAllToUnknownConstants(block.getArguments());
}
return success();
}
private:
/// Set all given values as not constants.
void setAllToUnknownConstants(ValueRange values) {
dataflow::ConstantValue unknownConstant(nullptr, nullptr);

View File

@@ -291,10 +291,10 @@ public:
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// extract multi dimensional index for current element
auto idx = srcIndices[elemIdx];
Value idxCol = idx[inOrder[0]]; // contiguous dimension
Value idxRow = idx[inOrder[1]]; // discontiguous dimension
Value strideCol = srcStrides[inOrder[0]];
Value strideRow = srcStrides[inOrder[1]];
Value idxCol = idx[outOrder[0]]; // contiguous dimension
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
Value strideCol = srcStrides[outOrder[0]];
Value strideRow = srcStrides[outOrder[1]];
// extract dynamic/static offset for immediate offsetting
unsigned immedateOffCol = 0;
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
@@ -338,7 +338,7 @@ public:
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
// compute immediate offset
Value immedateOff =
add(mul(i32_val(immedateOffRow), srcStrides[inOrder[1]]),
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
}

View File

@@ -128,16 +128,13 @@ public:
// Step 2: Decompose insert_slice_async to use load + insert_slice for
// pre-Ampere architectures or unsupported vectorized load sizes
// Step 3: Allocate shared memories and insert barriers
// Step 4: Convert SCF to CFG
// Step 5: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 6: Get axis and shared memory info
// Step 7: Convert the rest of ops via partial conversion
// Step 4: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 5: Get axis and shared memory info
// Step 6: Convert the rest of ops via partial conversion
//
// The reason for putting step 3 before step 4 is that the membar
// analysis currently only supports SCF but not CFG. The reason for a
// separation between 5/7 is that, step 6 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 7.
// The reason for a separation between 4/6 is that, step 5 is out of the
// scope of Dialect Conversion, thus we need to make sure the smem is not
// revised during the conversion of step 6.
// Step 1
decomposeMmaToDotOperand(mod, numWarps);
@@ -153,24 +150,13 @@ public:
membarPass.run();
// Step 4
RewritePatternSet scf_patterns(context);
mlir::populateSCFToControlFlowConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
scf::WhileOp, scf::ExecuteRegionOp>();
scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
// Step 5
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();
// Step 6 - get axis and shared memory info
// Step 5 - get axis and shared memory info
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(mod)))
@@ -180,7 +166,7 @@ public:
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));
// Step 7 - rewrite rest of ops
// Step 6 - rewrite rest of ops
// We set a higher benefit here to ensure that Triton's patterns run before
// arith patterns for some encodings not supported by the community
// patterns.
@@ -222,30 +208,16 @@ public:
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
#ifdef USE_ROCM
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP);
#else
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
#endif
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
// Take care of scf pattern introduced by LoadStoreOp
#ifdef USE_ROCM
RewritePatternSet scf_patterns_extra(context);
mlir::populateSCFToControlFlowConversionPatterns(scf_patterns_extra);
if (failed(
applyPartialConversion(mod, scf_target, std::move(scf_patterns_extra))))
return signalPassFailure();
RewritePatternSet patterns_extra(context);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns_extra);
if (failed(
applyPartialConversion(mod, target, std::move(patterns_extra))))
return signalPassFailure();
#endif
}
private:

View File

@@ -790,6 +790,141 @@ struct TritonGPUInferLayoutInterface
}
};
//===----------------------------------------------------------------------===//
// Canonicalizer
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
PatternRewriter &rewriter) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
auto dstType = op.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
return mlir::success();
}
Operation *arg = op->getOperand(0).getDefiningOp();
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
view.getResult());
return mlir::success();
}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.getDst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.getSrc(), newArg.getResult(),
insert_slice.getIndex(), insert_slice.getMask(),
insert_slice.getOther(), insert_slice.getCache(),
insert_slice.getEvict(), insert_slice.getIsVolatile(),
insert_slice.getAxis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
extract_slice.getSource().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(op.getOperand()) &&
!isSharedEncoding(op.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) {
return mlir::failure();
}
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
}
// cvt(type1, splat(type2, x)) -> splat(type1, x)
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
splat.getSrc());
return mlir::success();
}
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, op->getResultTypes(), range.getStart(), range.getEnd());
return mlir::success();
}
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
return mlir::failure();
}
//===----------------------------------------------------------------------===//
void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST

View File

@@ -0,0 +1,215 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;
int computeCapabilityToMMAVersion(int computeCapability) {
#ifdef USE_ROCM
return 1;
#endif
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else if (computeCapability < 100) {
// FIXME: temporarily add this to pass unit tests
return 2;
} else {
assert(false && "computeCapability > 100 not supported");
return 3;
}
}
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
return {16, 8};
else {
assert(false && "version not supported");
return {0, 0};
}
}
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value that ensures product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
llvm_unreachable("unsupported MMA version");
}
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
// for FMA, should retain the blocked layout.
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
llvm_unreachable("Mma layout only supports versionMajor in {1, 2}");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};
} // namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUAccelerateMatmulPass
: public TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
public:
TritonGPUAccelerateMatmulPass() = default;
TritonGPUAccelerateMatmulPass(int computeCapability) {
this->computeCapability = computeCapability;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::RewritePatternSet patterns(context);
patterns.add<::BlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};
std::unique_ptr<Pass>
mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) {
return std::make_unique<TritonGPUAccelerateMatmulPass>(computeCapability);
}
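
For intuition, here is a standalone replay of the warp-splitting loop in warpsPerTileV2 for a hypothetical 128x128 result tile and 8 warps. The shape, the warp count, and the restructured while loop are illustrative assumptions; the in-tree code also bails out to {numWarps, 1} whenever another tt.dot appears in the forward slice.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t shape[2] = {128, 128};     // dot result tile (M, N) -- assumed
  const int numWarps = 8;                  // assumed warp count
  const int64_t shapePerWarp[2] = {16, 8}; // MMAv2 tile handled per warp
  int64_t ret[2] = {1, 1};                 // warpsPerCTA being computed

  // Same splitting rule as warpsPerTileV2: keep doubling the dimension that
  // still has the most tiles left per warp until all warps are used.
  while (ret[0] * ret[1] < numWarps) {
    if (shape[0] / shapePerWarp[0] / ret[0] >=
        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
      if (ret[0] < shape[0] / shapePerWarp[0])
        ret[0] *= 2;
      else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  }
  std::cout << "warpsPerCTA = {" << ret[0] << ", " << ret[1] << "}\n";
  // Prints warpsPerCTA = {4, 2} for this 128x128 / 8-warp example.
  return 0;
}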

View File

@@ -1,22 +1,18 @@
set(LLVM_TARGET_DEFINITIONS Combine.td)
mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
CanonicalizeLoops.cpp
Combine.cpp
DecomposeConversions.cpp
FuseTranspositions.cpp
Pipeline.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
DecomposeConversions.cpp
TritonGPUConversion.cpp
UpdateMmaForVolta.cpp
Utility.cpp
DEPENDS
TritonGPUTransformsIncGen
TritonGPUCombineIncGen
LINK_LIBS PUBLIC
TritonIR

View File

@@ -1,55 +0,0 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
struct CanonicalizePass
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
CanonicalizePass() = default;
void runOnOperation() override {
// Canonicalize pass may have created dead code that
// standard scf.for canonicalization cannot handle
// as of LLVM 14. For example, the iteration arguments
// for the pointer of the synchronous loads that are
// discarded.
// The following piece of code is a workaround to
// very crudely remove dead code, by making an iteration
// argument yield itself if it is not used to create
// side effects anywhere.
getOperation()->walk([&](scf::ForOp forOp) -> void {
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
// condition 1: no other iter arguments depend on it
SetVector<Operation *> fwdSlice;
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
Operation *yieldOp = forOp.getBody()->getTerminator();
bool noOtherDependency = std::all_of(
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
return arg == yieldOp->getOperand(i) ||
!fwdSlice.contains(arg.getDefiningOp());
});
// condition 2: final value is not used after the loop
auto retVal = forOp.getResult(i);
bool noUserAfterLoop = retVal.getUsers().empty();
// yielding the region iter arg will cause loop canonicalization
// to clean up the dead code
if (noOtherDependency && noUserAfterLoop) {
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
}
}
});
}
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
return std::make_unique<CanonicalizePass>();
}

View File

@@ -1,8 +0,0 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
include "mlir/IR/PatternBase.td"
#endif

View File

@@ -0,0 +1,153 @@
#include "Utility.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
// convert(trans(convert(arg)))
// x = convert_layout arg: #distributed -> #shared_x
// y = trans x: #shared_x -> #shared_y
// z = convert_layout y: #shared_y -> #dot_operand
class ConvertTransConvert : public mlir::RewritePattern {
public:
ConvertTransConvert(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
auto tmpOp =
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
if (!tmpOp)
return mlir::failure();
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
tmpOp.getSrc().getDefiningOp());
if (!srcOp)
return mlir::failure();
auto arg = srcOp.getSrc();
auto X = tmpOp.getSrc();
// types
auto argType = arg.getType().cast<RankedTensorType>();
auto XType = X.getType().cast<RankedTensorType>();
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
// encodings
auto argEncoding = argType.getEncoding();
auto XEncoding =
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto ZEncoding =
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!ZEncoding)
return mlir::failure();
// new X encoding
auto newXOrder = triton::gpu::getOrder(argEncoding);
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XType.getElementType());
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
return mlir::failure();
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
newXType, arg);
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
newY);
return mlir::success();
}
};
} // namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUFuseTranspositionsPass
: public TritonGPUFuseTranspositionsBase<TritonGPUFuseTranspositionsPass> {
public:
TritonGPUFuseTranspositionsPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::PassManager pm(m.getContext());
pm.addPass(mlir::createCanonicalizerPass());
auto ret = pm.run(m);
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<ConvertTransConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
if (fixupLoops(m).failed())
signalPassFailure();
}
};
std::unique_ptr<Pass> mlir::createTritonGPUFuseTranspositionsPass() {
return std::make_unique<TritonGPUFuseTranspositionsPass>();
}

View File

@@ -22,7 +22,6 @@
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
@@ -139,132 +138,7 @@ public:
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
return mlir::success();
}
Operation *arg = op->getOperand(0).getDefiningOp();
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), view.getResult());
return mlir::success();
}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.getDst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.getSrc(), newArg.getResult(),
insert_slice.getIndex(), insert_slice.getMask(),
insert_slice.getOther(), insert_slice.getCache(),
insert_slice.getEvict(), insert_slice.getIsVolatile(),
insert_slice.getAxis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
extract_slice.getSource().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(convert.getOperand()) &&
!isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(convert.getOperand()) &&
isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
}
// cvt(type1, splat(type2, x)) -> splat(type1, x)
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
splat.getSrc());
return mlir::success();
}
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, op->getResultTypes(), range.getStart(), range.getEnd());
return mlir::success();
}
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
return mlir::failure();
return ConvertLayoutOp::canonicalize(convert, rewriter);
}
};
@@ -568,9 +442,9 @@ public:
};
//
class FoldConvertAndReduce : public mlir::RewritePattern {
class RematerializeForward : public mlir::RewritePattern {
public:
explicit FoldConvertAndReduce(mlir::MLIRContext *context)
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
@@ -837,393 +711,6 @@ public:
}
};
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
class RematerializeForward : public mlir::RewritePattern {
public:
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *_cvtOp,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp)
return mlir::failure();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return isInLoop(op) &&
!isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
triton::AtomicCASOp>(op) &&
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
!isa<triton::gpu::ConvertLayoutOp>(op);
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())
return failure();
for (Operation *op : cvtSlices) {
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op))
return failure();
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();
if (argOp && (argOp != cvt) &&
!isa<arith::ConstantOp, triton::SplatOp, triton::MakeRangeOp>(
argOp)) {
return failure();
}
}
}
// Otherwise, we push the conversion forward
// since we'll be able to move it out of
// the loop once it reaches the yield op
pushConversionForward(cvt, cvtSlices, rewriter);
return success();
}
};
// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
namespace {
int computeCapabilityToMMAVersion(int computeCapability) {
#ifdef USE_ROCM
return 1;
#endif
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else if (computeCapability < 100) {
    // FIXME: temporarily added to pass unit tests
return 2;
} else {
assert(false && "computeCapability > 100 not supported");
return 3;
}
}
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
return {16, 8};
else {
assert(false && "version not supported");
return {0, 0};
}
}
SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value and ensure product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
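  // Illustrative trace: shape = [128, 128], numWarps = 4, shapePerWarp = [16, 8]
  //   iter 1: 128/16/1 >= 128/16/1 and ret[0] < 8  -> ret = {2, 1}
  //   iter 2: 128/16/2 <  128/16/1                 -> ret = {2, 2}
  //   iter 3: ret[0] * ret[1] >= numWarps          -> stop, result {2, 2}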
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
} // namespace
class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstSharedLayout =
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (!srcBlockedLayout || !dstSharedLayout)
return failure();
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
return failure();
    // For now this only works if the single use is a transpose
// TODO: rematerialize #shared uses
auto users = op->getUsers();
if (std::distance(users.begin(), users.end()) != 1 ||
!isa<triton::TransOp>(*users.begin()))
return failure();
auto tmpShared = triton::gpu::SharedEncodingAttr::get(
op->getContext(), dstSharedLayout.getVec(),
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
srcBlockedLayout.getOrder());
auto tmpType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), tmpShared);
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), tmpType, cvt.getOperand());
auto newDstType = RankedTensorType::get(
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
srcType.getElementType(), dstSharedLayout);
auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
tmpCvt.getResult());
rewriter.replaceOp(*users.begin(), newTrans.getResult());
return success();
}
};
class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
assert(false && "not supported version");
return {0, 0};
}
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
// for FMA, should retain the blocked layout.
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
assert(false && "Mma layout only support versionMajor of 1 or 2");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};
// Convert + trans + convert
// x = convert_layout distributed -> #shared_x
// y = trans x -> #shared_y
// z = convert_layout y -> #dot_operand
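// The rewrite below re-creates x with a #shared encoding derived from the
// #dot_operand consumer (keeping the order of the distributed input), so the
// transposed data lands in shared memory already swizzled for the dot operand.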
class ConvertTransConvert : public mlir::RewritePattern {
public:
ConvertTransConvert(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
auto tmpOp =
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
if (!tmpOp)
return mlir::failure();
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
tmpOp.getSrc().getDefiningOp());
if (!srcOp)
return mlir::failure();
auto arg = srcOp.getSrc();
auto X = tmpOp.getSrc();
// types
auto argType = arg.getType().cast<RankedTensorType>();
auto XType = X.getType().cast<RankedTensorType>();
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
// encodings
auto argEncoding = argType.getEncoding();
auto XEncoding =
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto ZEncoding =
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!ZEncoding)
return mlir::failure();
// new X encoding
auto newXOrder = triton::gpu::getOrder(argEncoding);
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XType.getElementType());
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
return mlir::failure();
auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
newXType, arg);
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
newY);
return mlir::success();
}
};
//
class ConvertDotConvert : public mlir::RewritePattern {
public:
@@ -1275,31 +762,25 @@ public:
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
class TritonGPURemoveLayoutConversionsPass
: public TritonGPURemoveLayoutConversionsBase<
TritonGPURemoveLayoutConversionsPass> {
public:
TritonGPUCombineOpsPass() = default;
TritonGPUCombineOpsPass(int computeCapability) {
this->computeCapability = computeCapability;
}
TritonGPURemoveLayoutConversionsPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<SimplifyReduceCvt>(context);
patterns.add<FoldConvertAndReduce>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<BlockedToMMA>(context, computeCapability);
patterns.add<ConvertTransConvert>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<ConvertDotConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
@@ -1312,7 +793,6 @@ public:
}
};
std::unique_ptr<Pass>
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
std::unique_ptr<Pass> mlir::createTritonGPURemoveLayoutConversionsPass() {
return std::make_unique<TritonGPURemoveLayoutConversionsPass>();
}

View File

@@ -147,6 +147,12 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
if (!funcs.empty()) {
static const std::string libdevice = "libdevice";
    // first search for the environment variable path
std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH");
if (!env_path.empty()) {
externLibs.try_emplace(libdevice, env_path);
return externLibs;
}
namespace fs = std::filesystem;
// Search for libdevice relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and libdevice in
@@ -302,12 +308,17 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());
#ifdef USE_ROCM
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertControlFlowToLLVMPass());
#endif
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";

View File

@@ -65,7 +65,7 @@ def get_llvm_package_info():
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
system_suffix = f"linux-gnu-{linux_suffix}"
else:
raise RuntimeError(f"unsupported system: {system}")
return Package("llvm", "LLVM-C.lib", "", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
@@ -159,7 +159,11 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
lit_dir = shutil.which('lit')
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or \
os.getenv("HOMEPATH") or None
if not user_home:
raise RuntimeError("Could not find user home directory")
triton_cache_path = os.path.join(user_home, ".triton")
# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))

View File

@@ -1,4 +1,4 @@
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>
void init_superblocking(pybind11::module &m);
void init_torch_utils(pybind11::module &m);

View File

@@ -1462,7 +1462,7 @@ void init_triton_ir(py::module &&m) {
.def(
"add_sccp_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
.def("add_coalesce_pass",
.def("add_tritongpu_coalesce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCoalescePass());
})
@@ -1501,10 +1501,18 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());
})
.def("add_tritongpu_combine_pass",
.def("add_tritongpu_accelerate_matmul_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(
mlir::createTritonGPUCombineOpsPass(computeCapability));
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
})
.def("add_tritongpu_fuse_transpositions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
})
.def("add_tritongpu_remove_layout_conversions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
})
.def("add_tritongpu_update_mma_for_volta_pass",
[](mlir::PassManager &self) {

View File

@@ -8,7 +8,7 @@ import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = 'v100'
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]
#######################
# Utilities
@@ -34,7 +34,6 @@ mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
'v100': {
# square
(256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695},
@@ -51,29 +50,26 @@ matmul_data = {
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
# NOTE:
    # The A100 in the CI server is slow-ish for some reason.
    # On some other servers, we get about 90% of peak for 8kx8kx8k float16.
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(64, 4096, 4096): {'float16': 0.16, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.30, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.25, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions
# (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'a100': 0.},
}
@@ -88,9 +84,7 @@ def test_matmul(M, N, K, dtype_str):
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
if dtype == torch.int8:
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
@@ -99,10 +93,10 @@ def test_matmul(M, N, K, dtype_str):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
@@ -149,16 +143,61 @@ elementwise_data = {
def test_elementwise(N):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
#######################
# Flash-Attention
#######################
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
}
}
@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
@pytest.mark.parametrize("mode", ['forward', 'backward'])
@pytest.mark.parametrize("dtype_str", ['float16'])
def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
is_backward = mode == 'backward'
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
# init data
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
# benchmark
fn = lambda: triton.ops.attention(q, k, v, sm_scale)
if is_backward:
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
if is_backward:
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
cur_gpu_perf = total_flops / ms * 1e-9
# maximum flops
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)

View File

@@ -1034,7 +1034,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for shape in [(64, 64, 64), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16', 'float32']
@@ -1063,6 +1063,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if capability[0] == 7:
if (M, N, K, num_warps) == (128, 256, 32, 8):
pytest.skip("shared memory out of resource")
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -1298,18 +1301,18 @@ def test_noop(device='cuda'):
kernel[(1, )](x)
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned'])
def test_pointer_arguments(device):
@triton.jit
def kernel(x):
pass
x = torch.empty(1024, device=device)
result = True
try:
kernel[(1,)](x)
except ValueError:
result = True if device == 'cpu' else False
assert result
pin_memory = 'pinned' in device
x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory)
if device == "cpu":
with pytest.raises(ValueError):
kernel[(1,)](x)
else:
kernel[(1, )](x)
@pytest.mark.parametrize("value, value_type", [

View File

@@ -16,7 +16,6 @@ import tempfile
import warnings
from collections import namedtuple
from pathlib import Path
from sysconfig import get_paths
from typing import Any, Callable, Dict, Tuple, Union
import setuptools
@@ -981,32 +980,32 @@ def ast_to_ttir(fn, signature, specialization, constants):
return optimize_triton_ir(mod)
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
def ttir_to_ttgir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
def optimize_ttgir(mod, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_coalesce_pass()
# The combine pass converts blocked layout to mma layout
# for dot ops so that pipeline can get shared memory swizzled correctly.
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
    # Prefetch must be done after the pipeline pass because the pipeline pass
# extracts slices from the original tensor.
pm.add_tritongpu_prefetch_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_licm_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_cse_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
        # The update_mma_for_volta pass computes extra information for the MMA encoding, specifically for MMAv1
        # NOTE: this pass should be placed after all the passes that modify the MMA layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.run(mod)
return mod
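# Illustrative sketch of the split pipeline: layout assignment then layout
# optimization. `mod` is assumed to be the Triton-IR module returned by
# ast_to_ttir, and the capability value below is made up.
#
#   ttgir = ttir_to_ttgir(mod, num_warps=4)
#   ttgir = optimize_ttgir(ttgir, num_stages=3, compute_capability=80)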
@@ -1459,11 +1458,6 @@ def default_cache_dir():
return os.path.join(os.environ["HOME"], ".triton", "cache")
def default_cuda_dir():
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)
class CacheManager:
def __init__(self, key):
@@ -1537,14 +1531,15 @@ def _build(name, src, srcdir):
hip_include_dir = os.path.join(hip_home_dirs(), "include")
else:
cuda_lib_dirs = libcuda_dirs()
cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
base_dir = os.path.dirname(__file__)
cuda_path = os.path.join(base_dir, "third_party", "cuda")
cu_include_dir = os.path.join(cuda_path, "include")
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
cuda_header = os.path.join(cu_include_dir, "cuda.h")
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
cu_include_dir = triton_include_dir
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
@@ -1556,7 +1551,17 @@ def _build(name, src, srcdir):
cc = gcc if gcc is not None else clang
if cc is None:
raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
py_include_dir = get_paths()["include"]
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, 'get_default_scheme'):
scheme = sysconfig.get_default_scheme()
else:
scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting with Python 3.10, the default install
    # path changes to include 'local'. This change is required to use Triton with a system-wide Python.
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
if torch.version.hip is not None:
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
else:
@@ -1751,30 +1756,30 @@ def compile(fn, **kwargs):
raise RuntimeError('gfx_arch is None (not specified)')
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"amdgcn": (lambda path: Path(path).read_text(),
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
gfx_arch_full_details[2])),
}
else:
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
lambda src: ptx_to_cubin(src, capability))
}
# find out the signature of the function

View File

@@ -9,6 +9,8 @@ from triton._C.libtriton.triton import ir
T = TypeVar('T')
TRITON_MAX_TENSOR_NUMEL = 131072
def _to_tensor(x, builder):
if isinstance(x, bool):
@@ -254,6 +256,8 @@ class block_type(dtype):
self.numel = 1
for s in self.shape:
self.numel *= s
if self.numel > TRITON_MAX_TENSOR_NUMEL:
raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
self.name = self.__str__()
@@ -702,12 +706,13 @@ def num_programs(axis, _builder=None):
@builtin
def arange(start, end, _builder=None):
"""
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
Returns contiguous values within the left-closed and right-open interval [:code:`start`, :code:`end`). \
End - Start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 131072
:param start: Start of the interval. Must be a power of two.
:type start: int
:param stop: End of the interval. Must be a power of two >= start.
:type stop: int
:type start: int32
:param end: End of the interval. Must be a power of two > start.
:type end: int32
"""
start = _constexpr_to_value(start)
end = _constexpr_to_value(end)
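    # Illustrative usage sketch (kernel name and sizes are made up):
    # tl.arange(0, BLOCK) yields the int32 values [0, BLOCK); BLOCK must be a
    # power of two no larger than TRITON_MAX_TENSOR_NUMEL.
    #
    #   @triton.jit
    #   def _iota(out_ptr, BLOCK: tl.constexpr):
    #       offs = tl.arange(0, BLOCK)
    #       tl.store(out_ptr + offs, offs)
    #
    #   out = torch.empty(1024, dtype=torch.int32, device='cuda')
    #   _iota[(1,)](out, BLOCK=1024)   # out holds 0, 1, ..., 1023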

View File

@@ -5,9 +5,10 @@ from .. import impl
from . import core, extern
if torch.version.hip is not None:
LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
else:
LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", LOCAL_PATH)
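# TRITON_LIBDEVICE_PATH lets users point Triton at a different libdevice bitcode
# without reinstalling, e.g. (the path shown is only an example):
#   TRITON_LIBDEVICE_PATH=/opt/cuda/nvvm/libdevice/libdevice.10.bc python script.py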
@impl.extern
def clz(arg0, _builder=None):

View File

@@ -480,6 +480,12 @@ def not_equal(input: tl.tensor,
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
if not isinstance(start, int) or not isinstance(end, int):
raise ValueError("arange's arguments must be of type tl.constexpr")
is_start_int64 = bool(start >> 32)
is_end_int64 = bool(end >> 32)
if is_start_int64 or is_end_int64:
raise ValueError("arange must fit in int32")
if end <= start:
raise ValueError("arange's end argument must be greater than the start argument")
shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape)
@@ -655,6 +661,8 @@ def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:

View File

@@ -62,9 +62,9 @@ class Autotuner(KernelInterface):
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
return do_bench(kernel_call)
return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
except OutOfResources:
return float('inf')
return (float('inf'), float('inf'), float('inf'))
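            # do_bench is asked for the (median, 20th, 80th) percentile latencies,
            # so the out-of-resources fallback returns a matching triple of
            # infinities.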
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))

View File

@@ -93,7 +93,11 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
def allclose(x, y, tol=1e-2):
def allclose(x, y, atol=0, rtol=1e-2):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
if not isinstance(y, torch.Tensor):
y = torch.tensor(y)
if x.dtype != y.dtype:
        raise RuntimeError(f'{x.dtype} did not match with {y.dtype}')
if x.shape != y.shape:
@@ -101,12 +105,11 @@ def allclose(x, y, tol=1e-2):
if x.dtype == torch.bool:
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
tol = 0
rtol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
err = torch.max(diff) / torch.max(x_max, y_max)
return err <= tol
return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max)
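# Illustrative check (numbers are made up): passes when
# max|x - y| <= atol + rtol * max(x.max(), y.max()).
#
#   measured, reference = torch.tensor(0.81), torch.tensor(0.80)
#   assert allclose(measured, reference, atol=0.01, rtol=0.05)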
def nvsmi(attrs):

View File

@@ -83,7 +83,8 @@ if __name__ == '__main__':
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
# triton-ir -> triton-gpu-ir
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)

View File

@@ -289,7 +289,8 @@ class Libdevice(ExternLibrary):
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")"
header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n"
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"

View File

@@ -223,6 +223,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
# print(h.asm["ttgir"])
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
@@ -260,6 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None

View File

@@ -179,8 +179,8 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
return
}
// CHECK-LABEL: for_if_for
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK-LABEL: for_for_if
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -213,3 +213,34 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
}
return
}
// CHECK-LABEL: cf_for
func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
gpu.barrier
// CHECK-NEXT: %0 -> %0
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_1 -> %cst_1
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %2 -> %cst,%cst_0,%cst_1
// CHECK-NEXT: %3 -> %cst,%cst_0,%cst_1
// CHECK-NEXT: %4 -> %cst,%cst_0,%cst_1
cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
^bb1(%1: index, %2: tensor<128x32xf16, #A_SHARED>, %3: tensor<128x32xf16, #A_SHARED>, %4: tensor<128x32xf16, #A_SHARED>): // 2 preds: ^bb0, ^bb2
%5 = arith.cmpi slt, %1, %arg1 : index
cf.cond_br %5, ^bb2, ^bb3
^bb2: // pred: ^bb1
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
gpu.barrier
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
%8 = arith.addi %1, %arg2 : index
cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
^bb3: // pred: ^bb1
gpu.barrier
// CHECK-NEXT: %9 -> %9
%9 = tt.cat %0, %0 {axis = 0 : i64} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}

View File

@@ -315,8 +315,8 @@ func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.pt
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges span the entire function before cst2.
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK-LABEL: for_for_if
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file --mlir-disable-threading --convert-scf-to-cf -test-print-membar 2>&1 | FileCheck %s
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
@@ -45,11 +45,12 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
func.func @raw_single_block(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
// CHECK: Membar 5
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED>
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%1 = tt.load %0, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%2 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED>
return
}
@@ -57,46 +58,51 @@ func.func @raw_single_block(%A : !tt.ptr<f16>) {
func.func @war_single_block(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
// CHECK: Membar 5
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
// a2's liveness range ends here, and a3 and a2 have the same address range.
// So it makes sense to have a WAR dependency between a2 and a3.
// CHECK-NEXT: Membar 7
%a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%1 = tt.load %0, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%2 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: %4 = triton_gpu.convert_layout
%4 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
return
}
// CHECK-LABEL: scratch
func.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: Membar 1
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK-NEXT: Membar 3
%aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
%b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
return
}
// CHECK-LABEL: async_wait
func.func @async_wait() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: Membar 1
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
triton_gpu.async_wait {num = 4 : i32}
// CHECK-NEXT: Membar 4
%a_ = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
return
}
// CHECK-LABEL: alloc
func.func @alloc() {
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: Membar 2
%b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%1 = tt.cat %0, %0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %1 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
return
}
@@ -104,50 +110,58 @@ func.func @alloc() {
func.func @extract_slice() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
// CHECK: Membar 3
%cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
// CHECK-NEXT: Membar 5
%cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
%0 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: trans
func.func @trans() {
// CHECK-NOT: gpu.barrier
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
return
}
// CHECK-LABEL: insert_slice_async
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
// CHECK-LABEL: insert_slice_async_op
func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : i32
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
// CHECK: Membar 6
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
// CHECK: Membar 8
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
%3 = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%4 = tt.cat %3, %3 {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice
func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
// CHECK-LABEL: insert_slice_op
func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
%al = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
// CHECK: Membar 6
%a = tensor.insert_slice %al into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
// CHECK: Membar 8
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
// CHECK: Membar 10
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
%2 = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: tensor.insert_slice
%3 = tensor.insert_slice %2 into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%4 = tt.cat %3, %3 {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
return
}
@@ -157,18 +171,21 @@ func.func @multi_blocks(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
// CHECK: Membar 2
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
} else {
%cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: Membar 7
%b = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
}
// CHECK-NEXT: Membar 10
%c = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%2 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
@@ -178,12 +195,14 @@ func.func @multi_blocks_join_barrier(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
// CHECK: Membar 2
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
} else {
// CHECK-NEXT: Membar 5
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
}
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
@@ -196,17 +215,42 @@ func.func @multi_blocks_yield(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
// CHECK: Membar 2
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield %a : tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield %0 : tensor<32x16xf16, #A_SHARED>
} else {
// CHECK-NEXT: Membar 5
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield %b : tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield %1 : tensor<32x16xf16, #A_SHARED>
}
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
// CHECK-NEXT: Membar 9
%b = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%4 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
return
}
// Even though the entry block doesn't have a barrier, its successor blocks should still have barriers
// CHECK-LABEL: multi_blocks_entry_no_shared
func.func @multi_blocks_entry_no_shared(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%0 = tt.cat %cst1, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield %0 : tensor<32x16xf16, #A_SHARED>
} else {
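// No shared-memory access in this branch conflicts with anything written earlier
// (it only materializes a fresh constant), so no barrier is expected here.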
// CHECK-NOT: gpu.barrier
// CHECK: arith.constant
%cst1 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
scf.yield %cst1 : tensor<32x16xf16, #A_SHARED>
}
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
return
}
@@ -216,11 +260,14 @@ func.func @multi_blocks_noelse(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
// CHECK: Membar 2
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
}
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
return
}
@@ -231,18 +278,21 @@ func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
scf.if %i2 {
// CHECK: Membar 2
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
}
scf.yield
} else {
// CHECK-NEXT: Membar 6
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
scf.yield
}
// CHECK-NEXT: Membar 9
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%2 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
return
}
@@ -252,8 +302,9 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 3
%cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%5 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
@@ -265,17 +316,20 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
// CHECK-NEXT: Membar 6
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%7 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 9
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
@@ -285,41 +339,162 @@ func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 5
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 7
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%6 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%7 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 10
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
// CHECK-LABEL: for_reuse_nested
func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 5
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%6 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
%a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
// CHECK-NEXT: Membar 7
%cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%12 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
// CHECK-NEXT: Membar 11
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%15 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
// repeatedly write to the same shared memory addresses
// CHECK-LABEL: for_for_if
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
// CHECK: gpu.barrier
// CHECK-NEXT: arith.constant
%cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
} else {
// CHECK: gpu.barrier
// CHECK-NEXT: arith.constant
%cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
}
scf.yield %c_shared_next_next : tensor<128x32xf16, #A_SHARED>
}
scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
}
// %c_blocked_next can be converted from either %c_shared_init or %c_shared_next_next
// CHECK-LABEL: for_if_for
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK: gpu.barrier
%c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
// CHECK: gpu.barrier
// CHECK-NEXT: arith.constant
%cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
} else {
%c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
scf.yield %c_shared : tensor<128x32xf16, #A_SHARED>
}
scf.yield %c_shared_ : tensor<128x32xf16, #A_SHARED>
}
// CHECK-NOT: gpu.barrier
%b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
}
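// The tests below exercise barrier insertion across unstructured control flow (cf.cond_br / cf.br).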
// CHECK-LABEL: cf_if
func.func @cf_if(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.cond_br %i1, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.br ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
// CHECK: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
return
}
func.func @cf_if_else(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.cond_br %i1, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.br ^bb3(%0 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>)
^bb2: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.br ^bb3(%1 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>)
^bb3(%2: tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>): // 2 preds: ^bb1, ^bb2
cf.br ^bb4
^bb4: // pred: ^bb3
%3 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<64x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
return
}
func.func @cf_if_else_return(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
cf.cond_br %i1, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
return
^bb2: // pred: ^bb0
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
return
}

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-combine -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
// CHECK: offset = 0, size = 49152
// CHECK: offset = 49152, size = 49152

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
// -----

View File

@@ -65,8 +65,19 @@ struct TestAliasPass
};
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->getNumResults() < 1)
if (op->getNumResults() < 1) {
// Branch terminators (cf.cond_br, cf.br): report alias information for the enclosing block's arguments
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
auto *block = branch->getBlock();
for (auto arg : llvm::enumerate(block->getArguments())) {
auto operand = block->getArgument(arg.index());
auto opNames = getAllocOpNames(operand);
auto argName = getValueOperandName(arg.value(), state);
print(argName, opNames, os);
}
}
return;
}
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();
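
For context on the BranchOpInterface case added above: when a block ends in cf.cond_br or cf.br, the test pass now also reports alias information for that block's arguments, which is what the cf_if_else test earlier in this diff exercises through ^bb3(%2 : ...). A minimal sketch of the kind of IR this covers (not part of this commit); the function name @branch_block_arg and the #A_SHARED shared-layout alias are illustrative, borrowed from the membar tests above:

// Alias info for %arg would be reported when the walk reaches the cf.br terminating ^bb3.
func.func @branch_block_arg(%i1 : i1) {
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
cf.cond_br %i1, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%cst : tensor<16x16xf16, #A_SHARED>)
^bb2:
cf.br ^bb3(%cst : tensor<16x16xf16, #A_SHARED>)
^bb3(%arg : tensor<16x16xf16, #A_SHARED>):
cf.br ^bb4
^bb4:
return
}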

View File

@@ -1,6 +1,8 @@
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"
@@ -24,21 +26,13 @@ struct TestMembarPass
// Converting to std::string removes the quotes from op_name
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
// Print all ops after membar pass
Allocation allocation(operation);
MembarAnalysis membarPass(&allocation);
membarPass.run();
size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<gpu::BarrierOp>(op)) {
os << "Membar " << operationId << "\n";
}
if (op->getNumRegions() == 0) {
// Don't count parent Operation to simplify the test.
operationId++;
}
return;
});
os << *operation << "\n";
}
};
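
Net effect of this TestMembarPass change, for readers of the membar test diffs above: the driver no longer numbers operations and prints "Membar <id>" markers; it runs MembarAnalysis and prints the resulting module, so the lit tests now match the inserted gpu.barrier ops directly with "CHECK: gpu.barrier" / "CHECK-NEXT: <op>" pairs. A minimal sketch of the new check style (not part of this commit); the function name @check_style and the #A_SHARED alias are illustrative, mirroring the tests rewritten above:

// Old style:   // CHECK: Membar <id>
// New style:
// CHECK-LABEL: check_style
func.func @check_style() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// A barrier separates the shared-memory write of %cst0 from the read in tt.cat.
// CHECK: gpu.barrier
// CHECK-NEXT: tt.cat
%0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}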