diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 8a1389e8c..df7f3a10c 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -3,9 +3,10 @@ name: Integration Tests on: workflow_dispatch: pull_request: - branches: - - main - - triton-mlir + branches: [main] + merge_group: + branches: [main] + types: [checks_requested] concurrency: group: ${{ github.ref }} @@ -21,7 +22,7 @@ jobs: id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]' + echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], "macos-10.15"]' else echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]' fi @@ -43,29 +44,33 @@ jobs: run: | rm -rf ~/.triton/cache/ + - name: Update path + run: | + echo "$HOME/.local/bin/" >> $GITHUB_PATH + - name: Check imports if: ${{ matrix.runner != 'macos-10.15' }} run: | - pip install isort + pip3 install isort isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 ) - name: Check python style if: ${{ matrix.runner != 'macos-10.15' }} run: | - pip install autopep8 + pip3 install autopep8 autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 ) - name: Check cpp style if: ${{ matrix.runner != 'macos-10.15' }} run: | - pip install clang-format + pip3 install clang-format find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/triton/*" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i || (echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/triton/*" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1) - name: Flake8 if: ${{ matrix.runner != 'macos-10.15' }} run: | - pip install flake8 + pip3 install flake8 flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 ) - name: Install Triton @@ -94,3 +99,12 @@ jobs: cd python/ cd "build/$(ls build)" ctest + + - name: Regression tests + if: ${{ contains(matrix.runner, 'A100') }} + run: | + cd python/test/regression + sudo nvidia-smi -i 0 -pm 1 + sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350 + pytest -vs . 
+ sudo nvidia-smi -i 0 -rgc \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index e254712d5..81aeb3d59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,8 +51,8 @@ include_directories(${PYBIND11_INCLUDE_DIR}) if(WIN32) SET(BUILD_SHARED_LIBS OFF) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src) - add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32) + find_package(dlfcn-win32 REQUIRED) + set(CMAKE_DL_LIBS dlfcn-win32::dl) endif() set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden") @@ -215,8 +215,7 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) if(TRITON_BUILD_PYTHON_MODULE) add_library(triton SHARED ${PYTHON_SRC}) - - target_link_libraries(triton + set(TRITON_LIBRARIES TritonAnalysis TritonTransforms TritonGPUTransforms @@ -237,16 +236,25 @@ if(TRITON_BUILD_PYTHON_MODULE) MLIRROCDLToLLVMIRTranslation MLIRIR ) - - target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) - if(WIN32) - target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32 + target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS} + ${TRITON_LIBRARIES} + ) elseif(APPLE) - target_link_libraries(triton ${LLVM_LIBRARIES} z) + target_link_libraries(triton ${LLVM_LIBRARIES} z + ${TRITON_LIBRARIES} + ) else() - target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs) + target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs + ${TRITON_LIBRARIES} + ) endif() + + target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) +endif() + +if (UNIX AND NOT APPLE) + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") endif() if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index c6a2b335d..6ec638940 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -32,7 +32,7 @@ int main(int argc, char **argv) { // TODO: register Triton & TritonGPU passes mlir::DialectRegistry registry; - registry.insert(); diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index ceb192753..ccec8e7ab 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -36,28 +36,29 @@ public: void run(); private: - struct RegionInfo { + struct BlockInfo { using BufferIdSetT = Allocation::BufferIdSetT; BufferIdSetT syncReadBuffers; BufferIdSetT syncWriteBuffers; - RegionInfo() = default; - RegionInfo(const BufferIdSetT &syncReadBuffers, - const BufferIdSetT &syncWriteBuffers) + BlockInfo() = default; + BlockInfo(const BufferIdSetT &syncReadBuffers, + const BufferIdSetT &syncWriteBuffers) : syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) { } - /// Unions two RegionInfo objects. - void join(const RegionInfo &other) { + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { syncReadBuffers.insert(other.syncReadBuffers.begin(), other.syncReadBuffers.end()); syncWriteBuffers.insert(other.syncWriteBuffers.begin(), other.syncWriteBuffers.end()); + return *this; } - /// Returns true if buffers in two RegionInfo objects are intersected. - bool isIntersected(const RegionInfo &other, Allocation *allocation) const { + /// Returns true if buffers in two BlockInfo objects are intersected. 
+ bool isIntersected(const BlockInfo &other, Allocation *allocation) const { return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers, allocation) || /*WAR*/ @@ -74,6 +75,14 @@ private: syncWriteBuffers.clear(); } + /// Compares two BlockInfo objects. + bool operator==(const BlockInfo &other) const { + return syncReadBuffers == other.syncReadBuffers && + syncWriteBuffers == other.syncWriteBuffers; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + private: /// Returns true if buffers in two sets are intersected. bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs, @@ -99,19 +108,19 @@ private: /// op5 /// op6 /// op7 - /// region2 and region3 started with the information of region1. - /// Each region is analyzed separately and keeps their own copy of the - /// information. At op7, we union the information of the region2 and region3 - /// and update the information of region1. - void dfsOperation(Operation *operation, RegionInfo *blockInfo, - OpBuilder *builder); + /// TODO: Explain why we don't use ForwardAnalysis: + void resolve(Operation *operation, OpBuilder *builder); - /// Updates the RegionInfo operation based on the operation. - void transfer(Operation *operation, RegionInfo *blockInfo, - OpBuilder *builder); + /// Updates the BlockInfo operation based on the operation. + void update(Operation *operation, BlockInfo *blockInfo, OpBuilder *builder); + + /// Collects the successors of the terminator + void visitTerminator(Operation *operation, SmallVector &successors); private: Allocation *allocation; + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; }; } // namespace mlir diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 5aa54c928..e2ecb6e53 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -6,10 +6,10 @@ include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface // @@ -29,7 +29,7 @@ class TT_Op traits = []> : // extui, extsi, tructi def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, - NoMemoryEffect, + Pure, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast int64 to pointer"; @@ -42,7 +42,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, - NoMemoryEffect, + Pure, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast pointer to int64"; @@ -56,7 +56,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, // arith.bitcast doesn't support pointers def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, - NoMemoryEffect, + Pure, /*DeclareOpInterfaceMethods*/]> { let summary = "Cast between types of the same bitwidth"; @@ -71,7 +71,7 @@ def TT_BitcastOp : TT_Op<"bitcast", 
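The `BlockInfo` structure introduced in Membar.h above tracks which shared-memory buffers have been read or written since the last barrier: `join` merges the pending accesses of two program points, `isIntersected` flags read-after-write / write-after-read conflicts, and the new equality operators are what the fixed-point loop uses to detect convergence. A minimal self-contained sketch of the same bookkeeping, with plain integer buffer ids standing in for Triton's `Allocation` interface (all names here are illustrative):

#include <cassert>
#include <set>

// Toy stand-in for Membar's BlockInfo: buffer ids are plain ints and
// "intersection" ignores aliasing, which the real Allocation handles.
struct ToyBlockInfo {
  std::set<int> syncReadBuffers;   // reads not yet covered by a barrier
  std::set<int> syncWriteBuffers;  // writes not yet covered by a barrier

  ToyBlockInfo &join(const ToyBlockInfo &o) {
    syncReadBuffers.insert(o.syncReadBuffers.begin(), o.syncReadBuffers.end());
    syncWriteBuffers.insert(o.syncWriteBuffers.begin(), o.syncWriteBuffers.end());
    return *this;
  }
  static bool overlaps(const std::set<int> &a, const std::set<int> &b) {
    for (int id : a)
      if (b.count(id))
        return true;
    return false;
  }
  // A barrier is needed when pending accesses conflict with the new ones.
  bool isIntersected(const ToyBlockInfo &o) const {
    return overlaps(syncWriteBuffers, o.syncReadBuffers) || // RAW
           overlaps(syncReadBuffers, o.syncWriteBuffers) || // WAR
           overlaps(syncWriteBuffers, o.syncWriteBuffers);  // WAW
  }
  bool operator==(const ToyBlockInfo &o) const {
    return syncReadBuffers == o.syncReadBuffers &&
           syncWriteBuffers == o.syncWriteBuffers;
  }
  void sync() { syncReadBuffers.clear(); syncWriteBuffers.clear(); }
};

int main() {
  ToyBlockInfo pending;                // state accumulated since the last barrier
  pending.syncWriteBuffers = {0};      // some op wrote buffer 0
  ToyBlockInfo next;
  next.syncReadBuffers = {0};          // the next op reads buffer 0: RAW hazard
  assert(pending.isIntersected(next)); // -> insert a barrier, then sync()
  pending.sync();
  pending.join(next);                  // keep tracking the new op's accesses
  return 0;
}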
[SameOperandsAndResultShape, def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, SameOperandsAndResultEncoding, - NoMemoryEffect, + Pure, DeclareOpInterfaceMethods]> { let summary = "Floating point casting for custom types"; @@ -95,7 +95,7 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, // def TT_AddPtrOp : TT_Op<"addptr", - [NoMemoryEffect, + [Pure, SameOperandsAndResultShape, SameOperandsAndResultEncoding, TypesMatchWith<"result type matches ptr type", @@ -222,7 +222,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, // // Shape Manipulation Ops // -def TT_SplatOp : TT_Op<"splat", [NoMemoryEffect, +def TT_SplatOp : TT_Op<"splat", [Pure, SameOperandsAndResultElementType, SameOperandsAndResultEncoding]> { let summary = "splat"; @@ -236,7 +236,7 @@ def TT_SplatOp : TT_Op<"splat", [NoMemoryEffect, let hasFolder = 1; } -def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoMemoryEffect, +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "expand_dims"; @@ -248,6 +248,7 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoMemoryEffect, let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } +// view is not `pure` because it may reorder elements def TT_ViewOp : TT_Op<"view", [NoMemoryEffect, SameOperandsAndResultElementType]> { let summary = "view"; @@ -260,7 +261,7 @@ def TT_ViewOp : TT_Op<"view", [NoMemoryEffect, } -def TT_BroadcastOp : TT_Op<"broadcast", [NoMemoryEffect, +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, SameOperandsAndResultElementType, SameOperandsAndResultEncoding]> { let summary = "broadcast. No left-padding as of now."; @@ -274,6 +275,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [NoMemoryEffect, let hasFolder = 1; } +// cat is not `pure` because it may reorder elements def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, SameOperandsAndResultElementType]> { let summary = "concatenate 2 tensors"; @@ -285,7 +287,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; } -def TT_TransOp : TT_Op<"trans", [NoMemoryEffect, +def TT_TransOp : TT_Op<"trans", [Pure, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { @@ -301,7 +303,7 @@ def TT_TransOp : TT_Op<"trans", [NoMemoryEffect, // // SPMD Ops // -def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoMemoryEffect]> { +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { let arguments = (ins I32Attr:$axis); let results = (outs I32:$result); @@ -309,7 +311,7 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoMemoryEffect]> { let assemblyFormat = "attr-dict `:` type($result)"; } -def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoMemoryEffect]> { +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { let arguments = (ins I32Attr:$axis); let results = (outs I32:$result); @@ -320,7 +322,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoMemoryEffect]> { // // Dot Op // -def TT_DotOp : TT_Op<"dot", [NoMemoryEffect, +def TT_DotOp : TT_Op<"dot", [Pure, DeclareOpInterfaceMethods, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { @@ -340,7 +342,7 @@ def TT_DotOp : TT_Op<"dot", [NoMemoryEffect, // // Reduce Op // -def TT_ReduceOp : TT_Op<"reduce", [NoMemoryEffect, +def TT_ReduceOp : TT_Op<"reduce", [Pure, DeclareOpInterfaceMethods]> { let summary = "reduce"; @@ -364,7 +366,7 @@ def TT_ReduceOp : TT_Op<"reduce", [NoMemoryEffect, // // External elementwise 
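Most ops in TritonOps.td above move from `NoMemoryEffect` to `Pure`, which in MLIR bundles `NoMemoryEffect` with speculatability, while `view` and `cat` deliberately keep only `NoMemoryEffect` per the new comments. A toy model of what the two traits buy the optimizer, simplified relative to MLIR's actual legality rules:

#include <iostream>
#include <vector>

// Toy op model: an op can be erased when it has no memory effects and no
// users; it can additionally be hoisted/speculated when it is also
// speculatable (MLIR's `Pure` bundles both properties).
struct ToyOp {
  const char *name;
  bool hasMemoryEffect;
  bool speculatable;
  int numUsers;
};

bool isTriviallyDead(const ToyOp &op) {
  return !op.hasMemoryEffect && op.numUsers == 0;
}
bool canHoistOutOfLoop(const ToyOp &op) {
  return !op.hasMemoryEffect && op.speculatable;
}

int main() {
  std::vector<ToyOp> ops = {
      {"tt.splat", false, true, 0},  // Pure and unused -> erasable, hoistable
      {"tt.view", false, false, 0},  // NoMemoryEffect only -> erasable, not hoisted
      {"tt.store", true, false, 0},  // touches memory -> never erased here
  };
  for (const auto &op : ops)
    std::cout << op.name << ": dead=" << isTriviallyDead(op)
              << " hoistable=" << canHoistOutOfLoop(op) << "\n";
  return 0;
}

In this model `tt.view` can still be erased when unused, but unlike a `Pure` op it is not a candidate for hoisting out of conditionals or loops.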
op // -def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoMemoryEffect, Elementwise, SameOperandsAndResultShape, +def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, SameVariadicOperandSize]> { let summary = "ext_elemwise"; @@ -386,7 +388,7 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoMemoryEffect, Elementwise, SameO // Make Range Op // // TODO: should have ConstantLike as Trait -def TT_MakeRangeOp : TT_Op<"make_range", [NoMemoryEffect]> { +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { let summary = "make range"; let description = [{ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 68a1a942d..5732b6299 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -7,7 +7,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" // NoMemoryEffect +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">; @@ -18,13 +18,15 @@ class TTG_Op traits = []> : def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", [SameOperandsAndResultShape, SameOperandsAndResultElementType, - NoMemoryEffect]> { + Pure]> { let summary = "convert layout"; let arguments = (ins TT_Tensor:$src); let results = (outs TT_Tensor:$result); + let hasCanonicalizeMethod = 1; + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } @@ -59,7 +61,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { // This is needed because these ops don't // handle encodings // e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111 -def TTG_CmpIOp : TTG_Op<"cmpi", [NoMemoryEffect, Elementwise, +def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding]> { let summary = "integer comparison operation"; @@ -73,7 +75,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoMemoryEffect, Elementwise, let results = (outs TT_BoolLike:$result); } -def TTG_CmpFOp : TTG_Op<"cmpf", [NoMemoryEffect, Elementwise, +def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding]> { let summary = "floating-point comparison operation"; @@ -88,7 +90,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoMemoryEffect, Elementwise, } // TODO: migrate to arith::SelectOp on LLVM16 -def TTG_SelectOp : TTG_Op<"select", [NoMemoryEffect, Elementwise, +def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding]> { let summary = "select operation"; diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 0e0fbd05b..e96c00fbe 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -6,7 +6,9 @@ namespace mlir { std::unique_ptr createTritonGPUPipelinePass(int numStages = 2); -// TODO(Keren): prefetch pass not working yet +std::unique_ptr +createTritonGPUAccelerateMatmulPass(int computeCapability = 80); + std::unique_ptr createTritonGPUPrefetchPass(); std::unique_ptr 
createTritonGPUCanonicalizeLoopsPass(); @@ -17,10 +19,12 @@ std::unique_ptr createTritonGPUReorderInstructionsPass(); std::unique_ptr createTritonGPUDecomposeConversionsPass(); -std::unique_ptr createTritonGPUCombineOpsPass(int computeCapability = 80); +std::unique_ptr createTritonGPURemoveLayoutConversionsPass(); std::unique_ptr createTritonGPUVerifier(); +std::unique_ptr createTritonGPUFuseTranspositionsPass(); + std::unique_ptr createTritonGPUUpdateMmaForVoltaPass(); /// Generate the code for registering passes. diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 294fafa2b..e6aa6f7ae 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -7,7 +7,8 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ - Unroll loops to hide global memory -> shared memory latency. + Replace `LoadOp` in loops by `InsertSliceAsyncOp` instructions that asynchronously construct the data + needed at the next iteration }]; let constructor = "mlir::createTritonGPUPipelinePass()"; @@ -27,7 +28,8 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { let summary = "prefetch"; let description = [{ - Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency. + Decompose `DotOp` instructions in loops into several finer-grained `DotOp` + that may have their operands constructed at the end of the previous iteration }]; let constructor = "mlir::createTritonGPUPrefetchPass()"; @@ -37,6 +39,41 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { "mlir::arith::ArithDialect"]; } +def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., Nvidia tensor cores) + }]; + + let constructor = "mlir::createTritonGPUAccelerateMatmulPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability"> + ]; + +} + +def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> { + let summary = "fuse transpositions"; + + let description = [{ + Re-arranged layouts of tensors used as matrix multiplication operands so as to promote the use of + hardware-accelerated transpositions. 
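The rewritten pipeline description above (replace `LoadOp` in loops with `InsertSliceAsyncOp` so the data needed by the next iteration is constructed ahead of time) is classic software pipelining of the global-to-shared copy. A scalar, self-contained C++ sketch of the transformation, with hypothetical `load`/`compute` helpers standing in for the real Triton ops:

#include <array>
#include <cstdio>

constexpr int N = 4;
std::array<int, N> globalMem = {10, 20, 30, 40};

int load(int i) { return globalMem[i]; }              // stands in for tt.load
void compute(int tile) { std::printf("%d\n", tile); } // stands in for tt.dot etc.

// Before pipelining: each iteration waits for its own load.
void naiveLoop() {
  for (int i = 0; i < N; ++i) {
    int tile = load(i);
    compute(tile);
  }
}

// After pipelining (numStages == 2): iteration i consumes the value fetched
// during iteration i - 1, so the "async copy" for i + 1 can overlap compute(i).
void pipelinedLoop() {
  int prefetched = load(0);       // prologue: fetch the first tile
  for (int i = 0; i < N; ++i) {
    int tile = prefetched;
    if (i + 1 < N)
      prefetched = load(i + 1);   // would be an InsertSliceAsyncOp
    compute(tile);                // overlaps with the copy above
  }
}

int main() { naiveLoop(); pipelinedLoop(); return 0; }

The prefetch pass described next applies the same idea one level down, to the shared-memory operands of `DotOp` that can be constructed at the end of the previous iteration.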
+ }]; + + let constructor = "mlir::createTritonGPUFuseTranspositionsPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { let summary = "coalesce"; @@ -49,26 +86,16 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; } -def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { - let summary = "combine triton gpu ops"; +def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> { + let summary = "remove superfluous layout conversions"; let description = [{ - convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) => - convert_layout(%src, #LAYOUT_1) - - convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT }]; - let constructor = "mlir::createTritonGPUCombineOpsPass()"; + let constructor = "mlir::createTritonGPURemoveLayoutConversionsPass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonDialect"]; - - let options = [ - Option<"computeCapability", "compute-capability", - "int32_t", /*default*/"80", - "device compute capability"> - ]; } def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { @@ -95,19 +122,6 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir "mlir::triton::TritonDialect"]; } -def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> { - let summary = "canonicalize scf.ForOp ops"; - - let description = [{ - This implements some optimizations that are missing in the standard scf.ForOp - canonicalizer. - }]; - - let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()"; - - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; -} - def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> { let summary = "Update mma encodings for Volta"; diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 910274b2a..d7d3fcd45 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -2,74 +2,84 @@ #include "triton/Analysis/Alias.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include namespace mlir { void MembarAnalysis::run() { auto *operation = allocation->getOperation(); - RegionInfo regionInfo; OpBuilder builder(operation); - dfsOperation(operation, ®ionInfo, &builder); + resolve(operation, &builder); } -void MembarAnalysis::dfsOperation(Operation *operation, - RegionInfo *parentRegionInfo, - OpBuilder *builder) { - transfer(operation, parentRegionInfo, builder); - if (operation->getNumRegions()) { - // If there's any nested regions, we need to visit them. - // scf.if and scf.else: two regions - // scf.if only: two regions - // scf.for: one region - RegionInfo curRegionInfo; - auto traverseRegions = [&]() -> auto{ - for (auto ®ion : operation->getRegions()) { - // Copy the parent info as the current info. - RegionInfo regionInfo = *parentRegionInfo; - for (auto &block : region.getBlocks()) { - // assert(region.getBlocks().size() == 1 && - // "Multiple blocks in a region is not supported"); - for (auto &op : block.getOperations()) { - // Traverse the nested operation. 
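The two rewrites spelled out in the removed `TritonGPUCombineOps` description, `convert_layout(convert_layout(%src, #L0), #L1) => convert_layout(%src, #L1)` and `convert_layout(%src, #L) => %src` when the layout already matches, now live in `ConvertLayoutOp::canonicalize`. They are easy to model outside MLIR; the real pattern adds extra guards for shared encodings, shown later in Dialect.cpp. A self-contained sketch with layouts reduced to strings (names illustrative):

#include <cassert>
#include <memory>
#include <string>

// Toy value: either a raw tensor with a layout, or a convert_layout of
// another value. Real Triton layouts are attributes, not strings.
struct Val {
  std::string layout;
  std::shared_ptr<Val> convertedFrom; // non-null if this is a convert_layout
};

std::shared_ptr<Val> tensor(std::string layout) {
  return std::make_shared<Val>(Val{std::move(layout), nullptr});
}
std::shared_ptr<Val> convert(std::shared_ptr<Val> src, std::string layout) {
  // convert_layout(%src, #L) => %src if %src already has layout #L
  if (src->layout == layout)
    return src;
  // convert_layout(convert_layout(%x, #L0), #L1) => convert_layout(%x, #L1)
  if (src->convertedFrom)
    src = src->convertedFrom;
  return std::make_shared<Val>(Val{std::move(layout), src});
}

int main() {
  auto x = tensor("#blocked0");
  auto y = convert(convert(x, "#blocked1"), "#mma"); // folds to a single convert
  assert(y->convertedFrom == x && y->layout == "#mma");
  assert(convert(x, "#blocked0") == x);              // identity convert erased
  return 0;
}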
- dfsOperation(&op, ®ionInfo, builder); - } - } - curRegionInfo.join(regionInfo); +void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) { + // Initialize the blockList + std::deque blockList; + operation->walk([&](Block *block) { + for (auto &op : block->getOperations()) { + // Check if the operation belongs to scf dialect, if so, we need to + // throw an error + if (op.getDialect()->getNamespace() == "scf") { + op.emitError("scf dialect is not supported in membar. Please lower it " + "to cf dialect first."); + return; } - // Set the parent region info as the union of the nested region info. - *parentRegionInfo = curRegionInfo; - }; + } + if (block->isEntryBlock()) + blockList.emplace_back(block); + }); - traverseRegions(); - if (isa(operation)) { - // scf.for can have two possible inputs: the init value and the - // previous iteration's result. Although we've applied alias analysis, - // there could be unsynced memory accesses on reused memories. - // For example, consider the following code: - // %1 = convert_layout %0: blocked -> shared - // ... - // gpu.barrier - // ... - // %5 = convert_layout %4 : shared -> dot - // %6 = tt.dot %2, %5 - // scf.yield - // - // Though %5 could be released before scf.yield, it may shared the same - // memory with %1. So we actually have to insert a barrier before %1 to - // make sure the memory is synced. - traverseRegions(); + // A fixed point algorithm + while (!blockList.empty()) { + auto *block = blockList.front(); + blockList.pop_front(); + // Make a copy of the inputblockInfo but not update + auto inputBlockInfo = inputBlockInfoMap.lookup(block); + SmallVector successors; + for (auto &op : block->getOperations()) { + if (op.hasTrait()) { + visitTerminator(&op, successors); + } else { + update(&op, &inputBlockInfo, builder); + } + } + // Get the reference because we want to update if it changed + if (outputBlockInfoMap.count(block) && + inputBlockInfo == outputBlockInfoMap[block]) { + // If we have seen the block before and the inputBlockInfo is the same as + // the outputBlockInfo, we skip the successors + continue; + } + // Update the current block + outputBlockInfoMap[block].join(inputBlockInfo); + // Update the successors + for (auto *successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[block]); + blockList.emplace_back(successor); } } } -void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, - OpBuilder *builder) { - if (isa(op) || isa(op) || isa(op) || - isa(op) || isa(op)) { - // Do not insert barriers before control flow operations and - // alloc/extract/insert +void MembarAnalysis::visitTerminator(Operation *op, + SmallVector &successors) { + if (auto branchInterface = dyn_cast(op)) { + Block *parentBlock = branchInterface->getBlock(); + for (Block *successor : parentBlock->getSuccessors()) { + successors.push_back(successor); + } + return; + } + // Otherwise, it could be a return op + assert(isa(op) && "Unknown terminator"); +} + +void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + OpBuilder *builder) { + if (isa(op) || isa(op) || + isa(op)) { // alloc is an allocation op without memory write. 
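The new `resolve` routine above replaces the old region DFS with a forward fixed-point iteration over blocks: each block's input state is the join of its predecessors' output states, and a block's successors are only re-queued when its output state actually changes. A self-contained sketch of that convergence loop on a toy two-block CFG with a loop back edge (sets of buffer ids stand in for `BlockInfo`):

#include <deque>
#include <iostream>
#include <map>
#include <set>
#include <vector>

using State = std::set<int>; // toy BlockInfo: a set of pending buffer ids

int main() {
  // Toy CFG: block 0 -> block 1, block 1 -> block 1 (loop back edge).
  std::map<int, std::vector<int>> successors = {{0, {1}}, {1, {1}}};
  // Buffers each block touches (stand-in for MembarAnalysis::update).
  std::map<int, State> touched = {{0, {7}}, {1, {3}}};

  std::map<int, State> in, out;
  std::deque<int> worklist = {0};                 // start at the entry block
  while (!worklist.empty()) {
    int b = worklist.front();
    worklist.pop_front();
    State cur = in[b];                            // copy; do not mutate in-place
    cur.insert(touched[b].begin(), touched[b].end());
    if (out.count(b) && cur == out[b])
      continue;                                   // converged: skip successors
    out[b].insert(cur.begin(), cur.end());        // join into the output state
    for (int s : successors[b]) {
      in[s].insert(out[b].begin(), out[b].end()); // propagate to successors
      worklist.push_back(s);
    }
  }
  // After convergence the loop block sees buffers {3, 7} on entry.
  for (int id : in[1])
    std::cout << id << " ";
  std::cout << "\n";
  return 0;
}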
// FIXME(Keren): extract_slice is always alias for now return; @@ -77,7 +87,7 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, if (isa(op)) { // If the current op is a barrier, we sync previous reads and writes - regionInfo->sync(); + blockInfo->sync(); return; } @@ -85,26 +95,26 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, !isa(op->getNextNode())) { // If the current op is an async wait and the next op is not a barrier we // insert a barrier op and sync - regionInfo->sync(); + blockInfo->sync(); OpBuilder::InsertionGuard g(*builder); builder->setInsertionPointAfter(op); builder->create(op->getLoc()); - regionInfo->sync(); + blockInfo->sync(); return; } - RegionInfo curRegionInfo; + BlockInfo curBlockInfo; for (Value value : op->getOperands()) { for (auto bufferId : allocation->getBufferIds(value)) { if (bufferId != Allocation::InvalidBufferId) { if (isa(op) || isa(op)) { - // FIXME(Keren): insert_slice and insert_slice_async are always alias - // for now - curRegionInfo.syncWriteBuffers.insert(bufferId); + // FIXME(Keren): insert_slice and insert_slice_async are always + // alias for now + curBlockInfo.syncWriteBuffers.insert(bufferId); } else { // ConvertLayoutOp: shared memory -> registers - curRegionInfo.syncReadBuffers.insert(bufferId); + curBlockInfo.syncReadBuffers.insert(bufferId); } } } @@ -113,25 +123,25 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, // ConvertLayoutOp: registers -> shared memory auto bufferId = allocation->getBufferId(value); if (bufferId != Allocation::InvalidBufferId) { - curRegionInfo.syncWriteBuffers.insert(bufferId); + curBlockInfo.syncWriteBuffers.insert(bufferId); } } // Scratch buffer is considered as both shared memory write & read auto bufferId = allocation->getBufferId(op); if (bufferId != Allocation::InvalidBufferId) { - curRegionInfo.syncWriteBuffers.insert(bufferId); - curRegionInfo.syncReadBuffers.insert(bufferId); + curBlockInfo.syncWriteBuffers.insert(bufferId); + curBlockInfo.syncReadBuffers.insert(bufferId); } - if (regionInfo->isIntersected(curRegionInfo, allocation)) { + if (blockInfo->isIntersected(curBlockInfo, allocation)) { OpBuilder::InsertionGuard g(*builder); builder->setInsertionPoint(op); builder->create(op->getLoc()); - regionInfo->sync(); + blockInfo->sync(); } // Update the region info, even if barrier is inserted, we have to maintain // the current op's read/write buffers. - regionInfo->join(curRegionInfo); + blockInfo->join(curBlockInfo); } } // namespace mlir diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 99cad6b3f..f450a86cb 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -337,7 +337,8 @@ namespace { // Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis // interacts with constant propagation, but SparseConstantPropagation // doesn't seem to be sufficient. -struct ConstantAnalysis : public DataFlowAnalysis { +class ConstantAnalysis : public DataFlowAnalysis { +public: using DataFlowAnalysis::DataFlowAnalysis; LogicalResult initialize(Operation *top) override { @@ -359,12 +360,19 @@ struct ConstantAnalysis : public DataFlowAnalysis { value, op->getDialect()))); return success(); } + // Dead code analysis requires every operands has initialized ConstantValue + // state before it is visited. 
+ // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322 + // That's why we need to set all operands to unknown constants. setAllToUnknownConstants(op->getResults()); - for (Region ®ion : op->getRegions()) - setAllToUnknownConstants(region.getArguments()); + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) + setAllToUnknownConstants(block.getArguments()); + } return success(); } +private: /// Set all given values as not constants. void setAllToUnknownConstants(ValueRange values) { dataflow::ConstantValue unknownConstant(nullptr, nullptr); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 91bc80b0c..6c3e4c97c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -291,10 +291,10 @@ public: for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { // extract multi dimensional index for current element auto idx = srcIndices[elemIdx]; - Value idxCol = idx[inOrder[0]]; // contiguous dimension - Value idxRow = idx[inOrder[1]]; // discontiguous dimension - Value strideCol = srcStrides[inOrder[0]]; - Value strideRow = srcStrides[inOrder[1]]; + Value idxCol = idx[outOrder[0]]; // contiguous dimension + Value idxRow = idx[outOrder[1]]; // discontiguous dimension + Value strideCol = srcStrides[outOrder[0]]; + Value strideRow = srcStrides[outOrder[1]]; // extract dynamic/static offset for immediate offseting unsigned immedateOffCol = 0; if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) @@ -338,7 +338,7 @@ public: Value currPtr = gep(dstPtrTy, dstPtrBase, offset); // compute immediate offset Value immedateOff = - add(mul(i32_val(immedateOffRow), srcStrides[inOrder[1]]), + add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]), i32_val(immedateOffCol)); ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff); } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 9a7b4c2a3..3afd38a26 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -128,16 +128,13 @@ public: // Step 2: Decompose insert_slice_async to use load + insert_slice for // pre-Ampere architectures or unsupported vectorized load sizes // Step 3: Allocate shared memories and insert barriers - // Step 4: Convert SCF to CFG - // Step 5: Convert FuncOp to LLVMFuncOp via partial conversion - // Step 6: Get axis and shared memory info - // Step 7: Convert the rest of ops via partial conversion + // Step 4: Convert FuncOp to LLVMFuncOp via partial conversion + // Step 5: Get axis and shared memory info + // Step 6: Convert the rest of ops via partial conversion // - // The reason for putting step 3 before step 4 is that the membar - // analysis currently only supports SCF but not CFG. The reason for a - // separation between 5/7 is that, step 6 is out of the scope of Dialect - // Conversion, thus we need to make sure the smem is not revised during the - // conversion of step 7. + // The reason for a separation between 4/6 is that, step 5 is out of the + // scope of Dialect Conversion, thus we need to make sure the smem is not + // revised during the conversion of step 6. 
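The TritonGPUToLLVMBase.h hunk above switches the index and stride lookups from `inOrder` to `outOrder`: `order[0]` names the contiguous dimension, and the index/stride pairs used to build the linear shared-memory offset must all come from the same order. A simplified, self-contained sketch of that addressing convention (no swizzling, illustrative names):

#include <array>
#include <cassert>

// order[0] is the fastest-varying (contiguous) dimension, order[1] the other.
// For a row-major 2D tile, order = {1, 0}: columns are contiguous.
int linearOffset(std::array<int, 2> idx,      // multi-dimensional index
                 std::array<int, 2> strides,  // stride of each dimension
                 std::array<int, 2> order) {
  int idxCol = idx[order[0]];                 // contiguous dimension
  int idxRow = idx[order[1]];                 // discontiguous dimension
  int strideCol = strides[order[0]];
  int strideRow = strides[order[1]];
  return idxRow * strideRow + idxCol * strideCol;
}

int main() {
  // A 4x8 row-major tile: stride 8 along rows, stride 1 along columns.
  std::array<int, 2> strides = {8, 1};
  // Element (row 2, col 3) lives at 2 * 8 + 3 * 1 = 19.
  assert(linearOffset({2, 3}, strides, {1, 0}) == 19);
  return 0;
}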
// Step 1 decomposeMmaToDotOperand(mod, numWarps); @@ -153,24 +150,13 @@ public: membarPass.run(); // Step 4 - RewritePatternSet scf_patterns(context); - mlir::populateSCFToControlFlowConversionPatterns(scf_patterns); - mlir::ConversionTarget scf_target(*context); - scf_target.addIllegalOp(); - scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + RewritePatternSet funcPatterns(context); + funcPatterns.add(typeConverter, numWarps, /*benefit=*/1); if (failed( - applyPartialConversion(mod, scf_target, std::move(scf_patterns)))) + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) return signalPassFailure(); - // Step 5 - RewritePatternSet func_patterns(context); - func_patterns.add(typeConverter, numWarps, /*benefit=*/1); - if (failed( - applyPartialConversion(mod, funcTarget, std::move(func_patterns)))) - return signalPassFailure(); - - // Step 6 - get axis and shared memory info + // Step 5 - get axis and shared memory info std::unique_ptr solver = createDataFlowSolver(); AxisInfoAnalysis *axisInfoAnalysis = solver->load(); if (failed(solver->initializeAndRun(mod))) @@ -180,7 +166,7 @@ public: mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32), allocation.getSharedMemorySize())); - // Step 7 - rewrite rest of ops + // Step 6 - rewrite rest of ops // We set a higher benefit here to ensure triton's patterns runs before // arith patterns for some encoding not supported by the community // patterns. @@ -222,30 +208,16 @@ public: // Add arith/math's patterns to help convert scalar expression to LLVM. mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); #ifdef USE_ROCM mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP); #else - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - patterns); mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); #endif if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); - - // Take care of scf pattern introduced by LoadStoreOp -#ifdef USE_ROCM - RewritePatternSet scf_patterns_extra(context); - mlir::populateSCFToControlFlowConversionPatterns(scf_patterns_extra); - if (failed( - applyPartialConversion(mod, scf_target, std::move(scf_patterns_extra)))) - return signalPassFailure(); - RewritePatternSet patterns_extra(context); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns_extra); - if (failed( - applyPartialConversion(mod, target, std::move(patterns_extra)))) - return signalPassFailure(); -#endif } private: diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 18642626b..e700c0b8d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -790,6 +790,141 @@ struct TritonGPUInferLayoutInterface } }; +//===----------------------------------------------------------------------===// +// Canonicalizer +//===----------------------------------------------------------------------===// + +LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, + PatternRewriter &rewriter) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto srcType = op.getOperand().getType().cast(); + auto dstType = op.getType().cast(); + if (dstType.getEncoding().isa() && + 
srcType.getEncoding().isa()) + return mlir::failure(); + // convert to the same layout -- we can delete + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return mlir::success(); + } + Operation *arg = op->getOperand(0).getDefiningOp(); + // block argument + if (!arg) + return mlir::failure(); + // cvt(view) -> view + if (auto view = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + view.getResult()); + return mlir::success(); + } + // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) + auto alloc_tensor = dyn_cast(arg); + if (alloc_tensor) { + if (!isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType()); + return mlir::success(); + } + // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) + auto insert_slice = dyn_cast(arg); + if (insert_slice) { + if (!isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + auto newType = op->getResult(0).getType().cast(); + // Ensure that the new insert_slice op is placed in the same place as the + // old insert_slice op. Otherwise, the new insert_slice op may be placed + // after the async_wait op, which is not allowed. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insert_slice); + auto newArg = rewriter.create( + op->getLoc(), newType, insert_slice.getDst()); + rewriter.replaceOpWithNewOp( + op, newType, insert_slice.getSrc(), newArg.getResult(), + insert_slice.getIndex(), insert_slice.getMask(), + insert_slice.getOther(), insert_slice.getCache(), + insert_slice.getEvict(), insert_slice.getIsVolatile(), + insert_slice.getAxis()); + return mlir::success(); + } + // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) + auto extract_slice = dyn_cast(arg); + if (extract_slice) { + if (!isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + auto origType = + extract_slice.getSource().getType().cast(); + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), + op->getResult(0).getType().cast().getEncoding()); + auto origResType = op->getResult(0).getType().cast(); + auto resType = RankedTensorType::get( + origResType.getShape(), origResType.getElementType(), + extract_slice.getType().cast().getEncoding()); + // Ensure that the new extract_slice op is placed in the same place as the + // old extract_slice op. Otherwise, the new extract_slice op may be placed + // after the async_wait op, which is not allowed. 
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(extract_slice); + auto newArg = rewriter.create( + op->getLoc(), newType, extract_slice.getSource()); + rewriter.replaceOpWithNewOp( + op, resType, newArg.getResult(), extract_slice.offsets(), + extract_slice.sizes(), extract_slice.strides(), + extract_slice.static_offsets(), extract_slice.static_sizes(), + extract_slice.static_strides()); + return mlir::success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (llvm::isa(arg)) { + if (arg->getOperand(0).getDefiningOp() && + !isSharedEncoding(arg->getOperand(0)) && + isSharedEncoding(op.getOperand()) && + !isSharedEncoding(op.getResult())) { + return mlir::failure(); + } + if (isSharedEncoding(op.getOperand()) && isSharedEncoding(op.getResult())) { + return mlir::failure(); + } + auto srcType = op.getOperand().getType().cast(); + auto srcShared = + srcType.getEncoding().dyn_cast(); + if (srcShared && srcShared.getVec() > 1) + return mlir::failure(); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), arg->getOperand(0)); + return mlir::success(); + } + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return mlir::success(); + } + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return mlir::success(); + } + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = cst.getValue().dyn_cast()) { + auto newRet = SplatElementsAttr::get(op->getResultTypes().front(), + ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return mlir::success(); + } + return mlir::failure(); +} + +//===----------------------------------------------------------------------===// + void TritonGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp new file mode 100644 index 000000000..83f3be99a --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -0,0 +1,215 @@ +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include + +using namespace mlir; +namespace { +using triton::DotOp; +using triton::gpu::ConvertLayoutOp; +using triton::gpu::DotOperandEncodingAttr; +using triton::gpu::MmaEncodingAttr; +using triton::gpu::SliceEncodingAttr; + +int computeCapabilityToMMAVersion(int computeCapability) { +#ifdef USE_ROCM + return 1; +#endif + if (computeCapability < 70) { + return 0; + } else if (computeCapability < 80) { + return 1; + } else if (computeCapability < 90) { + return 2; + } else if (computeCapability < 100) { + // FIXME: temporarily add this to pass unis tests + return 2; + } else { + assert(false && "computeCapability > 100 not supported"); + return 3; + } +} + +SmallVector mmaVersionToShapePerWarp(int version) { + if (version == 1) + return {16, 16}; + else if (version == 2) + return {16, 8}; + else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +SmallVector warpsPerTileV1(const ArrayRef shape, + int numWarps) { + // Set a default value that ensures product of wpt equals numWarps + 
return {static_cast(numWarps), 1}; +} + +SmallVector warpsPerTileV2(triton::DotOp dotOp, + const ArrayRef shape, + int numWarps) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + if (llvm::find_if(slices, [](Operation *op) { + return isa(op); + }) != slices.end()) + return {(unsigned)numWarps, 1}; + + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = {16, 8}; + bool changed = false; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? + do { + changed = false; + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +class BlockedToMMA : public mlir::RewritePattern { + int computeCapability; + mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding + +public: + BlockedToMMA(mlir::MLIRContext *context, int computeCapability) + : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), + computeCapability(computeCapability) {} + + static SmallVector getWarpsPerTile(triton::DotOp dotOp, + const ArrayRef shape, + int version, int numWarps) { + switch (version) { + case 1: + return warpsPerTileV1(shape, numWarps); + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + default: + llvm_unreachable("unsupported MMA version"); + } + } + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto dotOp = cast(op); + // TODO: Check data-types and SM compatibility + auto oldRetType = dotOp.getResult().getType().cast(); + if (!oldRetType.getEncoding() || + oldRetType.getEncoding().isa()) + return failure(); + + // for FMA, should retain the blocked layout. 
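`warpsPerTileV2` above greedily doubles the warp count along whichever dimension still has the most tiles left per warp until the product reaches `numWarps`. The same loop, copied into a standalone form that runs without MLIR (the special case that keeps a single column of warps when the dot feeds a reduction is omitted):

#include <cassert>
#include <cstdint>
#include <vector>

// Same doubling heuristic as warpsPerTileV2, with the MMAv2 16x8 warp tile,
// but on plain int64_t shapes instead of MLIR types.
std::vector<unsigned> warpsPerTileV2(const std::vector<int64_t> &shape,
                                     int numWarps) {
  std::vector<unsigned> ret = {1, 1};
  std::vector<int64_t> shapePerWarp = {16, 8};
  do {
    if (ret[0] * ret[1] >= static_cast<unsigned>(numWarps))
      break;
    if (shape[0] / shapePerWarp[0] / ret[0] >=
        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
      if (ret[0] < shape[0] / shapePerWarp[0])
        ret[0] *= 2;
      else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  } while (true);
  return ret;
}

int main() {
  // A 128x128 dot tile with 8 warps ends up split 4 warps x 2 warps.
  auto wpt = warpsPerTileV2({128, 128}, 8);
  assert(wpt[0] == 4 && wpt[1] == 2);
  return 0;
}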
+ int versionMajor = computeCapabilityToMMAVersion(computeCapability); + if (!supportMMA(dotOp, versionMajor)) + return failure(); + + // get MMA encoding for the given number of warps + auto retShape = oldRetType.getShape(); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + + auto warpsPerTile = + getWarpsPerTile(dotOp, retShape, versionMajor, numWarps); + triton::gpu::MmaEncodingAttr mmaEnc; + if (versionMajor == 1) { + mmaEnc = triton::gpu::MmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++); + } else if (versionMajor == 2) { + mmaEnc = triton::gpu::MmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, 0 /*versionMinor*/, + warpsPerTile); + } else { + llvm_unreachable("Mma layout only supports versionMajor in {1, 2}"); + } + auto newRetType = + RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc); + + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = rewriter.create( + oldAcc.getLoc(), newRetType, oldAcc); + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = a.getType().cast(); + auto oldBType = b.getType().cast(); + auto oldAOrder = oldAType.getEncoding() + .cast() + .getParent() + .cast() + .getOrder(); + auto oldBOrder = oldBType.getEncoding() + .cast() + .getParent() + .cast() + .getOrder(); + Attribute isMMAv1RowA; + Attribute isMMAv1RowB; + if (versionMajor == 1) { + isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1); + isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1); + } + + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), + triton::gpu::DotOperandEncodingAttr::get( + oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA)); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), + triton::gpu::DotOperandEncodingAttr::get( + oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB)); + + a = rewriter.create(a.getLoc(), newAType, a); + b = rewriter.create(b.getLoc(), newBType, b); + auto newDot = rewriter.create( + dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32()); + + rewriter.replaceOpWithNewOp( + op, oldRetType, newDot.getResult()); + return success(); + } +}; +} // namespace + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUAccelerateMatmulPass + : public TritonGPUAccelerateMatmulBase { +public: + TritonGPUAccelerateMatmulPass() = default; + TritonGPUAccelerateMatmulPass(int computeCapability) { + this->computeCapability = computeCapability; + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns.add<::BlockedToMMA>(context, computeCapability); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +std::unique_ptr +mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index fbcb4dbe7..f33e7df12 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,22 +1,18 @@ -set(LLVM_TARGET_DEFINITIONS Combine.td) -mlir_tablegen(TritonGPUCombine.inc -gen-rewriters) -add_public_tablegen_target(TritonGPUCombineIncGen) - 
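With the declarations added to Passes.h, the new passes can be slotted into a pass pipeline like any other module pass. A hedged usage sketch; the pass ordering shown here is illustrative and not necessarily the one Triton's frontend actually builds:

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Illustrative only: requires MLIR and Triton headers/libraries to build.
// Registers the new passes for a given compute capability (80 == sm_80).
void buildIllustrativePipeline(mlir::PassManager &pm, int computeCapability) {
  pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
  pm.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
  pm.addPass(mlir::createTritonGPUFuseTranspositionsPass());
}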
add_mlir_dialect_library(TritonGPUTransforms + AccelerateMatmul.cpp Coalesce.cpp - CanonicalizeLoops.cpp - Combine.cpp + DecomposeConversions.cpp + FuseTranspositions.cpp Pipeline.cpp Prefetch.cpp + RemoveLayoutConversions.cpp ReorderInstructions.cpp - DecomposeConversions.cpp TritonGPUConversion.cpp UpdateMmaForVolta.cpp Utility.cpp DEPENDS TritonGPUTransformsIncGen - TritonGPUCombineIncGen LINK_LIBS PUBLIC TritonIR diff --git a/lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp b/lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp deleted file mode 100644 index 462f393c4..000000000 --- a/lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "mlir/Analysis/SliceAnalysis.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" - -using namespace mlir; -using namespace mlir::triton; - -#define GEN_PASS_CLASSES -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -namespace { - -struct CanonicalizePass - : public TritonGPUCanonicalizeLoopsBase { - CanonicalizePass() = default; - - void runOnOperation() override { - - // Canonicalize pass may have created dead code that - // standard scf.for canonicalization cannot handle - // as of LLVM 14. For example, the iteration arguments - // for the pointer of the synchronous loads that are - // discarded. - // The following piece of code is a workaround to - // very crudely remove dead code, by making an iteration - // argument yield itself if it is not used to create - // side effects anywhere. - getOperation()->walk([&](scf::ForOp forOp) -> void { - for (size_t i = 0; i < forOp.getNumResults(); ++i) { - // condition 1: no other iter arguments depend on it - SetVector fwdSlice; - mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice); - Operation *yieldOp = forOp.getBody()->getTerminator(); - bool noOtherDependency = std::all_of( - yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) { - return arg == yieldOp->getOperand(i) || - !fwdSlice.contains(arg.getDefiningOp()); - }); - // condition 2: final value is not used after the loop - auto retVal = forOp.getResult(i); - bool noUserAfterLoop = retVal.getUsers().empty(); - // yielding the region iter arg will cause loop canonicalization - // to clean up the dead code - if (noOtherDependency && noUserAfterLoop) { - yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]); - } - } - }); - } -}; -} // anonymous namespace - -std::unique_ptr mlir::createTritonGPUCanonicalizeLoopsPass() { - return std::make_unique(); -} \ No newline at end of file diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td deleted file mode 100644 index 6a7b10dbc..000000000 --- a/lib/Dialect/TritonGPU/Transforms/Combine.td +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef TRITONGPU_PATTERNS -#define TRITONGPU_PATTERNS - -include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td" -include "triton/Dialect/Triton/IR/TritonOps.td" -include "mlir/IR/PatternBase.td" - -#endif diff --git a/lib/Dialect/TritonGPU/Transforms/FuseTranspositions.cpp b/lib/Dialect/TritonGPU/Transforms/FuseTranspositions.cpp new file mode 100644 index 000000000..7584d3cca --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/FuseTranspositions.cpp @@ -0,0 +1,153 @@ +#include "Utility.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include 
"triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include + +using namespace mlir; +namespace { +using triton::DotOp; +using triton::gpu::ConvertLayoutOp; +using triton::gpu::DotOperandEncodingAttr; +using triton::gpu::MmaEncodingAttr; +using triton::gpu::SliceEncodingAttr; + +class OptimizeConvertToDotOperand : public mlir::RewritePattern { +public: + explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context) + : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, + context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto cvt = cast(op); + auto srcType = cvt.getOperand().getType().cast(); + auto dstType = cvt.getResult().getType().cast(); + // order + ArrayRef order; + if (auto srcBlockedLayout = + srcType.getEncoding().dyn_cast()) + order = srcBlockedLayout.getOrder(); + else if (auto srcSharedLayout = + srcType.getEncoding() + .dyn_cast()) + order = srcSharedLayout.getOrder(); + else + return failure(); + // dot operand output + auto dstDotOperandLayout = + dstType.getEncoding().dyn_cast(); + if (!dstDotOperandLayout) + return failure(); + if (!dstDotOperandLayout.getIsMMAv1Row()) + return failure(); + bool isMMAv1Row = + dstDotOperandLayout.getIsMMAv1Row().cast().getValue(); + if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row)) + return failure(); + + auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row); + auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get( + op->getContext(), dstDotOperandLayout.getOpIdx(), + dstDotOperandLayout.getParent(), newIsRow); + auto newDstType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), newDstEncoding); + auto newCvt = rewriter.create( + op->getLoc(), newDstType, cvt.getOperand()); + rewriter.replaceOp(op, newCvt.getResult()); + return success(); + } +}; + +// convert(trans(convert(arg))) +// x = convert_layout arg: #distributed -> #shared_x +// y = trans x: #shared_x -> #shared_y +// z = convert_layout y: #shared_y -> #dot_operand +class ConvertTransConvert : public mlir::RewritePattern { + +public: + ConvertTransConvert(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + + LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto dstOp = cast(op); + auto tmpOp = + dyn_cast_or_null(dstOp.getSrc().getDefiningOp()); + if (!tmpOp) + return mlir::failure(); + auto srcOp = dyn_cast_or_null( + tmpOp.getSrc().getDefiningOp()); + if (!srcOp) + return mlir::failure(); + auto arg = srcOp.getSrc(); + auto X = tmpOp.getSrc(); + // types + auto argType = arg.getType().cast(); + auto XType = X.getType().cast(); + auto ZType = dstOp.getResult().getType().cast(); + // encodings + auto argEncoding = argType.getEncoding(); + auto XEncoding = + XType.getEncoding().cast(); + auto ZEncoding = + ZType.getEncoding().dyn_cast(); + if (!ZEncoding) + return mlir::failure(); + // new X encoding + auto newXOrder = triton::gpu::getOrder(argEncoding); + auto newXEncoding = triton::gpu::SharedEncodingAttr::get( + getContext(), ZEncoding, XType.getShape(), newXOrder, + XType.getElementType()); + auto newXType = RankedTensorType::get(XType.getShape(), + XType.getElementType(), newXEncoding); + if (XEncoding == newXEncoding) + return mlir::failure(); + + auto newX = rewriter.create(srcOp.getLoc(), + newXType, arg); + auto newY = 
rewriter.create(tmpOp.getLoc(), newX); + rewriter.replaceOpWithNewOp(dstOp, ZType, + newY); + return mlir::success(); + } +}; + +} // namespace + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUFuseTranspositionsPass + : public TritonGPUFuseTranspositionsBase { +public: + TritonGPUFuseTranspositionsPass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::PassManager pm(m.getContext()); + pm.addPass(mlir::createCanonicalizerPass()); + auto ret = pm.run(m); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + if (fixupLoops(m).failed()) + signalPassFailure(); + } +}; + +std::unique_ptr mlir::createTritonGPUFuseTranspositionsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp similarity index 59% rename from lib/Dialect/TritonGPU/Transforms/Combine.cpp rename to lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 05bb94e41..f7bdaf978 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -22,7 +22,6 @@ using namespace mlir; namespace { -#include "TritonGPUCombine.inc" using triton::DotOp; using triton::gpu::ConvertLayoutOp; using triton::gpu::DotOperandEncodingAttr; @@ -139,132 +138,7 @@ public: if (!llvm::isa(op)) return mlir::failure(); auto convert = llvm::cast(op); - // we don't handle conversions to DotOperandEncodingAttr - // this is a heuristics to accommodate fused attention - auto srcType = convert.getOperand().getType().cast(); - auto dstType = convert.getType().cast(); - if (dstType.getEncoding().isa() && - srcType.getEncoding().isa()) - return mlir::failure(); - // convert to the same layout -- we can delete - if (op->getResultTypes() == op->getOperandTypes()) { - rewriter.replaceOp(op, op->getOperands()); - return mlir::success(); - } - Operation *arg = op->getOperand(0).getDefiningOp(); - // block argument - if (!arg) - return mlir::failure(); - // cvt(view) -> view - if (auto view = dyn_cast(arg)) { - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType(), view.getResult()); - return mlir::success(); - } - // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) - auto alloc_tensor = dyn_cast(arg); - if (alloc_tensor) { - if (!isSharedEncoding(op->getResult(0))) { - return mlir::failure(); - } - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType()); - return mlir::success(); - } - // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) - auto insert_slice = dyn_cast(arg); - if (insert_slice) { - if (!isSharedEncoding(op->getResult(0))) { - return mlir::failure(); - } - auto newType = op->getResult(0).getType().cast(); - // Ensure that the new insert_slice op is placed in the same place as the - // old insert_slice op. Otherwise, the new insert_slice op may be placed - // after the async_wait op, which is not allowed. 
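`ConvertTransConvert` above rewrites convert → trans → convert chains so that the intermediate shared tensor is allocated in an encoding derived from the dot operand it ultimately feeds; the transpose then mostly amounts to swapping which dimension is contiguous rather than forcing extra data movement. A deliberately simplified toy of that order bookkeeping (the real pattern builds a `SharedEncodingAttr` from the dot-operand attribute and the source layout's order):

#include <array>
#include <cassert>

// Toy "shared encoding": just the dimension order, order[0] = contiguous dim.
using Order = std::array<int, 2>;

// A transpose of a 2D shared tensor swaps which dimension is contiguous.
Order transposedOrder(Order o) { return {o[1], o[0]}; }

// Choose the pre-transpose order so that, after the transpose, the tensor is
// already laid out the way the dot operand wants it -- no extra conversion.
Order chooseSourceOrder(Order requiredByDotOperand) {
  return transposedOrder(requiredByDotOperand);
}

int main() {
  Order wantedByDot = {0, 1};   // the dot operand wants rows contiguous
  Order sourceOrder = chooseSourceOrder(wantedByDot);
  assert(transposedOrder(sourceOrder) == wantedByDot);
  return 0;
}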
- OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(insert_slice); - auto newArg = rewriter.create( - op->getLoc(), newType, insert_slice.getDst()); - rewriter.replaceOpWithNewOp( - op, newType, insert_slice.getSrc(), newArg.getResult(), - insert_slice.getIndex(), insert_slice.getMask(), - insert_slice.getOther(), insert_slice.getCache(), - insert_slice.getEvict(), insert_slice.getIsVolatile(), - insert_slice.getAxis()); - return mlir::success(); - } - // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) - auto extract_slice = dyn_cast(arg); - if (extract_slice) { - if (!isSharedEncoding(op->getResult(0))) { - return mlir::failure(); - } - auto origType = - extract_slice.getSource().getType().cast(); - auto newType = RankedTensorType::get( - origType.getShape(), origType.getElementType(), - op->getResult(0).getType().cast().getEncoding()); - auto origResType = op->getResult(0).getType().cast(); - auto resType = RankedTensorType::get( - origResType.getShape(), origResType.getElementType(), - extract_slice.getType().cast().getEncoding()); - // Ensure that the new extract_slice op is placed in the same place as the - // old extract_slice op. Otherwise, the new extract_slice op may be placed - // after the async_wait op, which is not allowed. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(extract_slice); - auto newArg = rewriter.create( - op->getLoc(), newType, extract_slice.getSource()); - rewriter.replaceOpWithNewOp( - op, resType, newArg.getResult(), extract_slice.offsets(), - extract_slice.sizes(), extract_slice.strides(), - extract_slice.static_offsets(), extract_slice.static_sizes(), - extract_slice.static_strides()); - return mlir::success(); - } - - // cvt(cvt(x, type1), type2) -> cvt(x, type2) - if (llvm::isa(arg)) { - if (arg->getOperand(0).getDefiningOp() && - !isSharedEncoding(arg->getOperand(0)) && - isSharedEncoding(convert.getOperand()) && - !isSharedEncoding(convert.getResult())) { - return mlir::failure(); - } - if (isSharedEncoding(convert.getOperand()) && - isSharedEncoding(convert.getResult())) { - return mlir::failure(); - } - auto srcType = convert.getOperand().getType().cast(); - auto srcShared = - srcType.getEncoding().dyn_cast(); - if (srcShared && srcShared.getVec() > 1) - return mlir::failure(); - rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), arg->getOperand(0)); - return mlir::success(); - } - // cvt(type1, splat(type2, x)) -> splat(type1, x) - if (auto splat = llvm::dyn_cast(arg)) { - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - splat.getSrc()); - return mlir::success(); - } - // cvt(type1, make_range(type2, x)) -> make_range(type1, x) - if (auto range = llvm::dyn_cast(arg)) { - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), range.getStart(), range.getEnd()); - return mlir::success(); - } - // cvt(type, constant) -> constant - if (auto cst = llvm::dyn_cast(arg)) - if (auto ret = cst.getValue().dyn_cast()) { - auto newRet = SplatElementsAttr::get(op->getResultTypes().front(), - ret.getSplatValue()); - rewriter.replaceOpWithNewOp(op, newRet); - return mlir::success(); - } - return mlir::failure(); + return ConvertLayoutOp::canonicalize(convert, rewriter); } }; @@ -568,9 +442,9 @@ public: }; // -class FoldConvertAndReduce : public mlir::RewritePattern { +class RematerializeForward : public mlir::RewritePattern { public: - explicit FoldConvertAndReduce(mlir::MLIRContext *context) + explicit RematerializeForward(mlir::MLIRContext *context) : 
mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, context) {} @@ -837,393 +711,6 @@ public: } }; -// ----------------------------------------------------------------------------- -// -// ----------------------------------------------------------------------------- - -class RematerializeForward : public mlir::RewritePattern { -public: - explicit RematerializeForward(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 2, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *_cvtOp, - mlir::PatternRewriter &rewriter) const override { - auto cvt = cast(_cvtOp); - auto forOp = dyn_cast(cvt->getParentOp()); - if (!forOp) - return mlir::failure(); - auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; - - SetVector cvtSlices; - auto filter = [&](Operation *op) { - return isInLoop(op) && - !isa(op) && - !isa(op) && !isa(op) && - !isa(op); - }; - mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter); - if (cvtSlices.empty()) - return failure(); - - for (Operation *op : cvtSlices) { - if (!isa(op) && - !op->hasTrait() && - !op->hasTrait() && - !isa(op)) - return failure(); - for (Value arg : op->getOperands()) { - Operation *argOp = arg.getDefiningOp(); - if (argOp && (argOp != cvt) && - !isa( - argOp)) { - return failure(); - } - } - } - - // Otherwise, we push the conversion forward - // since we'll be able to move it out of - // the loop once it reaches the yield op - pushConversionForward(cvt, cvtSlices, rewriter); - return success(); - } -}; - -// ----------------------------------------------------------------------------- -// -// ----------------------------------------------------------------------------- -namespace { -int computeCapabilityToMMAVersion(int computeCapability) { -#ifdef USE_ROCM - return 1; -#endif - if (computeCapability < 70) { - return 0; - } else if (computeCapability < 80) { - return 1; - } else if (computeCapability < 90) { - return 2; - } else if (computeCapability < 100) { - // FIXME: temporarily add this to pass unis tests - return 2; - } else { - assert(false && "computeCapability > 100 not supported"); - return 3; - } -} - -SmallVector mmaVersionToShapePerWarp(int version) { - if (version == 1) - return {16, 16}; - else if (version == 2) - return {16, 8}; - else { - assert(false && "version not supported"); - return {0, 0}; - } -} - -SmallVector warpsPerTileV1(const ArrayRef shape, - int numWarps) { - // Set a default value and ensure product of wpt equals numWarps - return {static_cast(numWarps), 1}; -} - -SmallVector warpsPerTileV2(triton::DotOp dotOp, - const ArrayRef shape, - int numWarps) { - SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { - return isa(op); - }) != slices.end()) - return {(unsigned)numWarps, 1}; - - SmallVector ret = {1, 1}; - SmallVector shapePerWarp = {16, 8}; - bool changed = false; - // TODO (@daadaada): double-check. - // original logic in - // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L252 - // seems buggy for shape = [32, 16] ? 
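Aside: the `warpsPerTileV2` loop being removed above (it presumably moves with the rest of the dot-acceleration logic into the new accelerate-matmul pass) keeps doubling the warp count along whichever output dimension still has warp-sized tiles left, until the product reaches `numWarps`. A minimal standalone Python sketch of that heuristic, assuming the MMAv2 shape-per-warp of 16x8 and omitting the early-out for dots that feed another dot:

# Sketch only -- not the pass itself. Integer division mirrors the C++ code.
def warps_per_tile_v2(shape, num_warps, shape_per_warp=(16, 8)):
    ret = [1, 1]
    while ret[0] * ret[1] < num_warps:
        # warp-sized tiles still uncovered along M and N
        tiles_m = shape[0] // shape_per_warp[0] // ret[0]
        tiles_n = shape[1] // (shape_per_warp[1] * 2) // ret[1]
        if tiles_m >= tiles_n and ret[0] < shape[0] // shape_per_warp[0]:
            ret[0] *= 2
        else:
            ret[1] *= 2
    return ret

print(warps_per_tile_v2((128, 128), num_warps=8))  # [4, 2]
# The shape questioned in the comment above: 8 warps on a 32x16 output give
# [2, 4], i.e. 4 warps along an N dimension with only one 16-wide tile.
print(warps_per_tile_v2((32, 16), num_warps=8))    # [2, 4]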
- do { - changed = false; - if (ret[0] * ret[1] >= numWarps) - break; - if (shape[0] / shapePerWarp[0] / ret[0] >= - shape[1] / (shapePerWarp[1] * 2) / ret[1]) { - if (ret[0] < shape[0] / shapePerWarp[0]) { - ret[0] *= 2; - } else - ret[1] *= 2; - } else { - ret[1] *= 2; - } - } while (true); - return ret; -} - -} // namespace - -class OptimizeBlockedToShared : public mlir::RewritePattern { -public: - explicit OptimizeBlockedToShared(mlir::MLIRContext *context) - : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto cvt = cast(op); - auto srcType = cvt.getOperand().getType().cast(); - auto dstType = cvt.getResult().getType().cast(); - auto srcBlockedLayout = - srcType.getEncoding().dyn_cast(); - auto dstSharedLayout = - dstType.getEncoding().dyn_cast(); - if (!srcBlockedLayout || !dstSharedLayout) - return failure(); - if (srcBlockedLayout.getOrder() == dstSharedLayout.getOrder()) - return failure(); - // For now only works if single use is transpose - // TODO: rematerialize #shared uses - auto users = op->getUsers(); - if (std::distance(users.begin(), users.end()) != 1 || - !isa(*users.begin())) - return failure(); - - auto tmpShared = triton::gpu::SharedEncodingAttr::get( - op->getContext(), dstSharedLayout.getVec(), - dstSharedLayout.getPerPhase(), dstSharedLayout.getMaxPhase(), - srcBlockedLayout.getOrder()); - auto tmpType = RankedTensorType::get(srcType.getShape(), - srcType.getElementType(), tmpShared); - auto tmpCvt = rewriter.create( - op->getLoc(), tmpType, cvt.getOperand()); - - auto newDstType = RankedTensorType::get( - users.begin()->getResultTypes()[0].cast().getShape(), - srcType.getElementType(), dstSharedLayout); - - auto newTrans = rewriter.create(op->getLoc(), newDstType, - tmpCvt.getResult()); - - rewriter.replaceOp(*users.begin(), newTrans.getResult()); - return success(); - } -}; - -class OptimizeConvertToDotOperand : public mlir::RewritePattern { -public: - explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context) - : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto cvt = cast(op); - auto srcType = cvt.getOperand().getType().cast(); - auto dstType = cvt.getResult().getType().cast(); - // order - ArrayRef order; - if (auto srcBlockedLayout = - srcType.getEncoding().dyn_cast()) - order = srcBlockedLayout.getOrder(); - else if (auto srcSharedLayout = - srcType.getEncoding() - .dyn_cast()) - order = srcSharedLayout.getOrder(); - else - return failure(); - // dot operand output - auto dstDotOperandLayout = - dstType.getEncoding().dyn_cast(); - if (!dstDotOperandLayout) - return failure(); - if (!dstDotOperandLayout.getIsMMAv1Row()) - return failure(); - bool isMMAv1Row = - dstDotOperandLayout.getIsMMAv1Row().cast().getValue(); - if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row)) - return failure(); - - auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row); - auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get( - op->getContext(), dstDotOperandLayout.getOpIdx(), - dstDotOperandLayout.getParent(), newIsRow); - auto newDstType = RankedTensorType::get( - dstType.getShape(), dstType.getElementType(), newDstEncoding); - auto newCvt = rewriter.create( - op->getLoc(), newDstType, cvt.getOperand()); - rewriter.replaceOp(op, 
newCvt.getResult()); - return success(); - } -}; - -class BlockedToMMA : public mlir::RewritePattern { - int computeCapability; - mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding - -public: - BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context), - computeCapability(computeCapability) {} - - static SmallVector getWarpsPerTile(triton::DotOp dotOp, - const ArrayRef shape, - int version, int numWarps) { - switch (version) { - case 1: - return warpsPerTileV1(shape, numWarps); - case 2: - return warpsPerTileV2(dotOp, shape, numWarps); - default: - assert(false && "not supported version"); - return {0, 0}; - } - } - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - // TODO: Check data-types and SM compatibility - auto oldRetType = dotOp.getResult().getType().cast(); - if (!oldRetType.getEncoding() || - oldRetType.getEncoding().isa()) - return failure(); - - // for FMA, should retain the blocked layout. - int versionMajor = computeCapabilityToMMAVersion(computeCapability); - if (!supportMMA(dotOp, versionMajor)) - return failure(); - - // get MMA encoding for the given number of warps - auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); - int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - - auto warpsPerTile = - getWarpsPerTile(dotOp, retShape, versionMajor, numWarps); - triton::gpu::MmaEncodingAttr mmaEnc; - if (versionMajor == 1) { - mmaEnc = triton::gpu::MmaEncodingAttr::get( - oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++); - } else if (versionMajor == 2) { - mmaEnc = triton::gpu::MmaEncodingAttr::get( - oldRetType.getContext(), versionMajor, 0 /*versionMinor*/, - warpsPerTile); - } else { - assert(false && "Mma layout only support versionMajor of 1 or 2"); - } - auto newRetType = - RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc); - - // convert accumulator - auto oldAcc = dotOp.getOperand(2); - auto newAcc = rewriter.create( - oldAcc.getLoc(), newRetType, oldAcc); - Value a = dotOp.getA(); - Value b = dotOp.getB(); - auto oldAType = a.getType().cast(); - auto oldBType = b.getType().cast(); - auto oldAOrder = oldAType.getEncoding() - .cast() - .getParent() - .cast() - .getOrder(); - auto oldBOrder = oldBType.getEncoding() - .cast() - .getParent() - .cast() - .getOrder(); - Attribute isMMAv1RowA; - Attribute isMMAv1RowB; - if (versionMajor == 1) { - isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1); - isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1); - } - - auto newAType = RankedTensorType::get( - oldAType.getShape(), oldAType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get( - oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA)); - auto newBType = RankedTensorType::get( - oldBType.getShape(), oldBType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get( - oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB)); - - a = rewriter.create(a.getLoc(), newAType, a); - b = rewriter.create(b.getLoc(), newBType, b); - auto newDot = rewriter.create( - dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32()); - - rewriter.replaceOpWithNewOp( - op, oldRetType, newDot.getResult()); - return success(); - } -}; - -// Convert + trans + convert -// x = convert_layout distributed -> #shared_x -// y = trans x -> #shared_y -// z = convert_layout y -> #dot_operand -class 
ConvertTransConvert : public mlir::RewritePattern { - -public: - ConvertTransConvert(mlir::MLIRContext *context) - : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), - 1, context) {} - - LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dstOp = cast(op); - auto tmpOp = - dyn_cast_or_null(dstOp.getSrc().getDefiningOp()); - if (!tmpOp) - return mlir::failure(); - auto srcOp = dyn_cast_or_null( - tmpOp.getSrc().getDefiningOp()); - if (!srcOp) - return mlir::failure(); - auto arg = srcOp.getSrc(); - auto X = tmpOp.getSrc(); - // types - auto argType = arg.getType().cast(); - auto XType = X.getType().cast(); - auto ZType = dstOp.getResult().getType().cast(); - // encodings - auto argEncoding = argType.getEncoding(); - auto XEncoding = - XType.getEncoding().cast(); - auto ZEncoding = - ZType.getEncoding().dyn_cast(); - if (!ZEncoding) - return mlir::failure(); - // new X encoding - auto newXOrder = triton::gpu::getOrder(argEncoding); - auto newXEncoding = triton::gpu::SharedEncodingAttr::get( - getContext(), ZEncoding, XType.getShape(), newXOrder, - XType.getElementType()); - auto newXType = RankedTensorType::get(XType.getShape(), - XType.getElementType(), newXEncoding); - if (XEncoding == newXEncoding) - return mlir::failure(); - - auto newX = rewriter.create(srcOp.getLoc(), - newXType, arg); - auto newY = rewriter.create(tmpOp.getLoc(), newX); - rewriter.replaceOpWithNewOp(dstOp, ZType, - newY); - return mlir::success(); - } -}; - // class ConvertDotConvert : public mlir::RewritePattern { public: @@ -1275,31 +762,25 @@ public: #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" -class TritonGPUCombineOpsPass - : public TritonGPUCombineOpsBase { +class TritonGPURemoveLayoutConversionsPass + : public TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { public: - TritonGPUCombineOpsPass() = default; - TritonGPUCombineOpsPass(int computeCapability) { - this->computeCapability = computeCapability; - } + TritonGPURemoveLayoutConversionsPass() = default; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); mlir::RewritePatternSet patterns(context); - patterns.add(context); - patterns.add(context); patterns.add(context); patterns.add(context); - patterns.add(context); - patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); - patterns.add(context, computeCapability); - patterns.add(context); + patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { @@ -1312,7 +793,6 @@ public: } }; -std::unique_ptr -mlir::createTritonGPUCombineOpsPass(int computeCapability) { - return std::make_unique(computeCapability); +std::unique_ptr mlir::createTritonGPURemoveLayoutConversionsPass() { + return std::make_unique(); } diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 490cd7e43..6128a4b87 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -147,6 +147,12 @@ static std::map getExternLibs(mlir::ModuleOp module) { if (!funcs.empty()) { static const std::string libdevice = "libdevice"; + // first search for environmental path + std::string env_path = ::triton::tools::getenv("TRITON_LIBDEVICE_PATH"); + if (!env_path.empty()) { + externLibs.try_emplace(libdevice, env_path); + return externLibs; + } namespace fs = 
std::filesystem; // Search for libdevice relative to its library path if used from Python // Then native code is in `triton/_C/libtriton.so` and libdevice in @@ -302,12 +308,17 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, /*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); + pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability)); // Canonicalize to eliminate the remaining UnrealizedConversionCastOp pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability. pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::createCanonicalizerPass()); +#ifdef USE_ROCM + pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(createConvertControlFlowToLLVMPass()); +#endif if (failed(pm.run(module))) { llvm::errs() << "Pass execution failed"; diff --git a/python/setup.py b/python/setup.py index b2f981dff..e607c8001 100644 --- a/python/setup.py +++ b/python/setup.py @@ -65,7 +65,7 @@ def get_llvm_package_info(): linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7' system_suffix = f"linux-gnu-{linux_suffix}" else: - raise RuntimeError(f"unsupported system: {system}") + return Package("llvm", "LLVM-C.lib", "", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") release_suffix = "assert" if use_assert_enabled_llvm else "release" name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}' @@ -159,7 +159,11 @@ class CMakeBuild(build_ext): def build_extension(self, ext): lit_dir = shutil.which('lit') - triton_cache_path = os.path.join(os.environ["HOME"], ".triton") + user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or \ + os.getenv("HOMEPATH") or None + if not user_home: + raise RuntimeError("Could not find user home directory") + triton_cache_path = os.path.join(user_home, ".triton") # lit is used by the test suite thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path) extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) diff --git a/python/src/main.cc b/python/src/main.cc index d09679727..801a83a4b 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -1,4 +1,4 @@ -#include +#include void init_superblocking(pybind11::module &m); void init_torch_utils(pybind11::module &m); diff --git a/python/src/triton.cc b/python/src/triton.cc index 5112e9cd2..b011186e8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1462,7 +1462,7 @@ void init_triton_ir(py::module &&m) { .def( "add_sccp_pass", [](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); }) - .def("add_coalesce_pass", + .def("add_tritongpu_coalesce_pass", [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCoalescePass()); }) @@ -1501,10 +1501,18 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUPrefetchPass()); }) - .def("add_tritongpu_combine_pass", + .def("add_tritongpu_accelerate_matmul_pass", [](mlir::PassManager &self, int computeCapability) { self.addPass( - mlir::createTritonGPUCombineOpsPass(computeCapability)); + mlir::createTritonGPUAccelerateMatmulPass(computeCapability)); + }) + .def("add_tritongpu_fuse_transpositions_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPUFuseTranspositionsPass()); + }) + .def("add_tritongpu_remove_layout_conversions_pass", + [](mlir::PassManager &self) { + 
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass()); }) .def("add_tritongpu_update_mma_for_volta_pass", [](mlir::PassManager &self) { diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index afec19019..23d7bc293 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -8,7 +8,7 @@ import triton import triton.language as tl from triton.testing import get_dram_gbps, get_max_tensorcore_tflops -DEVICE_NAME = 'v100' +DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]] ####################### # Utilities @@ -34,7 +34,6 @@ mem_clocks = {'v100': 877, 'a100': 1215} matmul_data = { 'v100': { # square - (256, 256, 256): {'float16': 0.027}, (512, 512, 512): {'float16': 0.158}, (1024, 1024, 1024): {'float16': 0.466}, (2048, 2048, 2048): {'float16': 0.695}, @@ -51,29 +50,26 @@ matmul_data = { (4096, 64, 4096): {'float16': 0.264}, (8192, 64, 8192): {'float16': 0.452}, }, + # NOTE: + # A100 in the CI server is slow-ish for some reason. + # On some other servers, we are getting about 90% peak for 8kx8x8k float16 'a100': { - (256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006}, - (512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030}, - (1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385}, - (4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711}, - (8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860}, + (512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05}, + (1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169}, + (2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34}, + (4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46}, + (8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51}, # tall-skinny (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005}, (16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259}, - (16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431}, + (16, 8192, 8192): {'float16': 0.07, 'float32': 0.0648, 'int8': 0.0431}, (64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169}, - (64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174}, + (64, 4096, 4096): {'float16': 0.16, 'float32': 0.162, 'int8': 0.097}, + (64, 8192, 8192): {'float16': 0.30, 'float32': 0.257, 'int8': 0.174}, (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177}, + (4096, 64, 4096): {'float16': 0.16, 'float32': 0.177, 'int8': 0.102}, + (8192, 64, 8192): {'float16': 0.25, 'float32': 0.230, 'int8': 0.177}, } - # # deep reductions - # (64 , 64 , 16384) : {'a100': 0.}, - # (64 , 64 , 65536) : {'a100': 0.}, - # (256 , 256 , 8192 ) : {'a100': 0.}, - # (256 , 256 , 32768) : {'a100': 0.}, } @@ -88,9 +84,7 @@ def test_matmul(M, N, K, dtype_str): torch.manual_seed(0) ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str] cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - ref_sm_clock = sm_clocks[DEVICE_NAME] max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run 
at {ref_sm_clock} MHz' if dtype == torch.int8: a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda') b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda') @@ -99,10 +93,10 @@ def test_matmul(M, N, K, dtype_str): a = torch.randn((M, K), dtype=dtype, device='cuda') b = torch.randn((K, N), dtype=dtype, device='cuda') fn = lambda: triton.ops.matmul(a, b) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=300) cur_gpu_perf = 2. * M * N * K / ms * 1e-9 cur_gpu_util = cur_gpu_perf / max_gpu_perf - triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) + assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05) ####################### @@ -149,16 +143,61 @@ elementwise_data = { def test_elementwise(N): torch.manual_seed(0) ref_gpu_util = elementwise_data[DEVICE_NAME][N] - cur_mem_clock = nvsmi(['clocks.current.memory'])[0] - ref_mem_clock = mem_clocks[DEVICE_NAME] max_gpu_perf = get_dram_gbps() - assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz' z = torch.empty((N, ), dtype=torch.float16, device='cuda') x = torch.randn_like(z) y = torch.randn_like(z) grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500) cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6 cur_gpu_util = cur_gpu_perf / max_gpu_perf - triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) + assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05) + +####################### +# Flash-Attention +####################### + + +flash_attention_data = { + "a100": { + (4, 48, 4096, 64, 'forward', 'float16'): 0.37, + (4, 48, 4096, 64, 'backward', 'float16'): 0.25, + } +} + + +@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]]) +@pytest.mark.parametrize("mode", ['forward', 'backward']) +@pytest.mark.parametrize("dtype_str", ['float16']) +def test_flash_attention(Z, H, N_CTX, D_HEAD, mode, dtype_str): + is_backward = mode == 'backward' + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Flash attention only supported for compute capability < 80") + torch.manual_seed(20) + dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] + # init data + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() + sm_scale = 0.2 + # benchmark + fn = lambda: triton.ops.attention(q, k, v, sm_scale) + if is_backward: + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, percentiles=None, warmup=100, rep=500) + # compute flops + flops_per_matmul = 2. 
* Z * H * N_CTX * N_CTX * D_HEAD * 0.5 + total_flops = 2 * flops_per_matmul + if is_backward: + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + cur_gpu_perf = total_flops / ms * 1e-9 + # maximum flops + cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) + cur_gpu_util = cur_gpu_perf / max_gpu_perf + ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, mode, dtype_str)] + assert triton.testing.allclose(cur_gpu_util, ref_gpu_util, atol=0.01, rtol=0.05) diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 51108ebb8..2cf65ced5 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1034,7 +1034,7 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'): @pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype", [(*shape, 4, False, False, epilogue, allow_tf32, dtype) - for shape in [(64, 64, 64)] + for shape in [(64, 64, 64), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] for allow_tf32 in [True, False] for dtype in ['float16', 'float32'] @@ -1063,6 +1063,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi pytest.skip("Only test int8 on devices with sm >= 80") elif dtype == 'float32' and allow_tf32: pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) == (128, 256, 32, 8): + pytest.skip("shared memory out of resource") torch.backends.cuda.matmul.allow_tf32 = allow_tf32 @@ -1298,18 +1301,18 @@ def test_noop(device='cuda'): kernel[(1, )](x) -@pytest.mark.parametrize("device", ['cuda', 'cpu']) +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) def test_pointer_arguments(device): @triton.jit def kernel(x): pass - x = torch.empty(1024, device=device) - result = True - try: - kernel[(1,)](x) - except ValueError: - result = True if device == 'cpu' else False - assert result + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1,)](x) + else: + kernel[(1, )](x) @pytest.mark.parametrize("value, value_type", [ diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 0d8047bd3..60c6b72da 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -16,7 +16,6 @@ import tempfile import warnings from collections import namedtuple from pathlib import Path -from sysconfig import get_paths from typing import Any, Callable, Dict, Tuple, Union import setuptools @@ -981,32 +980,32 @@ def ast_to_ttir(fn, signature, specialization, constants): return optimize_triton_ir(mod) -def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability): +def ttir_to_ttgir(mod, num_warps): pm = _triton.ir.pass_manager(mod.context) pm.add_convert_triton_to_tritongpu_pass(num_warps) + pm.run(mod) + return mod + + +def optimize_ttgir(mod, num_stages, compute_capability): + pm = _triton.ir.pass_manager(mod.context) pm.enable_debug() - pm.add_coalesce_pass() - # The combine pass converts blocked layout to mma layout - # for dot ops so that pipeline can get shared memory swizzled correctly. 
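Aside: the old single `ttir_to_ttgir(mod, num_warps, num_stages, compute_capability)` entry point is split here into a plain dialect conversion plus a separate `optimize_ttgir` pipeline. A minimal sketch of how a caller composes the two after this change, mirroring the `aot.py` hunk later in this diff; the wrapper name and the concrete argument values are placeholders, not part of the change:

# Sketch only: `ttir_mod` stands for an already-parsed Triton-IR module.
import triton

def lower_ttir(ttir_mod, num_warps=4, num_stages=3, compute_capability=80):
    # 1) dialect conversion only, no layout optimization
    ttgir = triton.compiler.ttir_to_ttgir(ttir_mod, num_warps)
    # 2) TritonGPU pipeline: coalesce, accelerate-matmul,
    #    remove-layout-conversions, fuse-transpositions, pipeline, prefetch, ...
    return triton.compiler.optimize_ttgir(ttgir, num_stages, compute_capability)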
- pm.add_tritongpu_combine_pass(compute_capability) + pm.add_tritongpu_coalesce_pass() + pm.add_tritongpu_accelerate_matmul_pass(compute_capability) + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_fuse_transpositions_pass() pm.add_tritongpu_pipeline_pass(num_stages) - # Prefetch must be done after pipeline pass because pipeline pass - # extracts slices from the original tensor. pm.add_tritongpu_prefetch_pass() - pm.add_canonicalizer_pass() - pm.add_cse_pass() - pm.add_tritongpu_combine_pass(compute_capability) - pm.add_licm_pass() - pm.add_tritongpu_combine_pass(compute_capability) - pm.add_cse_pass() + pm.add_tritongpu_fuse_transpositions_pass() + pm.add_tritongpu_remove_layout_conversions_pass() pm.add_tritongpu_decompose_conversions_pass() if compute_capability // 10 == 7: # The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1 # NOTE this pass should be placed after all the passes those modifies mma layout pm.add_tritongpu_update_mma_for_volta_pass() + pm.add_tritongpu_reorder_instructions_pass() pm.add_cse_pass() pm.add_symbol_dce_pass() - pm.add_tritongpu_reorder_instructions_pass() pm.run(mod) return mod @@ -1459,11 +1458,6 @@ def default_cache_dir(): return os.path.join(os.environ["HOME"], ".triton", "cache") -def default_cuda_dir(): - default_dir = "/usr/local/cuda" - return os.getenv("CUDA_HOME", default=default_dir) - - class CacheManager: def __init__(self, key): @@ -1537,14 +1531,15 @@ def _build(name, src, srcdir): hip_include_dir = os.path.join(hip_home_dirs(), "include") else: cuda_lib_dirs = libcuda_dirs() - cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir()) + base_dir = os.path.dirname(__file__) + cuda_path = os.path.join(base_dir, "third_party", "cuda") + cu_include_dir = os.path.join(cuda_path, "include") triton_include_dir = os.path.join(os.path.dirname(__file__), "include") cuda_header = os.path.join(cu_include_dir, "cuda.h") triton_cuda_header = os.path.join(triton_include_dir, "cuda.h") if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header): cu_include_dir = triton_include_dir - suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible @@ -1556,7 +1551,17 @@ def _build(name, src, srcdir): cc = gcc if gcc is not None else clang if cc is None: raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") - py_include_dir = get_paths()["include"] + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
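Aside: the sysconfig scheme workaround above is self-contained enough to check on its own. A standalone sketch of the include-path resolution it implements (the helper name is ours, the example output path is illustrative):

# Resolve the CPython header directory the way _build() now does: use the
# public API on Python >= 3.10 and fold Debian's 'posix_local' scheme back
# to 'posix_prefix' so headers under the system prefix are found.
import sysconfig

def python_include_dir():
    if hasattr(sysconfig, 'get_default_scheme'):   # Python >= 3.10
        scheme = sysconfig.get_default_scheme()
    else:                                          # older interpreters
        scheme = sysconfig._get_default_scheme()
    if scheme == 'posix_local':                    # Debian-patched default
        scheme = 'posix_prefix'
    return sysconfig.get_paths(scheme=scheme)["include"]

print(python_include_dir())  # e.g. /usr/include/python3.10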
+ if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + if torch.version.hip is not None: ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so]) else: @@ -1751,30 +1756,30 @@ def compile(fn, **kwargs): raise RuntimeError('gfx_arch is None (not specified)') stages = { "ast": (lambda path: fn, None), - "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context), - lambda src: ast_to_ttir(src, signature, configs[0], constants)), - "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context), - lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)), + "ttir": (lambda path: parse_mlir_module(path, context), + lambda src: ast_to_ttir(src, signature, configs[0], constants)), + "ttgir": (lambda path: parse_mlir_module(path, context), + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)), "llir": (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, capability)), + lambda src: ttgir_to_llir(src, extern_libs, capability)), "amdgcn": (lambda path: Path(path).read_text(), - lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch, - gfx_arch_full_details[0], + lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch, + gfx_arch_full_details[0], gfx_arch_full_details[2])), } else: stages = { "ast": (lambda path: fn, None), - "ttir": (lambda path: _triton.ir.parse_mlir_module(path, context), - lambda src: ast_to_ttir(src, signature, configs[0], constants)), - "ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context), - lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)), + "ttir": (lambda path: parse_mlir_module(path, context), + lambda src: ast_to_ttir(src, signature, configs[0], constants)), + "ttgir": (lambda path: parse_mlir_module(path, context), + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)), "llir": (lambda path: Path(path).read_text(), - lambda src: ttgir_to_llir(src, extern_libs, capability)), + lambda src: ttgir_to_llir(src, extern_libs, capability)), "ptx": (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, capability)), "cubin": (lambda path: Path(path).read_bytes(), - lambda src: ptx_to_cubin(src, capability)) + lambda src: ptx_to_cubin(src, capability)) } # find out the signature of the function diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 43e1b993a..be175d531 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -9,6 +9,8 @@ from triton._C.libtriton.triton import ir T = TypeVar('T') +TRITON_MAX_TENSOR_NUMEL = 131072 + def _to_tensor(x, builder): if isinstance(x, bool): @@ -254,6 +256,8 @@ class block_type(dtype): self.numel = 1 for s in self.shape: self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") self.name = self.__str__() @@ -702,12 +706,13 @@ def num_programs(axis, _builder=None): @builtin def arange(start, end, _builder=None): """ - Returns contiguous values within the open interval [:code:`start`, :code:`end`). + Returns contiguous values within the left-closed and right-open interval [:code:`start`, :code:`end`). \ + End - Start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 131072 :param start: Start of the interval. Must be a power of two. 
- :type start: int - :param stop: End of the interval. Must be a power of two >= start. - :type stop: int + :type start: int32 + :param end: End of the interval. Must be a power of two > start. + :type end: int32 """ start = _constexpr_to_value(start) end = _constexpr_to_value(end) diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py index 87aa00f0c..b3f992fe5 100644 --- a/python/triton/language/libdevice.py +++ b/python/triton/language/libdevice.py @@ -5,9 +5,10 @@ from .. import impl from . import core, extern if torch.version.hip is not None: - LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc") + LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cuda2gcn.bc") else: - LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc") + LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party", "cuda", "lib", "libdevice.10.bc") +LIBDEVICE_PATH = os.getenv("TRITON_LIBDEVICE_PATH", LOCAL_PATH) @impl.extern def clz(arg0, _builder=None): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 05548f0a3..5a2af59fa 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -480,6 +480,12 @@ def not_equal(input: tl.tensor, def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: if not isinstance(start, int) or not isinstance(end, int): raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") shape = [end - start] ret_ty = tl.block_type(tl.int32, shape) @@ -655,6 +661,8 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value if src_ty.is_block(): dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) if src_ty == dst_ty: diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 1f5ea971b..aa5fb47b0 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -62,9 +62,9 @@ class Autotuner(KernelInterface): self.hook(args) self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) try: - return do_bench(kernel_call) + return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8)) except OutOfResources: - return float('inf') + return (float('inf'), float('inf'), float('inf')) def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) diff --git a/python/triton/testing.py b/python/triton/testing.py index f277ec140..7a1684a77 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -93,7 +93,11 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''): npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) -def allclose(x, y, tol=1e-2): +def allclose(x, y, atol=0, rtol=1e-2): + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) if x.dtype != y.dtype: raise RuntimeError(f'{x.dtype} did not match with {x.dtype}') if x.shape != y.shape: @@ -101,12 +105,11 @@ def allclose(x, y, tol=1e-2): if x.dtype == torch.bool: return torch.sum(x ^ y) == 0 if x.dtype in 
[torch.int8, torch.int16, torch.int32, torch.int64]: - tol = 0 + rtol = 0 diff = abs(x - y) x_max = torch.max(x) y_max = torch.max(y) - err = torch.max(diff) / torch.max(x_max, y_max) - return err <= tol + return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max) def nvsmi(attrs): diff --git a/python/triton/tools/aot.py b/python/triton/tools/aot.py index b6c6f3608..77a6f50f7 100644 --- a/python/triton/tools/aot.py +++ b/python/triton/tools/aot.py @@ -83,7 +83,8 @@ if __name__ == '__main__': raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation") # triton-ir -> triton-gpu-ir - module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm) + module = triton.compiler.ttir_to_ttgir(module, num_warps=4) + module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm) if args.target == 'triton-gpu-ir': print(module.str()) sys.exit(0) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index 22011c273..5ae89d77b 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -289,7 +289,8 @@ class Libdevice(ExternLibrary): # return extern.dispatch("libdevice", , , , _builder) import_str = "from . import core, extern\n" import_str += "import os\n" - header_str = "LIBDEVICE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")" + header_str = "LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\", \"cuda\", \"lib\", \"libdevice.10.bc\")\n" + header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", LOCAL_PATH)\n" func_str = "" for symbols in self._symbol_groups.values(): func_str += "@extern.extern\n" diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index b7832210f..de7bb8623 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -223,6 +223,7 @@ class _attention(torch.autograd.Function): BLOCK_DMODEL=Lk, num_warps=num_warps, num_stages=2, ) + # print(h.asm["ttgir"]) ctx.save_for_backward(q, k, v, o, L, m) ctx.grid = grid @@ -260,6 +261,7 @@ class _attention(torch.autograd.Function): BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, num_stages=1, ) + # print(h.asm["ttgir"]) return dq, dk, dv, None diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index d1ce4ae7b..e036e92f4 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -179,8 +179,8 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B return } -// CHECK-LABEL: for_if_for -func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +// CHECK-LABEL: for_for_if +func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %cst -> %cst %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 @@ -213,3 +213,34 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr } return } + +// CHECK-LABEL: cf_for +func.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { + // CHECK: %cst -> %cst + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: %cst_0 -> %cst_0 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED> + gpu.barrier + // CHECK-NEXT: %0 -> 
%0 + %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + // CHECK-NEXT: %cst_1 -> %cst_1 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #A_SHARED> + // CHECK-NEXT: %2 -> %cst,%cst_0,%cst_1 + // CHECK-NEXT: %3 -> %cst,%cst_0,%cst_1 + // CHECK-NEXT: %4 -> %cst,%cst_0,%cst_1 + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) +^bb1(%1: index, %2: tensor<128x32xf16, #A_SHARED>, %3: tensor<128x32xf16, #A_SHARED>, %4: tensor<128x32xf16, #A_SHARED>): // 2 preds: ^bb0, ^bb2 + %5 = arith.cmpi slt, %1, %arg1 : index + cf.cond_br %5, ^bb2, ^bb3 +^bb2: // pred: ^bb1 + %6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked> + gpu.barrier + %7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked> + %8 = arith.addi %1, %arg2 : index + cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) +^bb3: // pred: ^bb1 + gpu.barrier + // CHECK-NEXT: %9 -> %9 + %9 = tt.cat %0, %0 {axis = 0 : i64} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED> + return +} diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index f79222aa7..18df0010d 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -315,8 +315,8 @@ func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.pt // a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2. // So they cannot be reused by cst0 and cst1, but can be reused by cst2. 
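Aside: the comment above states the reuse rule these allocation tests check: a buffer slot can be recycled only by a value whose liveness range does not overlap the original's. A toy Python model of that interval check (names and instruction indices are illustrative, not the allocator's):

# Toy interval model: a new shared-memory value may take over an old buffer's
# offset only if their liveness ranges are disjoint.
def can_reuse(old_live, new_live):
    (old_first, old_last), (new_first, new_last) = old_live, new_live
    return old_last < new_first or new_last < old_first

loop_carried = (0, 10)                        # e.g. a_shared_init, live across the loop
assert not can_reuse(loop_carried, (4, 6))    # a value inside the loop: no reuse
assert can_reuse(loop_carried, (12, 13))      # a value after the loop (cst2): reuse OK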
-// CHECK-LABEL: for_if_for -func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +// CHECK-LABEL: for_for_if +func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 17880b209..e368fada3 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-membar 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --mlir-disable-threading --convert-scf-to-cf -test-print-membar 2>&1 | FileCheck %s #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> @@ -45,11 +45,12 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> - %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> - // CHECK: Membar 5 - %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED> + %0 = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %1 = tt.load %0, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %2 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.convert_layout + %3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED> return } @@ -57,46 +58,51 @@ func.func @raw_single_block(%A : !tt.ptr) { func.func @war_single_block(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> - %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> - // CHECK: Membar 5 - %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> - // a2's liveness range ends here, and a3 and a2 have the same address range. - // So it makes sense to have a WAR dependency between a2 and a3. 
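Aside: these membar checks now assert on the inserted `gpu.barrier` ops directly instead of on analysis IDs. As a toy model (not the analysis itself) of the hazard rule they exercise: a barrier is required before an op whose shared-memory accesses intersect the buffers still pending from earlier ops, including the write-after-read case described in the removed comment above.

# Toy hazard check: `pending_*` are buffer ranges touched since the last barrier.
def needs_barrier(pending_reads, pending_writes, op_reads, op_writes):
    war = pending_reads & op_writes                    # write-after-read
    raw_waw = pending_writes & (op_reads | op_writes)  # read/write-after-write
    return bool(war or raw_waw)

# The a2/a3 case above: a3 writes an address range a2 just finished reading.
assert needs_barrier({"a2_range"}, set(), op_reads=set(), op_writes={"a2_range"})
# The plain RAW case: reading a buffer an earlier op wrote.
assert needs_barrier(set(), {"buf"}, op_reads={"buf"}, op_writes=set())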
- // CHECK-NEXT: Membar 7 - %a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> + %0 = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %1 = tt.load %0, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %2 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.convert_layout + %3 = triton_gpu.convert_layout %2 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> + // CHECK: gpu.barrier + // CHECK-NEXT: %4 = triton_gpu.convert_layout + %4 = triton_gpu.convert_layout %1 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> return } // CHECK-LABEL: scratch func.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK: Membar 1 - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK-NEXT: Membar 3 - %aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> - %b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0> + // CHECK: gpu.barrier + // CHECK-NEXT: tt.cat + %0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.convert_layout + %1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> + %2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0> return } // CHECK-LABEL: async_wait func.func @async_wait() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> - // CHECK: Membar 1 - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: tt.cat + %0 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> triton_gpu.async_wait {num = 4 : i32} - // CHECK-NEXT: Membar 4 - %a_ = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.convert_layout + %1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> return } // CHECK-LABEL: alloc func.func @alloc() { - %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> - // CHECK: Membar 2 - %b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> + %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + %1 = tt.cat %0, %0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.convert_layout + %2 = triton_gpu.convert_layout %1 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> return } @@ -104,50 +110,58 @@ func.func @alloc() { func.func @extract_slice() { %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : index - %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> - // CHECK: 
Membar 3
-  %cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
-  // CHECK-NEXT: Membar 5
-  %cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
+  %0 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: triton_gpu.convert_layout
+  %1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: triton_gpu.convert_layout
+  %2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
   return
 }
 
 // CHECK-LABEL: trans
 func.func @trans() {
+  // CHECK-NOT: gpu.barrier
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
   %b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
   return
 }
 
-// CHECK-LABEL: insert_slice_async
-func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
+// CHECK-LABEL: insert_slice_async_op
+func.func @insert_slice_async_op(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
   %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
   %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
   %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A_SHARED>
   %index = arith.constant 0 : i32
-  %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
-  // CHECK: Membar 6
-  %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
-  // CHECK: Membar 8
-  %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
+  %3 = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<1x16x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %4 = tt.cat %3, %3 {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
   return
 }
 
-// CHECK-LABEL: insert_slice
-func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+// CHECK-LABEL: insert_slice_op
+func.func @insert_slice_op(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
   %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
   %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
   %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
   %index = arith.constant 0 : index
-  %al = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
-  // CHECK: Membar 6
-  %a = tensor.insert_slice %al into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
-  // CHECK: Membar 8
-  %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
-  // CHECK: Membar 10
-  %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
+  %2 = tt.load %a_ptr, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #AL>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tensor.insert_slice
+  %3 = tensor.insert_slice %2 into %tensor[%index, 0, 0][1, 16, 16][1, 1, 1]: tensor<16x16xf16, #AL> into tensor<1x16x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %4 = tt.cat %3, %3 {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %5 = tt.cat %4, %4 {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED>
   return
 }
 
@@ -157,18 +171,21 @@ func.func @multi_blocks(%i1 : i1) {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   scf.if %i1 {
-    // CHECK: Membar 2
-    %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
     scf.yield
   } else {
     %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
     %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
-    // CHECK-NEXT: Membar 7
-    %b = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %1 = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
     scf.yield
   }
-  // CHECK-NEXT: Membar 10
-  %c = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %2 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
   return
 }
 
@@ -178,12 +195,14 @@ func.func @multi_blocks_join_barrier(%i1 : i1) {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   scf.if %i1 {
-    // CHECK: Membar 2
-    %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
    scf.yield
   } else {
-    // CHECK-NEXT: Membar 5
-    %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
     scf.yield
   }
   %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
@@ -196,17 +215,42 @@ func.func @multi_blocks_yield(%i1 : i1) {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   %a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
-    // CHECK: Membar 2
-    %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
-    scf.yield %a : tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    scf.yield %0 : tensor<32x16xf16, #A_SHARED>
   } else {
-    // CHECK-NEXT: Membar 5
-    %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
-    scf.yield %b : tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    scf.yield %1 : tensor<32x16xf16, #A_SHARED>
   }
   %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
-  // CHECK-NEXT: Membar 9
-  %b = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %4 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
+  return
+}
+
+// Even though the entry block doesn't have a barrier, the successors should have barriers
+// CHECK-LABEL: multi_blocks_entry_no_shared
+func.func @multi_blocks_entry_no_shared(%i1 : i1) {
+  %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
+  %a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
+    %0 = tt.cat %cst1, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    scf.yield %0 : tensor<32x16xf16, #A_SHARED>
+  } else {
+    // CHECK-NOT: gpu.barrier
+    // CHECK: arith.constant
+    %cst1 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
+    scf.yield %cst1 : tensor<32x16xf16, #A_SHARED>
+  }
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %1 = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED>
   return
 }
 
@@ -216,11 +260,14 @@ func.func @multi_blocks_noelse(%i1 : i1) {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   scf.if %i1 {
-    // CHECK: Membar 2
-    %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
     scf.yield
   }
-  %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: triton_gpu.convert_layout
+  %1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
   return
 }
 
@@ -231,18 +278,21 @@ func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   scf.if %i1 {
     scf.if %i2 {
-      // CHECK: Membar 2
-      %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+      // CHECK: gpu.barrier
+      // CHECK-NEXT: tt.cat
+      %0 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
       scf.yield
     }
     scf.yield
   } else {
-    // CHECK-NEXT: Membar 6
-    %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %1 = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
     scf.yield
   }
-  // CHECK-NEXT: Membar 9
-  %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: triton_gpu.convert_layout
+  %2 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
   return
 }
 
@@ -252,8 +302,9 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
-    // CHECK-NEXT: Membar 3
-    %cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %5 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
     scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
   }
   return
@@ -265,17 +316,20 @@ func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
 func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
-  // CHECK-NEXT: Membar 2
-  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
     %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
-    // CHECK-NEXT: Membar 6
-    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %7 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #AL>
     scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
   }
-  // CHECK-NEXT: Membar 9
-  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
   return
 }
 
@@ -285,41 +339,162 @@ func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
 func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
-  // CHECK-NEXT: Membar 2
-  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
-    // CHECK-NEXT: Membar 5
-    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
-    // CHECK-NEXT: Membar 7
-    %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %6 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %7 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
     scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
   }
-  // CHECK-NEXT: Membar 10
-  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %9 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
   return
 }
-
 // CHECK-LABEL: for_reuse_nested
 func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
-  // CHECK-NEXT: Membar 2
-  %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
   %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
-    // CHECK-NEXT: Membar 5
-    %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+    // CHECK: gpu.barrier
+    // CHECK-NEXT: tt.cat
+    %6 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
     %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
-      // CHECK-NEXT: Membar 7
-      %cst2 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
+      // CHECK: gpu.barrier
+      // CHECK-NEXT: tt.cat
+      %12 = tt.cat %a_shared_nested, %b_shared_nested {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED>
      scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
     }
    scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
   }
-  // CHECK-NEXT: Membar 11
-  %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %15 = tt.cat %0, %0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
+  return
+}
+
+// repeatedly write to the same shared memory addresses
+// CHECK-LABEL: for_for_if
+func.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
+      %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
+        // CHECK: gpu.barrier
+        // CHECK-NEXT: arith.constant
+        %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+        scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
+      } else {
+        // CHECK: gpu.barrier
+        // CHECK-NEXT: arith.constant
+        %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+        scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
+      }
+      scf.yield %c_shared_next_next : tensor<128x32xf16, #A_SHARED>
+    }
+    scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  return
+}
+
+// c_blocked_next can either be converted from c_shared_init or c_shared_next_next
+// CHECK-LABEL: for_if_for
+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+  %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+  // CHECK: gpu.barrier
+  %c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
+
+  %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
+    %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> {
+      // CHECK: gpu.barrier
+      // CHECK-NEXT: arith.constant
+      %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
+      scf.yield %cst0 : tensor<128x32xf16, #A_SHARED>
+    } else {
+      %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) {
+        // CHECK: gpu.barrier
+        // CHECK-NEXT: triton_gpu.convert_layout
+        %c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
+        scf.yield %c_shared : tensor<128x32xf16, #A_SHARED>
+      }
+      scf.yield %c_shared_ : tensor<128x32xf16, #A_SHARED>
+    }
+    // CHECK-NOT: gpu.barrier
+    %b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL>
+    scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
+  }
+  return
+}
+
+// CHECK-LABEL: cf_if
+func.func @cf_if(%i1 : i1) {
+  %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  cf.cond_br %i1, ^bb1, ^bb2
+^bb1: // pred: ^bb0
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  cf.br ^bb2
+^bb2: // 2 preds: ^bb0, ^bb1
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: triton_gpu.convert_layout
+  %1 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
+  return
+}
+
+func.func @cf_if_else(%i1 : i1) {
+  %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  cf.cond_br %i1, ^bb1, ^bb2
+^bb1: // pred: ^bb0
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  cf.br ^bb3(%0 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>)
+^bb2: // pred: ^bb0
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
  cf.br ^bb3(%1 : tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>)
+^bb3(%2: tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>): // 2 preds: ^bb1, ^bb2
+  cf.br ^bb4
+^bb4: // pred: ^bb3
+  %3 = triton_gpu.convert_layout %cst : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<16x16xf16, #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>>
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %4 = tt.cat %2, %2 {axis = 0 : i64} : (tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<64x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  return
+}
+
+func.func @cf_if_else_return(%i1 : i1) {
+  %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  cf.cond_br %i1, ^bb1, ^bb2
+^bb1: // pred: ^bb0
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %0 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
+  return
+^bb2: // pred: ^bb0
+  // CHECK: gpu.barrier
+  // CHECK-NEXT: tt.cat
+  %1 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>, tensor<16x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>) -> tensor<32x16xf16, #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>>
   return
 }
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
index 51cccccfb..aa99c8766 100644
--- a/test/TritonGPU/coalesce.mlir
+++ b/test/TritonGPU/coalesce.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-coalesce -canonicalize | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-coalesce | FileCheck %s
 
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index f1a56b82f..393dd8eff 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s
 
 #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index 01dc3f0ab..fd7492f44 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-combine -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
 
 // CHECK: offset = 0, size = 49152
 // CHECK: offset = 49152, size = 49152
diff --git a/test/TritonGPU/update-mma-for-volta.mlir b/test/TritonGPU/update-mma-for-volta.mlir
index 2f0d31fe8..9e63d11e6 100644
--- a/test/TritonGPU/update-mma-for-volta.mlir
+++ b/test/TritonGPU/update-mma-for-volta.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-combine -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s
 
 // -----
 
diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp
index 09ae90579..908018735 100644
--- a/test/lib/Analysis/TestAlias.cpp
+++ b/test/lib/Analysis/TestAlias.cpp
@@ -65,8 +65,19 @@ struct TestAliasPass
     };
 
     operation->walk([&](Operation *op) {
-      if (op->getNumResults() < 1)
+      if (op->getNumResults() < 1) {
+        // cond br, br
+        if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+          auto *block = branch->getBlock();
+          for (auto arg : llvm::enumerate(block->getArguments())) {
+            auto operand = block->getArgument(arg.index());
+            auto opNames = getAllocOpNames(operand);
+            auto argName = getValueOperandName(arg.value(), state);
+            print(argName, opNames, os);
+          }
+        }
         return;
+      }
       if (auto forOp = dyn_cast<scf::ForOp>(op)) {
         for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
           auto operand = forOp.getOpOperandForRegionIterArg(arg.value()).get();
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
index ab9b9f3fb..c43978b39 100644
--- a/test/lib/Analysis/TestMembar.cpp
+++ b/test/lib/Analysis/TestMembar.cpp
@@ -1,6 +1,8 @@
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/Membar.h"
 
@@ -24,21 +26,13 @@ struct TestMembarPass
     // Convert to std::string can remove quotes from op_name
     auto opName = SymbolTable::getSymbolName(operation).getValue().str();
     os << opName << "\n";
+
+    // Print all ops after membar pass
     Allocation allocation(operation);
    MembarAnalysis membarPass(&allocation);
     membarPass.run();
-    size_t operationId = 0;
-    operation->walk([&](Operation *op) {
-      if (isa<gpu::BarrierOp>(op)) {
-        os << "Membar " << operationId << "\n";
-      }
-      if (op->getNumRegions() == 0) {
-        // Don't count parent Operation to simplify the test.
-        operationId++;
-      }
-      return;
-    });
+    os << *operation << "\n";
   }
 };