Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-02232023
.github/workflows/integration-tests.yml (vendored, 30 changed lines)
@@ -3,9 +3,10 @@ name: Integration Tests
on:
workflow_dispatch:
pull_request:
branches:
- main
- triton-mlir
branches: [main]
merge_group:
branches: [main]
types: [checks_requested]

concurrency:
group: ${{ github.ref }}
@@ -21,7 +22,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
@@ -43,29 +44,33 @@ jobs:
run: |
rm -rf ~/.triton/cache/

- name: Update path
run: |
echo "$HOME/.local/bin/" >> $GITHUB_PATH

- name: Check imports
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install isort
pip3 install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )

- name: Check python style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install autopep8
pip3 install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )

- name: Check cpp style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install clang-format
pip3 install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/triton/*" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/triton/*" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)

- name: Flake8
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install flake8
pip3 install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )

- name: Install Triton
@@ -94,3 +99,12 @@ jobs:
cd python/
cd "build/$(ls build)"
ctest

- name: Regression tests
if: ${{ contains(matrix.runner, 'A100') }}
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
pytest -vs .
sudo nvidia-smi -i 0 -rgc
@@ -51,8 +51,8 @@ include_directories(${PYBIND11_INCLUDE_DIR})

if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
find_package(dlfcn-win32 REQUIRED)
set(CMAKE_DL_LIBS dlfcn-win32::dl)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
@@ -215,8 +215,7 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC})

target_link_libraries(triton
set(TRITON_LIBRARIES
TritonAnalysis
TritonTransforms
TritonGPUTransforms
@@ -237,16 +236,25 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRROCDLToLLVMIRTranslation
MLIRIR
)

target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}
)
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs
${TRITON_LIBRARIES}
)
endif()

target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
endif()

if (UNIX AND NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
endif()

if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
@@ -32,7 +32,7 @@ int main(int argc, char **argv) {

// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
mlir::math::MathDialect, mlir::arith::ArithDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
@@ -36,28 +36,29 @@ public:
void run();

private:
struct RegionInfo {
struct BlockInfo {
using BufferIdSetT = Allocation::BufferIdSetT;

BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;

RegionInfo() = default;
RegionInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
BlockInfo() = default;
BlockInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}

/// Unions two RegionInfo objects.
void join(const RegionInfo &other) {
/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
return *this;
}

/// Returns true if buffers in two RegionInfo objects are intersected.
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
/// Returns true if buffers in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/
@@ -74,6 +75,14 @@ private:
syncWriteBuffers.clear();
}

/// Compares two BlockInfo objects.
bool operator==(const BlockInfo &other) const {
return syncReadBuffers == other.syncReadBuffers &&
syncWriteBuffers == other.syncWriteBuffers;
}

bool operator!=(const BlockInfo &other) const { return !(*this == other); }

private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
@@ -99,19 +108,19 @@ private:
/// op5
/// op6
/// op7
/// region2 and region3 started with the information of region1.
/// Each region is analyzed separately and keeps its own copy of the
/// information. At op7, we union the information of region2 and region3
/// and update the information of region1.
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// TODO: Explain why we don't use ForwardAnalysis:
void resolve(Operation *operation, OpBuilder *builder);

/// Updates the RegionInfo based on the operation.
void transfer(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// Updates the BlockInfo based on the operation.
void update(Operation *operation, BlockInfo *blockInfo, OpBuilder *builder);

/// Collects the successors of the terminator
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);

private:
Allocation *allocation;
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
};

} // namespace mlir
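As a reading aid for the `BlockInfo` hazard test above, here is a small self-contained sketch (plain C++, with `std::set<int>` standing in for `BufferIdSetT` and the `Allocation`-based interval check omitted — i.e. the assumption that two buffer ids alias only when equal) of how tracking pending reads/writes and intersecting them detects the RAW/WAR hazards that force a barrier:

```cpp
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <set>

// Simplified stand-in for BlockInfo: buffer ids are plain ints, no Allocation.
struct MiniBlockInfo {
  std::set<int> syncReadBuffers;
  std::set<int> syncWriteBuffers;

  static bool intersects(const std::set<int> &a, const std::set<int> &b) {
    std::set<int> tmp;
    std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                          std::inserter(tmp, tmp.begin()));
    return !tmp.empty();
  }

  // Mirrors the shape of BlockInfo::isIntersected: a barrier is needed when
  // the op's accesses conflict with pending ones (RAW or WAR shown here).
  bool needsBarrier(const MiniBlockInfo &cur) const {
    return intersects(syncWriteBuffers, cur.syncReadBuffers) || // RAW
           intersects(syncReadBuffers, cur.syncWriteBuffers);   // WAR
  }
};

int main() {
  MiniBlockInfo pending;          // state accumulated so far in the block
  pending.syncWriteBuffers = {7}; // e.g. a convert_layout wrote buffer 7

  MiniBlockInfo cur;              // accesses of the op being visited
  cur.syncReadBuffers = {7};      // the op reads buffer 7 -> RAW hazard

  std::printf("barrier needed: %s\n", pending.needsBarrier(cur) ? "yes" : "no");
  return 0;
}
```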
@@ -6,10 +6,10 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface

//
@@ -29,7 +29,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
// extui, extsi, trunci
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";

@@ -42,7 +42,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,

def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";

@@ -56,7 +56,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast between types of the same bitwidth";

@@ -71,7 +71,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,

def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoMemoryEffect,
Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";

@@ -95,7 +95,7 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
//

def TT_AddPtrOp : TT_Op<"addptr",
[NoMemoryEffect,
[Pure,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
@@ -222,7 +222,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoMemoryEffect,
def TT_SplatOp : TT_Op<"splat", [Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "splat";
@@ -236,7 +236,7 @@ def TT_SplatOp : TT_Op<"splat", [NoMemoryEffect,
let hasFolder = 1;
}

def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoMemoryEffect,
def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
@@ -248,6 +248,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoMemoryEffect,
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}

// view is not `pure` because it may reorder elements
def TT_ViewOp : TT_Op<"view", [NoMemoryEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
@@ -260,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoMemoryEffect,

}

def TT_BroadcastOp : TT_Op<"broadcast", [NoMemoryEffect,
def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultEncoding]> {
let summary = "broadcast. No left-padding as of now.";
@@ -274,6 +275,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoMemoryEffect,
let hasFolder = 1;
}

// cat is not `pure` because it may reorder elements
def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
@@ -285,7 +287,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}

def TT_TransOp : TT_Op<"trans", [NoMemoryEffect,
def TT_TransOp : TT_Op<"trans", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {

@@ -301,7 +303,7 @@ def TT_TransOp : TT_Op<"trans", [NoMemoryEffect,
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoMemoryEffect]> {
def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {
let arguments = (ins I32Attr:$axis);

let results = (outs I32:$result);
@@ -309,7 +311,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoMemoryEffect]> {
let assemblyFormat = "attr-dict `:` type($result)";
}

def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoMemoryEffect]> {
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
let arguments = (ins I32Attr:$axis);

let results = (outs I32:$result);
@@ -320,7 +322,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoMemoryEffect]> {
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoMemoryEffect,
def TT_DotOp : TT_Op<"dot", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
@@ -340,7 +342,7 @@ def TT_DotOp : TT_Op<"dot", [NoMemoryEffect,
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [NoMemoryEffect,
def TT_ReduceOp : TT_Op<"reduce", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";

@@ -364,7 +366,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoMemoryEffect,
//
// External elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoMemoryEffect, Elementwise, SameOperandsAndResultShape,
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
@@ -386,7 +388,7 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoMemoryEffect, Elementwise, SameO
// Make Range Op
//
// TODO: should have ConstantLike as Trait
def TT_MakeRangeOp : TT_Op<"make_range", [NoMemoryEffect]> {
def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
let summary = "make range";

let description = [{
@@ -7,7 +7,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType

def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
@@ -18,13 +18,15 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape,
SameOperandsAndResultElementType,
NoMemoryEffect]> {
Pure]> {
let summary = "convert layout";

let arguments = (ins TT_Tensor:$src);

let results = (outs TT_Tensor:$result);

let hasCanonicalizeMethod = 1;

let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}

@@ -59,7 +61,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoMemoryEffect, Elementwise,
def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
@@ -73,7 +75,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoMemoryEffect, Elementwise,
let results = (outs TT_BoolLike:$result);
}

def TTG_CmpFOp : TTG_Op<"cmpf", [NoMemoryEffect, Elementwise,
def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
@@ -88,7 +90,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoMemoryEffect, Elementwise,
}

// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoMemoryEffect, Elementwise,
def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
@@ -6,7 +6,9 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);

// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass> createTritonGPUPrefetchPass();

std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
@@ -17,10 +19,12 @@ std::unique_ptr<Pass> createTritonGPUReorderInstructionsPass();

std::unique_ptr<Pass> createTritonGPUDecomposeConversionsPass();

std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();

std::unique_ptr<Pass> createTritonGPUVerifier();

std::unique_ptr<Pass> createTritonGPUFuseTranspositionsPass();

std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass();

/// Generate the code for registering passes.
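The header above only declares pass factories. As an illustration of how such factories are typically consumed (the selection and ordering below are assumptions for demonstration, not the repository's canonical compilation pipeline), they would be chained through an `mlir::PassManager`:

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Hypothetical example pipeline built only from factories declared above;
// the choice and order of passes here are illustrative assumptions.
static void buildExamplePipeline(mlir::PassManager &pm, int computeCapability) {
  pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
  pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
  pm.addPass(mlir::createTritonGPUPipelinePass(/*numStages=*/2));
  pm.addPass(mlir::createTritonGPUPrefetchPass());
  pm.addPass(mlir::createTritonGPUDecomposeConversionsPass());
  pm.addPass(mlir::createTritonGPUReorderInstructionsPass());
}
```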
@@ -7,7 +7,8 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";

let description = [{
Unroll loops to hide global memory -> shared memory latency.
Replace `LoadOp` in loops by `InsertSliceAsyncOp` instructions that asynchronously construct the data
needed at the next iteration
}];

let constructor = "mlir::createTritonGPUPipelinePass()";
@@ -27,7 +28,8 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";

let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
Decompose `DotOp` instructions in loops into several finer-grained `DotOp`
that may have their operands constructed at the end of the previous iteration
}];

let constructor = "mlir::createTritonGPUPrefetchPass()";
@@ -37,6 +39,41 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
"mlir::arith::ArithDialect"];
}

def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
let summary = "accelerate matmul";

let description = [{
Optimize the input/output layouts of the `dot` instruction to make them compatible with hardware accelerators
(e.g., Nvidia tensor cores)
}];

let constructor = "mlir::createTritonGPUAccelerateMatmulPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];

let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];

}

def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> {
let summary = "fuse transpositions";

let description = [{
Re-arrange layouts of tensors used as matrix multiplication operands so as to promote the use of
hardware-accelerated transpositions.
}];

let constructor = "mlir::createTritonGPUFuseTranspositionsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}

def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";

@@ -49,26 +86,16 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}

def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
let summary = "remove superfluous layout conversions";

let description = [{
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)

convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];

let constructor = "mlir::createTritonGPUCombineOpsPass()";
let constructor = "mlir::createTritonGPURemoveLayoutConversionsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];

let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}

def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> {
@@ -95,19 +122,6 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir
"mlir::triton::TritonDialect"];
}

def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
let summary = "canonicalize scf.ForOp ops";

let description = [{
This implements some optimizations that are missing in the standard scf.ForOp
canonicalizer.
}];

let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}

def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> {
let summary = "Update mma encodings for Volta";
@@ -2,74 +2,84 @@
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <deque>

namespace mlir {

void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
RegionInfo regionInfo;
OpBuilder builder(operation);
dfsOperation(operation, &regionInfo, &builder);
resolve(operation, &builder);
}

void MembarAnalysis::dfsOperation(Operation *operation,
RegionInfo *parentRegionInfo,
OpBuilder *builder) {
transfer(operation, parentRegionInfo, builder);
if (operation->getNumRegions()) {
// If there's any nested regions, we need to visit them.
// scf.if and scf.else: two regions
// scf.if only: two regions
// scf.for: one region
RegionInfo curRegionInfo;
auto traverseRegions = [&]() -> auto{
for (auto &region : operation->getRegions()) {
// Copy the parent info as the current info.
RegionInfo regionInfo = *parentRegionInfo;
for (auto &block : region.getBlocks()) {
// assert(region.getBlocks().size() == 1 &&
// "Multiple blocks in a region is not supported");
for (auto &op : block.getOperations()) {
// Traverse the nested operation.
dfsOperation(&op, &regionInfo, builder);
}
}
curRegionInfo.join(regionInfo);
void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
// Initialize the blockList
std::deque<Block *> blockList;
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
for (auto &op : block->getOperations()) {
// Check if the operation belongs to scf dialect, if so, we need to
// throw an error
if (op.getDialect()->getNamespace() == "scf") {
op.emitError("scf dialect is not supported in membar. Please lower it "
"to cf dialect first.");
return;
}
// Set the parent region info as the union of the nested region info.
*parentRegionInfo = curRegionInfo;
};
}
if (block->isEntryBlock())
blockList.emplace_back(block);
});

traverseRegions();
if (isa<scf::ForOp>(operation)) {
// scf.for can have two possible inputs: the init value and the
// previous iteration's result. Although we've applied alias analysis,
// there could be unsynced memory accesses on reused memories.
// For example, consider the following code:
// %1 = convert_layout %0: blocked -> shared
// ...
// gpu.barrier
// ...
// %5 = convert_layout %4 : shared -> dot
// %6 = tt.dot %2, %5
// scf.yield
//
// Though %5 could be released before scf.yield, it may share the same
// memory with %1. So we actually have to insert a barrier before %1 to
// make sure the memory is synced.
traverseRegions();
// A fixed point algorithm
while (!blockList.empty()) {
auto *block = blockList.front();
blockList.pop_front();
// Make a copy of the inputBlockInfo but do not update it
auto inputBlockInfo = inputBlockInfoMap.lookup(block);
SmallVector<Block *> successors;
for (auto &op : block->getOperations()) {
if (op.hasTrait<OpTrait::IsTerminator>()) {
visitTerminator(&op, successors);
} else {
update(&op, &inputBlockInfo, builder);
}
}
// Get the reference because we want to update if it changed
if (outputBlockInfoMap.count(block) &&
inputBlockInfo == outputBlockInfoMap[block]) {
// If we have seen the block before and the inputBlockInfo is the same as
// the outputBlockInfo, we skip the successors
continue;
}
// Update the current block
outputBlockInfoMap[block].join(inputBlockInfo);
// Update the successors
for (auto *successor : successors) {
inputBlockInfoMap[successor].join(outputBlockInfoMap[block]);
blockList.emplace_back(successor);
}
}
}

void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
OpBuilder *builder) {
if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op)) {
// Do not insert barriers before control flow operations and
// alloc/extract/insert
void MembarAnalysis::visitTerminator(Operation *op,
SmallVector<Block *> &successors) {
if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
Block *parentBlock = branchInterface->getBlock();
for (Block *successor : parentBlock->getSuccessors()) {
successors.push_back(successor);
}
return;
}
// Otherwise, it could be a return op
assert(isa<func::ReturnOp>(op) && "Unknown terminator");
}

void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
OpBuilder *builder) {
if (isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op) ||
isa<triton::TransOp>(op)) {
// alloc is an allocation op without memory write.
// FIXME(Keren): extract_slice is always alias for now
return;
@@ -77,7 +87,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,

if (isa<gpu::BarrierOp>(op)) {
// If the current op is a barrier, we sync previous reads and writes
regionInfo->sync();
blockInfo->sync();
return;
}

@@ -85,26 +95,26 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
!isa<gpu::BarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
regionInfo->sync();
blockInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
regionInfo->sync();
blockInfo->sync();
return;
}

RegionInfo curRegionInfo;
BlockInfo curBlockInfo;
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always alias
// for now
curRegionInfo.syncWriteBuffers.insert(bufferId);
// FIXME(Keren): insert_slice and insert_slice_async are always
// alias for now
curBlockInfo.syncWriteBuffers.insert(bufferId);
} else {
// ConvertLayoutOp: shared memory -> registers
curRegionInfo.syncReadBuffers.insert(bufferId);
curBlockInfo.syncReadBuffers.insert(bufferId);
}
}
}
@@ -113,25 +123,25 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncWriteBuffers.insert(bufferId);
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curRegionInfo.syncWriteBuffers.insert(bufferId);
curRegionInfo.syncReadBuffers.insert(bufferId);
curBlockInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncReadBuffers.insert(bufferId);
}

if (regionInfo->isIntersected(curRegionInfo, allocation)) {
if (blockInfo->isIntersected(curBlockInfo, allocation)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());
regionInfo->sync();
blockInfo->sync();
}
// Update the region info, even if barrier is inserted, we have to maintain
// the current op's read/write buffers.
regionInfo->join(curRegionInfo);
blockInfo->join(curBlockInfo);
}

} // namespace mlir
@@ -337,7 +337,8 @@ namespace {
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
// interacts with constant propagation, but SparseConstantPropagation
// doesn't seem to be sufficient.
struct ConstantAnalysis : public DataFlowAnalysis {
class ConstantAnalysis : public DataFlowAnalysis {
public:
using DataFlowAnalysis::DataFlowAnalysis;

LogicalResult initialize(Operation *top) override {
@@ -359,12 +360,19 @@ struct ConstantAnalysis : public DataFlowAnalysis {
value, op->getDialect())));
return success();
}
// Dead code analysis requires every operand to have an initialized
// ConstantValue state before it is visited.
// https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
// That's why we need to set all operands to unknown constants.
setAllToUnknownConstants(op->getResults());
for (Region &region : op->getRegions())
setAllToUnknownConstants(region.getArguments());
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks())
setAllToUnknownConstants(block.getArguments());
}
return success();
}

private:
/// Set all given values as not constants.
void setAllToUnknownConstants(ValueRange values) {
dataflow::ConstantValue unknownConstant(nullptr, nullptr);
@@ -291,10 +291,10 @@ public:
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// extract multi dimensional index for current element
auto idx = srcIndices[elemIdx];
Value idxCol = idx[inOrder[0]]; // contiguous dimension
Value idxRow = idx[inOrder[1]]; // discontiguous dimension
Value strideCol = srcStrides[inOrder[0]];
Value strideRow = srcStrides[inOrder[1]];
Value idxCol = idx[outOrder[0]]; // contiguous dimension
Value idxRow = idx[outOrder[1]]; // discontiguous dimension
Value strideCol = srcStrides[outOrder[0]];
Value strideRow = srcStrides[outOrder[1]];
// extract dynamic/static offset for immediate offsetting
unsigned immedateOffCol = 0;
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp()))
@@ -338,7 +338,7 @@ public:
Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
// compute immediate offset
Value immedateOff =
add(mul(i32_val(immedateOffRow), srcStrides[inOrder[1]]),
add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
i32_val(immedateOffCol));
ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
}
@@ -128,16 +128,13 @@ public:
// Step 2: Decompose insert_slice_async to use load + insert_slice for
// pre-Ampere architectures or unsupported vectorized load sizes
// Step 3: Allocate shared memories and insert barriers
// Step 4: Convert SCF to CFG
// Step 5: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 6: Get axis and shared memory info
// Step 7: Convert the rest of ops via partial conversion
// Step 4: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 5: Get axis and shared memory info
// Step 6: Convert the rest of ops via partial conversion
//
// The reason for putting step 3 before step 4 is that the membar
// analysis currently only supports SCF but not CFG. The reason for a
// separation between 5/7 is that, step 6 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 7.
// The reason for a separation between 4/6 is that, step 5 is out of the
// scope of Dialect Conversion, thus we need to make sure the smem is not
// revised during the conversion of step 6.

// Step 1
decomposeMmaToDotOperand(mod, numWarps);
@@ -153,24 +150,13 @@ public:
membarPass.run();

// Step 4
RewritePatternSet scf_patterns(context);
mlir::populateSCFToControlFlowConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
scf::WhileOp, scf::ExecuteRegionOp>();
scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();

// Step 5
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();

// Step 6 - get axis and shared memory info
// Step 5 - get axis and shared memory info
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(mod)))
@@ -180,7 +166,7 @@ public:
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));

// Step 7 - rewrite rest of ops
// Step 6 - rewrite rest of ops
// We set a higher benefit here to ensure triton's patterns run before
// arith patterns for some encodings not supported by the community
// patterns.
@@ -222,30 +208,16 @@ public:
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
#ifdef USE_ROCM
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP);
#else
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
#endif

if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

// Take care of scf pattern introduced by LoadStoreOp
#ifdef USE_ROCM
RewritePatternSet scf_patterns_extra(context);
mlir::populateSCFToControlFlowConversionPatterns(scf_patterns_extra);
if (failed(
applyPartialConversion(mod, scf_target, std::move(scf_patterns_extra))))
return signalPassFailure();
RewritePatternSet patterns_extra(context);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns_extra);
if (failed(
applyPartialConversion(mod, target, std::move(patterns_extra))))
return signalPassFailure();
#endif
}

private:
@@ -790,6 +790,141 @@ struct TritonGPUInferLayoutInterface
}
};

//===----------------------------------------------------------------------===//
// Canonicalizer
//===----------------------------------------------------------------------===//

LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
PatternRewriter &rewriter) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
auto dstType = op.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
return mlir::success();
}
Operation *arg = op->getOperand(0).getDefiningOp();
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
view.getResult());
return mlir::success();
}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.getDst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.getSrc(), newArg.getResult(),
insert_slice.getIndex(), insert_slice.getMask(),
insert_slice.getOther(), insert_slice.getCache(),
insert_slice.getEvict(), insert_slice.getIsVolatile(),
insert_slice.getAxis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
extract_slice.getSource().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}

// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(op.getOperand()) &&
!isSharedEncoding(op.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) {
return mlir::failure();
}
auto srcType = op.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
}
// cvt(type1, splat(type2, x)) -> splat(type1, x)
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
splat.getSrc());
return mlir::success();
}
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, op->getResultTypes(), range.getStart(), range.getEnd());
return mlir::success();
}
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
return mlir::failure();
}

//===----------------------------------------------------------------------===//

void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp (new file, 215 lines)
@@ -0,0 +1,215 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>

using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;

int computeCapabilityToMMAVersion(int computeCapability) {
#ifdef USE_ROCM
return 1;
#endif
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else if (computeCapability < 100) {
// FIXME: temporarily add this to pass unit tests
return 2;
} else {
assert(false && "computeCapability > 100 not supported");
return 3;
}
}

SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
return {16, 8};
else {
assert(false && "version not supported");
return {0, 0};
}
}

SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value that ensures product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}

SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};

SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}

class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding

public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}

static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
llvm_unreachable("unsupported MMA version");
}
}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();

// for FMA, should retain the blocked layout.
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();

// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
llvm_unreachable("Mma layout only supports versionMajor in {1, 2}");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);

// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}

auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());

rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};
} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUAccelerateMatmulPass
: public TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
public:
TritonGPUAccelerateMatmulPass() = default;
TritonGPUAccelerateMatmulPass(int computeCapability) {
this->computeCapability = computeCapability;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

mlir::RewritePatternSet patterns(context);
patterns.add<::BlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};

std::unique_ptr<Pass>
mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) {
return std::make_unique<TritonGPUAccelerateMatmulPass>(computeCapability);
}
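For intuition about the `warpsPerTileV2` heuristic above, here is a small self-contained sketch (plain C++, no MLIR dependencies; the shape and warp count in `main` are made-up example inputs) that reproduces only the tile-splitting loop, assuming an MMAv2-style 16x8 per-warp tile and no downstream `tt.dot` user:

```cpp
#include <cstdio>
#include <vector>

// Simplified restatement of the splitting loop in warpsPerTileV2 above.
static std::vector<unsigned> exampleWarpsPerTile(long long M, long long N,
                                                 int numWarps) {
  std::vector<unsigned> ret = {1, 1};
  const long long shapePerWarp[2] = {16, 8};
  while (ret[0] * ret[1] < static_cast<unsigned>(numWarps)) {
    // Prefer splitting along the dimension with more warp-tiles remaining.
    if (M / shapePerWarp[0] / ret[0] >= N / (shapePerWarp[1] * 2) / ret[1]) {
      if (ret[0] < M / shapePerWarp[0])
        ret[0] *= 2;
      else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  }
  return ret;
}

int main() {
  // Example only: a 128x128 dot result distributed over 8 warps.
  auto wpt = exampleWarpsPerTile(128, 128, 8);
  std::printf("warpsPerTile = {%u, %u}\n", wpt[0], wpt[1]); // prints {4, 2}
  return 0;
}
```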
@@ -1,22 +1,18 @@
set(LLVM_TARGET_DEFINITIONS Combine.td)
mlir_tablegen(TritonGPUCombine.inc -gen-rewriters)
add_public_tablegen_target(TritonGPUCombineIncGen)

add_mlir_dialect_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
CanonicalizeLoops.cpp
Combine.cpp
DecomposeConversions.cpp
FuseTranspositions.cpp
Pipeline.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
DecomposeConversions.cpp
TritonGPUConversion.cpp
UpdateMmaForVolta.cpp
Utility.cpp

DEPENDS
TritonGPUTransformsIncGen
TritonGPUCombineIncGen

LINK_LIBS PUBLIC
TritonIR
@@ -1,55 +0,0 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

struct CanonicalizePass
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
CanonicalizePass() = default;

void runOnOperation() override {

// Canonicalize pass may have created dead code that
// standard scf.for canonicalization cannot handle
// as of LLVM 14. For example, the iteration arguments
// for the pointer of the synchronous loads that are
// discarded.
// The following piece of code is a workaround to
// very crudely remove dead code, by making an iteration
// argument yield itself if it is not used to create
// side effects anywhere.
getOperation()->walk([&](scf::ForOp forOp) -> void {
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
// condition 1: no other iter arguments depend on it
SetVector<Operation *> fwdSlice;
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
Operation *yieldOp = forOp.getBody()->getTerminator();
bool noOtherDependency = std::all_of(
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
return arg == yieldOp->getOperand(i) ||
!fwdSlice.contains(arg.getDefiningOp());
});
// condition 2: final value is not used after the loop
auto retVal = forOp.getResult(i);
bool noUserAfterLoop = retVal.getUsers().empty();
// yielding the region iter arg will cause loop canonicalization
// to clean up the dead code
if (noOtherDependency && noUserAfterLoop) {
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
}
}
});
}
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
return std::make_unique<CanonicalizePass>();
}
@@ -1,8 +0,0 @@
#ifndef TRITONGPU_PATTERNS
#define TRITONGPU_PATTERNS

include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
include "mlir/IR/PatternBase.td"

#endif
153
lib/Dialect/TritonGPU/Transforms/FuseTranspositions.cpp
Normal file
@@ -0,0 +1,153 @@
#include "Utility.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>

using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;

class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();

auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};

// convert(trans(convert(arg)))
// x = convert_layout arg: #distributed -> #shared_x
// y = trans x: #shared_x -> #shared_y
// z = convert_layout y: #shared_y -> #dot_operand
class ConvertTransConvert : public mlir::RewritePattern {

public:
ConvertTransConvert(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}

LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
auto tmpOp =
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
if (!tmpOp)
return mlir::failure();
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
tmpOp.getSrc().getDefiningOp());
if (!srcOp)
return mlir::failure();
auto arg = srcOp.getSrc();
auto X = tmpOp.getSrc();
// types
auto argType = arg.getType().cast<RankedTensorType>();
auto XType = X.getType().cast<RankedTensorType>();
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
// encodings
auto argEncoding = argType.getEncoding();
auto XEncoding =
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto ZEncoding =
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!ZEncoding)
return mlir::failure();
// new X encoding
auto newXOrder = triton::gpu::getOrder(argEncoding);
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XType.getElementType());
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
return mlir::failure();

auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
newXType, arg);
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
newY);
return mlir::success();
}
};

} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUFuseTranspositionsPass
: public TritonGPUFuseTranspositionsBase<TritonGPUFuseTranspositionsPass> {
public:
TritonGPUFuseTranspositionsPass() = default;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

mlir::PassManager pm(m.getContext());
pm.addPass(mlir::createCanonicalizerPass());
auto ret = pm.run(m);

mlir::RewritePatternSet patterns(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<ConvertTransConvert>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
if (fixupLoops(m).failed())
signalPassFailure();
}
};

std::unique_ptr<Pass> mlir::createTritonGPUFuseTranspositionsPass() {
return std::make_unique<TritonGPUFuseTranspositionsPass>();
}
@@ -22,7 +22,6 @@

using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
@@ -139,132 +138,7 @@ public:
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristics to accommodate fused attention
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>() &&
srcType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
return mlir::success();
}
Operation *arg = op->getOperand(0).getDefiningOp();
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), view.getResult());
return mlir::success();
}
// cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
if (alloc_tensor) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
op, op->getResult(0).getType());
return mlir::success();
}
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
if (insert_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, insert_slice.getDst());
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.getSrc(), newArg.getResult(),
insert_slice.getIndex(), insert_slice.getMask(),
insert_slice.getOther(), insert_slice.getCache(),
insert_slice.getEvict(), insert_slice.getIsVolatile(),
insert_slice.getAxis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
if (extract_slice) {
if (!isSharedEncoding(op->getResult(0))) {
return mlir::failure();
}
auto origType =
extract_slice.getSource().getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
return mlir::success();
}

// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
if (arg->getOperand(0).getDefiningOp() &&
!isSharedEncoding(arg->getOperand(0)) &&
isSharedEncoding(convert.getOperand()) &&
!isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
if (isSharedEncoding(convert.getOperand()) &&
isSharedEncoding(convert.getResult())) {
return mlir::failure();
}
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto srcShared =
srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (srcShared && srcShared.getVec() > 1)
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
}
// cvt(type1, splat(type2, x)) -> splat(type1, x)
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
splat.getSrc());
return mlir::success();
}
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op, op->getResultTypes(), range.getStart(), range.getEnd());
return mlir::success();
}
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}
return mlir::failure();
return ConvertLayoutOp::canonicalize(convert, rewriter);
}
};

@@ -568,9 +442,9 @@ public:
};

//
class FoldConvertAndReduce : public mlir::RewritePattern {
class RematerializeForward : public mlir::RewritePattern {
public:
explicit FoldConvertAndReduce(mlir::MLIRContext *context)
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}

@@ -837,393 +711,6 @@ public:
}
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

class RematerializeForward : public mlir::RewritePattern {
public:
explicit RematerializeForward(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
2, context) {}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *_cvtOp,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
if (!forOp)
return mlir::failure();
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };

SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return isInLoop(op) &&
!isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
triton::AtomicCASOp>(op) &&
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
!isa<triton::gpu::ConvertLayoutOp>(op);
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())
return failure();

for (Operation *op : cvtSlices) {
if (!isa<triton::ViewOp, triton::CatOp>(op) &&
!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op))
return failure();
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();
if (argOp && (argOp != cvt) &&
!isa<arith::ConstantOp, triton::SplatOp, triton::MakeRangeOp>(
argOp)) {
return failure();
}
}
}

// Otherwise, we push the conversion forward
// since we'll be able to move it out of
// the loop once it reaches the yield op
pushConversionForward(cvt, cvtSlices, rewriter);
return success();
}
};

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------
namespace {
int computeCapabilityToMMAVersion(int computeCapability) {
#ifdef USE_ROCM
return 1;
#endif
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else if (computeCapability < 100) {
// FIXME: temporarily add this to pass unis tests
return 2;
} else {
assert(false && "computeCapability > 100 not supported");
return 3;
}
}

SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
if (version == 1)
return {16, 16};
else if (version == 2)
return {16, 8};
else {
assert(false && "version not supported");
return {0, 0};
}
}

SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
int numWarps) {
// Set a default value and ensure product of wpt equals numWarps
return {static_cast<unsigned>(numWarps), 1};
}

SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};

SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
bool changed = false;
// TODO (@daadaada): double-check.
// original logic in
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252
// seems buggy for shape = [32, 16] ?
do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (shape[0] / shapePerWarp[0] / ret[0] >=
shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
if (ret[0] < shape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);
return ret;
}
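// For illustration only (not part of the upstream change): a worked trace of the
// warpsPerTileV2 heuristic above, assuming shape = {128, 128}, numWarps = 4 and
// shapePerWarp = {16, 8}:
//   start:       ret = {1, 1}
//   iteration 1: 128/16/1 = 8 >= 128/(8*2)/1 = 8 and ret[0] < 8, so ret = {2, 1}
//   iteration 2: 128/16/2 = 4 <  128/(8*2)/1 = 8, so ret = {2, 2}
//   iteration 3: ret[0] * ret[1] = 4 >= numWarps, so the loop stops with ret = {2, 2}
// i.e. the four warps end up tiled 2x2 over the M and N dimensions of the dot result.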
} // namespace

class OptimizeBlockedToShared : public mlir::RewritePattern {
public:
explicit OptimizeBlockedToShared(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstSharedLayout =
dstType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
if (!srcBlockedLayout || !dstSharedLayout)
return failure();
if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder())
return failure();
// For now only works if single use is transpose
// TODO: rematerialize #shared uses
auto users = op->getUsers();
if (std::distance(users.begin(), users.end()) != 1 ||
!isa<triton::TransOp>(*users.begin()))
return failure();

auto tmpShared = triton::gpu::SharedEncodingAttr::get(
op->getContext(), dstSharedLayout.getVec(),
dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(),
srcBlockedLayout.getOrder());
auto tmpType = RankedTensorType::get(srcType.getShape(),
srcType.getElementType(), tmpShared);
auto tmpCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), tmpType, cvt.getOperand());

auto newDstType = RankedTensorType::get(
users.begin()->getResultTypes()[0].cast<RankedTensorType>().getShape(),
srcType.getElementType(), dstSharedLayout);

auto newTrans = rewriter.create<triton::TransOp>(op->getLoc(), newDstType,
tmpCvt.getResult());

rewriter.replaceOp(*users.begin(), newTrans.getResult());
return success();
}
};

class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
: RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
context) {}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
// order
ArrayRef<unsigned> order;
if (auto srcBlockedLayout =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
order = srcBlockedLayout.getOrder();
else if (auto srcSharedLayout =
srcType.getEncoding()
.dyn_cast<triton::gpu::SharedEncodingAttr>())
order = srcSharedLayout.getOrder();
else
return failure();
// dot operand output
auto dstDotOperandLayout =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!dstDotOperandLayout)
return failure();
if (!dstDotOperandLayout.getIsMMAv1Row())
return failure();
bool isMMAv1Row =
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();

auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
dstDotOperandLayout.getParent(), newIsRow);
auto newDstType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(), newDstEncoding);
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newDstType, cvt.getOperand());
rewriter.replaceOp(op, newCvt.getResult());
return success();
}
};

class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding

public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}

static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int version, int numWarps) {
switch (version) {
case 1:
return warpsPerTileV1(shape, numWarps);
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
default:
assert(false && "not supported version");
return {0, 0};
}
}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();

// for FMA, should retain the blocked layout.
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();

// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
assert(false && "Mma layout only support versionMajor of 1 or 2");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);

// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto oldAOrder = oldAType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}

auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));

a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());

rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
return success();
}
};

// Convert + trans + convert
// x = convert_layout distributed -> #shared_x
// y = trans x -> #shared_y
// z = convert_layout y -> #dot_operand
class ConvertTransConvert : public mlir::RewritePattern {

public:
ConvertTransConvert(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}

LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
auto tmpOp =
dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
if (!tmpOp)
return mlir::failure();
auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
tmpOp.getSrc().getDefiningOp());
if (!srcOp)
return mlir::failure();
auto arg = srcOp.getSrc();
auto X = tmpOp.getSrc();
// types
auto argType = arg.getType().cast<RankedTensorType>();
auto XType = X.getType().cast<RankedTensorType>();
auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
// encodings
auto argEncoding = argType.getEncoding();
auto XEncoding =
XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto ZEncoding =
ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (!ZEncoding)
return mlir::failure();
// new X encoding
auto newXOrder = triton::gpu::getOrder(argEncoding);
auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
getContext(), ZEncoding, XType.getShape(), newXOrder,
XType.getElementType());
auto newXType = RankedTensorType::get(XType.getShape(),
XType.getElementType(), newXEncoding);
if (XEncoding == newXEncoding)
return mlir::failure();

auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
newXType, arg);
auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
newY);
return mlir::success();
}
};

//
class ConvertDotConvert : public mlir::RewritePattern {
public:
@@ -1275,31 +762,25 @@ public:
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUCombineOpsPass
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
class TritonGPURemoveLayoutConversionsPass
: public TritonGPURemoveLayoutConversionsBase<
TritonGPURemoveLayoutConversionsPass> {
public:
TritonGPUCombineOpsPass() = default;
TritonGPUCombineOpsPass(int computeCapability) {
this->computeCapability = computeCapability;
}
TritonGPURemoveLayoutConversionsPass() = default;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

mlir::RewritePatternSet patterns(context);

patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<SimplifyReduceCvt>(context);
patterns.add<FoldConvertAndReduce>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<BlockedToMMA>(context, computeCapability);
patterns.add<ConvertTransConvert>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<ConvertDotConvert>(context);

if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
@@ -1312,7 +793,6 @@ public:
}
};

std::unique_ptr<Pass>
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
std::unique_ptr<Pass> mlir::createTritonGPURemoveLayoutConversionsPass() {
return std::make_unique<TritonGPURemoveLayoutConversionsPass>();
}
@@ -147,6 +147,12 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {

if (!funcs.empty()) {
static const std::string libdevice = "libdevice";
// first search for environmental path
std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH");
if (!env_path.empty()) {
externLibs.try_emplace(libdevice, env_path);
return externLibs;
}
namespace fs = std::filesystem;
// Search for libdevice relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and libdevice in
@@ -302,12 +308,17 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());
#ifdef USE_ROCM
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertControlFlowToLLVMPass());
#endif

if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";

@@ -65,7 +65,7 @@ def get_llvm_package_info():
linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
system_suffix = f"linux-gnu-{linux_suffix}"
else:
raise RuntimeError(f"unsupported system: {system}")
return Package("llvm", "LLVM-C.lib", "", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
@@ -159,7 +159,11 @@ class CMakeBuild(build_ext):

def build_extension(self, ext):
lit_dir = shutil.which('lit')
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or \
os.getenv("HOMEPATH") or None
if not user_home:
raise RuntimeError("Could not find user home directory")
triton_cache_path = os.path.join(user_home, ".triton")
# lit is used by the test suite
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))

@@ -1,4 +1,4 @@
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>

void init_superblocking(pybind11::module &m);
void init_torch_utils(pybind11::module &m);

@@ -1462,7 +1462,7 @@ void init_triton_ir(py::module &&m) {
.def(
"add_sccp_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
.def("add_coalesce_pass",
.def("add_tritongpu_coalesce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCoalescePass());
})
@@ -1501,10 +1501,18 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());
})
.def("add_tritongpu_combine_pass",
.def("add_tritongpu_accelerate_matmul_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(
mlir::createTritonGPUCombineOpsPass(computeCapability));
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
})
.def("add_tritongpu_fuse_transpositions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
})
.def("add_tritongpu_remove_layout_conversions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
})
.def("add_tritongpu_update_mma_for_volta_pass",
[](mlir::PassManager &self) {

@@ -8,7 +8,7 @@ import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops

DEVICE_NAME = 'v100'
DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]]

#######################
# Utilities
@@ -34,7 +34,6 @@ mem_clocks = {'v100': 877, 'a100': 1215}
matmul_data = {
'v100': {
# square
(256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695},
@@ -51,29 +50,26 @@ matmul_data = {
(4096, 64, 4096): {'float16': 0.264},
(8192, 64, 8192): {'float16': 0.452},
},
# NOTE:
# A100 in the CI server is slow-ish for some reason.
# On some other servers, we are getting about 90% peak for 8kx8x8k float16
'a100': {
(256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006},
(512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030},
(1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385},
(4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711},
(8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860},
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
# tall-skinny
(16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005},
(16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259},
(16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431},
(16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431},
(64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169},
(64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174},
(64, 4096, 4096): {'float16': 0.16, 'float32': 0.162, 'int8': 0.097},
(64, 8192, 8192): {'float16': 0.30, 'float32': 0.257, 'int8': 0.174},
(1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017},
(4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177},
(4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102},
(8192, 64, 8192): {'float16': 0.25, 'float32': 0.230, 'int8': 0.177},
}
# # deep reductions
# (64 , 64 , 16384) : {'a100': 0.},
# (64 , 64 , 65536) : {'a100': 0.},
# (256 , 256 , 8192 ) : {'a100': 0.},
# (256 , 256 , 32768) : {'a100': 0.},
}
@@ -88,9 +84,7 @@ def test_matmul(M, N, K, dtype_str):
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
if dtype == torch.int8:
a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda')
b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda')
@@ -99,10 +93,10 @@ def test_matmul(M, N, K, dtype_str):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((K, N), dtype=dtype, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300)
cur_gpu_perf = 2. * M * N * K / ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)


#######################
@@ -149,16 +143,61 @@ elementwise_data = {
def test_elementwise(N):
torch.manual_seed(0)
ref_gpu_util = elementwise_data[DEVICE_NAME][N]
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)

#######################
# Flash-Attention
#######################


flash_attention_data = {
"a100": {
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
}
}


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]])
@pytest.mark.parametrize("mode", ['forward', 'backward'])
@pytest.mark.parametrize("dtype_str", ['float16'])
def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str):
is_backward = mode == 'backward'
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
# init data
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
# benchmark
fn = lambda: triton.ops.attention(q, k, v, sm_scale)
if is_backward:
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500)
# compute flops
flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 2 * flops_per_matmul
if is_backward:
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
cur_gpu_perf = total_flops / ms * 1e-9
# maximum flops
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
cur_gpu_util = cur_gpu_perf / max_gpu_perf
ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)]
assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05)
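# For illustration only (not part of the upstream change): with the benchmarked
# parameters above (Z=4, H=48, N_CTX=4096, D_HEAD=64), flops_per_matmul is
# 2 * 4 * 48 * 4096 * 4096 * 64 * 0.5 ~= 2.06e11, so total_flops ~= 4.12e11 for the
# forward pass (and 2.5x that for backward). A 2 ms forward run would therefore be
# reported as 4.12e11 / 2 * 1e-9 ~= 206 TFLOP/s before dividing by the Tensor Core peak.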
@@ -1034,7 +1034,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):

@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for shape in [(64, 64, 64), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16', 'float32']
@@ -1063,6 +1063,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if capability[0] == 7:
if (M, N, K, num_warps) == (128, 256, 32, 8):
pytest.skip("shared memory out of resource")

torch.backends.cuda.matmul.allow_tf32 = allow_tf32

@@ -1298,18 +1301,18 @@ def test_noop(device='cuda'):
kernel[(1, )](x)


@pytest.mark.parametrize("device", ['cuda', 'cpu'])
@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned'])
def test_pointer_arguments(device):
@triton.jit
def kernel(x):
pass
x = torch.empty(1024, device=device)
result = True
try:
kernel[(1,)](x)
except ValueError:
result = True if device == 'cpu' else False
assert result
pin_memory = 'pinned' in device
x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory)
if device == "cpu":
with pytest.raises(ValueError):
kernel[(1,)](x)
else:
kernel[(1, )](x)


@pytest.mark.parametrize("value, value_type", [

@@ -16,7 +16,6 @@ import tempfile
import warnings
from collections import namedtuple
from pathlib import Path
from sysconfig import get_paths
from typing import Any, Callable, Dict, Tuple, Union

import setuptools
@@ -981,32 +980,32 @@ def ast_to_ttir(fn, signature, specialization, constants):
return optimize_triton_ir(mod)


def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
def ttir_to_ttgir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod


def optimize_ttgir(mod, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_coalesce_pass()
# The combine pass converts blocked layout to mma layout
# for dot ops so that pipeline can get shared memory swizzled correctly.
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
# Prefetch must be done after pipeline pass because pipeline pass
# extracts slices from the original tensor.
pm.add_tritongpu_prefetch_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_licm_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_cse_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
# NOTE this pass should be placed after all the passes those modifies mma layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.run(mod)
return mod

@@ -1459,11 +1458,6 @@ def default_cache_dir():
return os.path.join(os.environ["HOME"], ".triton", "cache")


def default_cuda_dir():
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)


class CacheManager:

def __init__(self, key):
@@ -1537,14 +1531,15 @@ def _build(name, src, srcdir):
hip_include_dir = os.path.join(hip_home_dirs(), "include")
else:
cuda_lib_dirs = libcuda_dirs()
cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
base_dir = os.path.dirname(__file__)
cuda_path = os.path.join(base_dir, "third_party", "cuda")

cu_include_dir = os.path.join(cuda_path, "include")
triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
cuda_header = os.path.join(cu_include_dir, "cuda.h")
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
cu_include_dir = triton_include_dir

suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
@@ -1556,7 +1551,17 @@ def _build(name, src, srcdir):
cc = gcc if gcc is not None else clang
if cc is None:
raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
py_include_dir = get_paths()["include"]
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, 'get_default_scheme'):
scheme = sysconfig.get_default_scheme()
else:
scheme = sysconfig._get_default_scheme()
# 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
# path changes to include 'local'. This change is required to use triton with system-wide python.
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

if torch.version.hip is not None:
ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so])
else:
@@ -1751,30 +1756,30 @@ def compile(fn, **kwargs):
raise RuntimeError('gfx_arch is None (not specified)')
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"amdgcn": (lambda path: Path(path).read_text(),
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch,
gfx_arch_full_details[0],
gfx_arch_full_details[2])),
}
else:
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),
lambda src: llir_to_ptx(src, capability)),
"cubin": (lambda path: Path(path).read_bytes(),
lambda src: ptx_to_cubin(src, capability))
lambda src: ptx_to_cubin(src, capability))
}

# find out the signature of the function
@@ -9,6 +9,8 @@ from triton._C.libtriton.triton import ir

T = TypeVar('T')

TRITON_MAX_TENSOR_NUMEL = 131072


def _to_tensor(x, builder):
if isinstance(x, bool):
@@ -254,6 +256,8 @@ class block_type(dtype):
self.numel = 1
for s in self.shape:
self.numel *= s
if self.numel > TRITON_MAX_TENSOR_NUMEL:
raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")

self.name = self.__str__()

@@ -702,12 +706,13 @@ def num_programs(axis, _builder=None):
@builtin
def arange(start, end, _builder=None):
"""
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
Returns contiguous values within the left-closed and right-open interval [:code:`start`, :code:`end`). \
End - Start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 131072

:param start: Start of the interval. Must be a power of two.
:type start: int
:param stop: End of the interval. Must be a power of two >= start.
:type stop: int
:type start: int32
:param end: End of the interval. Must be a power of two > start.
:type end: int32
"""
start = _constexpr_to_value(start)
end = _constexpr_to_value(end)

@@ -5,9 +5,10 @@ from .. import impl
from . import core, extern

if torch.version.hip is not None:
LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc")
else:
LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc")
LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", LOCAL_PATH)

@impl.extern
def clz(arg0, _builder=None):

@@ -480,6 +480,12 @@ def not_equal(input: tl.tensor,
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
if not isinstance(start, int) or not isinstance(end, int):
raise ValueError("arange's arguments must be of type tl.constexpr")
is_start_int64 = bool(start >> 32)
is_end_int64 = bool(end >> 32)
if is_start_int64 or is_end_int64:
raise ValueError("arange must fit in int32")
if end <= start:
raise ValueError("arange's end argument must be greater than the start argument")

shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape)
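# For illustration only (not part of the upstream change): together with the
# TRITON_MAX_TENSOR_NUMEL = 131072 check added to block_type above, these guards mean
# that, for example, tl.arange(0, 131072) is the largest range that still builds
# (131072 = 2**17 elements), while tl.arange(0, 262144) raises a ValueError from
# block_type and tl.arange(0, 2**32) is rejected for not fitting in int32.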
@@ -655,6 +661,8 @@ def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:

@@ -62,9 +62,9 @@ class Autotuner(KernelInterface):
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
try:
return do_bench(kernel_call)
return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
except OutOfResources:
return float('inf')
return (float('inf'), float('inf'), float('inf'))

def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))

@@ -93,7 +93,11 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)


def allclose(x, y, tol=1e-2):
def allclose(x, y, atol=0, rtol=1e-2):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
if not isinstance(y, torch.Tensor):
y = torch.tensor(y)
if x.dtype != y.dtype:
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
if x.shape != y.shape:
@@ -101,12 +105,11 @@ def allclose(x, y, tol=1e-2):
if x.dtype == torch.bool:
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
tol = 0
rtol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
err = torch.max(diff) / torch.max(x_max, y_max)
return err <= tol
return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max)
|
||||
|
||||
|
||||
def nvsmi(attrs):
|
||||
|
||||
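Editor's note (illustrative numbers, not from the commit): the rewritten check compares the largest absolute difference against atol + rtol * max(x.max(), y.max()); for example:

import torch

x = torch.tensor([100.0, 200.0])
y = torch.tensor([100.5, 201.0])

atol, rtol = 0, 1e-2
lhs = torch.max(torch.abs(x - y))                          # 1.0
rhs = atol + rtol * torch.max(torch.max(x), torch.max(y))  # 0 + 0.01 * 201 = 2.01
print(bool(lhs <= rhs))  # True: the tensors agree to within 1% of the larger maximum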
@@ -83,7 +83,8 @@ if __name__ == '__main__':
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")

# triton-ir -> triton-gpu-ir
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)
@@ -289,7 +289,8 @@ class Libdevice(ExternLibrary):
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")"
header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n"
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
@@ -223,6 +223,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
# print(h.asm["ttgir"])

ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
@@ -260,6 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None

@@ -179,8 +179,8 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_if_for
|
||||
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK-LABEL: for_for_if
|
||||
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: %cst -> %cst
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
@@ -213,3 +213,34 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cf_for
|
||||
func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_0 -> %cst_0
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
gpu.barrier
|
||||
// CHECK-NEXT: %0 -> %0
|
||||
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %cst_1 -> %cst_1
|
||||
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: %2 -> %cst,%cst_0,%cst_1
|
||||
// CHECK-NEXT: %3 -> %cst,%cst_0,%cst_1
|
||||
// CHECK-NEXT: %4 -> %cst,%cst_0,%cst_1
|
||||
cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
|
||||
^bb1(%1: index, %2: tensor<128x32xf16, #A_SHARED>, %3: tensor<128x32xf16, #A_SHARED>, %4: tensor<128x32xf16, #A_SHARED>): // 2 preds: ^bb0, ^bb2
|
||||
%5 = arith.cmpi slt, %1, %arg1 : index
|
||||
cf.cond_br %5, ^bb2, ^bb3
|
||||
^bb2: // pred: ^bb1
|
||||
%6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
|
||||
gpu.barrier
|
||||
%7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked>
|
||||
%8 = arith.addi %1, %arg2 : index
|
||||
cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>)
|
||||
^bb3: // pred: ^bb1
|
||||
gpu.barrier
|
||||
// CHECK-NEXT: %9 -> %9
|
||||
%9 = tt.cat %0, %0 {axis = 0 : i64} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -315,8 +315,8 @@ func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.pt
|
||||
|
||||
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2.
|
||||
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
|
||||
// CHECK-LABEL: for_if_for
|
||||
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK-LABEL: for_for_if
|
||||
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK: offset = 0, size = 8192
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 8192, size = 8192
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file --mlir-disable-threading --convert-scf-to-cf -test-print-membar 2>&1 | FileCheck %s
|
||||
|
||||
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||
#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
|
||||
@@ -45,11 +45,12 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
|
||||
func.func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: Membar 5
|
||||
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED>
|
||||
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%1 = tt.load %0, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -57,46 +58,51 @@ func.func @raw_single_block(%A : !tt.ptr<f16>) {
|
||||
func.func @war_single_block(%A : !tt.ptr<f16>) {
|
||||
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: Membar 5
|
||||
%a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
|
||||
// a2's liveness range ends here, and a3 and a2 have the same address range.
|
||||
// So it makes sense to have a WAR dependency between a2 and a3.
|
||||
// CHECK-NEXT: Membar 7
|
||||
%a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
%0 = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%1 = tt.load %0, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: %4 = triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: scratch
|
||||
func.func @scratch() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 1
|
||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 3
|
||||
%aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
%b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
%2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: async_wait
|
||||
func.func @async_wait() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 1
|
||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
triton_gpu.async_wait {num = 4 : i32}
|
||||
// CHECK-NEXT: Membar 4
|
||||
%a_ = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc
|
||||
func.func @alloc() {
|
||||
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 2
|
||||
%b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%1 = tt.cat %0, %0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -104,50 +110,58 @@ func.func @alloc() {
|
||||
func.func @extract_slice() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 3
|
||||
%cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
// CHECK-NEXT: Membar 5
|
||||
%cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
||||
%0 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: trans
|
||||
func.func @trans() {
|
||||
// CHECK-NOT: gpu.barrier
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice_async
|
||||
func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK-LABEL: insert_slice_async_op
|
||||
func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
%tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
%a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 6
|
||||
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 8
|
||||
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
|
||||
%3 = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%4 = tt.cat %3, %3 {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: insert_slice
|
||||
func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
// CHECK-LABEL: insert_slice_op
|
||||
func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
|
||||
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
|
||||
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
%tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
%al = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
|
||||
// CHECK: Membar 6
|
||||
%a = tensor.insert_slice %al into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 8
|
||||
%b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
|
||||
// CHECK: Membar 10
|
||||
%c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
|
||||
%2 = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tensor.insert_slice
|
||||
%3 = tensor.insert_slice %2 into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%4 = tt.cat %3, %3 {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,18 +171,21 @@ func.func @multi_blocks(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
// CHECK: Membar 2
|
||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
} else {
|
||||
%cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 7
|
||||
%b = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
// CHECK-NEXT: Membar 10
|
||||
%c = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%2 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,12 +195,14 @@ func.func @multi_blocks_join_barrier(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
// CHECK: Membar 2
|
||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
} else {
|
||||
// CHECK-NEXT: Membar 5
|
||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
@@ -196,17 +215,42 @@ func.func @multi_blocks_yield(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
|
||||
// CHECK: Membar 2
|
||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %a : tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %0 : tensor<32x16xf16, #A_SHARED>
|
||||
} else {
|
||||
// CHECK-NEXT: Membar 5
|
||||
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %b : tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %1 : tensor<32x16xf16, #A_SHARED>
|
||||
}
|
||||
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
// CHECK-NEXT: Membar 9
|
||||
%b = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%4 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
// Even though the entry block doesn't have a barrier, the successors should have barriers
|
||||
// CHECK-LABEL: multi_blocks_entry_no_shared
|
||||
func.func @multi_blocks_entry_no_shared(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%0 = tt.cat %cst1, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %0 : tensor<32x16xf16, #A_SHARED>
|
||||
} else {
|
||||
// CHECK-NOT: gpu.barrier
|
||||
// CHECK: arith.constant
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield %cst1 : tensor<32x16xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -216,11 +260,14 @@ func.func @multi_blocks_noelse(%i1 : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
// CHECK: Membar 2
|
||||
%a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -231,18 +278,21 @@ func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
scf.if %i1 {
|
||||
scf.if %i2 {
|
||||
// CHECK: Membar 2
|
||||
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
scf.yield
|
||||
} else {
|
||||
// CHECK-NEXT: Membar 6
|
||||
%b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
// CHECK-NEXT: Membar 9
|
||||
%a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%2 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,8 +302,9 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 3
|
||||
%cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%5 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
@@ -265,17 +316,20 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
|
||||
// CHECK-NEXT: Membar 6
|
||||
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%7 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
|
||||
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NEXT: Membar 9
|
||||
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -285,41 +339,162 @@ func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
|
||||
func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 5
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 7
|
||||
%cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%6 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%7 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NEXT: Membar 10
|
||||
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: for_reuse_nested
|
||||
func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: Membar 2
|
||||
%cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 5
|
||||
%cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%6 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
%a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK-NEXT: Membar 7
|
||||
%cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%12 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
|
||||
scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NEXT: Membar 11
|
||||
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%15 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
// repeatedly write to the same shared memory addresses
|
||||
// CHECK-LABEL: for_for_if
|
||||
func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
%c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
|
||||
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: arith.constant
|
||||
%cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
|
||||
} else {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: arith.constant
|
||||
%cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
scf.yield %c_shared_next_next : tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// c_block_next can either be converted from c_shared_init or c_shared_next_next
|
||||
// CHECK-LABEL: for_if_for
|
||||
func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
|
||||
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
%c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
|
||||
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
%c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: arith.constant
|
||||
%cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
|
||||
} else {
|
||||
%c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
|
||||
scf.yield %c_shared : tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
scf.yield %c_shared_ : tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
// CHECK-NOT: gpu.barrier
|
||||
%b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
|
||||
scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cf_if
|
||||
func.func @cf_if(%i1 : i1) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.cond_br %i1, ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.br ^bb2
|
||||
^bb2: // 2 preds: ^bb0, ^bb1
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
|
||||
return
|
||||
}
|
||||
|
||||
func.func @cf_if_else(%i1 : i1) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.cond_br %i1, ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.br ^bb3(%0 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>)
|
||||
^bb2: // pred: ^bb0
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.br ^bb3(%1 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>)
|
||||
^bb3(%2: tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>): // 2 preds: ^bb1, ^bb2
|
||||
cf.br ^bb4
|
||||
^bb4: // pred: ^bb3
|
||||
%3 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<64x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
return
|
||||
}
|
||||
|
||||
func.func @cf_if_else_return(%i1 : i1) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
cf.cond_br %i1, ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
return
|
||||
^bb2: // pred: ^bb0
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.cat
|
||||
%1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s

#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-combine -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s

// CHECK: offset = 0, size = 49152
// CHECK: offset = 49152, size = 49152

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s

// -----

@@ -65,8 +65,19 @@ struct TestAliasPass
};

operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->getNumResults() < 1)
if (op->getNumResults() < 1) {
// cond br, br
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
auto *block = branch->getBlock();
for (auto arg : llvm::enumerate(block->getArguments())) {
auto operand = block->getArgument(arg.index());
auto opNames = getAllocOpNames(operand);
auto argName = getValueOperandName(arg.value(), state);
print(argName, opNames, os);
}
}
return;
}
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();

@@ -1,6 +1,8 @@
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"

@@ -24,21 +26,13 @@ struct TestMembarPass
// Converting to std::string removes the quotes from op_name
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";

// Print all ops after membar pass
Allocation allocation(operation);
MembarAnalysis membarPass(&allocation);
membarPass.run();

size_t operationId = 0;
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<gpu::BarrierOp>(op)) {
os << "Membar " << operationId << "\n";
}
if (op->getNumRegions() == 0) {
// Don't count parent Operation to simplify the test.
operationId++;
}
return;
});
os << *operation << "\n";
}
};