diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3ac51329b..ad2bdc8a2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -239,6 +239,13 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
 
+# TODO: Figure out which target is sufficient to fix errors; triton is
+# apparently not enough. For now, link libstdc++fs for all targets to
+# support old GCC compilers such as 8.3.0.
+if (NOT WIN32 AND NOT APPLE)
+  link_libraries(stdc++fs)
+endif()
+
 if(TRITON_BUILD_PYTHON_MODULE)
   add_library(triton SHARED ${PYTHON_SRC})
   set(TRITON_LIBRARIES
@@ -259,7 +266,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
     MLIRLLVMDialect
     MLIRSupport
     MLIRTargetLLVMIRExport
-    MLIRExecutionEngine
     MLIRMathToLLVM
     MLIRNVVMToLLVMIRTranslation
     MLIRROCDLToLLVMIRTranslation
@@ -278,9 +284,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
     target_link_libraries(triton ${ROCM_LIBRARIES} ${LLVM_LIBRARIES} z
       ${TRITON_LIBRARIES}
     )
-    # TODO: Figure out which target is sufficient to fix errors; triton is
-    # apparently not enough
-    link_libraries(stdc++fs)
   endif()
 
   target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
diff --git a/README.md b/README.md
index 52bb5b41e..cafc069b6 100644
--- a/README.md
+++ b/README.md
@@ -90,10 +90,10 @@ arbitrary LLVM version.
 1. Find the version of LLVM that Triton builds against.  Check
    `python/setup.py` for a line like
 
-       version = "llvm-17.0.0-c5dede880d17"
+       version = "llvmorg-18-init-7000-g76ce4736721a"
 
    This means that the version of Triton you have builds against
-   [LLVM](https://github.com/llvm/llvm-project) c5dede880d17.
+   [LLVM](https://github.com/llvm/llvm-project) 76ce4736721a.
 
 2. `git checkout` LLVM at this revision.  Optionally, make additional
    modifications to LLVM.
diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index 075d71282..7ce48d227 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -72,7 +72,6 @@ llvm_update_compile_flags(triton-translate)
   MLIRPass
   MLIRSupport
   MLIRTransforms
-  MLIRExecutionEngine
   MLIRMathToLLVM
   MLIRTransformUtils
   MLIRLLVMToLLVMIRTranslation
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index ee1da6a20..9993827bc 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -104,10 +104,15 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
       "", llvm::cl::desc("AMDGCN features. e.g. '+sramecc,-xnack'"),
'+sramecc,-xnack'"), llvm::cl::value_desc("features"), llvm::cl::init("+sramecc,-xnack")); + static llvm::cl::opt enableFpFusion( + "enable-fp-fusion", llvm::cl::desc("Enables fusion of fadd/fmul"), + llvm::cl::init(true)); + llvm::InitLLVM y(argc, argv); registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); + registerPassManagerCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, toolName); mlir::MLIRContext context; @@ -142,12 +147,17 @@ LogicalResult tritonTranslateMain(int argc, char **argv, llvm::outs() << *llvmir << '\n'; } else if (targetKind == "ptx") { llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(), +<<<<<<< HEAD ptxVersion.getValue()); } else if (targetKind == "hsaco") { auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO( *llvmir, GCNArch.getValue(), GCNTriple.getValue(), GCNFeatures.getValue()); llvm::outs() << hsaco; +======= + ptxVersion.getValue(), + enableFpFusion.getValue()); +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 } else { llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n"; return failure(); diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 410d02fab..de2ef6d7d 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -63,8 +63,6 @@ Memory Ops load store - atomic_cas - atomic_xchg Indexing Ops @@ -129,10 +127,11 @@ Atomic Ops :toctree: generated :nosignatures: - atomic_cas atomic_add + atomic_cas atomic_max atomic_min + atomic_xchg Comparison ops diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h index 175006404..a06df5ae2 100644 --- a/include/triton/Analysis/Alias.h +++ b/include/triton/Analysis/Alias.h @@ -63,11 +63,12 @@ private: // Shared Memory Alias Analysis //===----------------------------------------------------------------------===// class SharedMemoryAliasAnalysis - : public dataflow::SparseDataFlowAnalysis> { + : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { public: - using dataflow::SparseDataFlowAnalysis< - dataflow::Lattice>::SparseDataFlowAnalysis; - using dataflow::SparseDataFlowAnalysis< + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< dataflow::Lattice>::getLatticeElement; /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. 
diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h
index 8d28f46aa..efde6db06 100644
--- a/include/triton/Analysis/AxisInfo.h
+++ b/include/triton/Analysis/AxisInfo.h
@@ -271,8 +271,8 @@ private:
   std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
 };
 
-class AxisInfoAnalysis
-    : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
+class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
+                             dataflow::Lattice<AxisInfo>> {
 private:
   AxisInfoVisitorList visitors;
 
@@ -284,7 +284,7 @@ private:
 public:
   AxisInfoAnalysis(DataFlowSolver &solver);
 
-  using dataflow::SparseDataFlowAnalysis<
+  using dataflow::SparseForwardDataFlowAnalysis<
       dataflow::Lattice<AxisInfo>>::getLatticeElement;
 
   using FuncAxisInfoMapT = DenseMap;
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 658edd183..5a9564814 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -54,6 +54,8 @@ public:
 
   SmallVector<unsigned> getScratchConfig();
 
+  SmallVector<unsigned> getOrderWithAxisAtBeginning();
+
   unsigned getScratchSizeInBytes();
 
   bool isSupportedLayout();
@@ -133,9 +135,9 @@ bool supportMMA(Value value, int version);
 
 bool isSingleValue(Value value);
 
-bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
+bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
-bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
+bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td
index 896a27c17..0f546aed5 100644
--- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td
@@ -27,6 +27,7 @@ include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
 
 def I8Ptr_global : LLVM_IntPtrBase<8, 1>;
 def I8Ptr_shared : LLVM_IntPtrBase<8, 3>;
@@ -44,9 +45,13 @@ def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
   let assemblyFormat = "attr-dict";
 }
 
-def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", []> {
-  let arguments = (ins I32Attr:$pendings);
-  let assemblyFormat = "attr-dict";
+def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>,
+     AllTypesMatch<["input", "output"]>]> {
+  let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings);
+  let results = (outs LLVM_AnyStruct:$output);
+  let assemblyFormat = "$input attr-dict `:` type($input)";
 }
 
 def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier_init", [MemoryEffects<[MemWrite]>]> {
diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index dd763d345..3a9e4a404 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -9,9 +9,8 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/FunctionInterfaces.h"
-#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "triton/Dialect/Triton/IR/Dialect.h.inc"
 #include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
 #include "triton/Dialect/Triton/IR/Traits.h"
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 575db87be..9445b5c40 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -6,11 +6,11 @@ include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
 include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "mlir/IR/OpBase.td"
-include "mlir/IR/FunctionInterfaces.td" // FunctionOpInterface
 include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
 include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
 include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
 include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
+include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
 include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
 include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -668,6 +668,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
     void setCalleeFromCallable(CallInterfaceCallable callee) {
       (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
     }
+
+    // Required by CallOpInterface.
+    MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
   }];
 
   let assemblyFormat = [{
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
index 7d8cc7b41..ed0cb4cd8 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -269,17 +269,17 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
   let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
 }
 
-def TTNG_DotWaitOp : TTNG_Op<"dot_wait", []> {
+def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                                          AllTypesMatch<["input", "output"]>]> {
   let summary = "dot wait";
-
+  let arguments = (ins TT_FpIntTensor:$input, I32Attr:$pendings);
+  let results = (outs TT_FpIntTensor:$output);
   let description = [{
     This operation defines the waiting action for an async dot (e.g. MMAv3).
     Subsequent operations should not execute until this operation completes
     waiting.
   }];
-
-  let arguments = (ins I32Attr:$pendings);
-
-  let assemblyFormat = "attr-dict";
+  let assemblyFormat = "$input attr-dict `:` type($input)";
 }
 
 def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
diff --git a/include/triton/Target/PTX/PTXTranslation.h b/include/triton/Target/PTX/PTXTranslation.h
index 63ea87a5c..9dd20f85f 100644
--- a/include/triton/Target/PTX/PTXTranslation.h
+++ b/include/triton/Target/PTX/PTXTranslation.h
@@ -10,7 +10,8 @@ class Module;
 namespace triton {
 
 // Translate TritonGPU IR to PTX code.
-std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
+std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
+                                 bool enable_fp_fusion);
 
 } // namespace triton
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 14e766a1a..cf09310d4 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -885,7 +885,8 @@ public:
 //===----------------------------------------------------------------------===//
 
 AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
-    : dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
+    : dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
+          solver) {
   // UnrealizedConversionCast:
   // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
   // in the process of a PartialConversion, where UnrealizedConversionCast
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 8b09e7293..adb08ef68 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
          getParentOrder(getSrcLayout())[0];
 }
 
+SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+  auto srcLayout = getSrcLayout();
+  auto order = triton::gpu::getOrder(srcLayout);
+  auto it = std::find(order.begin(), order.end(), axis);
+  // delete the axis from order
+  order.erase(it);
+  // insert axis at the beginning of order
+  order.insert(order.begin(), axis);
+  return order;
+}
+
 // Thread offset is the thread index offset of two adjacent threads on the
 // reduction axis within the warp.
 unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
@@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
-    if (threadsPerWarp.size() == 1) {
-      threadOffset = 1;
-    } else {
-      assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts");
-      threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0];
+    auto order = triton::gpu::getOrder(srcLayout);
+    for (unsigned i = 0; i < order.size(); i++) {
+      if (order[i] == axis)
+        break;
+      threadOffset *= threadsPerWarp[order[i]];
     }
   }
   return threadOffset;
 }
@@ -150,8 +161,10 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
 }
 
 bool ReduceOpHelper::isWarpSynchronous() {
-  auto argsLayout = getSrcLayout();
-  return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1;
+  auto srcLayout = getSrcLayout();
+  auto srcShape = getSrcShape();
+  return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] ==
+         1;
 }
 
 SmallVector<unsigned> ReduceOpHelper::getScratchConfig() {
@@ -502,10 +515,10 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
   // when #mma = MmaEncoding
   return src && dst && src.getVersionMajor() == 3 &&
          src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
-         dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
+         dst.getWarpsPerCTA()[1] == 1;
 }
 
-bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
 }
 
@@ -521,7 +534,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
          srcTy.getElementType().isF16();
 }
 
-bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
     return true;
   // dot_op = #mma
@@ -713,7 +726,10 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
     auto *currentOp = (slice)[currentIndex];
     // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
-    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+    mlir::BackwardSliceOptions opt;
+    opt.omitBlockArguments = true;
+    opt.filter = backwardFilter;
+    getBackwardSlice(currentOp, &backwardSlice, opt);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
     // Compute and insert the forwardSlice starting from currentOp.
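The BackwardSliceOptions change in the hunk above recurs in several files later in this patch: the filter-argument overload of getBackwardSlice is gone, and callers now fill in an options struct. A sketch of the new call shape (the filter body is illustrative only):

    #include "mlir/Analysis/SliceAnalysis.h"
    #include "llvm/ADT/SetVector.h"

    static void collectProducers(mlir::Operation *root,
                                 llvm::SetVector<mlir::Operation *> &slice) {
      mlir::BackwardSliceOptions opt;
      opt.omitBlockArguments = true; // do not chase values through block args
      opt.filter = [](mlir::Operation *op) {
        return op->getNumRegions() == 0; // illustrative filter only
      };
      mlir::getBackwardSlice(root, &slice, opt);
    }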
diff --git a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt
index 2f81f6766..9af263686 100644
--- a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt
+++ b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt
@@ -14,7 +14,7 @@ add_mlir_conversion_library(NVGPUToLLVM
   LINK_LIBS PUBLIC
     MLIRIR
     MLIRPass
-    MLIRGPUOps
+    MLIRGPUDialect
     MLIRGPUToNVVMTransforms
     MLIRGPUToROCDLTransforms
     MLIRGPUTransforms
diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp
index 02b4c024f..1ba29b04b 100644
--- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp
+++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp
@@ -29,8 +29,6 @@ const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;";
 const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;";
 const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;";
 const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;";
-const std::string Wgmma_Wait_Group_Op =
-    "wgmma.wait_group.sync.aligned #pendings;";
 const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;";
 const std::string Fence_Mbarrier_Init_Op =
     "fence.mbarrier_init.release.cluster;";
@@ -200,29 +198,6 @@ public:
     return {};
   }
 
-  Type getReturnType(std::vector<std::string> outputConstraints,
-                     mlir::PatternRewriter &rewriter) const {
-    auto ctx = rewriter.getContext();
-    Type resTy;
-    if (outputConstraints.empty()) {
-      resTy = void_ty(ctx);
-    } else {
-      SmallVector<Type> retTys;
-      for (auto &outputConstraint : outputConstraints) {
-        assert(outputConstraint[0] == '=' &&
-               "Constraint must be for an output");
-        Type retTy = getTypeFromConstraint(outputConstraint[1], rewriter);
-        retTys.push_back(retTy);
-      }
-      if (retTys.size() == 1) {
-        resTy = retTys[0];
-      } else {
-        resTy = struct_ty(retTys);
-      }
-    }
-    return resTy;
-  }
-
   std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const {
     std::vector<std::pair<int, int>> patchLocations;
     std::vector<std::string> patchValues;
@@ -285,7 +260,8 @@ public:
     outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end());
     auto &ptxInstr = *ptxBuilder.create(ptxAsmPatched);
     ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true);
-    auto retTy = getReturnType(outputConstraints, rewriter);
+    auto retTy =
+        op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType();
     auto res = ptxBuilder.launch(rewriter, loc, retTy,
                                  /*hasSideEffects*/ hasSideEffects);
     if (op->getNumResults() == 0) {
@@ -700,6 +676,45 @@ public:
   }
 };
 
+class WGMMAWaitGroupOpPattern
+    : public NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp,
+                                WGMMAWaitGroupOpPattern> {
+public:
+  using Base =
+      NVGPUOpPatternBase<ttn::WGMMAWaitGroupOp, WGMMAWaitGroupOpPattern>;
+  using Base::Base;
+
+  std::vector<std::string>
+  getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
+    auto outputStructType = op.getType().cast<LLVM::LLVMStructType>();
+    uint32_t numOutputRegs = outputStructType.getBody().size();
+    std::string output =
+        outputStructType.getBody().front().isF32() ? "=f" : "=r";
+    return std::vector<std::string>(numOutputRegs, output);
+  }
+
+  OperandsAndConstraints
+  getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const {
+    OperandsAndConstraints operandsAndConstraints;
+    auto input = op.getInput();
+    operandsAndConstraints.push_back({input, "0"});
+    return operandsAndConstraints;
+  }
+
+  std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const {
+    auto outputStructType = op.getType().dyn_cast<LLVM::LLVMStructType>();
+    uint32_t numCRegs = outputStructType.getBody().size();
+    std::string args = "";
+    uint32_t asmOpIdx = 0;
+    for (uint32_t i = 0; i < numCRegs; ++i) {
+      args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
"" : ","); + } + auto ptxAsm = "// wait for regs: " + args + "\n\t" + + "wgmma.wait_group.sync.aligned #pendings;"; + return ptxAsm; + } +}; + class WGMMAOpPattern : public NVGPUOpPatternBase { public: using Base = NVGPUOpPatternBase; @@ -1072,7 +1087,6 @@ public: POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op) POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) - POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, Wgmma_Wait_Group_Op) POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op) POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op) POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op) @@ -1100,7 +1114,8 @@ public: OffsetOfStmatrixV4OpPattern, MBarrierArriveOpPattern, ClusterArriveOpPattern, TMALoadTiledOpPattern, TMAStoreTiledOpPattern, LoadDSmemOpPattern, WGMMAOpPattern, - StoreDSmemOpPattern, OffsetOfSts64OpPattern>(context); + WGMMAWaitGroupOpPattern, StoreDSmemOpPattern, + OffsetOfSts64OpPattern>(context); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 794055e21..480424331 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -49,7 +49,7 @@ add_mlir_conversion_library(TritonGPUToLLVM ASMBuilder MLIRIR MLIRPass - MLIRGPUOps + MLIRGPUDialect MLIRGPUToNVVMTransforms MLIRGPUToROCDLTransforms MLIRGPUTransforms diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp index 97f8fb4ee..fcf6b2d76 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp @@ -146,10 +146,8 @@ struct DotWaitOpConversion matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto pendings = op.getPendings(); - rewriter.create(op.getLoc(), pendings); - - // Safe to remove the op since it doesn't have any return value. - rewriter.eraseOp(op); + rewriter.replaceOpWithNewOp( + op, adaptor.getInput(), pendings); return success(); } }; diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index b16aee5d8..21f1811c7 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -168,7 +168,9 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, int numRepM = ceil(shapePerCTA[0], instrShape[0] * wpt[0]); int numRepK = ceil(shapePerCTA[1], instrShape[2]); - Value warp = udiv(thread, i32_val(32)); + // The descriptor should be calculated based on the first warp of the + // warpgroup. 
+  Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
   Value warpM = urem(warp, i32_val(wpt[0]));
   Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0]));
 
@@ -199,7 +201,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter,
   int numRepK = ceil(shapePerCTA[0], instrShape[2]);
   int numRepN = ceil(shapePerCTA[1], instrShape[1] * wpt[1]);
 
-  Value warp = udiv(thread, i32_val(32));
+  Value warp = and_(udiv(thread, i32_val(32)), i32_val(0xFFFFFFFC));
   Value warpMN = udiv(warp, i32_val(wpt[0]));
   Value warpN = urem(warpMN, i32_val(wpt[1]));
   Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1]));
@@ -293,6 +295,26 @@ static bool isZero(Value v) {
   return false;
 }
 
+static SmallVector<Value> emitWait(ConversionPatternRewriter &rewriter,
+                                   Location loc, SmallVector<Value> acc,
+                                   int pendings) {
+  SmallVector<Type> types(acc.size(), acc[0].getType());
+  auto structTy =
+      LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
+  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
+  int i = 0;
+  for (Value v : acc) {
+    llvmStruct = insert_val(structTy, llvmStruct, v, i++);
+  }
+  Value res = rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, llvmStruct,
+                                                               pendings);
+  SmallVector<Value> results;
+  for (int i = 0; i < acc.size(); ++i) {
+    results.push_back(extract_val(types[0], res, i));
+  }
+  return results;
+}
+
 LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
                          ConversionPatternRewriter &rewriter, Location loc,
                          Operation *op, Value a, Value b, Value c, Value d,
@@ -427,7 +449,7 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter,
     rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
 
   if (sync)
-    rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, 0);
+    mmaResults = emitWait(rewriter, loc, mmaResults, 0);
 
   SmallVector<Value> results =
       unpackAccumulator(rewriter, loc, mmaResults, dTensorTy);
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 6bb16fe46..102bea5b2 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -6,7 +6,24 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
 
 /* ----- FP8E5M2 ------ */
 // This data-type is the standard FP8E5M2 format
+static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
+  std::string ret;
+  if (!hasNativeFP) {
+    ret = "{                            \n"
+          ".reg .b32 a<2>;              \n"
+          "and.b32 a0, $1, 0xfffefffe;  \n"   // a0 &= 0xfffefffe
+          "and.b32 a1, $2, 0xfffefffe;  \n"   // (strip lowest bit)
+          "add.u32 a0, a0, 0x00800080;  \n"   // a0 += 0x00800080
+          "add.u32 a1, a1, 0x00800080;  \n"   // (round to nearest)
+          "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
+          "}";
+  } else {
+    ret = "cvt.rn.satfinite.e5m2x2.f16x2 $0, $1; \n\t";
+  }
+  return ret;
+}
+
+<<<<<<< HEAD
 #ifdef USE_ROCM
 static SmallVector<Value>
 Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,
@@ -356,11 +373,115 @@ const std::string Bf16_to_Fp8E5M2 =
     "or.b32 $0, nosign, sign;                    \n" // restore sign
     "}";
 #endif
+=======
+static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
+  std::string ret;
+  if (!hasNativeFP) {
+    ret = "{                           \n"
+          "prmt.b32 $0, 0, $2, 0x5140; \n\t"
+          "prmt.b32 $1, 0, $2, 0x7362; \n\t"
+          "}";
+  } else {
+    ret = "cvt.rn.f16x2.e5m2x2 $0, $1; \n\t";
+  }
+  return ret;
+}
+
+static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
+  std::string ret;
+  if (!hasNativeFP) {
+    ret =
+        "{                                      \n"
+        ".reg .b32 a<2>, b<2>;                  \n" // if input = 0xf1f2f3f4
+        "prmt.b32 a0, 0, $2, 0x5140;            \n" // a0 = 0xf300f400
+        "prmt.b32 a1, 0, $2, 0x7362;            \n" // a1 = 0xf100f200
+        "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0;  \n" // b0 = a0 & 0x7fff7fff
0xc0; \n" // b0 = a0 & 0x7fff7fff + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign) + "shr.b32 b0, b0, 3; \n" // b0 >>= 3 + "shr.b32 b1, b1, 3; \n" // shift into bf16 position + "add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4 + // exponent compensate = 112 + "add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16 + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0) + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign) + "}"; + } else { + ret = "{ \n" + ".reg .b32 a; \n" + ".reg .f16 a<2>; \n" + ".reg .b16 b<2>; \n" + "cvt.rn.f16x2.e5m2x2 a, $1; \n" + "mov.b32 {a0, a1}, a; \n" + "cvt.bf16.f16 b0, a0; \n" + "cvt.bf16.f16 b1, a1; \n" + "mov.b32 $0, {b0, b1}; \n" + "}"; + } + return ret; +} +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 + +static const std::string Bf16_to_Fp8E5M2(bool hasNativeFP) { + std::string ret; + if (!hasNativeFP) { + ret = + "{ \n" // bf16=fp8>>3 + 112<<7 + ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" // fp8_min = 0b00000000 + ".reg .u32 fp8_min, fp8_max, rn_; \n" // fp8_max = 0b11111111 + "mov.u32 fp8_min, 0x38003800; \n" // so bf16_min = 0x3800 + "mov.u32 fp8_max, 0x57e057e0; \n" // so bf16_max = 0x57e0 + "mov.u32 rn_, 0x00100010; \n" // round to nearest + "and.b32 sign0, $1, 0x80008000; \n" // sign0=in0&0x80008000 + "and.b32 sign1, $2, 0x80008000; \n" // (store sign) + "prmt.b32 sign, sign0, sign1, 0x7531; \n" + "and.b32 nosign0, $1, 0x7fff7fff; \n" // nosign0=in0&0x7fff7fff + "and.b32 nosign1, $2, 0x7fff7fff; \n" // (strip sign) + + // nosign = clamp(nosign, min, max) + ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n" + "and.b32 nosign_0_0, nosign0, 0xffff0000; \n" + "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n" + "min.u32 nosign_0_0, nosign_0_0, 0x57e00000; \n" + "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n" + "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n" + "min.u32 nosign_0_1, nosign_0_1, 0x57e0; \n" + "or.b32 nosign0, nosign_0_0, nosign_0_1; \n" + "and.b32 nosign_1_0, nosign1, 0xffff0000; \n" + "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n" + "min.u32 nosign_1_0, nosign_1_0, 0x57e00000; \n" + "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n" + "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n" + "min.u32 nosign_1_1, nosign_1_1, 0x57e0; \n" + "or.b32 nosign1, nosign_1_0, nosign_1_1; \n" + + "add.u32 nosign0, nosign0, rn_; \n" // nosign0 += rn_ + "add.u32 nosign1, nosign1, rn_; \n" // (round to nearest) + "sub.u32 nosign0, nosign0, 0x38003800; \n" // nosign0-=0x38003800 + "sub.u32 nosign1, nosign1, 0x38003800; \n" // (compensate offset) + "shl.b32 nosign0, nosign0, 3; \n" // nosign0 <<= 3 + "shl.b32 nosign1, nosign1, 3; \n" // shift into to fp8e4 + "prmt.b32 nosign, nosign0, nosign1, 0x7531; \n" // nosign0 = 0xf100f200 + // nosign1 = 0xf300f400 + // nosign = 0xf3f4f1f2 + "or.b32 $0, nosign, sign; \n" // restore sign + "}"; + } else { + ret = "{ \n" + ".reg .b16 a<2>; \n" + ".reg .f32 b<2>; \n" + "mov.b32 {a0, a1}, $1; \n" + "cvt.f32.bf16 b0, a0; \n" + "cvt.f32.bf16 b1, a1; \n" + "cvt.rn.satfinite.e5m2x2.f32 $0, b0, b1; \n" + "}"; + } + return ret; +} /* ----- FP8E4M3B15 ------ */ // This data-type is a variant of the standard FP8E4M3 format. // It was designed for fast software conversion to FP16 on // nvidia GPUs that do not support it natively. 
+<<<<<<< HEAD
 // Specifically, this data-type:
 //   - has infinities
 //   - has multiple nans (when all exponent bits are 1)
@@ -404,6 +525,11 @@ Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
   };
 }
 #else
+=======
+// This is the same format as FP8E4M3Nv, but:
+//   - the exponent bias is 15 instead of 7
+//   - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
 const std::string Fp8E4M3B15_to_Fp16 =
     "{                                      \n"
     ".reg .b32 a<2>, b<2>;                  \n"
     "prmt.b32 a0, 0, $2, 0x5140;            \n"
     "prmt.b32 a1, 0, $2, 0x7362;            \n"
     "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0;  \n"
     "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0;  \n"
     "shr.b32  b0, b0, 1;                    \n"
     "shr.b32  b1, b1, 1;                    \n"
     "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
     "shl.b32 $1, b1, 7;                     \n"
     "}                                      \n";
+<<<<<<< HEAD
 #endif
 
 #ifdef USE_ROCM
@@ -464,6 +591,10 @@ Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
 }
 #else
 const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
+=======
+
+static const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
   std::string ret;
   ret += "{                                      \n"
         ".reg .pred p<4>;                       \n"
@@ -509,6 +640,7 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) {
 // $0 = (($2 << 1) & 0x80008000u) | (($2 << 7) & 0x3f803f80u);
 // $1 = (($2 << 0) & 0x80008000u) | (($2 << 0) & 0x3f803f80u);
 // WARN: subnormal (0bs0000xxx) are not handled
+<<<<<<< HEAD
 #ifdef USE_ROCM
 static SmallVector<Value>
 Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -540,6 +672,9 @@ Fp8E4M3B15x4_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
 }
 #else
 const std::string Fp8E4M3B15x4_to_Fp16 =
+=======
+static const std::string Fp8E4M3B15x4_to_Fp16 =
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
     "{                               \n"
     ".reg .b32 a<2>;                 \n"
     "add.u32 a0, $2, $2;             \n"
@@ -557,6 +692,7 @@ const std::string Fp8E4M3B15x4_to_Fp16 =
 //      ((e4.y >> 0) & (0x80008000u >> 0)) |
 //      ((e4.y >> 0) & (0x3f803f80u >> 0)) ;
 // WARN: subnormal (0bs0000xxx) are not handled
+<<<<<<< HEAD
 #ifdef USE_ROCM
 static SmallVector<Value>
 Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
@@ -591,6 +727,9 @@ Fp16_to_Fp8E4M3B15x4(Location loc, ConversionPatternRewriter &rewriter,
 }
 #else
 const std::string Fp16_to_Fp8E4M3B15x4 =
+=======
+static const std::string Fp16_to_Fp8E4M3B15x4 =
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
     "{                               \n"
     ".reg .b32 a<2>;                 \n"
     "shr.b32 a0, $1, 1;              \n"
@@ -904,17 +1043,18 @@ const std::string Bf16_to_Fp8E4M3 =
 #endif
 
 // Fp8E4M3 (x2) -> Fp16 (x2) (packed)
-const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
-                                      "cvt.rn.f16x2.e4m3x2 $0, $1; \n"
-                                      "}";
+static const std::string Fp8E4M3Nv_to_Fp16 = "{ \n"
+                                             "cvt.rn.f16x2.e4m3x2 $0, $1; \n"
+                                             "}";
 
 // Fp16 (x2) -> Fp8E4M3 (x2) (packed)
-const std::string Fp16_to_Fp8E4M3Nv = "{ \n"
-                                      "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
-                                      "}";
+static const std::string Fp16_to_Fp8E4M3Nv =
+    "{ \n"
+    "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n"
+    "}";
 
 #ifndef USE_ROCM
 // Fp8E4M3 (x2) -> Fp16 (x2) (packed)
-const std::string Fp8E4M3Nv_to_Bf16 =
+static const std::string Fp8E4M3Nv_to_Bf16 =
     "{                        \n"
     ".reg .b32 a;             \n"
     ".reg .f16 a<2>;          \n"
@@ -927,7 +1067,7 @@ const std::string Fp8E4M3Nv_to_Bf16 =
     "}";
 
 // Bf16 (x2) -> Fp8E4M3 (x2) (packed)
-const std::string Bf16_to_Fp8E4M3Nv =
+static const std::string Bf16_to_Fp8E4M3Nv =
     "{                        \n"
     ".reg .b16 a<2>;          \n"
     ".reg .f32 b<2>;          \n"
@@ -938,7 +1078,7 @@ const std::string Bf16_to_Fp8E4M3Nv =
     "}";
 
 /* ----- Packed integer to BF16 ------ */
-const std::string S8_to_Bf16 =
+static const std::string S8_to_Bf16 =
     "{                        \n"
     ".reg .s8 s<4>;           \n"
     ".reg .f32 f<4>;          \n"
@@ -952,6 +1092,12 @@ const std::string S8_to_Bf16 =
     "}";
 #endif
 
+// Fp32 (x2) -> Fp8 (x2) (packed)
+static const std::string Fp32_to_Fp8E4M3Nv =
+    "cvt.rn.satfinite.e4m3x2.f32 $0, $2, $1; \n";
+static const std::string Fp32_to_Fp8E5M2 =
+    "cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n";
+
 static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
                                         Type inType, Type ouType) {
   auto inTensorTy = inType.dyn_cast<RankedTensorType>();
@@ -1383,9 +1529,14 @@ struct FpToFpOpConversion
         // F8 -> F16
         {{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
         {{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
+<<<<<<< HEAD
         {{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
         {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
         {{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
+=======
+        {{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16},
+        {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)},
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
         // F16 -> F8
 #ifdef USE_ROCM
         {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
@@ -1393,27 +1544,44 @@ struct FpToFpOpConversion
         {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
 #endif
         {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
+<<<<<<< HEAD
         {{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
         {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
         {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
         // F8 -> BF16
         {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
 #ifndef USE_ROCM
+=======
+        {{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv},
+        {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)},
+        // F8 -> BF16
+        {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)},
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
         {{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16},
 #endif
         // BF16 -> F8
+<<<<<<< HEAD
         {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2},
 #ifndef USE_ROCM
         {{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
 #endif
+=======
+        {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2(computeCapability >= 90)},
+        {{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv},
+        // F32 -> F8
+        {{F32TyID, F8E4M3TyID}, Fp32_to_Fp8E4M3Nv},
+        {{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2},
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
     };
 
     int inVecWidthBits = 32;
     int outVecWidthBits = 32;
-    if (srcTy.isFloat8E4M3FNUZ()) {
+    if (srcTy.isFloat8E4M3FNUZ() ||
+        (computeCapability >= 90 && srcTy.isFloat8E5M2())) {
       inVecWidthBits = 16;
       outVecWidthBits = 32;
     }
-    if (dstTy.isFloat8E4M3FNUZ()) {
+    if (dstTy.isFloat8E4M3FNUZ() ||
+        (computeCapability >= 90 && dstTy.isFloat8E5M2())) {
      inVecWidthBits = 32;
      outVecWidthBits = 16;
    }
@@ -1450,18 +1618,24 @@ struct FpToFpOpConversion
     size_t numElements = 4;
     if (srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ()) {
+        dstElementType.isFloat8E4M3FNUZ() ||
+        (computeCapability >= 90 &&
+         (srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) {
       numElements = 2;
     }
-    bool isSrcFP32 = srcElementType.isF32();
+    bool useFP16IntermediateSrc =
+        srcElementType.isF32() &&
+        !(computeCapability >= 90 &&
+          (dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2()));
     bool isDstFP32 = dstElementType.isF32();
-    auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType,
-                                     isDstFP32 ? f16_ty : dstElementType);
+    auto cvtFunc =
+        getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType,
                          isDstFP32 ? f16_ty : dstElementType);
     SmallVector<Value> inVals;
     for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
       inVals.push_back(operands[i][0]);
     }
-    if (isSrcFP32)
+    if (useFP16IntermediateSrc)
       for (Value &v : inVals)
         v = convertFp32ToFp16(loc, rewriter, v);
     inVals.resize(numElements,
@@ -2115,18 +2289,18 @@ void populateElementwiseOpToLLVMPatterns(
   POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
   POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
   POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
-  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)    // &
-  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)      // |
-  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)    // ^
-  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)    // <<
-  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp)  // >>
-  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp)  // >>
-  POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
-  POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
-  POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp)  // smin
-  POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp)  // smax
-  POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp)  // umin
-  POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp)  // umax
+  POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp)        // &
+  POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp)          // |
+  POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp)        // ^
+  POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)        // <<
+  POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp)      // >>
+  POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp)      // >>
+  POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
+  POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
+  POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp)      // smin
+  POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp)      // smax
+  POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp)      // umin
+  POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp)      // umax
 #undef POPULATE_BINARY_OP
 
 #define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM/Fp8E4M3B15.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM/Fp8E4M3B15.cpp
deleted file mode 100644
index e69de29bb..000000000
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM/Fp8E5M2.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM/Fp8E5M2.cpp
deleted file mode 100644
index e69de29bb..000000000
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index 1fd621079..843cf2fce 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1549,10 +1549,12 @@ struct InsertSliceAsyncOpConversion
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto resTy = dst.getType().cast<RankedTensorType>();
     auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
-    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
+    assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
+            "Unexpected srcLayout in InsertSliceAsyncOpConversion"));
     auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
     auto srcShape = srcTy.getShape();
-    assert(srcShape.size() == 2 &&
+    assert((srcShape.size() == 1 || srcShape.size() == 2) &&
            "insert_slice_async: Unexpected rank of %src");
 
     Value llDst = adaptor.getDst();
@@ -1617,25 +1619,15 @@ struct InsertSliceAsyncOpConversion
     unsigned numElems = getTotalElemsPerThread(srcTy);
     unsigned perPhase = resSharedLayout.getPerPhase();
     unsigned maxPhase = resSharedLayout.getMaxPhase();
-    auto sizePerThread = srcBlockedLayout.getSizePerThread();
-    auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
-    auto inOrder = srcBlockedLayout.getOrder();
     DenseMap<unsigned, Value> sharedPtrs =
         getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
                               smemObj, rewriter, offsetVals, srcStrides);
-    // If perPhase * maxPhase > threadsPerCTA, we will have elements
-    // that share the same tile indices. The index calculation will
-    // be cached.
-    auto numSwizzleRows = std::max<unsigned>(
-        (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
     // A sharedLayout encoding has a "vec" parameter.
     // On the column dimension, if inVec > outVec, it means we have to divide
     // single vector read into multiple ones
     auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
 
-    auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
-
     for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
       // 16 * 8 = 128bits
       auto maxBitWidth =
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index fb148c423..09ed1f98f 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -419,16 +419,15 @@ private:
         getMultiDimWarpId(helper, warpId, loc, rewriter);
     Value warpIdAxis = multiDimWarpId[axis];
 
-    if (!helper.isReductionOnLayoutFastAxis()) {
-      std::reverse(order.begin(), order.end());
-    }
+    auto smemOrder = helper.getOrderWithAxisAtBeginning();
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> acc = it.second;
 
       SmallVector<Value> writeIdx = indices[key];
       writeIdx[axis] = warpIdAxis;
-      Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order);
+      Value writeOffset =
+          linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
       for (unsigned i = 0; i < op.getNumOperands(); ++i) {
         auto elemPtrTy = getElementPtrType(op, i);
         Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset);
@@ -513,10 +512,7 @@ private:
     Location loc = op.getLoc();
     auto srcLayout = helper.getSrcLayout();
     auto axis = op.getAxis();
-    auto order = getOrder(srcLayout);
-    if (!helper.isReductionOnLayoutFastAxis()) {
-      std::reverse(order.begin(), order.end());
-    }
+    auto smemOrder = helper.getOrderWithAxisAtBeginning();
     SmallVector<Value> results(op.getNumOperands());
     for (unsigned i = 0; i < op.getNumOperands(); ++i) {
       if (auto resultTy =
@@ -532,7 +528,7 @@ private:
           SmallVector<Value> readIdx = resultIndices[j];
          readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0));
          Value readOffset =
-              linearize(rewriter, loc, readIdx, smemShape, order);
+              linearize(rewriter, loc, readIdx, smemShape, smemOrder);
          Value readPtr = gep(getElementPtrType(op, i), smemBases[i], readOffset);
          resultVals[j] = load(readPtr);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index d72f26842..767478cea 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -622,10 +622,13 @@ struct AllocTensorOpConversion
     // TODO: we need to modify the pipeline pass to give a proper shared
     // encoding to 3D tensors
     SmallVector<unsigned> newOrder;
-    if (resultTy.getShape().size() == 3)
-      newOrder = {1 + order[0], 1 + order[1], 0};
-    else
+    if (resultTy.getShape().size() != order.size()) {
+      for (auto i = 0; i < order.size(); ++i)
+        newOrder.push_back(order[i] + 1);
+      newOrder.push_back(0);
+    } else {
       newOrder = SmallVector<unsigned>(order.begin(), order.end());
+    }
 
     auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
     auto smemObj =
@@ -659,10 +662,13 @@ struct ExtractSliceOpConversion
     SmallVector<Value> opOffsetVals;
     SmallVector<Value> offsetVals;
     auto mixedOffsets = op.getMixedOffsets();
-    for (auto i = 0; i < mixedOffsets.size(); ++i) {
-      if (op.isDynamicOffset(i))
-        opOffsetVals.emplace_back(adaptor.getOffsets()[i]);
-      else
+    for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
+      if (op.isDynamicOffset(i)) {
+        // adaptor.getOffsets() returns the list of dynamic offsets only, so
+        // its size may not be the same as mixedOffsets.
+        opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
+        ++j;
+      } else
         opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
       offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
     }
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index cf031a339..034eff593 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -146,7 +146,8 @@ protected:
     }
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
-        /*dsoLocal*/ false, LLVM::CConv::C, attributes);
+        /*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/SymbolRefAttr{},
+        attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
@@ -361,8 +362,13 @@ public:
     unsigned numElemsPerSwizzlingRow =
         swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth();
     Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow);
-    unsigned leadingDimOffset =
-        numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
+    unsigned leadingDimOffset;
+    if (outOrder.size() == 2) {
+      leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]];
+    } else {
+      leadingDimOffset = numElemsPerSwizzlingRow;
+    }
+
     Value leadingDimOffsetVal = i32_val(leadingDimOffset);
     // Return values
     DenseMap<unsigned, Value> ret;
@@ -374,9 +380,15 @@ public:
       // Extract multi dimensional index for current element
       auto idx = srcIndices[elemIdx];
       Value idxCol = idx[outOrder[0]]; // contiguous dimension
-      Value idxRow = idx[outOrder[1]]; // discontiguous dimension
+      Value idxRow, strideRow;
+      if (outOrder.size() == 2) {
+        idxRow = idx[outOrder[1]]; // discontiguous dimension
+        strideRow = srcStrides[outOrder[1]];
+      } else {
+        idxRow = i32_val(0);
+        strideRow = i32_val(0);
+      }
       Value strideCol = srcStrides[outOrder[0]];
-      Value strideRow = srcStrides[outOrder[1]];
       // compute phase = (row // perPhase) % maxPhase
       Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
       // extract dynamic/static offset for immediate offsetting
@@ -428,10 +440,16 @@ public:
       offset = add(offset, add(rowOff, mul(colOff, strideCol)));
       Value currPtr = gep(dstPtrTy, dstPtrBase, offset);
       // compute immediate offset
-      Value immedateOff =
-          add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
-              i32_val(immedateOffCol));
-      ret[elemIdx] = gep(dstPtrTy, currPtr, immedateOff);
+      Value immediateOff;
+      if (outOrder.size() == 2) {
+        immediateOff =
+            add(mul(i32_val(immedateOffRow), srcStrides[outOrder[1]]),
+                i32_val(immedateOffCol));
+      } else {
+        immediateOff = i32_val(immedateOffCol);
+      }
+
+      ret[elemIdx] = gep(dstPtrTy, currPtr, immediateOff);
     }
     return ret;
   }
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 7d08533a4..d690acf6d 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -371,13 +371,15 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
   Type type = val.getType();
   if (type != i32_ty) {
     val = bitcast(val, int_ty(bits));
-    val = zext(i32_ty, val);
+    if (bits < 32)
+      val = zext(i32_ty, val);
   }
   Value mask = i32_val(0xFFFFFFFF);
   Value result = rewriter.create<NVVM::ShflOp>(loc, i32_ty, mask, val, i, clamp,
                                                mode, UnitAttr());
   if (type != i32_ty) {
-    result = trunc(int_ty(bits), result);
+    if (bits < 32)
+      result = trunc(int_ty(bits), result);
     result = bitcast(result, type);
   }
   return result;
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
index fd69a4c63..ce28e41ac 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -97,7 +97,7 @@
     ptxBuilder.launch(rewriter, op->getLoc(), voidTy);                         \
   } while (0)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
-#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
+#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
 #define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
 
 // Types
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index b4f091d7b..327a9138e 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -115,8 +115,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
       // Floating point
       GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
       // MaxMin
-      GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
-      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
+      GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
+      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
       GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
       // Floating point
       GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 2d7bec31c..d392544ef 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -1,9 +1,9 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/FunctionImplementation.h"
-#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt
index 8477f4dcc..bab4a5dec 100644
--- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt
@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
   TritonGPUAttrDefsIncGen
 
   LINK_LIBS PUBLIC
-  MLIRGPUOps
+  MLIRGPUDialect
   TritonIR
 )
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 8d584d4b4..1db61b8e4 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
                              result.operands))
     return failure();
 
-  // Deduce operand_segment_sizes from the number of the operands.
+  // Deduce operandSegmentSizes from the number of the operands.
   auto operandSegmentSizesAttrName =
       OpT::getOperandSegmentSizesAttrName(result.name);
   result.addAttribute(
@@ -1616,7 +1616,7 @@ template <typename OpT>
 void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
   printer << " ";
   printer << insertSliceOp.getOperation()->getOperands();
-  // "operand_segment_sizes" can be deduced, so we don't print it.
+  // "operandSegmentSizes" can be deduced, so we don't print it.
   printer.printOptionalAttrDict(
       insertSliceOp->getAttrs(),
       {insertSliceOp.getOperandSegmentSizesAttrName()});
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 1fe5a1c63..a0887afdc 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
     int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
     int origBitWidth = finalBitWidth;
     SetVector<Operation *> slice;
-    mlir::getBackwardSlice(x, &slice, bwdFilter);
+    mlir::BackwardSliceOptions opt;
+    opt.omitBlockArguments = true;
+    opt.filter = bwdFilter;
+    getBackwardSlice(x, &slice, opt);
     Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
     if (firstOp)
       if (Value arg = firstOp->getOperand(0))
@@ -235,8 +238,11 @@ public:
     if (versionMajor == 1) {
       SetVector<Operation *> aBwdSlices, bBwdSlices;
       auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
-      getBackwardSlice(a, &aBwdSlices, {isCvt});
-      getBackwardSlice(b, &bBwdSlices, {isCvt});
+      mlir::BackwardSliceOptions opt;
+      opt.omitBlockArguments = true;
+      opt.filter = isCvt;
+      getBackwardSlice(a, &aBwdSlices, opt);
+      getBackwardSlice(b, &bBwdSlices, opt);
       // get the source of the first conversion found in slices
       auto getCvtArgOrder = [](Operation *op) {
         return cast<ConvertLayoutOp>(op)
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index 8e13719b7..d82f6dcd2 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -98,7 +98,9 @@ public:
     // and all operations between the load and the conversion
     // should be layout preserving
     SetVector<Operation *> slice;
-    getBackwardSlice(op, &slice);
+    mlir::BackwardSliceOptions opt;
+    opt.omitBlockArguments = true;
+    getBackwardSlice(op, &slice, opt);
     int loadIdx = -1;
     bool checkOp = false;
     for (int i = 0; i < slice.size(); i++) {
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 67b86e238..86d228988 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -160,6 +160,8 @@ class LoopPipeliner {
   void checkOpShareBarriers(SetVector<Operation *> &ops);
   int numLoadsRequireAsyncWait = 0;
   int numLoadsRequireMBarrier = 0;
+  // Number of buffers to allocate for each input.
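+  // (numStages - 1 buffers before Hopper, numStages on Hopper; see
+  // PipelinePass::runOnOperation below.)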
+  int numSharedMemorySlices = 0;
 
   /// Iterator values
   Value nextIV;
@@ -280,9 +282,12 @@ class LoopPipeliner {
 
 public:
   LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
-                bool mode, ConsumerReleaseMap &consumerReleaseMap)
+                bool mode, int numSharedMemorySlices,
+                ConsumerReleaseMap &consumerReleaseMap)
       : forOp(forOp), numStages(numStages), numWarps(numWarps),
-        numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
+        numCTAs(numCTAs), mode(mode),
+        numSharedMemorySlices(numSharedMemorySlices),
+        consumerReleaseMap(consumerReleaseMap) {
     // cache yieldOp
     yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
   }
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
     auto ty = loadOp.getType().cast<RankedTensorType>();
     SmallVector<int64_t> bufferShape(ty.getShape().begin(),
                                     ty.getShape().end());
-    bufferShape.insert(bufferShape.begin(), numStages);
+    bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
     auto CTALayout = ttg::getCTALayout(ty.getEncoding());
     Attribute sharedEnc;
     if (auto dotOpEnc = cvt.getType()
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
   pipelineIterIdx = builder.create<arith::AddIOp>(
       iv.getLoc(), pipelineIterIdx,
      builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
+  Value numSlices = builder.create<arith::ConstantIntOp>(
+      iv.getLoc(), numSharedMemorySlices, 32);
+  Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
+  pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
+                                             numSlices, pipelineIterIdx, _0);
   // Some values have not been used by any ops in the loop body
   for (BlockArgument arg : forOp.getRegionIterArgs())
     setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
   Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
   Value numStagesVal =
       builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
+  Value numSlices =
+      builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
 
   // nextWaitIdx
   Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
-  Value nextWaitIdx = getBoundedIterationValue(
-      builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
+  Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
+                                               numSlices, waitIdxPlusOne, _0);
 
   // Indices of InsertSliceAsyncOp and ExtractSliceOp
   Value insertSliceIndex = pipelineIterIdx;
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
   // Bump pipelineIterIdx
   Value pipelineIterIdxPlusOne =
       builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
-  pipelineIterIdx =
-      getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
-                               pipelineIterIdxPlusOne, _0);
+  pipelineIterIdx = getBoundedIterationValue(
+      builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
 
   // Bump curWaitIdx
   curWaitIdx = nextWaitIdx;
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
     // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 
     llvm::SmallVector<scf::ForOp> newForOps;
+
+    // Currently we schedule stage 0 after stage `numStages - 1` during
+    // pipelining, therefore we only need `numStages - 1` slices of memory.
+    // On Hopper we have a separate post-processing pass that pipelines wgmma,
+    // so we need an extra buffer for each input.
+    // Note that an alternative would be to keep allocating `numStages` buffers
+    // and remove the barrier between the loads from shared memory and the
+    // copies from global to shared. This would require improving the existing
+    // membar analysis.
+    int numSharedMemorySlices =
+        computeCapability < 90 ? numStages - 1 : numStages;
+
     // Do the pipelining
     getOperation()->walk([&](scf::ForOp forOp) -> void {
       LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
-                              this->numCTAs, mode, consumerReleaseMap);
+                              this->numCTAs, mode, numSharedMemorySlices,
+                              consumerReleaseMap);
 
       if (pipeliner.initialize().failed())
         return;
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
 
   /// XXX(Keren): Clean up the following duplicate code with checkDotOp
   /// dots to be pipelined
-  SetVector<Value> dots;
+  SmallVector<tt::DotOp> dots;
+  SmallVector<unsigned> resultNeedSync;
   for (Operation &op : *loop) {
     if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
       auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
         if (!CArg || !CArg.hasOneUse())
           valid = false;
 
-        if (valid)
-          dots.insert(dotOp);
+        if (valid) {
+          dots.push_back(dotOp);
+          resultNeedSync.push_back(
+              dotOp->getUses().begin()->getOperandNumber());
+        }
       }
     }
   }
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
     return;
 
   OpBuilder builder(forOp);
-
-  // 0. insert dot_wait after the last dot in the loop
-  Value dot = dots.back();
-  auto loc = dot.getLoc();
-  builder.setInsertionPointAfter(dot.getDefiningOp());
-  auto dotWait = builder.create<ttng::DotWaitOp>(loc, dots.size());
+  // 0. insert dot_wait after the last dot in the loop, as we implicitly
+  // pipeline wgmma ops by one stage.
+  // This is needed to prevent the shared memory inputs from being overwritten
+  // before the operation completes.
+  // TODO: merge this with the rest of the pipelining transformation and look at
+  // a better representation for async dots.
+  tt::DotOp lastDot = dots.back();
+  builder.setInsertionPointAfter(lastDot);
+  auto dotWait = builder.create<ttng::DotWaitOp>(
+      lastDot.getLoc(), lastDot.getResult(), dots.size());
 
   // 1. replace Dot with DotAsync
   for (size_t idx = 0; idx < dots.size(); ++idx) {
-    Value dot = dots[idx];
-    auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
-    builder.setInsertionPoint(dot.getDefiningOp());
+    tt::DotOp dotOp = dots[idx];
+    builder.setInsertionPoint(dotOp);
     auto dotAsync = builder.create<ttng::DotAsyncOp>(
-        loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
-        dotOp.getMaxNumImpreciseAcc());
-    dot.replaceAllUsesWith(dotAsync.getResult());
-    updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
-    dot.getDefiningOp()->erase();
+        dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
+        dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
+    dotOp.replaceAllUsesWith(dotAsync.getResult());
+    updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
+    dotOp->erase();
   }
 
   // 2. If there's any outstanding DotAsyncOps, we need to wait for them.
   builder.setInsertionPointAfter(forOp);
-  Value loopNotEmpty = builder.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
-      forOp.getUpperBound());
-  // TODO[goostavz]: it's a workaround to put the DotWaitOp in an IfOp for
-  // a bug in ptxas which mistakenly analysis the control flow and turn the GMMA
-  // into synchronuous implementation for safety.
-  // Remove this If once the bug is fixed.
- auto ifOp = builder.create(loc, ArrayRef{}, loopNotEmpty, - /*hasElse*/ false); - builder.setInsertionPointToStart(ifOp.thenBlock()); - builder.create(forOp.getLoc(), 0); + for (unsigned resultIndex : resultNeedSync) { + Value result = forOp->getResult(resultIndex); + if (result.use_empty()) + continue; + auto dotWait = + builder.create(forOp.getLoc(), result, 0); + result.replaceAllUsesExcept(dotWait.getResult(), dotWait); + } } Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc, diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index af6ae1c30..f40a21cc9 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr; // // ----------------------------------------------------------------------------- +<<<<<<< HEAD // convert(blocked, dot_operand) -> // convert(blocked, mma) + convert(mma, dot_operand) // if this value is itself the result of a dot operation @@ -102,6 +103,9 @@ public: }; // +======= +// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0)) +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 class ConvertDotConvert : public mlir::RewritePattern { public: ConvertDotConvert(mlir::MLIRContext *context) @@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { getForwardSlice(currentValue, &forwardSlice); for (Operation *op : forwardSlice) { if (auto convertOp = dyn_cast(op)) { - if (convertOp.getResult() - .getType() - .cast() - .getEncoding() - .isa()) - return true; + Attribute dstEncoding = convertOp.getResult() + .getType() + .cast() + .getEncoding(); + if (auto mmaLayout = + dstEncoding.dyn_cast()) + return (mmaLayout.getVersionMajor() > 1) ? true + : mmaLayout == encoding; + if (dstEncoding.isa()) + return encoding.cast() + .getVersionMajor() > 1; } auto yield = dyn_cast(op); if (!yield) @@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { return rewrittenValue; OpBuilder rewriter(value.getContext()); rewriter.setInsertionPointAfterValue(rewrittenValue); + // Workaround: The pipeliner will insert async.wait after a pipelined loop + // to ensure that there is no pending copies and it is safe to re-use shared + // memory. We shouldn't insert ops that may use shared memory in between the + // loop and the async.wait. This is a hack until we fix the IR + // representation of async wait. 
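For intuition, here is a minimal Python model (helper name borrowed from the pass, everything else hypothetical) of the wrap-around arithmetic that `getBoundedIterationValue` emits in the hunks above: with this change the insert/wait indices cycle through `numSharedMemorySlices` shared-memory buffers instead of `numStages`.

```python
def bounded_iteration_value(idx_plus_one: int, bound: int) -> int:
    # Models getBoundedIterationValue(builder, idxPlusOne, bound, idxPlusOne, 0):
    # select 0 once the incremented index reaches the bound.
    return 0 if idx_plus_one >= bound else idx_plus_one

# Pre-Hopper (computeCapability < 90) with numStages = 3 cycles through
# numStages - 1 = 2 slices: 0, 1, 0, 1, ...
idx = 0
for _ in range(4):
    print(idx)
    idx = bounded_iteration_value(idx + 1, 2)
```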
diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index af6ae1c30..f40a21cc9 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
 //
 // -----------------------------------------------------------------------------

+<<<<<<< HEAD
 // convert(blocked, dot_operand) ->
 // convert(blocked, mma) + convert(mma, dot_operand)
 // if this value is itself the result of a dot operation
@@ -102,6 +103,9 @@ public:
 };

 //
+=======
+// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
 class ConvertDotConvert : public mlir::RewritePattern {
 public:
   ConvertDotConvert(mlir::MLIRContext *context)
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
     getForwardSlice(currentValue, &forwardSlice);
     for (Operation *op : forwardSlice) {
       if (auto convertOp = dyn_cast(op)) {
-        if (convertOp.getResult()
-                .getType()
-                .cast()
-                .getEncoding()
-                .isa())
-          return true;
+        Attribute dstEncoding = convertOp.getResult()
+                                    .getType()
+                                    .cast()
+                                    .getEncoding();
+        if (auto mmaLayout = dstEncoding.dyn_cast())
+          return (mmaLayout.getVersionMajor() > 1) ? true
+                                                   : mmaLayout == encoding;
+        if (dstEncoding.isa())
+          return encoding.cast().getVersionMajor() > 1;
       }
       auto yield = dyn_cast(op);
       if (!yield)
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
       return rewrittenValue;
     OpBuilder rewriter(value.getContext());
     rewriter.setInsertionPointAfterValue(rewrittenValue);
+    // Workaround: The pipeliner will insert an async.wait after a pipelined
+    // loop to ensure that there are no pending copies and that it is safe to
+    // re-use shared memory. We shouldn't insert ops that may use shared
+    // memory between the loop and the async.wait. This is a hack until we
+    // fix the IR representation of async wait.
+    if (Operation *op = rewrittenValue.getDefiningOp()) {
+      if (isa(op->getNextNode()))
+        rewriter.setInsertionPointAfter(op->getNextNode());
+    }
     auto tmpType = RankedTensorType::get(tensorType.getShape(),
                                          tensorType.getElementType(), encoding);
     Value converted = rewriter.create(
@@ -1122,7 +1140,6 @@ public:
   hoistConvert(m);

   mlir::RewritePatternSet decomposePatterns(context);
-  decomposePatterns.add(context);
   decomposePatterns.add(context);
   if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
           .failed()) {
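A rough Python model of the sharpened predicate in `hasConvertToMMATransisitiveUse` (the encoding types here are stand-ins, not the MLIR attribute classes): Hopper-style MMA destination layouts (version > 1) always count as an MMA use, version-1 MMA layouts only when they match the layout being propagated, and dot-operand destinations only pay off when the source itself is a Hopper MMA layout.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MmaEncoding:
    version_major: int

@dataclass(frozen=True)
class DotOperandEncoding:
    parent: MmaEncoding

def counts_as_mma_use(dst, src: MmaEncoding) -> bool:
    if isinstance(dst, MmaEncoding):
        # version > 1 always qualifies; v1 only if it is the same layout
        return dst.version_major > 1 or dst == src
    if isinstance(dst, DotOperandEncoding):
        return src.version_major > 1
    return False

assert counts_as_mma_use(MmaEncoding(2), MmaEncoding(2))
assert not counts_as_mma_use(DotOperandEncoding(MmaEncoding(1)), MmaEncoding(1))
```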
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
index 25a43529d..54cce21ad 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
@@ -91,7 +91,7 @@ private:
     // suport ForOp only
     if (auto forOp = dyn_cast(argOwner)) {
       // prologue
-      auto iterOperands = forOp.getIterOperands();
+      auto iterOperands = forOp.getInitArgs();
       if (argNum == 0)
         return false;
       if (dependOnSharedEncOperand(iterOperands[argNum - 1]))
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
index d79da1ee9..af2799301 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
                 arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
                 arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
                 arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
-                arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
-                arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
-                arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
-                arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
-                arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
-                arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
+                arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
+                arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
+                arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
+                arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
+                arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
+                arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
+                arith::UIToFPOp, arith::XOrIOp>(op))
     return true;
   if (llvm::isa backwardSlice;
   mod.walk([&](triton::MakeTensorPtrOp op) -> void {
     assert(isa(op->getParentOp()));
-    getBackwardSlice(op.getOperation(), &backwardSlice);
+    mlir::BackwardSliceOptions opt;
+    opt.omitBlockArguments = true;
+    getBackwardSlice(op.getOperation(), &backwardSlice, opt);
     op->removeAttr("async_agent");
   });
   for (auto op : backwardSlice) {
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp
index 9ebc78497..3bc0b69f3 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp
@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
     builder.setInsertionPoint(agentIdOp);
     Value globalRoleId = builder.create(loc, 0, 32);
     int globalNumWarps = 0;
+    SmallVector deprecatedOps;
    for (auto cmpOp : agentIdOp->getUsers()) {
       assert(isa(cmpOp));
       for (auto u : cmpOp->getUsers()) {
@@ -111,11 +112,14 @@
           Value cond = builder.create(loc, lowerBound, upperBound);
           cmpOp->getResult(0).replaceAllUsesWith(cond);
-          cmpOp->erase();
+          deprecatedOps.push_back(cmpOp);
           break;
         }
       }
     }
+    for (Operation *cmpOp : deprecatedOps) {
+      cmpOp->erase();
+    }
   });
 }

@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
 }

 Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
-                          bool skipFirstWait) {
+                          bool emptyBarrier) {
   // TODO: currently we only support one loop, no nested loop, while or
   // condition.
   auto loc = op->getLoc();
   auto forOp = op->getParentOfType();
   if (!forOp) {
-    return builder.create(loc, skipFirstWait, 1);
+    return builder.create(loc, emptyBarrier, 1);
   }
-  auto defOp = op->getOperand(0).getDefiningOp();
-  assert(isa(defOp) &&
-         "mbarrier's definingOp is not createTokenOp");
-  ttng::CreateTokenOp createTokenOp = dyn_cast(defOp);
-  Value numStage =
-      builder.create(loc, createTokenOp.getNum(), 32);
-  Value curStep = forOp.getBody()->getArguments().back();
-  if (curStep.getType() == builder.getIndexType()) {
-    curStep =
-        builder.create(loc, numStage.getType(), curStep);
+  // for (..., phase, pipelineIdx)
+  unsigned numArgs = forOp.getBody()->getNumArguments();
+  assert(numArgs > 2 && "Unexpected number of arguments");
+  Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
+  if (emptyBarrier) {
+    Value _1_1b = builder.create(loc, 1, 1);
+    curPhase = builder.create(loc, curPhase, _1_1b);
   }
-  Value curPhase = builder.create(loc, curStep, numStage);
-  if (skipFirstWait) {
-    // If skipFirstWait, it waits for phaseBit 1
-    Value _1 = builder.create(loc, 1, 32);
-    curPhase = builder.create(loc, curPhase, _1);
-  }
-  Value _2 = builder.create(loc, 2, 32);
-  // TODO: May use alternative methods of phaseBit calculation to avoid high
-  // overhead of RemOp
-  Value phaseBit = builder.create(loc, curPhase, _2);
-  Value _0 = builder.create(loc, 0, 32);
-  return builder.create(loc, arith::CmpIPredicate::ne, phaseBit,
-                        _0);
+  return curPhase;
 }

 int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
   auto loc = op.getLoc();
   // The first producer_aquire should be met immediately, so initailly producer
   // skips the fisrt wait
-  Value phase = getMBarrierPhaseBit(builder, op, 1);
+  Value phase = getMBarrierPhaseBit(builder, op, true);
   auto waitOp = builder.create(loc, bufferEmpty, phase);
   assert(op.getOperation()->hasAttr("async_agent"));
   setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
 void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
                            Value bufferFull) {
   auto loc = op.getLoc();
-  Value phase = getMBarrierPhaseBit(builder, op, 0);
+  Value phase = getMBarrierPhaseBit(builder, op, false);
   auto waitOp = builder.create(loc, bufferFull, phase);
   assert(op.getOperation()->hasAttr("async_agent"));
   setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
       builder.create(loc, nameBarrierId - 1, 32);
   // Process mutex users
   int numUsers = 0;
+  SmallVector deprecatedOps;
   for (Operation *user : createMutexOp.getResult().getUsers()) {
     numUsers++;
     assert(numUsers <= 2);
@@ -543,14 +533,20 @@
       Value barLeave = builder.create(
           loc, isRole0, namedBarrierId1, namedBarrierId0);
       builder.create(loc, barLeave, numThreads);
-    } else
+    } else {
       llvm_unreachable("Unexpected user of mutex");
+    }
+    deprecatedOps.push_back(user);
+  }
+  for (Operation *user : deprecatedOps) {
     user->erase();
   }
   nameBarrierId -= 2;
   nameBarrierIdEnd -= 2;
-  createMutexOp.erase();
 });
+
+  parentOp->walk(
+      [](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
 }

 void processLockOp(OpBuilder &builder, ttng::LockOp op) {
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
     OpBuilder builder(createMutexOp);

     // Process mutex users
+    SmallVector deprecatedOps;
     for (Operation *user : createMutexOp.getResult().getUsers()) {
       auto loc = user->getLoc();
       builder.setInsertionPoint(user);
@@ -596,6 +593,9 @@
         processUnlockOp(builder, op);
       else
         llvm_unreachable("Unexpected user of mutex");
+      deprecatedOps.push_back(user);
+    }
+    for (Operation *user : deprecatedOps) {
       user->erase();
     }
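The mbarrier phase bit is no longer recomputed with div/rem from a step counter; the loop now carries an explicit `(phase, pipelineIdx)` pair, materialized by `createNewMathLoop` in WSPipeline.cpp below. A minimal Python model of the carried update and the producer-side flip for empty barriers (function names hypothetical):

```python
def advance(phase: bool, idx: int, num_stages: int):
    """Model of the loop-carried (phase, pipelineIdx) update:
    the index wraps at num_stages and the phase flips on each wrap."""
    idx += 1
    if idx >= num_stages:
        return not phase, 0
    return phase, idx

def producer_acquire_phase(phase: bool) -> bool:
    # emptyBarrier == True: wait on the flipped phase, so the first
    # producer_acquire is satisfied immediately.
    return not phase

phase, idx = False, 0
for _ in range(4):  # num_stages = 2: (False,1) (True,0) (True,1) (False,0)
    phase, idx = advance(phase, idx, 2)
```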
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp
index dd39e94b8..3a54e60a4 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp
@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
   persistentForOp.getInitArgsMutable()
       .slice(persistentForOp.getInitArgs().size() - 1, 1)
       .assign(newIdx);
-  auto yield =
-      llvm::cast(persistentForOp.getBody()->getTerminator());
-  auto idxPlusOneOp =
-      yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
-  assert(isa(idxPlusOneOp));
-  assert(idxPlusOneOp->getOperand(0) ==
-         persistentForOp.getBody()->getArgument(
-             persistentForOp.getBody()->getNumArguments() - 1));
+
+  pipelineIdx = persistentForOp.getBody()->getArgument(
+      persistentForOp.getBody()->getNumArguments() - 1);
+  Operation *idxPlusOneOp = nullptr;
+  for (OpOperand &v : pipelineIdx.getUses()) {
+    if (isa(v.getOwner())) {
+      idxPlusOneOp = v.getOwner();
+      break;
+    }
+  }
+  assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
+  Operation *use = *idxPlusOneOp->getUsers().begin();
+  assert(isa(use) || isa(use) ||
+         isa(use));
   idxPlusOneOp->setOperand(1, numRolesValue);

   // Add operations at the start of persistentForOp
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
       unlockLocs[i] = op;
   }

-  // Update unlockLocs
-  // ====================== IR after async launch dots ======================
-  // * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
-  // %3) {
-  // *   triton_nvidia_gpu.producer_wait arg2
-  // *   %5 = triton_nvidia_gpu.dot_async %4, %5
-  // *   triton_nvidia_gpu.dot_wait {pendings = 1}
-  // *   %6 = arith.cmpi sgt, arg0, %c0
-  // *   scf.if %6 {
-  // *     %7 = arith.subi arg2, c1
-  // *     triton_nvidia_gpu.consumer_release %7
-  // *   }
-  // *   %8 = arith.addi arg2, c1
-  // *   scf.yield %5, %8
-  // * }
-  // * triton_nvidia_gpu.dot_wait {pendings = 0}
-  // * %9 = arith.subi %0#1, c1
-  // * triton_nvidia_gpu.consumer_release %9
-  // * =======================================================================
-  // after async launch dots, there will be outstanding consumerReleaseOp after
-  // ForOp. we should expend the unlockLocs from ForOp to the outstanding
-  // consumerReleaseOp.
-  for (int i = 0; i < numRoles; ++i) {
-    Operation *unlockOp = unlockLocs[i];
-    auto filter = [&](Operation *op) {
-      return op->getBlock() == unlockOp->getBlock();
-    };
-    if (isa(unlockOp)) {
-      SetVector slices;
-      mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
-      auto iter = llvm::find_if(slices, [](Operation *op) {
-        return isa(op);
-      });
-      if (iter != slices.end()) {
-        unlockLocs[i] = *iter;
-      }
-    }
-  }
-
   // Only cases where all lock/unlock locations are in same level make sense.
   for (int i = 1; i < numRoles; ++i) {
     if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
     else
       lockLocs[i] = unlockLocs[prevTypeIds[i]];
   }
+
+  // Update lockLocs
+  // ====================== IR after async launch dots ======================
+  // * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
+  // %3) {
+  // *   triton_nvidia_gpu.producer_wait arg2
+  // *   %5 = triton_nvidia_gpu.dot_async %4, %5
+  // *   triton_nvidia_gpu.dot_wait {pendings = 1}
+  // *   %6 = arith.cmpi sgt, arg0, %c0
+  // *   scf.if %6 {
+  // *     %7 = arith.subi arg2, c1
+  // *     triton_nvidia_gpu.consumer_release %7
+  // *   }
+  // *   %8 = arith.addi arg2, c1
+  // *   scf.yield %5, %8
+  // * }
+  // * triton_nvidia_gpu.dot_wait {pendings = 0}
+  // * ...
+  // * triton_nvidia_gpu.consumer_release ..
+  // * =======================================================================
+  // After async launch dots, there will be an outstanding consumerReleaseOp
+  // after the ForOp. We should set the epilogue lockLocs after the
+  // outstanding consumerReleaseOp.
+  for (int i = 0; i < numRoles; ++i) {
+    Operation *lockOp = lockLocs[i];
+    if (isa(lockOp)) {
+      Operation *loc = nullptr;
+      unsigned numOutstandingConsumerRelease = 0;
+      for (auto v : lockOp->getResults()) {
+        SetVector slices;
+        mlir::getForwardSlice(v, &slices);
+        auto iter = llvm::find_if(slices, [](Operation *op) {
+          return isa(op);
+        });
+        if (iter != slices.end()) {
+          numOutstandingConsumerRelease++;
+          loc = *iter;
+        }
+      }
+      assert(numOutstandingConsumerRelease <= 1 &&
+             "should have only one outstanding "
+             "consumerReleaseOp after "
+             "async launch dots");
+      if (loc)
+        lockLocs[i] = loc;
+    }
+  }
+
   // lock
   for (int i = 0; i < numRoles; ++i) {
     builder.setInsertionPointAfter(lockLocs[i]);
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp
index 5d6417fab..e21f6ddd3 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp
@@ -129,11 +129,12 @@ DenseMap createBuffer(const SmallVector &channels,
 }

 //===----------------------------------------------------------------------===//
-// appendPipelineIdxToLoopArgs
+// createNewLoops
 //===----------------------------------------------------------------------===//

-scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
-                                       scf::ForOp &parentForOp) {
+// for(...) -> for(..., pipelineIdx)
+scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
+                                   scf::ForOp &parentForOp) {
   auto loc = forOp.getLoc();
   Block *body = forOp.getBody();
@@ -200,6 +201,117 @@
   return newForOp;
 }

+// for(...) -> for(..., phase, pipelineIdx)
+scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
+                             scf::ForOp &parentForOp) {
+  auto loc = forOp.getLoc();
+  Block *body = forOp.getBody();
+
+  // The agentId set of pipelineIdx is the union of the agentId sets of all
+  // ops in the for loop
+  OpBuilderWithAgentIds builder(forOp.getContext());
+  builder.setAgentIdsFromArray(collectAgentIds(forOp));
+
+  builder.setInsertionPoint(forOp);
+  Value numStagesVal =
+      builder.createWithAgentIds(loc, numStages, 32);
+
+  // 0. Append phase and pipelineIdx to block arguments
+  Value phase =
+      body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
+  Value pipelineIdx =
+      body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
+
+  // 1. Prepare index and phase for the next iteration
+  // nextIdx = curIdx + 1
+  // nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
+  //              curPhase^1))
+  // nextIdx = nextIdx >= numStages ? 0 : nextIdx
+  auto yieldOp = llvm::cast(body->getTerminator());
+  builder.setInsertionPoint(yieldOp);
+  Value one = builder.createWithAgentIds(loc, 1, 32);
+  Value zero = builder.createWithAgentIds(loc, 0, 32);
+  Value _1_1b = builder.createWithAgentIds(loc, 1, 1);
+  // generate index for next iter
+  Value nextPipelineIdx =
+      builder.createWithAgentIds(loc, pipelineIdx, one);
+  Value pipelineGECond = builder.createWithAgentIds(
+      loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
+  Value pipelineLTCond = builder.createWithAgentIds(
+      loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
+  Value cyclePipelineIdx = builder.createWithAgentIds(
+      loc, nextPipelineIdx, numStagesVal);
+  nextPipelineIdx = builder.createWithAgentIds(
+      loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
+  // generate phase for next iter
+  Value flipPhase =
+      builder.createWithAgentIds(loc, phase, _1_1b);
+  Value cond0 = builder.createWithAgentIds(
+      loc, pipelineGECond, flipPhase);
+  Value cond1 = builder.createWithAgentIds(
+      loc, pipelineLTCond, phase);
+  Value nextPhase =
+      builder.createWithAgentIds(loc, cond0, cond1);
+
+  // 2. Append phase and pipelineIdx to yield operands
+  yieldOp->insertOperands(yieldOp.getNumOperands(),
+                          {nextPhase, nextPipelineIdx});
+
+  // 3. Create newLoopArgs
+  SmallVector newLoopArgs;
+  for (auto operand : forOp.getInitArgs())
+    newLoopArgs.push_back(operand);
+
+  builder.setInsertionPoint(forOp);
+  Value initPipelineIdx, initEmptyIdx, initPhase;
+  zero = builder.createWithAgentIds(loc, 0, 32);
+  if (parentForOp) {
+    // Make sure the prior pipelineIdx is inserted at the end of parentForOp
+    initPipelineIdx = parentForOp.getBody()->getArguments().back();
+    Value numSteps = builder.createWithAgentIds(
+        loc, forOp.getUpperBound(), forOp.getLowerBound());
+    numSteps = builder.createWithAgentIds(loc, numSteps,
+                                          forOp.getStep());
+    Value one = builder.createWithAgentIds(loc, 1, 32);
+    Value two = builder.createWithAgentIds(loc, 2, 32);
+    numSteps = builder.createWithAgentIds(loc, numSteps, one);
+    numSteps = builder.createWithAgentIds(loc, numSteps,
+                                          forOp.getStep());
+    // initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
+    // initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
+    initPipelineIdx = builder.createWithAgentIds(
+        loc, initPipelineIdx, numSteps);
+    Value pipelineIdx = builder.createWithAgentIds(
+        loc, initPipelineIdx, numStagesVal);
+    initPipelineIdx = builder.createWithAgentIds(
+        loc, initPipelineIdx,
+        builder.createWithAgentIds(loc, pipelineIdx,
+                                   numStagesVal));
+    pipelineIdx =
+        builder.createWithAgentIds(loc, pipelineIdx, one);
+    initPhase = builder.createWithAgentIds(
+        loc, builder.getI1Type(), pipelineIdx);
+  } else {
+    // phase init to false and pipelineIdx init to 0
+    initPipelineIdx = zero;
+    initPhase = builder.createWithAgentIds(loc, 0, 1);
+  }
+  newLoopArgs.append({initPhase, initPipelineIdx});
+
+  // 4. Create newForOp and take the region of forOp
+  auto newForOp = builder.createWithAgentIds(
+      loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
+      newLoopArgs);
+  newForOp.getRegion().takeBody(forOp.getRegion());
+
+  // 5. Replace forOp with newForOp
+  for (unsigned i = 0; i < forOp.getNumResults(); ++i)
+    forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
+  forOp.erase();
+
+  return newForOp;
+}
+
 //===----------------------------------------------------------------------===//
 // appendPipelineIdxArgs
 //===----------------------------------------------------------------------===//
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector &backbone, int numStages) {

   for (auto &op : orderedForOps) {
     scf::ForOp parentForOp = op->getParentOfType();
-    auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
+    scf::ForOp newForOp;
+    bool hasDotOp = false;
+    for (Operation &subOp : *op.getBody()) {
+      if (isa(&subOp)) {
+        hasDotOp = true;
+        break;
+      }
+    }
+
+    if (hasDotOp) {
+      // for(...) -> for(..., phase, pipelineIdx)
+      newForOp = createNewMathLoop(op, numStages, parentForOp);
+    } else {
+      // for(...) -> for(..., pipelineIdx)
+      newForOp = createNewPersistentLoop(op, numStages, parentForOp);
+    }
     auto backboneForItr =
         std::find(backbone.begin(), backbone.end(), op.getOperation());
     if (backboneForItr != backbone.end()) {
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap> &map,
   }
   builder.setAgentIdsFromArray(agentsPC);
   Value pipelineIdx;
-  Value numStagesVal = builder.createWithAgentIds(
-      headProducer->getLoc(), numStages, 32);
   if (auto forOp = headProducer->getParentOfType()) {
     pipelineIdx = forOp.getBody()->getArguments().back();
   } else {
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap> &map,
   // insert ProducerAcquireOp
   builder.setInsertionPoint(headProducer);
-  if (headProducer->getParentOfType()) {
-    pipelineIdx = builder.createWithAgentIds(
-        headProducer->getLoc(), pipelineIdx, numStagesVal);
-  }
   builder.setAgentIdsFromArray(agentP);
   builder.createWithAgentIds(headProducer->getLoc(), token, pipelineIdx);
@@ -738,7 +859,8 @@
           loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
           dotOp.getMaxNumImpreciseAcc());
       dot.replaceAllUsesWith(dotAsync.getResult());
-      builder.createWithAgentIds(loc, 1);
+      builder.createWithAgentIds(
+          loc, dotAsync.getResult(), 1);

       // 1. insert ConsumerReleaseOp for DotAsyncOps
       Value cond = builder.createWithAgentIds(
@@ -747,31 +869,43 @@
       auto ifOp = builder.createWithAgentIds(loc, ArrayRef{}, cond,
                                              /*hasElse*/ false);
+      setAgentIds(ifOp.thenYield().getOperation(), agentIds);
       builder.setInsertionPointToStart(ifOp.thenBlock());
-      Value one = builder.createWithAgentIds(
-          headConsumer->getLoc(), 1, 32);
-      auto oriIdx = forOp.getBody()->getArguments().back();
-      Value consumerReleaseIdx =
-          builder.createWithAgentIds(loc, oriIdx, one);
-      consumerReleaseIdx = builder.createWithAgentIds(
-          loc, consumerReleaseIdx, numStagesVal);
+      Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
+      Value zero = builder.createWithAgentIds(loc, 0, 32);
+      Value one = builder.createWithAgentIds(loc, 1, 32);
+      Value lastStage = builder.createWithAgentIds(
+          loc, numStages - 1, 32);
+      Value consumerReleaseIdxMinusOne =
+          builder.createWithAgentIds(loc, consumerReleaseIdx,
+                                     one);
+      cond = builder.createWithAgentIds(
+          loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
+      consumerReleaseIdx = builder.createWithAgentIds(
+          loc, cond, lastStage, consumerReleaseIdxMinusOne);
       builder.createWithAgentIds(loc, token, consumerReleaseIdx);
-      setAgentIds(ifOp.thenYield().getOperation(), agentIds);

       // 2. If there's any outstanding DotAsyncOps, we need to wait for them.
       builder.setInsertionPointAfter(forOp);
-      builder.createWithAgentIds(forOp.getLoc(),
-                                 0);
+      unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
+      Value result = forOp->getResult(resultIndex);
+      auto dotWait = builder.createWithAgentIds(
+          forOp.getLoc(), result, 0);
+      result.replaceAllUsesExcept(dotWait.getResult(), dotWait);

       // 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
-      Value one_ = builder.createWithAgentIds(
-          headConsumer->getLoc(), 1, 32);
+      zero = builder.createWithAgentIds(loc, 0, 32);
+      one = builder.createWithAgentIds(loc, 1, 32);
+      lastStage = builder.createWithAgentIds(
+          loc, numStages - 1, 32);
       consumerReleaseIdx = forOp.getResults().back();
-      consumerReleaseIdx = builder.createWithAgentIds(
-          loc, consumerReleaseIdx, one_);
-      consumerReleaseIdx = builder.createWithAgentIds(
-          loc, consumerReleaseIdx, numStagesVal);
+      consumerReleaseIdxMinusOne = builder.createWithAgentIds(
+          loc, consumerReleaseIdx, one);
+      cond = builder.createWithAgentIds(
+          loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
+      consumerReleaseIdx = builder.createWithAgentIds(
+          loc, cond, lastStage, consumerReleaseIdxMinusOne);
       builder.createWithAgentIds(loc, token, consumerReleaseIdx);

       dotOp->erase();
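The release index above is now computed with a compare-and-select instead of an add-then-modulo; a small Python equivalent (hypothetical helper name) of what `buildAsyncComm` emits:

```python
def consumer_release_idx(idx: int, num_stages: int) -> int:
    # Model of the select emitted above: release the buffer one slot
    # behind the current pipeline index, wrapping at stage 0.
    return num_stages - 1 if idx == 0 else idx - 1

assert [consumer_release_idx(i, 3) for i in range(3)] == [2, 0, 1]
```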
diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt
index 9c0a6c26e..e6c7e2b23 100644
--- a/lib/Target/LLVMIR/CMakeLists.txt
+++ b/lib/Target/LLVMIR/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_translation_library(TritonLLVMIR
   PUBLIC
   MLIRArithToLLVM
   MLIRBuiltinToLLVMIRTranslation
-  MLIRExecutionEngineUtils
   MLIRIndexToLLVM
   MLIRIR
   MLIRLLVMDialect
diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
index 3ae1bac1a..5fb7f3fbd 100644
--- a/lib/Target/PTX/PTXTranslation.cpp
+++ b/lib/Target/PTX/PTXTranslation.cpp
@@ -44,7 +44,8 @@ static bool findAndReplace(std::string &str, const std::string &begin,
   return true;
 }

-std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
+std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version,
+                                 bool enable_fp_fusion) {
   // LLVM version in use may not officially support target hardware.
   // Supported versions for LLVM 14 are here:
   // https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -84,13 +85,15 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
   auto target =
       llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
   llvm::TargetOptions opt;
-  opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
+  if (enable_fp_fusion)
+    opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
   opt.UnsafeFPMath = false;
   opt.NoInfsFPMath = false;
   opt.NoNaNsFPMath = true;
+  opt.TrapUnreachable = true;
   std::unique_ptr machine{target->createTargetMachine(
       module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
-      std::nullopt, llvm::CodeGenOpt::Aggressive)};
+      std::nullopt, llvm::CodeGenOptLevel::Aggressive)};
   // set data layout
   if (layout.empty())
     module.setDataLayout(machine->createDataLayout());
@@ -106,7 +109,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
   llvm::legacy::PassManager pass;
   // emit
   machine->addPassesToEmitFile(pass, pstream, nullptr,
-                               llvm::CodeGenFileType::CGFT_AssemblyFile);
+                               llvm::CodeGenFileType::AssemblyFile);
   pass.run(module);
 }
 // post-process
diff --git a/python/setup.py b/python/setup.py
index f420e7f18..c5ddff6e0 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -73,23 +73,20 @@ def get_llvm_package_info():
     if arch == 'aarch64':
         arch = 'arm64'
     if system == "Darwin":
-        system_suffix = "apple-darwin"
         arch = platform.machine()
+        system_suffix = f"macos-{arch}"
     elif system == "Linux":
+        # TODO: arm64
         vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
         vglibc = vglibc[0] * 100 + vglibc[1]
-        linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
-        system_suffix = f"linux-gnu-{linux_suffix}"
+        system_suffix = 'ubuntu-x64' if vglibc > 217 else 'centos-x64'
     else:
         return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
-    use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
-    release_suffix = "assert" if use_assert_enabled_llvm else "release"
-    name = f'llvm+mlir-17.0.0-{arch}-{system_suffix}-{release_suffix}'
-    version = "llvm-17.0.0-c5dede880d17"
-    url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
-    # FIXME: remove the following once github.com/ptillet/triton-llvm-releases has arm64 llvm releases
-    if arch == 'arm64' and 'linux' in system_suffix:
-        url = f"https://github.com/acollins3/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
+    # use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
+    # release_suffix = "assert" if use_assert_enabled_llvm else "release"
+    rev = "b1115f8c"
+    name = f"llvm-{rev}-{system_suffix}"
+    url = f"https://tritonlang.blob.core.windows.net/llvm-builds/{name}.tar.gz"
     return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 149ef1034..fe335d93a 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1236,7 +1236,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_minf",
           [](TritonOpBuilder &self, mlir::Value &lhs,
              mlir::Value &rhs) -> mlir::Value {
-             return mlir::Value(self.create<mlir::arith::MinFOp>(lhs, rhs));
+             return mlir::Value(self.create<mlir::arith::MinimumFOp>(lhs, rhs));
           })
       .def("create_maxsi",
           [](TritonOpBuilder &self, mlir::Value &lhs,
@@ -1251,7 +1251,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_maxf",
           [](TritonOpBuilder &self, mlir::Value &lhs,
              mlir::Value &rhs) -> mlir::Value {
-             return mlir::Value(self.create<mlir::arith::MaxFOp>(lhs, rhs));
+             return mlir::Value(self.create<mlir::arith::MaximumFOp>(lhs, rhs));
           })
       // AddPtr (similar to GEP)
       .def("create_addptr",
@@ -2006,7 +2006,8 @@ void init_triton_translation(py::module &m) {
   m.def(
       "translate_llvmir_to_ptx",
-      [](const std::string llvmIR, int capability, int version) -> std::string {
+      [](const std::string llvmIR, int capability, int version,
+         bool enable_fp_fusion) -> std::string {
         py::gil_scoped_release allow_threads;
         // create LLVM module from C++
         llvm::LLVMContext context;
@@ -2021,75 +2022,77 @@
               "lineno: " + std::to_string(error.getLineNo()));
         }
         // translate module to PTX
-        auto ptxCode =
-            triton::translateLLVMIRToPTX(*module, capability, version);
+        auto ptxCode = triton::translateLLVMIRToPTX(*module, capability,
+                                                    version, enable_fp_fusion);
         return ptxCode;
       },
       ret::take_ownership);

-  m.def(
-      "compile_ptx_to_cubin",
-      [](const std::string &ptxCode, const std::string &ptxasPath,
-         int capability) -> py::object {
-        std::string cubin;
-        {
-          py::gil_scoped_release allow_threads;
+  m.def("compile_ptx_to_cubin",
+        [](const std::string &ptxCode, const std::string &ptxasPath,
+           int capability, bool enable_fp_fusion) -> py::object {
+          std::string cubin;
+          {
+            py::gil_scoped_release allow_threads;

-          // compile ptx with ptxas
-          llvm::SmallString<64> fsrc;
-          llvm::SmallString<64> flog;
-          llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
-          llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
-          std::string fbin = std::string(fsrc) + ".o";
-          llvm::FileRemover logRemover(flog);
-          llvm::FileRemover binRemover(fbin);
-          const char *_fsrc = fsrc.c_str();
-          const char *_flog = flog.c_str();
-          const char *_fbin = fbin.c_str();
-          std::ofstream ofs(_fsrc);
-          ofs << ptxCode << std::endl;
-          ofs.close();
+            // compile ptx with ptxas
+            llvm::SmallString<64> fsrc;
+            llvm::SmallString<64> flog;
+            llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
+            llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
+            std::string fbin = std::string(fsrc) + ".o";
+            llvm::FileRemover logRemover(flog);
+            llvm::FileRemover binRemover(fbin);
+            const char *_fsrc = fsrc.c_str();
+            const char *_flog = flog.c_str();
+            const char *_fbin = fbin.c_str();
+            std::ofstream ofs(_fsrc);
+            ofs << ptxCode << std::endl;
+            ofs.close();

-          auto lineInfoOption =
-              triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
-                  ? ""
-                  : " -lineinfo";
-          auto capabilitySuffix = (capability == 90) ? "a " : " ";
-          auto outputFileName = std::string(_fsrc) + ".o";
-          auto logRedirect = " 2> " + std::string(_flog);
-          std::string cmd = ptxasPath + lineInfoOption + " -v --gpu-name=sm_" +
-                            std::to_string(capability) + capabilitySuffix +
-                            _fsrc + " -o " + outputFileName + logRedirect;
+            auto lineInfoOption =
+                triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
+                    ? ""
+                    : " -lineinfo";
+            auto fmadOption = enable_fp_fusion ? "" : " --fmad=false";
+            auto capabilitySuffix = (capability == 90) ? "a " : " ";
+            auto outputFileName = std::string(_fsrc) + ".o";
+            auto logRedirect = " 2> " + std::string(_flog);
+            std::string cmd = ptxasPath + lineInfoOption + fmadOption +
+                              " -v --gpu-name=sm_" +
+                              std::to_string(capability) + capabilitySuffix +
+                              _fsrc + " -o " + outputFileName + logRedirect;

-          int err = system(cmd.c_str());
-          if (err != 0) {
-            err >>= 8;
-            std::ifstream _log(_flog);
-            std::string log(std::istreambuf_iterator(_log), {});
-            if (err == 255) {
-              throw std::runtime_error("Internal Triton PTX codegen error: \n" +
-                                       log);
-            } else if (err == 128 + SIGSEGV) {
-              throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
-                                       "` to confirm that this is a "
-                                       "bug in `ptxas`\n" +
-                                       log);
+            int err = system(cmd.c_str());
+            if (err != 0) {
+              err >>= 8;
+              std::ifstream _log(_flog);
+              std::string log(std::istreambuf_iterator(_log), {});
+              if (err == 255) {
+                throw std::runtime_error(
+                    "Internal Triton PTX codegen error: \n" + log);
+              } else if (err == 128 + SIGSEGV) {
+                throw std::runtime_error("Please run `ptxas " +
+                                         fsrc.str().str() +
+                                         "` to confirm that this is a "
+                                         "bug in `ptxas`\n" +
+                                         log);
+              } else {
+                throw std::runtime_error("`ptxas` failed with error code " +
+                                         std::to_string(err) + ": \n" + log);
+              }
+              return {};
             } else {
-              throw std::runtime_error("`ptxas` failed with error code " +
-                                       std::to_string(err) + ": \n" + log);
+              llvm::FileRemover srcRemover(fsrc);
+              std::ifstream _cubin(_fbin, std::ios::binary);
+              cubin = std::string(std::istreambuf_iterator(_cubin), {});
+              _cubin.close();
+              // Do not return here, exit the gil scope and return below
             }
-            return {};
-          } else {
-            llvm::FileRemover srcRemover(fsrc);
-            std::ifstream _cubin(_fbin, std::ios::binary);
-            cubin = std::string(std::istreambuf_iterator(_cubin), {});
-            _cubin.close();
-            // Do not return here, exit the gil scope and return below
           }
-        }
-        py::bytes bytes(cubin);
-        return std::move(bytes);
-      });
+          py::bytes bytes(cubin);
+          return std::move(bytes);
+        });

   m.def("add_external_libs",
         [](mlir::ModuleOp &op, const std::vector &names,
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a8557b4e6..c8cab4a99 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -860,9 +860,9 @@ def test_unary_op(dtype_x, expr, num_ctas, device):
 # ----------------

-@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']])
-def test_math_op(dtype_x, expr, device):
-    _test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
+@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin'] for x in ['x', '3.0']])
+def test_math_op(dtype_x, expr, device, x):
+    _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device)

 # ----------------
 # test abs
@@ -1662,10 +1662,18 @@ reduce_configs2 = [
     for op in ['min', 'max', 'sum']
 ]

+reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)]
+reduce_configs3 = [
+    (op, 'float32', shape, axis)
+    for op in ['min', 'max', 'sum', 'argmin', 'argmax']
+    for shape in reduce3d_shapes
+    for axis in [0, 1, 2]
+]

-@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
+
+@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3)
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
+def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
     if is_hip():
@@ -1673,17 +1681,31 @@
     # triton kernel
     @triton.jit
-    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr):
         range_m = tl.arange(0, BLOCK_M)
         range_n = tl.arange(0, BLOCK_N)
-        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
-        z = GENERATE_TEST_HERE
-        if AXIS is None:
-            tl.store(Z, z)
-        elif AXIS == 1:
-            tl.store(Z + range_m, z)
+        range_k = tl.arange(0, BLOCK_K)
+        if IS_3D:
+            x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :])
         else:
-            tl.store(Z + range_n, z)
+            x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
+        z = GENERATE_TEST_HERE
+        if IS_3D:
+            if AXIS is None:
+                tl.store(Z, z)
+            elif AXIS == 0:
+                tl.store(Z + range_n[:, None] * BLOCK_K + range_k[None, :], z)
+            elif AXIS == 1:
+                tl.store(Z + range_m[:, None] * BLOCK_K + range_k[None, :], z)
+            else:
+                tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
+        else:
+            if AXIS is None:
+                tl.store(Z, z)
+            elif AXIS == 0:
+                tl.store(Z + range_n, z)
+            else:
+                tl.store(Z + range_m, z)

     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
     # input
@@ -1706,10 +1728,13 @@
     z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
-    ret_numel = 1 if axis is None else shape[1 - axis]
-    z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
+    z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis)
+    z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs),
                       device=device, dst_type=z_tri_dtype_str)
+    BLOCK_K = 1 if len(shape) == 2 else shape[2]
+    IS_3D = bool(len(shape) == 3)
     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0],
-                 BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas)
+                 BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas)
     z_tri = to_numpy(z_tri)
     # compare
     if op == 'sum':
@@ -1890,7 +1915,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
     ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
     arith_op = {
-        "max": {"int32": "arith.maxsi", "float32": "arith.maxf", "float16": "arith.maxf"},
+        "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"},
         "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
     }[reduce_op][dtype_str]
     numpy_op = {
@@ -3049,6 +3074,20 @@ def test_constexpr_scalar_shape(device):
     kernel[(1,)](x_tri, 32)
     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)

+
+@triton.jit
+def static_assert_func():
+    tl.static_assert(tl.constexpr(False), "Assert is firing because the constexpr propagation did not work properly")
+
+
+def test_constexpr_propagation():
+    @triton.jit
+    def _kernel(COND: tl.constexpr):
+        NEW_COND = COND
+        if NEW_COND:
+            static_assert_func()
+    _kernel[(1,)](False)
+
 # -------------
 # test call
 # -------------
@@ -3893,3 +3932,22 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
         torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
     else:
         torch.testing.assert_close(ref_out, C)
+
+# -----------------------
+# test enable_fp_fusion
+# -----------------------
+
+
+@pytest.mark.parametrize("enable_fp_fusion", [False, True])
+def test_enable_fp_fusion(enable_fp_fusion):
+    # Sequential multiply-add can be fused by the backend
+    @triton.jit
+    def mul_add(data):
+        ptrs = data + tl.arange(0, 128)
+        tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)
+
+    data = torch.randn((128,), device='cuda', dtype=torch.float32)
+    h = mul_add[(1,)](data, enable_fp_fusion=enable_fp_fusion)
+
+    found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None
+    assert found_fma == enable_fp_fusion
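The `code_generator.py` changes below tighten several front-end diagnostics. Two hypothetical examples (not from the patch itself) of kernels they now report with a clear error instead of failing obscurely:

```python
import triton
import triton.language as tl

@triton.jit
def bad_nested(x_ptr):
    def helper():          # nested function definition inside a jitted fn:
        pass               # now raises UnsupportedLanguageConstruct

@triton.jit
def bad_loop_carry(n):
    x = 0                  # starts as int32
    for i in range(n):
        x = x + 1.0        # re-assigned as float32: the loop-carried type
                           # check now names the variable and both types
```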
diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py
index 21bd0cbf4..1bba2d079 100644
--- a/python/triton/compiler/code_generator.py
+++ b/python/triton/compiler/code_generator.py
@@ -228,6 +228,7 @@ class CodeGenerator(ast.NodeVisitor):
         self.local_defs: Dict[str, tensor] = {}
         self.global_uses: Dict[str, tensor] = {}
         self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
+        self.fn = None

         builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)}
         builtin_namespace.update((
@@ -322,6 +323,8 @@ class CodeGenerator(ast.NodeVisitor):

     def visit_FunctionDef(self, node):
         arg_names, kwarg_names = self.visit(node.args)
+        if self.fn:
+            raise UnsupportedLanguageConstruct(None, node, "nested function definition is not supported.")
         # initialize defaults
         for i, default_value in enumerate(node.args.defaults):
             arg_node = node.args.args[-i - 1]
@@ -335,9 +338,9 @@
             self.visit(init_node)
         # initialize function
         visibility = "public" if self.is_kernel else "private"
-        fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
-        self.module.push_back(fn)
-        entry = fn.add_entry_block()
+        self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
+        self.module.push_back(self.fn)
+        entry = self.fn.add_entry_block()
         arg_values = []
         idx = 0
         for i, arg_name in enumerate(arg_names):
@@ -350,8 +353,8 @@
             else:
                 if i in self.attributes:
                     for name, value in self.attributes[i]:
-                        fn.set_arg_attr(idx, name, value)
-                arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx]))
+                        self.fn.set_arg_attr(idx, name, value)
+                arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
                 idx += 1

         insert_pt = self.builder.get_insertion_block()
@@ -367,14 +370,14 @@
             # update return type
             if isinstance(self.last_ret_type, tuple):
                 self.prototype.ret_types = list(self.last_ret_type)
-                fn.reset_type(self.prototype.to_ir(self.builder))
+                self.fn.reset_type(self.prototype.to_ir(self.builder))
             else:
                 self.prototype.ret_types = [self.last_ret_type]
-                fn.reset_type(self.prototype.to_ir(self.builder))
+                self.fn.reset_type(self.prototype.to_ir(self.builder))
         if insert_pt:
             self.builder.set_insertion_point_to_end(insert_pt)
         # Remove dead code
-        fn.finalize()
+        self.fn.finalize()

     def visit_arguments(self, node):
         arg_names = []
@@ -412,6 +415,9 @@
             raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
         names = _names[0]
         values = self.visit(node.value)
+        if not isinstance(node.value, ast.Constant) and _is_constexpr(values):
+            self.set_value(names, values)
+            return
         if not _is_list_like(names):
             names = [names]
         if not _is_list_like(values):
@@ -686,9 +692,13 @@
         for name in loop_defs:
             if name in liveins:
                 # We should not def new constexpr
-                assert _is_triton_tensor(loop_defs[name])
-                assert _is_triton_tensor(liveins[name])
-                assert loop_defs[name].type == liveins[name].type
+                assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constexpr {name} in the loop'
+                assert _is_triton_tensor(liveins[name]), f'cannot reassign constexpr {name} in the loop'
+                assert loop_defs[name].type == liveins[name].type, \
+                    f'Loop-carried variable {name} has initial type {liveins[name].type} '\
+                    f'but is re-assigned to {loop_defs[name].type} in loop! '\
+                    f'Please make sure that the type stays consistent.'
+
                 # these are loop-carried values
                 names.append(name)
                 ret_types.append(loop_defs[name].type)
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index c8d57d0ef..f55f3adb6 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -37,6 +37,7 @@ CUDA_DEFAULT_WARP_SIZE = 32
 class CudaTargetDescriptor:
     capability: int
     num_warps: int
+    enable_fp_fusion: bool

 def _is_cuda(target):
@@ -147,6 +148,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target,
         pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability)
         pm.add_tritongpu_wsmutex_pass(capability)
         pm.add_tritongpu_wsmaterialization_pass(capability)
+        pm.add_licm_pass()
         pm.add_cse_pass()
     else:
         if is_hip():
@@ -220,7 +222,7 @@ def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None)
     if ptx_version is None:
         _, cuda_version = path_to_ptxas()
         ptx_version = ptx_get_version(cuda_version)
-    return translate_llvmir_to_ptx(mod, target.capability, ptx_version)
+    return translate_llvmir_to_ptx(mod, target.capability, ptx_version, target.enable_fp_fusion)

 def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor):
@@ -231,7 +233,7 @@
     :return: str
     '''
     ptxas, _ = path_to_ptxas()
-    return compile_ptx_to_cubin(ptx, ptxas, target.capability)
+    return compile_ptx_to_cubin(ptx, ptxas, target.capability, target.enable_fp_fusion)

 # ------------------------------------------------------------------------------
@@ -407,8 +409,12 @@ def compile(fn, **kwargs):
     assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
     num_ctas = kwargs.get("num_ctas", 1)
     num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability))
+<<<<<<< HEAD
     waves_per_eu = kwargs.get("waves_per_eu", 0)
     matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0)
+=======
+    enable_fp_fusion = kwargs.get("enable_fp_fusion", True)
+>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
     # TODO[shuhaoj]: Default should be to enable warp specialization once possible
     enable_warp_specialization = kwargs.get("enable_warp_specialization", False)
     # TODO[shuhaoj]: persistent can be decoupled with warp specialization
@@ -431,7 +437,7 @@
     # build architecture descriptor
     if device_type == "cuda":
         _device_backend = get_backend(device_type)
-        target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps)
+        target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion)
     else:
         _device_backend = get_backend(device_type)
         assert _device_backend
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 5e2abf192..a262e7d66 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1057,14 +1057,19 @@ def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option=
     """
     Return a tensor of data whose values are loaded from memory at location defined by `pointer`:

     (1) `pointer` could be a single element pointer, then a scalar will be loaded
+
         - `mask` and `other` must be scalar too
         - `other` is implicitly typecast to `pointer.dtype.element_ty`
         - `boundary_check` and `padding_option` must be empty
+
     (2) `pointer` could be element-wise tensor of pointers, in which case:
+
         - `mask` and `other` are implicitly broadcast to `pointer.shape`
         - `other` is implicitly typecast to `pointer.dtype.element_ty`
         - `boundary_check` and `padding_option` must be empty
+
     (3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
+
         - `mask` and `other` must be None
         - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access
@@ -1103,14 +1108,20 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
     """
     Store a tensor of data into memory locations defined by `pointer`:

     (1) `pointer` could be a single element pointer, then a scalar will be stored
+
         - `mask` must be scalar too
         - `boundary_check` and `padding_option` must be empty
+
     (2) `pointer` could be element-wise tensor of pointers, in which case:
+
         - `mask` is implicitly broadcast to `pointer.shape`
         - `boundary_check` must be empty
+
     (3) or `pointer` could be a block pointer defined by `make_block_ptr`, in which case:
+
         - `mask` must be None
         - `boundary_check` can be specified to control the behavior of out-of-bound access
+
     `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.

     :param pointer: The memory location where the elements of `value` are stored
@@ -1165,29 +1176,35 @@ def advance(base: tensor, offsets, _builder=None):

 # -----------------------

-def _add_atomic_docstr(name: str) -> Callable[[T], T]:
+def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:

     def _decorator(func: T) -> T:
-        docstr = """
+        docstr = f"""
     Performs an atomic {name} at the memory location specified by :code:`pointer`.

     Return the data stored at :code:`pointer` before the atomic operation.

-    :param pointer: The memory locations to compare-and-swap.
-    :type pointer: Block of dtype=triton.PointerDType
+    :param pointer: The memory locations to operate on
+    :type pointer: Block of dtype=triton.PointerDType"""
+        if has_cmp:
+            docstr += """
     :param cmp: The values expected to be found in the atomic object
-    :type cmp: Block of dtype=`pointer.dtype.element_ty`
-    :param val: The values to copy in case the expected value matches the contained value.
-    :type val: Block of dtype=`pointer.dtype.element_ty`
+    :type cmp: Block of dtype=pointer.dtype.element_ty"""
+        docstr += """
+    :param val: The values with which to perform the atomic operation
+    :type val: Block of dtype=pointer.dtype.element_ty
+    :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
+        "ACQUIRE", "RELEASE", or "RELAXED")
+    :type sem: str
     """
-        func.__doc__ = docstr.format(name=name)
+        func.__doc__ = docstr
         return func

     return _decorator

 @builtin
-@_add_atomic_docstr("compare-and-swap")
+@_add_atomic_docstr("compare-and-swap", has_cmp=True)
 def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
     cmp = _to_tensor(cmp, _builder)
     val = _to_tensor(val, _builder)
@@ -1309,6 +1326,8 @@ def fdiv(x, y, ieee_rounding=False, _builder=None):
     :type ieee_rounding: bool
     """
     ieee_rounding = _constexpr_to_value(ieee_rounding)
+    x = _to_tensor(x, _builder)
+    y = _to_tensor(y, _builder)
     return semantic.fdiv(x, y, ieee_rounding, _builder)

@@ -1330,36 +1349,42 @@ def _add_math_1arg_docstr(name: str) -> Callable[[T], T]:

 @builtin
 @_add_math_1arg_docstr("exponential")
 def exp(x, _builder=None):
+    x = _to_tensor(x, _builder)
     return semantic.exp(x, _builder)

 @builtin
 @_add_math_1arg_docstr("natural logarithm")
 def log(x, _builder=None):
+    x = _to_tensor(x, _builder)
     return semantic.log(x, _builder)

 @builtin
 @_add_math_1arg_docstr("cosine")
 def cos(x, _builder=None):
+    x = _to_tensor(x, _builder)
     return semantic.cos(x, _builder)

 @builtin
 @_add_math_1arg_docstr("sine")
 def sin(x, _builder=None):
+    x = _to_tensor(x, _builder)
     return semantic.sin(x, _builder)

 @builtin
 @_add_math_1arg_docstr("square root")
 def sqrt(x, _builder=None):
+    x = _to_tensor(x, _builder)
     return semantic.sqrt(x, _builder)

 @builtin
 @_add_math_1arg_docstr("absolute value")
 def abs(x, _builder=None):
+    x = _to_tensor(x, _builder)
     return semantic.abs(x, _builder)
diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py
index e066ea082..ac364ffa7 100644
--- a/python/triton/runtime/interpreter.py
+++ b/python/triton/runtime/interpreter.py
@@ -453,7 +453,7 @@ def _unwrap(tensor):

 builder = Builder()

-RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
+RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization', 'enable_fp_fusion']

 class GridExecutor:
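Two user-visible effects of the `core.py` changes: the one-argument math builtins now accept plain Python scalars (they are wrapped via `_to_tensor`, matching the new `test_math_op` cases that pass `'3.0'`), and the atomic docstrings document the `sem` memory-semantics argument. A small hypothetical kernel exercising both (assuming `sem` is accepted by the other atomics as the shared docstring suggests):

```python
import triton
import triton.language as tl

@triton.jit
def demo(out_ptr, counter_ptr):
    offs = tl.arange(0, 64)
    # scalars are auto-wrapped now: tl.exp(1.0) and tl.sqrt(4.0) are valid
    tl.store(out_ptr + offs, tl.exp(1.0) + tl.sqrt(4.0))
    # relaxed-semantics atomic increment; sem defaults to "ACQUIRE_RELEASE"
    tl.atomic_add(counter_ptr, 1, sem="RELAXED")
```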
num_ctas={num_ctas}, num_stages={num_stages}, enable_warp_specialization={enable_warp_specialization}, enable_fp_fusion={enable_fp_fusion}]({arg_reprs})" +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 key = str(key) class LegacyCompiler: @@ -297,7 +305,11 @@ class JITFunction(KernelInterface[T]): pass kwargs = dict(signature=signature, device=device, constants=constants, +<<<<<<< HEAD num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, +======= + num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 configs=configs) return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={ @@ -351,7 +363,11 @@ class JITFunction(KernelInterface[T]): def regular_args_v(args_proxy): return [args_proxy[arg_name] for arg_name in regular_args] +<<<<<<< HEAD def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type): +======= + def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type): +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 from ..compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps) @@ -402,7 +418,11 @@ class JITFunction(KernelInterface[T]): if num_stages is None: num_stages = get_arch_default_num_stages(device_type) +<<<<<<< HEAD key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug) +======= + key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, self.debug) +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 if extern_libs is not None: key = (key, tuple(extern_libs.items())) @@ -430,8 +450,13 @@ class JITFunction(KernelInterface[T]): for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") +<<<<<<< HEAD if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs): bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) +======= + if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, configs): + bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, enable_fp_fusion=enable_fp_fusion, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 # Create tensormaps and append to args args = bin.assemble_tensormap_to_arg(args) if not warmup: @@ -446,8 
+471,13 @@ class JITFunction(KernelInterface[T]): args_signature = args_signature + ', ' if len(args_signature) > 0 else '' src = f""" import triton +<<<<<<< HEAD def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, stream, warmup, device, device_type) +======= +def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, enable_fp_fusion=True, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): + return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, enable_fp_fusion, extern_libs, stream, warmup, device, device_type) +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 """ scope = {"launcher_body": launcher_body} exec(src, scope) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 2b94edd60..827413e9d 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -32,13 +32,17 @@ def _attn_fwd_inner( STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + N_CTX: tl.constexpr, ): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M - else: + elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX K_block_ptr = tl.advance(K_block_ptr, (0, lo)) V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator @@ -72,6 +76,7 @@ def _attn_fwd_inner( K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i +<<<<<<< HEAD @triton.jit def _attn_fwd_inner( acc, l_i, m_i, q, @@ -138,6 +143,31 @@ def _attn_fwd_inner( ], key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], ) +======= +# We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting +# the code below and commenting out the equivalent parameters is convenient for +# re-tuning. 
- @triton.jit - def _attn_fwd_inner( - acc, l_i, m_i, q, @@ -138,6 +143,31 @@ def _attn_fwd_inner( - ], - key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], - ) +# We don't run auto-tuning every time, in order to keep the tutorial fast. +# Uncommenting the code below (and commenting out the equivalent hard-coded +# parameters) is convenient for re-tuning. +# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8), +# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8), +# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8), +# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8), +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4), +# ], +# key=['N_CTX'], +# ) @triton.jit @@ -227,8 +257,12 @@ def _attn_fwd( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 4 - STAGE, offs_m, offs_n, N_CTX, pre_load_v, + 4 - STAGE, offs_m, offs_n, N_CTX, ) # stage 2: on-band if STAGE & 2: @@ -239,8 +273,12 @@ def _attn_fwd( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, N_CTX, pre_load_v, + 2, offs_m, offs_n, N_CTX, ) # epilogue # write back m @@ -701,6 +739,7 @@ class _attention(torch.autograd.Function): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q, dtype=v.dtype) - if torch.version.hip is None: - BLOCK_M = 128 @@ -716,6 +755,20 @@ class _attention(torch.autograd.Function): - ) - M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + BLOCK_M = 128 + BLOCK_N = 64 if Lk <= 64 else 32 + num_stages = 4 if Lk <= 64 else 3 + num_warps = 4 + stage = 3 if causal else 1 + # Tuning for H100 + if torch.cuda.get_device_capability()[0] == 9: + num_warps = 8 + num_stages = 7 if Lk >= 64 else 3 + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -726,6 +779,11 @@ class _attention(torch.autograd.Function): N_CTX=q.shape[2], BLOCK_DMODEL=Lk, STAGE=stage, + num_warps=num_warps, + num_stages=num_stages, ) ## restore the grid for bwd kernel @@ -905,6 +963,7 @@ try: except BaseException: HAS_FLASH = False
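The benchmark below only feeds fp8 inputs when the running PyTorch build actually exposes `float8_e5m2`, using a plain `hasattr` probe. The same probe works standalone; a small sketch of what the cast costs in precision (assumes a PyTorch build with float8 support and a CUDA device):

```python
import torch

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8:
    x = torch.randn(8, device="cuda", dtype=torch.float16)
    x8 = x.to(torch.float8_e5m2)  # e5m2 keeps fp16's five exponent bits...
    err = (x - x8.to(torch.float16)).abs().max()  # ...but only 2 mantissa bits
    print(f"max rounding error: {err.item():.4f}")
else:
    print("no float8_e5m2 in this build; the benchmark sticks to fp16 inputs")
```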
- # vary seq length for fixed head and batch=4 - configs = [] - for mode in ['fwd', 'bwd']: @@ -939,6 +998,36 @@ for mode in ['fwd', 'bwd']: - 'mode': mode, 'causal': causal}) - ) +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [] +for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton"] + (["flash"] if HAS_FLASH else []), + line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.float16, + "mode": mode, + "causal": causal, + }, + ) + ) @triton.testing.perf_report(configs) @@ -956,6 +1045,9 @@ def bench_flash_attention( if provider == "triton": q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if mode == "fwd" and TORCH_HAS_FP8: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) if mode == "fwd": q = q.to(torch_dtype) diff --git a/test/Analysis/test-alias.mlir index 1f3b1df65..6c68cd24d 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -216,10 +216,10 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: %arg9 -> %cst_1 // CHECK-NEXT: %0#0 -> %cst // CHECK-NEXT: %0#1 -> %cst_0 - // CHECK-NEXT: %0#2 -> %cst_2,%cst_2 + // CHECK-NEXT: %0#2 -> %cst_1,%cst_2,%cst_2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { // CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2 - // CHECK-NEXT: %1 -> %cst_2,%cst_2 + // CHECK-NEXT: %1 -> %cst_1,%cst_2,%cst_2 %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) { // CHECK-NEXT: %2 -> %cst_2,%cst_2 %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> { @@ -257,9 +257,9 @@ tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, % %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 - %6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked> + %6 = tt.cat %cst, %cst_0 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16> gpu.barrier - %7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #blocked> + %7 = tt.cat %2, %3 {axis = 0 : i64} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16> %8 = arith.addi %1, %arg2 : index cf.br ^bb1(%8, %4, %2, %3 : index, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) ^bb3: // pred: ^bb1
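The tightened expectations above say that a loop-carried shared-memory value can alias the buffer it was initialized with (`%cst_1`), not only the buffers yielded inside the loop (`%cst_2`). A toy fixed-point version of that forward join, in Python rather than the actual MLIR dataflow framework (illustrative only):

```python
def loop_alias_set(init_aliases, yielded_aliases):
    """Grow the iter_arg's alias set until it stops changing."""
    aliases = set(init_aliases)
    while True:
        grown = aliases | set(yielded_aliases)
        if grown == aliases:
            return sorted(aliases)
        aliases = grown


# %arg11 starts out as %cst_1 and the loop body can yield %cst_2:
print(loop_alias_set({"%cst_1"}, {"%cst_2"}))  # ['%cst_1', '%cst_2']
```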
"triton_gpu.num-warps" = 4 : // ----- +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#slice1d0 = #triton_gpu.slice<{dim = 0, parent = #blocked1}> +#shared = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: basic_insert_slice_async_1d + tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<64> : tensor<64xi32, #slice1d0> + %58 = tt.splat %arg0 : (!tt.ptr) -> tensor<64x!tt.ptr, #slice1d0> + %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> + %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> + %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> + %71 = triton_gpu.alloc_tensor : tensor<2x64xi64, #shared> + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 + // CHECK-NEXT: cp.async.commit_group + %73 = triton_gpu.insert_slice_async %66, %71, %c0_i32 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x!tt.ptr, #slice1d0> -> tensor<2x64xi64, #shared> + triton_gpu.async_commit_group + tt.return + } +} + +// ----- + #block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [4], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> @@ -2012,6 +2043,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c // ----- +<<<<<<< HEAD // CHECK-LABEL: copyitem // GCN: llvm.store // GCN: llvm.load @@ -2023,11 +2055,21 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @copyitem() attributes {noinline = false} { %cst = arith.constant dense : tensor<4x1xi1, #blocked> +======= +// CHECK-LABEL: reduce_slice +// CHECK-NOT: st.shared +// CHECK-NOT: ld.shared +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 4, 2], warpsPerCTA = [2, 4, 2], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], 
CTAOrder = [0, 1, 2]}> +#sliced2 = #triton_gpu.slice<{dim = 2, parent = #blocked}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @reduce_slice() attributes {noinline = false} { + %cst = arith.constant dense : tensor<4x1xi1, #sliced2> +>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177 %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ ^bb0(%arg0: i1, %arg1: i1): %1 = arith.ori %arg0, %arg1 : i1 tt.reduce.return %1 : i1 - }) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + }) : (tensor<4x1xi1, #sliced2>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #sliced2}>> tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 21a7e29b0..e2a39c4b3 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 4], CTASplitNum = [1, 4], CTAOrder = [0, 1], hasLeadingOffset = true}> @@ -12,8 +12,8 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : %dst = triton_gpu.alloc_tensor : tensor<1x64x64xf16, #shared> %c0 = arith.constant 0 : i32 %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> - // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operand_segment_sizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 - %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> + // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {operandSegmentSizes = array} : !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 + %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> tt.return } } @@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i16) : i16 // CHECK: nvgpu.tma_load_tiled %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C15]] - %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> + %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : 
i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> tt.return } } @@ -55,7 +55,7 @@ module attributes {"triton_gpu.num-ctas" = 4 : i32, "triton_gpu.num-warps" = 4 : %src = tt.make_tensor_ptr %basePtr, [%dim0, %dim1], [%stride0, %stride1], [%coord0, %coord1] {order = array} : !tt.ptr, 1> // CHECK: nvgpu.cluster_id // CHECK: nvgpu.tma_load_tiled - %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> + %res = triton_nvidia_gpu.insert_slice_async_v2 %src, %dst, %c0, %mbar {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<1x64x64xf16, #shared>, i32, !tt.ptr -> tensor<1x64x64xf16, #shared> tt.return } } @@ -175,6 +175,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: tensor<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %opA = triton_gpu.convert_layout %a : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> @@ -183,3 +184,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_fp8_to_f16_conversion + tt.func @test_fp8_to_f16_conversion( + %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>, + %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { + // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> + %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> + // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> + %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> + + // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> + %out2 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> + // CHECK-COUNT-2: 
cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> + %out3 = tt.fp_to_fp %in2 : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + + // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> + %out4 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> + // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> + %out5 = tt.fp_to_fp %in3 : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + tt.return + } +} diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir index 8b9705db5..8b5a333c7 100644 --- a/test/NVGPU/test_cga.mlir +++ b/test/NVGPU/test_cga.mlir @@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : nvgpu.cga_barrier_arrive nvgpu.cga_barrier_wait - %ptr = llvm.mlir.null : !llvm.ptr + %ptr = llvm.mlir.zero : !llvm.ptr // CHECK: llvm.inline_asm %v = nvgpu.cluster_id diff --git a/test/NVGPU/test_mbarrier.mlir b/test/NVGPU/test_mbarrier.mlir index b12ea5864..95b608810 100644 --- a/test/NVGPU/test_mbarrier.mlir +++ b/test/NVGPU/test_mbarrier.mlir @@ -2,7 +2,7 @@ #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @test_mbarrier() { - %mbarrier = llvm.mlir.null : !llvm.ptr + %mbarrier = llvm.mlir.zero : !llvm.ptr %pred = arith.constant 1 : i1 // CHECK: llvm.inline_asm nvgpu.mbarrier_init %mbarrier, %pred { count = 32 : i32 } : !llvm.ptr diff --git a/test/NVGPU/test_tma.mlir b/test/NVGPU/test_tma.mlir index 4cf7f9b5e..9ffb35d2d 100644 --- a/test/NVGPU/test_tma.mlir +++ b/test/NVGPU/test_tma.mlir @@ -2,9 +2,9 @@ #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @test_tma(%im2colOffsets0 : !llvm.struct<(i16, i16)>, %im2colOffsets1 : !llvm.struct<(i16, i16, i16)>) { - %mbarrier = llvm.mlir.null : !llvm.ptr - %tmaDesc = llvm.mlir.null : !llvm.ptr - %dst = llvm.mlir.null : !llvm.ptr + %mbarrier = llvm.mlir.zero : !llvm.ptr + %tmaDesc = llvm.mlir.zero : !llvm.ptr + %dst = llvm.mlir.zero : !llvm.ptr %l2desc = arith.constant 0 : i64 %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 @@ -16,13 +16,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint - nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 - nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32 + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 // CHECK: llvm.inline_asm {{.*}} 
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint // CHECK: llvm.inline_asm {{.*}} cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint - nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i16 - nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operand_segment_sizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %mask {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i16 + nvgpu.tma_load_tiled %dst, %mbarrier, %tmaDesc, %l2desc, %pred, %c0, %c1, %c2, %c3 {operandSegmentSizes = array}: !llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1, i32, i32, i32, i32 tt.return } diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir index 9ad7e606c..33ed5d1bf 100644 --- a/test/NVGPU/test_wgmma.mlir +++ b/test/NVGPU/test_wgmma.mlir @@ -2,7 +2,7 @@ #SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { tt.func @test_tma(%opC : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) { - %buffer = llvm.mlir.null : !llvm.ptr + %buffer = llvm.mlir.zero : !llvm.ptr %height = arith.constant 16 : i32 // CHECK: llvm.ptrtoint // CHECK: llvm.inline_asm @@ -30,3 +30,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 tt.return } } + +// ----- + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @wgmma_wait(%in: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) { + // CHECK: // wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63 + // CHECK: wgmma.wait_group.sync.aligned 0; + %out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} : + !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + tt.return + } +} diff --git a/test/Triton/print.mlir b/test/Triton/print.mlir new file mode 100644 index 000000000..f164e4684 --- /dev/null +++ b/test/Triton/print.mlir @@ -0,0 +1,31 @@ +// RUN: triton-translate %s --mlir-print-ir-after-all -o %t 2>&1 | FileCheck %s + +// CHECK: IR Dump After SCFToControlFlow (convert-scf-to-cf) +// CHECK: tt.func public @add_kernel_0d1d2d3de +// CHECK: IR Dump After ConvertIndexToLLVMPass (convert-index-to-llvm) +// CHECK: tt.func public @add_kernel_0d1d2d3de + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = 
[4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel_0d1d2d3de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> + %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %7 = tt.splat %arg0 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %10 = tt.splat %arg1 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> + %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<1024x!tt.ptr, #blocked> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 8f5685ae8..869f39146 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -223,7 +223,7 @@ tt.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = // CHECK: [[$col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> // CHECK: [[$col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -// CHECK-LABEL: transpose +// CHECK-LABEL: @transpose module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { // CHECK-NOT: triton_gpu.convert_layout @@ -920,7 +920,7 @@ tt.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: ! 
%29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2> %30 = "tt.reduce" (%29) ({ ^bb0(%arg4: f32, %arg5: f32): - %max = arith.maxf %arg4, %arg5 : f32 + %max = arith.maximumf %arg4, %arg5 : f32 tt.reduce.return %max : f32 }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0> @@ -1697,11 +1697,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ ^bb0(%arg28: f32, %arg29: f32): - %153 = arith.maxf %arg28, %arg29 : f32 + %153 = arith.maximumf %arg28, %arg29 : f32 tt.reduce.return %153 : f32 }) : (tensor<128x64xf32, #blocked2>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> %124 = triton_gpu.convert_layout %123 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xf32, #blocked1> - %125 = arith.maxf %arg25, %124 : tensor<128xf32, #blocked1> + %125 = arith.maximumf %arg25, %124 : tensor<128xf32, #blocked1> %126 = arith.subf %arg25, %125 : tensor<128xf32, #blocked1> %127 = tt.extern_elementwise %126 {pure = true, libname = "libdevice", libpath = "/root/.pyenv/versions/3.9.9/lib/python3.9/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_exp2f"} : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #blocked1> %128 = triton_gpu.convert_layout %125 : (tensor<128xf32, #blocked1>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index ff4a1fe1e..7e31a1468 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -32,19 +32,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %14 = triton_nvidia_gpu.get_thread_id : i32 %15 = arith.cmpi eq, %14, %c0_i32 : i32 %16 = arith.andi %15, %10 : i1 - triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 - %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> - %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, 
%c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %20 = tt.advance %3, [%c0_i32, %c128_i32] : , 1> %21 = tt.advance %6, [%c128_i32, %c0_i32] : , 1> %22 = arith.cmpi sgt, %arg5, %c128_i32 : i32 %23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1> %24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> %25 = arith.andi %15, %22 : i1 - triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 - %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> - %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> %29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> %30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1) : i32 { @@ -52,7 +52,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c triton_nvidia_gpu.mbarrier_wait %33, %arg23 : // CHECK: triton_nvidia_gpu.fence_async_shared %34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma> - triton_nvidia_gpu.dot_wait {pendings = 1 : i32} %35 = tt.advance %arg11, [%c0_i32, %c128_i32] : , 1> %36 = tt.advance %arg12, [%c128_i32, %c0_i32] : , 1> %37 = arith.addi %arg19, %c128_i32 : i32 @@ -65,10 +64,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c 
%44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1> %45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> %46 = arith.andi %15, %38 : i1 - triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 - %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> - %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> %b_48 = triton_gpu.convert_layout %48 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1> %s_48 = triton_gpu.convert_layout %b_48 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1> @@ -88,10 +87,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %64 = arith.ori %62, %63 : i1 scf.yield %34, %35, %36, %47, %49, %s_48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1 } - scf.if %10 { - triton_nvidia_gpu.dot_wait {pendings = 0 : i32} - } - %31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma> + %31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> %32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1> triton_nvidia_gpu.store_async %8, %32 : !tt.ptr, 1>, tensor<128x128xf16, #shared1> triton_gpu.async_bulk_commit_group @@ -136,19 +133,19 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %14 = triton_nvidia_gpu.get_thread_id : i32 %15 = arith.cmpi eq, %14, %c0_i32 : i32 %16 = arith.andi %15, %10 : i1 - triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 - %17 = triton_nvidia_gpu.insert_slice_async_v2 
%3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %13, %16 {operandSegmentSizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> - %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %20 = tt.advance %3, [%c0_i32, %c128_i32] : , 1> %21 = tt.advance %6, [%c128_i32, %c0_i32] : , 1> %22 = arith.cmpi sgt, %arg5, %c128_i32 : i32 %23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1> %24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> %25 = arith.andi %15, %22 : i1 - triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 - %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> - %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %24, %25 {operandSegmentSizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> %29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> %b_29 = triton_gpu.convert_layout %29 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1> @@ -158,7 
+155,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c triton_nvidia_gpu.mbarrier_wait %33, %arg23 : // CHECK: triton_nvidia_gpu.fence_async_shared %34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma> - triton_nvidia_gpu.dot_wait {pendings = 1 : i32} %35 = tt.advance %arg11, [%c0_i32, %c128_i32] : , 1> %36 = tt.advance %arg12, [%c128_i32, %c0_i32] : , 1> %37 = arith.addi %arg19, %c128_i32 : i32 @@ -171,10 +167,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1> %45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> %46 = arith.andi %15, %38 : i1 - triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 - %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + triton_nvidia_gpu.mbarrier_arrive %45, %46 {operandSegmentSizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> - %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> %50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> %51 = arith.addi %arg20, %c1_i32 : i32 %52 = arith.cmpi uge, %51, %c3_i32 : i32 @@ -192,10 +188,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %64 = arith.ori %62, %63 : i1 scf.yield %34, %35, %36, %47, %49, %48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1 } - scf.if %10 { - triton_nvidia_gpu.dot_wait {pendings = 0 : i32} - } - %31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %w = triton_nvidia_gpu.dot_wait %30#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma> + %31 = arith.truncf %w : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> %32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> 
tensor<128x128xf16, #shared1> triton_nvidia_gpu.store_async %8, %32 : !tt.ptr, 1>, tensor<128x128xf16, #shared1> triton_gpu.async_bulk_commit_group diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index a5bb2f239..d56b27291 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -15,7 +15,6 @@ // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] @@ -32,13 +31,13 @@ // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] @@ -46,7 +45,7 @@ // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] tt.func @matmul_loop(%lb : index, %ub : index, %step : index, @@ -97,7 +96,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor // CHECK: %[[A0BUFFER:.*]] = 
triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] @@ -108,12 +106,12 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] @@ -121,7 +119,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, @@ -174,23 +172,22 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // CHECK: %[[arg_b0_dot_op:.*]] = 
triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32
-// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
-// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]]
tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
diff --git a/test/TritonGPU/materialize-load-store.mlir b/test/TritonGPU/materialize-load-store.mlir
index 58bc51514..d8a8f85d5 100644
--- a/test/TritonGPU/materialize-load-store.mlir
+++ b/test/TritonGPU/materialize-load-store.mlir
@@ -12,7 +12,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
  %a_tileptr_init = tt.make_tensor_ptr %A, [%c64, %c16], [%c16, %c1], [%c0, %c0] { order = array } : !tt.ptr, 1>
  // CHECK: %[[BUFFER:.*]] = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared>
  // CHECK: %[[MBAR:.*]] = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : !tt.ptr
- // CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operand_segment_sizes = array, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr, i1
+ // CHECK: triton_nvidia_gpu.mbarrier_arrive %[[MBAR]], %{{.*}} {operandSegmentSizes = array, trackAsyncOp = false, txCount = 2048 : i32} : !tt.ptr, i1
  // CHECK: %[[INSERT:.*]] = triton_nvidia_gpu.insert_slice_async_v2 %[[TENSOR_PTR]], %[[BUFFER]], %{{.*}}, %[[MBAR]]
  // CHECK: %[[EXT:.*]] = triton_gpu.extract_slice %[[INSERT]][0, 0, 0] [1, 64, 16] [1, 1, 1] : tensor<1x64x16xf16, #shared> to tensor<64x16xf16, #shared>
  // CHECK: triton_nvidia_gpu.mbarrier_wait %[[MBAR]], %false :
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index 6c9264400..6a2b6fabc 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-remove-layout-conversions -tritongpu-pipeline=num-stages=3 -test-print-allocation 2>&1 | FileCheck %s
-// CHECK: offset = 0, size = 49152
-// CHECK: offset = 49152, size = 49152
-// CHECK: size = 98304
+// CHECK: offset = 0, size = 32768
+// CHECK: offset = 32768, size = 32768
+// CHECK: size = 65536
module {
  tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
  %cst = arith.constant dense : tensor<64x64xi1>
diff --git a/test/TritonGPU/wsmaterialization.mlir b/test/TritonGPU/wsmaterialization.mlir
index 07ee80f9b..fd9fcf535 100644
--- a/test/TritonGPU/wsmaterialization.mlir
+++ b/test/TritonGPU/wsmaterialization.mlir
@@ -1,234 +1,172 @@
// RUN: triton-opt -split-input-file -triton-nvidia-gpu-ws-materialization='compute-capability=90' %s | FileCheck %s
-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}>
-#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32} {
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
+module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK-LABEL: @simple_gemm
  // CHECK: triton_nvidia_gpu.alloc_mbarrier
  // CHECK: scf.if
  // CHECK: scf.for
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_wait
-  // CHECK: triton_gpu.insert_slice
-  // CHECK: triton_gpu.insert_slice
+  // CHECK: triton_nvidia_gpu.insert_slice_async_v2
+  // CHECK: triton_nvidia_gpu.insert_slice_async_v2
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_arrive
  // CHECK: scf.yield
  // CHECK: scf.if
-  // CHECK: triton_nvidia_gpu.bar_wait
  // CHECK: scf.for
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_wait
  // CHECK: triton_gpu.extract_slice
  // CHECK: triton_gpu.extract_slice
-  // CHECK: tt.dot
+  // CHECK: triton_nvidia_gpu.dot_async
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_arrive
  // CHECK: scf.yield
-  // CHECK: triton_nvidia_gpu.bar_arrive
-  // CHECK: triton_nvidia_gpu.bar_wait
+  // CHECK: triton_nvidia_gpu.dot_wait
+  // CHECK: triton_nvidia_gpu.extract_mbarrier
+  // CHECK: triton_nvidia_gpu.mbarrier_arrive
  // CHECK: tt.store
-  // CHECK: triton_nvidia_gpu.bar_arrive
-  tt.func public @simple_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
-    %0 = triton_gpu.alloc_tensor : tensor<3x32x128xf16, #shared>
-    %1 = triton_gpu.alloc_tensor : tensor<3x128x32xf16, #shared1>
+  tt.func public @simple_gemm(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
+    %0 = triton_gpu.alloc_tensor : tensor<3x128x64xf16, #shared>
+    %1 = triton_gpu.alloc_tensor : tensor<3x64x128xf16, #shared1>
    %2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token>
-    %3 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex
-    %4 = triton_nvidia_gpu.create_mutex : !triton_nvidia_gpu.mutex
-    %5 = triton_nvidia_gpu.get_agent_id : i32
    %c0_i32 = arith.constant 0 : i32
-    %6 = arith.cmpi eq, %5, %c0_i32 : i32
-    scf.if %6 {
-      %cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<32x128xi32, #blocked>
-      %cst_1 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<32> : tensor<128x32xi32, #blocked1>
-      %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32
-      %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32
-      %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index
-      %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index
-      %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32
-      %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32
-      %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32
-      %8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %9 = tt.get_program_id y {async_agent = dense<0> : vector<1xi32>} : i32
-      %10 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %11 = arith.divsi %10, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %12 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %13 = arith.divsi %12, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %14 = arith.muli %13, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %15 = arith.divsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %16 = arith.muli %15, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %17 = arith.subi %11, %16 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %18 = arith.cmpi slt, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %19 = arith.select %18, %17, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %20 = arith.remsi %8, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %21 = arith.addi %16, %20 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %22 = arith.remsi %8, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %23 = arith.divsi %22, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %24 = arith.muli %21, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %25 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %26 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %27 = tt.splat %24 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %28 = arith.addi %27, %25 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %29 = arith.muli %23, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %31 = arith.addi %30, %26 {async_agent = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %32 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %33 = arith.remsi %28, %32 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %34 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %35 = arith.remsi %31, %34 {async_agent = dense<0> : vector<1xi32>, tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %36 = arith.muli %9, %c32_i32 {async_agent = dense<0> : vector<1xi32>} : i32
-      %37 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
-      %38 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
-      %39 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
-      %40 = tt.splat %36 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
-      %41 = arith.addi %39, %37 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
-      %42 = arith.addi %40, %38 {async_agent = dense<0> : vector<1xi32>} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
-      %43 = tt.expand_dims %33 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
-      %44 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked1>
-      %45 = arith.muli %43, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1>
-      %46 = tt.expand_dims %41 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x32xi32, #blocked1>
-      %47 = tt.broadcast %45 {async_agent = dense<0> : vector<1xi32>} : (tensor<128x1xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
-      %48 = tt.broadcast %46 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x32xi32, #blocked1>) -> tensor<128x32xi32, #blocked1>
-      %49 = arith.addi %47, %48 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32xi32, #blocked1>
-      %50 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<128x32x!tt.ptr, #blocked1>
-      %51 = tt.addptr %50, %49 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
-      %52 = tt.expand_dims %42 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>
-      %53 = tt.expand_dims %35 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x128xi32, #blocked>
-      %54 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x128xi32, #blocked>
-      %55 = arith.muli %53, %54 {async_agent = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked>
-      %56 = tt.broadcast %52 {async_agent = dense<0> : vector<1xi32>} : (tensor<32x1xi32, #blocked>) -> tensor<32x128xi32, #blocked>
-      %57 = tt.broadcast %55 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x128xi32, #blocked>) -> tensor<32x128xi32, #blocked>
-      %58 = arith.addi %56, %57 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128xi32, #blocked>
-      %59 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<32x128x!tt.ptr, #blocked>
-      %60 = tt.addptr %59, %58 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
-      %61 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %62 = arith.divsi %61, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %63 = arith.index_cast %62 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index
-      %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
-      %64:3 = scf.for %arg9 = %c0 to %63 step %c1 iter_args(%arg10 = %51, %arg11 = %60, %arg12 = %c0_i32_2) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, i32) {
-        triton_nvidia_gpu.producer_acquire %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-        %65 = triton_gpu.insert_slice %arg10, %1, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32x!tt.ptr, #blocked1> -> tensor<3x128x32xf16, #shared1>
-        %66 = triton_gpu.insert_slice %arg11, %0, %arg12 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128x!tt.ptr, #blocked> -> tensor<3x32x128xf16, #shared>
-        %67 = tt.addptr %arg10, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
-        %68 = tt.addptr %arg11, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
-        %c1_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
-        %c3_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
-        %69 = arith.addi %arg12, %c1_i32_3 {async_agent = dense<0> : vector<1xi32>} : i32
-        %70 = arith.remsi %69, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
-        triton_nvidia_gpu.producer_commit %2, %arg12 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-        scf.yield %67, %68, %70 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, i32
-      } {async_agent = dense<0> : vector<1xi32>}
-    }
-    %c1_i32 = arith.constant 1 : i32
-    %c1_i32_0 = arith.constant 1 : i32
-    %7 = arith.cmpi sge, %5, %c1_i32_0 : i32
-    scf.if %7 {
-      %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
-      %c31_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 31 : i32
-      %c127_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 127 : i32
-      %c1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 1 : index
-      %c0 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : index
-      %c32_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 32 : i32
-      %c128_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 128 : i32
-      %c8_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 8 : i32
-      %8 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %9 = arith.addi %arg3, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %10 = arith.divsi %9, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %11 = arith.addi %arg4, %c127_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %12 = arith.divsi %11, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %13 = arith.muli %12, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %14 = arith.divsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %15 = arith.muli %14, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %16 = arith.subi %10, %15 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %17 = arith.cmpi slt, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %18 = arith.select %17, %16, %c8_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %19 = arith.remsi %8, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %20 = arith.addi %15, %19 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %21 = arith.remsi %8, %13 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %22 = arith.divsi %21, %18 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %23 = arith.muli %20, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %24 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %25 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %26 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %27 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %28 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %29 = tt.splat %23 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %30 = arith.addi %28, %24 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %31 = arith.addi %29, %26 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %32 = arith.muli %22, %c128_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %33 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %34 = tt.splat %32 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %35 = arith.addi %33, %25 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %36 = arith.addi %34, %27 {async_agent = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %37 = tt.splat %arg3 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %38 = tt.splat %arg4 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %39 = arith.addi %arg5, %c31_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %40 = arith.divsi %39, %c32_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %41 = arith.index_cast %40 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32 to index
+    %c127_i32 = arith.constant 127 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c128_i32 = arith.constant 128 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %3 = tt.get_program_id x : i32
+    %4 = arith.addi %arg6, %c127_i32 : i32
+    %5 = arith.divsi %4, %c128_i32 : i32
+    %6 = arith.addi %arg5, %c127_i32 : i32
+    %7 = arith.divsi %6, %c128_i32 : i32
+    %8 = arith.muli %5, %c8_i32 : i32
+    %9 = arith.divsi %3, %8 : i32
+    %10 = arith.muli %9, %c8_i32 : i32
+    %11 = arith.subi %7, %10 : i32
+    %12 = arith.minsi %11, %c8_i32 : i32
+    %13 = arith.remsi %3, %12 : i32
+    %14 = arith.addi %10, %13 : i32
+    %15 = arith.remsi %3, %8 : i32
+    %16 = arith.divsi %15, %12 : i32
+    %17 = arith.muli %14, %c128_i32 : i32
+    %18 = arith.muli %16, %c128_i32 : i32
+    %19 = arith.extsi %arg5 : i32 to i64
+    %20 = arith.extsi %arg7 : i32 to i64
+    %21 = arith.extsi %arg8 : i32 to i64
+    %22 = tt.make_tensor_ptr %arg0, [%19, %20], [%21, %c1_i64], [%17, %c0_i32] {order = array} : , 1>
+    %23 = arith.extsi %arg6 : i32 to i64
+    %24 = arith.extsi %arg9 : i32 to i64
+    %25 = tt.make_tensor_ptr %arg1, [%20, %23], [%c1_i64, %24], [%c0_i32, %18] {order = array} : , 1>
+    %26 = arith.extsi %arg11 : i32 to i64
+    %27 = tt.make_tensor_ptr %arg4, [%19, %23], [%26, %c1_i64], [%17, %18] {order = array} : , 1>
+    %28 = triton_nvidia_gpu.get_agent_id : i32
+    %c0_i32_0 = arith.constant 0 : i32
+    %29 = arith.cmpi eq, %28, %c0_i32_0 : i32
+    scf.if %29 {
+      %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
+      %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
      %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
-      triton_nvidia_gpu.lock %3 {mutex.barId = dense<1> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
-      %42:2 = scf.for %arg9 = %c0 to %41 step %c1 iter_args(%arg10 = %cst, %arg11 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i32) {
-        triton_nvidia_gpu.consumer_wait %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-        %62 = triton_gpu.extract_slice %1[%arg11, 0, 0] [1, 128, 32] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x32xf16, #shared1> to tensor<128x32xf16, #shared1>
-        %63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared>
-        %64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1>
-        %65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared>
-        %66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
-        %c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
-        %c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32
-        %67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
-        %68 = arith.remsi %67, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
-        triton_nvidia_gpu.consumer_release %2, %arg11 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-        scf.yield %66, %68 : tensor<128x128xf32, #mma>, i32
+      %false = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} false
+      %31:4 = scf.for %arg12 = %c0_i32 to %arg7 step %c64_i32 iter_args(%arg13 = %22, %arg14 = %25, %arg15 = %false, %arg16 = %c0_i32_1) -> (!tt.ptr, 1>, !tt.ptr, 1>, i1, i32) : i32 {
+        triton_nvidia_gpu.producer_acquire %2, %arg16 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+        %32 = triton_gpu.insert_slice %arg13, %0, %arg16 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<3x128x64xf16, #shared>
+        %33 = triton_gpu.insert_slice %arg14, %1, %arg16 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<3x64x128xf16, #shared1>
+        triton_nvidia_gpu.producer_commit %2, %arg16 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+        %34 = tt.advance %arg13, [%c0_i32, %c64_i32] {async_agent = dense<0> : vector<1xi32>} : , 1>
+        %35 = tt.advance %arg14, [%c64_i32, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : , 1>
+        %c1_i32_2 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
+        %c0_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
+        %true = arith.constant {async_agent = dense<0> : vector<1xi32>} true
+        %36 = arith.addi %arg16, %c1_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
+        %37 = arith.cmpi uge, %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %38 = arith.cmpi ult, %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %39 = arith.subi %36, %c3_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %40 = arith.select %37, %39, %36 {async_agent = dense<0> : vector<1xi32>} : i32
+        %41 = arith.xori %arg15, %true {async_agent = dense<0> : vector<1xi32>} : i1
+        %42 = arith.andi %37, %41 {async_agent = dense<0> : vector<1xi32>} : i1
+        %43 = arith.andi %38, %arg15 {async_agent = dense<0> : vector<1xi32>} : i1
+        %44 = arith.ori %42, %43 {async_agent = dense<0> : vector<1xi32>} : i1
+        scf.yield {async_agent = dense<0> : vector<1xi32>} %34, %35, %44, %40 : !tt.ptr, 1>, !tt.ptr, 1>, i1, i32
+      } {async_agent = dense<0> : vector<1xi32>}
+    } {async_agent = dense<0> : vector<1xi32>}
    %c1_i32 = arith.constant 1 : i32
+    %30 = arith.cmpi eq, %28, %c1_i32 : i32
+    scf.if %30 {
+      %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
+      %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
+      %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
+      %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
+      %false = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} false
+      %31:3 = scf.for %arg12 = %c0_i32 to %arg7 step %c64_i32 iter_args(%arg13 = %cst, %arg14 = %false, %arg15 = %c0_i32_1) -> (tensor<128x128xf32, #mma>, i1, i32) : i32 {
+        triton_nvidia_gpu.consumer_wait %2, %arg15 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+        %37 = triton_gpu.extract_slice %0[%arg15, 0, 0] [1, 128, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x128x64xf16, #shared> to tensor<128x64xf16, #shared>
+        %38 = triton_gpu.convert_layout %37 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #shared>
+        %39 = triton_gpu.extract_slice %1[%arg15, 0, 0] [1, 64, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x64x128xf16, #shared1> to tensor<64x128xf16, #shared1>
+        %40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #shared1>
+        %41 = triton_nvidia_gpu.dot_async %38, %40, %arg13 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+        %42 = arith.cmpi sgt, %arg12, %c0_i32 {async_agent = dense<1> : vector<1xi32>} : i32
+        scf.if %42 {
+          %c0_i32_6 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
+          %c1_i32_7 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
+          %c2_i32_8 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
+          %52 = arith.subi %arg15, %c1_i32_7 {async_agent = dense<1> : vector<1xi32>} : i32
+          %53 = arith.cmpi eq, %arg15, %c0_i32_6 {async_agent = dense<1> : vector<1xi32>} : i32
+          %54 = arith.select %53, %c2_i32_8, %52 {async_agent = dense<1> : vector<1xi32>} : i32
+          triton_nvidia_gpu.consumer_release %2, %54 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+        } {async_agent = dense<1> : vector<1xi32>}
+        %c1_i32_4 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
+        %c0_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
+        %true = arith.constant {async_agent = dense<1> : vector<1xi32>} true
+        %43 = arith.addi %arg15, %c1_i32_4 {async_agent = dense<1> : vector<1xi32>} : i32
+        %44 = arith.cmpi uge, %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
+        %45 = arith.cmpi ult, %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
+        %46 = arith.subi %43, %c3_i32 {async_agent = dense<1> : vector<1xi32>} : i32
+        %47 = arith.select %44, %46, %43 {async_agent = dense<1> : vector<1xi32>} : i32
+        %48 = arith.xori %arg14, %true {async_agent = dense<1> : vector<1xi32>} : i1
+        %49 = arith.andi %44, %48 {async_agent = dense<1> : vector<1xi32>} : i1
+        %50 = arith.andi %45, %arg14 {async_agent = dense<1> : vector<1xi32>} : i1
+        %51 = arith.ori %49, %50 {async_agent = dense<1> : vector<1xi32>} : i1
+        scf.yield {async_agent = dense<1> : vector<1xi32>} %41, %51, %47 : tensor<128x128xf32, #mma>, i1, i32
      } {async_agent = dense<1> : vector<1xi32>}
-      triton_nvidia_gpu.unlock %3 {mutex.barId = dense<2> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
-      triton_nvidia_gpu.lock %4 {mutex.barId = dense<3> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
-      %43 = arith.truncf %42#0 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
-      %44 = tt.expand_dims %30 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
-      %45 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x1xi32, #blocked2>
-      %46 = arith.muli %44, %45 {async_agent = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked2>
-      %47 = tt.expand_dims %35 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi32, #blocked2>
-      %48 = tt.broadcast %46 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
-      %49 = tt.broadcast %47 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi32, #blocked2>) -> tensor<128x128xi32, #blocked2>
-      %50 = arith.addi %48, %49 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi32, #blocked2>
-      %51 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<128x128x!tt.ptr, #blocked2>
-      %52 = tt.addptr %51, %50 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked2>, tensor<128x128xi32, #blocked2>
-      %53 = "triton_gpu.cmpi"(%31, %37) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %54 = tt.expand_dims %53 {async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi1, #blocked2>
-      %55 = "triton_gpu.cmpi"(%36, %38) {async_agent = dense<1> : vector<1xi32>, predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %56 = tt.expand_dims %55 {async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x128xi1, #blocked2>
-      %57 = tt.broadcast %54 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x1xi1, #blocked2>) -> tensor<128x128xi1, #blocked2>
-      %58 = tt.broadcast %56 {async_agent = dense<1> : vector<1xi32>} : (tensor<1x128xi1, #blocked2>) -> tensor<128x128xi1, #blocked2>
-      %59 = arith.andi %57, %58 {async_agent = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked2>
-      %60 = triton_gpu.convert_layout %43 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #blocked2>
-      tt.store %52, %60, %59 {async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<128x128xf16, #blocked2>
-      triton_nvidia_gpu.unlock %4 {mutex.barId = dense<4> : vector<1xi32>, mutex.numThreads = dense<256> : vector<1xi32>} : !triton_nvidia_gpu.mutex
-    }
+      %32 = triton_nvidia_gpu.dot_wait %31#0 {async_agent = dense<1> : vector<1xi32>, pendings = 0 : i32} : tensor<128x128xf32, #mma>
+      %c0_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 0 : i32
+      %c1_i32_3 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
+      %c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
+      %33 = arith.subi %31#2, %c1_i32_3 {async_agent = dense<1> : vector<1xi32>} : i32
+      %34 = arith.cmpi eq, %31#2, %c0_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
+      %35 = arith.select %34, %c2_i32, %33 {async_agent = dense<1> : vector<1xi32>} : i32
+      triton_nvidia_gpu.consumer_release %2, %35 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+      %36 = triton_gpu.convert_layout %32 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x128xf32, #mma>) -> tensor<128x128xf32, #blocked2>
+      tt.store %27, %36 {async_agent = dense<1> : vector<1xi32>, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<128x128xf32, #blocked2>
+    } {async_agent = dense<1> : vector<1xi32>}
    tt.return
  }
}

// -----

-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
-module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability" = 90 : i32, "triton_gpu.enable-warp-specialization" = 1 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK-LABEL: @matmal_from_wsmutex
  // CHECK: triton_nvidia_gpu.alloc_mbarrier
  // CHECK: scf.if
  // CHECK: scf.for
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_wait
-  // CHECK: triton_gpu.insert_slice
-  // CHECK: triton_gpu.insert_slice
+  // CHECK: triton_nvidia_gpu.insert_slice_async_v2
+  // CHECK: triton_nvidia_gpu.insert_slice_async_v2
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_arrive
  // CHECK: scf.yield
@@ -239,174 +177,224 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
  // CHECK: triton_nvidia_gpu.mbarrier_wait
  // CHECK: triton_gpu.extract_slice
  // CHECK: triton_gpu.extract_slice
-  // CHECK: tt.dot
+  // CHECK: triton_nvidia_gpu.dot_async
  // CHECK: triton_nvidia_gpu.extract_mbarrier
  // CHECK: triton_nvidia_gpu.mbarrier_arrive
  // CHECK: scf.yield
  // CHECK: triton_nvidia_gpu.bar_arrive
+  // CHECK: triton_nvidia_gpu.dot_wait
+  // CHECK: triton_nvidia_gpu.extract_mbarrier
+  // CHECK: triton_nvidia_gpu.mbarrier_arrive
  // CHECK: triton_nvidia_gpu.bar_wait
  // CHECK: tt.store
  // CHECK: triton_nvidia_gpu.bar_arrive
-  tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) {
+  tt.func public @matmal_from_wsmutex(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %0 = triton_gpu.alloc_tensor : tensor<3x64x16xf16, #shared>
    %1 = triton_gpu.alloc_tensor : tensor<3x16x64xf16, #shared1>
    %2 = triton_nvidia_gpu.create_token {num = 3 : i32} : tensor<3x!triton_nvidia_gpu.token>
-    %3 = triton_nvidia_gpu.get_agent_id : i32
+    %c63_i32 = arith.constant 63 : i32
    %c0_i32 = arith.constant 0 : i32
-    %4 = arith.cmpi eq, %3, %c0_i32 : i32
-    scf.if %4 {
-      %cst = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<16x64xi32, #blocked>
-      %cst_0 = arith.constant {async_agent = dense<0> : vector<1xi32>} dense<16> : tensor<64x16xi32, #blocked1>
-      %c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32
-      %c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c64_i32 = arith.constant 64 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %3 = tt.get_program_id x : i32
+    %4 = arith.addi %arg6, %c63_i32 : i32
+    %5 = arith.divsi %4, %c64_i32 : i32
+    %6 = arith.addi %arg5, %c63_i32 : i32
+    %7 = arith.divsi %6, %c64_i32 : i32
+    %8 = arith.muli %5, %c8_i32 : i32
+    %9 = arith.divsi %3, %8 : i32
+    %10 = arith.muli %9, %c8_i32 : i32
+    %11 = arith.subi %7, %10 : i32
+    %12 = arith.minsi %11, %c8_i32 : i32
+    %13 = arith.remsi %3, %8 : i32
+    %14 = arith.remsi %13, %12 : i32
+    %15 = arith.addi %10, %14 : i32
+    %16 = arith.divsi %13, %12 : i32
+    %17 = arith.muli %15, %c64_i32 : i32
+    %18 = arith.muli %16, %c64_i32 : i32
+    %19 = arith.extsi %arg5 : i32 to i64
+    %20 = arith.extsi %arg7 : i32 to i64
+    %21 = arith.extsi %arg8 : i32 to i64
+    %22 = tt.make_tensor_ptr %arg0, [%19, %20], [%21, %c1_i64], [%17, %c0_i32] {order = array} : , 1>
+    %23 = arith.extsi %arg6 : i32 to i64
+    %24 = arith.extsi %arg9 : i32 to i64
+    %25 = tt.make_tensor_ptr %arg1, [%20, %23], [%c1_i64, %24], [%c0_i32, %18] {order = array} : , 1>
+    %26 = arith.extsi %arg10 : i32 to i64
+    %27 = tt.make_tensor_ptr %arg4, [%19, %23], [%26, %c1_i64], [%17, %18] {order = array} : , 1>
+    %28 = triton_nvidia_gpu.get_agent_id : i32
+    %c0_i32_0 = arith.constant 0 : i32
+    %29 = arith.cmpi eq, %28, %c0_i32_0 : i32
+    scf.if %29 {
+      %c132_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 132 : i32
+      %c15_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 15 : i32
      %c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32
-      %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
-      %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
-      %6 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32
-      %7 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %8 = arith.divsi %7, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %9 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %10 = arith.divsi %9, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %11 = arith.muli %8, %10 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %12 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %13 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %14 = tt.splat %arg3 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-      %15 = tt.splat %arg4 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-      %16 = tt.splat %arg6 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked1>
-      %17 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
-      %18 = tt.expand_dims %17 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1>
-      %19 = tt.broadcast %18 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x16xi32, #blocked1>) -> tensor<64x16xi32, #blocked1>
-      %20 = tt.splat %arg0 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<64x16x!tt.ptr, #blocked1>
-      %21 = tt.make_range {async_agent = dense<0> : vector<1xi32>, end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
-      %22 = tt.expand_dims %21 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16x1xi32, #blocked>
-      %23 = tt.splat %arg7 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<1x64xi32, #blocked>
-      %24 = tt.broadcast %22 {async_agent = dense<0> : vector<1xi32>} : (tensor<16x1xi32, #blocked>) -> tensor<16x64xi32, #blocked>
-      %25 = tt.splat %arg1 {async_agent = dense<0> : vector<1xi32>} : (!tt.ptr) -> tensor<16x64x!tt.ptr, #blocked>
+      %31 = arith.muli %7, %5 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
+      %32 = arith.addi %arg7, %c15_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+      %33 = arith.divsi %32, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+      %34 = arith.subi %c0_i32, %33 {async_agent = dense<0> : vector<1xi32>} : i32
+      %35 = arith.muli %34, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
      %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
-      %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
-      %26 = scf.for %arg9 = %6 to %11 step %c114_i32 iter_args(%arg10 = %c0_i32_2) -> (i32) : i32 {
-        %27 = arith.divsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32
-        %28 = arith.remsi %arg9, %10 {async_agent = dense<0> : vector<1xi32>} : i32
-        %29 = arith.muli %27, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
-        %30 = tt.splat %29 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-        %31 = arith.addi %30, %12 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-        %32 = arith.remsi %31, %14 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
-        %33 = arith.muli %28, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
-        %34 = tt.splat %33 {async_agent = dense<0> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-        %35 = arith.addi %34, %13 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-        %36 = arith.remsi %35, %15 {async_agent = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-        %37 = tt.expand_dims %32 {async_agent = dense<0> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
-        %38 = arith.muli %37, %16 {async_agent = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1>
-        %39 = tt.broadcast %38 {async_agent = dense<0> : vector<1xi32>} : (tensor<64x1xi32, #blocked1>) -> tensor<64x16xi32, #blocked1>
-        %40 = arith.addi %39, %19 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16xi32, #blocked1>
-        %41 = tt.addptr %20, %40 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr, #blocked1>, tensor<64x16xi32, #blocked1>
-        %42 = tt.expand_dims %36 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x64xi32, #blocked>
-        %43 = arith.muli %42, %23 {async_agent = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked>
-        %44 = tt.broadcast %43 {async_agent = dense<0> : vector<1xi32>} : (tensor<1x64xi32, #blocked>) -> tensor<16x64xi32, #blocked>
-        %45 = arith.addi %24, %44 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64xi32, #blocked>
-        %46 = tt.addptr %25, %45 {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr, #blocked>, tensor<16x64xi32, #blocked>
-        %c3_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
-        %47 = arith.subi %arg5, %c0_i32_1 {async_agent = dense<0> : vector<1xi32>} : i32
-        %48 = arith.divui %47, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
-        %49 = arith.muli %arg10, %48 {async_agent = dense<0> : vector<1xi32>} : i32
-        %c3_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
-        %50:3 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %41, %arg13 = %46, %arg14 = %49) -> (tensor<64x16x!tt.ptr, #blocked1>, tensor<16x64x!tt.ptr, #blocked>, i32) : i32 {
-          %52 = arith.remsi %arg14, %c3_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
-          triton_nvidia_gpu.producer_acquire %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-          %53 = triton_gpu.insert_slice %arg12, %0, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16x!tt.ptr, #blocked1> -> tensor<3x64x16xf16, #shared>
-          %54 = triton_gpu.insert_slice %arg13, %1, %52 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #blocked> -> tensor<3x16x64xf16, #shared1>
-          triton_nvidia_gpu.producer_commit %2, %52 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-          %55 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<64x16x!tt.ptr, #blocked1>, tensor<64x16xi32, #blocked1>
-          %56 = tt.addptr %arg13, %cst {async_agent = dense<0> : vector<1xi32>} : tensor<16x64x!tt.ptr, #blocked>, tensor<16x64xi32, #blocked>
+      %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
+      %36:5 = scf.for %arg11 = %3 to %31 step %c132_i32 iter_args(%arg12 = %22, %arg13 = %25, %arg14 = %15, %arg15 = %16, %arg16 = %c0_i32_1) -> (!tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32) : i32 {
+        %37 = arith.divsi %arg11, %8 {async_agent = dense<0> : vector<1xi32>} : i32
+        %38 = arith.muli %37, %c8_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %39 = arith.subi %7, %38 {async_agent = dense<0> : vector<1xi32>} : i32
+        %40 = arith.minsi %39, %c8_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %41 = arith.remsi %arg11, %8 {async_agent = dense<0> : vector<1xi32>} : i32
+        %42 = arith.remsi %41, %40 {async_agent = dense<0> : vector<1xi32>} : i32
+        %43 = arith.addi %38, %42 {async_agent = dense<0> : vector<1xi32>} : i32
+        %44 = arith.divsi %41, %40 {async_agent = dense<0> : vector<1xi32>} : i32
+        %45 = arith.subi %43, %arg14 {async_agent = dense<0> : vector<1xi32>} : i32
+        %46 = arith.muli %45, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %47 = tt.advance %arg12, [%46, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : , 1>
+        %48 = arith.subi %44, %arg15 {async_agent = dense<0> : vector<1xi32>} : i32
+        %49 = arith.muli %48, %c64_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %50 = tt.advance %arg13, [%c0_i32, %49] {async_agent = dense<0> : vector<1xi32>} : , 1>
+        %c3_i32_2 = arith.constant {async_agent = dense<0> : vector<1xi32>} 3 : i32
+        %c0_i32_3 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
+        %51 = arith.subi %arg7, %c0_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %52 = arith.addi %51, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %c1_i32_4 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
+        %c2_i32 = arith.constant {async_agent = dense<0> : vector<1xi32>} 2 : i32
+        %53 = arith.subi %52, %c1_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
+        %54 = arith.divui %53, %c16_i32 {async_agent = dense<0> : vector<1xi32>} : i32
+        %55 = arith.muli %arg16, %54 {async_agent = dense<0> : vector<1xi32>} : i32
+        %56 = arith.divui %55, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
+        %57 = arith.muli %56, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
+        %58 = arith.subi %55, %57 {async_agent = dense<0> : vector<1xi32>} : i32
+        %59 = arith.andi %56, %c1_i32_4 {async_agent = dense<0> : vector<1xi32>} : i32
+        %60 = arith.trunci %59 {async_agent = dense<0> : vector<1xi32>} : i32 to i1
+        %61:4 = scf.for %arg17 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg18 = %47, %arg19 = %50, %arg20 = %60, %arg21 = %58) -> (!tt.ptr, 1>, !tt.ptr, 1>, i1, i32) : i32 {
+          triton_nvidia_gpu.producer_acquire %2, %arg21 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+          %65 = triton_gpu.insert_slice %arg18, %0, %arg21 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<3x64x16xf16, #shared>
+          %66 = triton_gpu.insert_slice %arg19, %1, %arg21 {async_agent = dense<0> : vector<1xi32>, axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<3x16x64xf16, #shared1>
+          triton_nvidia_gpu.producer_commit %2, %arg21 {async_agent = dense<0> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+          %67 = tt.advance %arg18, [%c0_i32, %c16_i32] {async_agent = dense<0> : vector<1xi32>} : , 1>
+          %68 = tt.advance %arg19, [%c16_i32, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : , 1>
          %c1_i32_6 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
-          %57 = arith.addi %arg14, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32
-          scf.yield {async_agent = dense<0> : vector<1xi32>} %55, %56, %57 : tensor<64x16x!tt.ptr, #blocked1>, tensor<16x64x!tt.ptr, #blocked>, i32
+          %c0_i32_7 = arith.constant {async_agent = dense<0> : vector<1xi32>} 0 : i32
+          %true = arith.constant {async_agent = dense<0> : vector<1xi32>} true
+          %69 = arith.addi %arg21, %c1_i32_6 {async_agent = dense<0> : vector<1xi32>} : i32
+          %70 = arith.cmpi uge, %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
+          %71 = arith.cmpi ult, %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
+          %72 = arith.subi %69, %c3_i32_2 {async_agent = dense<0> : vector<1xi32>} : i32
+          %73 = arith.select %70, %72, %69 {async_agent = dense<0> : vector<1xi32>} : i32
+          %74 = arith.xori %arg20, %true {async_agent = dense<0> : vector<1xi32>} : i1
+          %75 = arith.andi %70, %74 {async_agent = dense<0> : vector<1xi32>} : i1
+          %76 = arith.andi %71, %arg20 {async_agent = dense<0> : vector<1xi32>} : i1
+          %77 = arith.ori %75, %76 {async_agent = dense<0> : vector<1xi32>} : i1
+          scf.yield {async_agent = dense<0> : vector<1xi32>} %67, %68, %77, %73 : !tt.ptr, 1>, !tt.ptr, 1>, i1, i32
        } {async_agent = dense<0> : vector<1xi32>}
+        %62 = tt.advance %61#0, [%c0_i32, %35] {async_agent = dense<0> : vector<1xi32>} : , 1>
+        %63 = tt.advance %61#1, [%35, %c0_i32] {async_agent = dense<0> : vector<1xi32>} : , 1>
        %c1_i32_5 = arith.constant {async_agent = dense<0> : vector<1xi32>} 1 : i32
-        %51 = arith.addi %arg10, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32
-        scf.yield {async_agent = dense<0> : vector<1xi32>} %51 : i32
+        %64 = arith.addi %arg16, %c1_i32_5 {async_agent = dense<0> : vector<1xi32>} : i32
+        scf.yield {async_agent = dense<0> : vector<1xi32>} %62, %63, %43, %44, %64 : !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32
      } {async_agent = dense<0> : vector<1xi32>}
    } {async_agent = dense<0> : vector<1xi32>}
    %c1_i32 = arith.constant 1 : i32
-    %5 = arith.cmpi eq, %3, %c1_i32 : i32
-    scf.if %5 {
-      %c0_i32_0 = arith.constant 0 : i32
-      %6 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32
-      %7 = arith.cmpi ne, %6, %c0_i32_0 : i32
-      %8 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
-      %9 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
+    %30 = arith.cmpi eq, %28, %c1_i32 : i32
+    scf.if %30 {
+      %c0_i32_1 = arith.constant 0 : i32
+      %31 = triton_nvidia_gpu.get_mutex_role_id {async_agent = dense<1> : vector<1xi32>, num = 2 : i32} : i32
+      %32 = arith.cmpi ne, %31, %c0_i32_1 : i32
+      %33 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
+      %34 = triton_nvidia_gpu.create_mutex {async_agent = dense<1> : vector<1xi32>} : !triton_nvidia_gpu.mutex
      %cst = arith.constant {async_agent = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma>
-      %c63_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 63 : i32
-      %c114_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 114 : i32
+      %c132_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 132 : i32
      %c16_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 16 : i32
-      %c0_i32_1 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
-      %c64_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 64 : i32
-      %10 = tt.get_program_id x {async_agent = dense<[0, 1]> : vector<2xi32>, axis = 0 : i32} : i32
-      %11 = arith.addi %arg3, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %12 = arith.divsi %11, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %13 = arith.addi %arg4, %c63_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %14 = arith.divsi %13, %c64_i32 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %15 = arith.muli %12, %14 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
-      %16 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-      %17 = tt.make_range {async_agent = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-      %18 = tt.splat %arg8 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64x1xi32, #blocked2>
-      %19 = tt.splat %arg2 {async_agent = dense<1> : vector<1xi32>} : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked2>
+      %35 = arith.muli %7, %5 {async_agent = dense<[0, 1]> : vector<2xi32>} : i32
      %c3_i32 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 3 : i32
      %c0_i32_2 = arith.constant {async_agent = dense<[0, 1]> : vector<2xi32>} 0 : i32
-      %20 = arith.muli %c114_i32, %6 {async_agent = dense<1> : vector<1xi32>} : i32
-      %21 = arith.addi %10, %20 {async_agent = dense<1> : vector<1xi32>} : i32
+      %36 = arith.muli %c132_i32, %31 {async_agent = dense<1> : vector<1xi32>} : i32
+      %37 = arith.addi %3, %36 {async_agent = dense<1> : vector<1xi32>} : i32
      %c2_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 2 : i32
-      %22 = arith.muli %c114_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32
-      %23 = arith.addi %c0_i32_2, %6 {async_agent = dense<1> : vector<1xi32>} : i32
-      %24 = scf.for %arg9 = %21 to %15 step %22 iter_args(%arg10 = %23) -> (i32) : i32 {
-        %25 = arith.cmpi ne, %arg9, %10 : i32
-        %26 = arith.ori %25, %7 {agent.mutex_role = 0 : i32} : i1
-        scf.if %26 {
-          triton_nvidia_gpu.lock %8 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex
+      %38 = arith.muli %c132_i32, %c2_i32 {async_agent = dense<1> : vector<1xi32>} : i32
+      %39 = arith.addi %c0_i32_2, %31 {async_agent = dense<1> : vector<1xi32>} : i32
+      %40:4 = scf.for %arg11 = %37 to %35 step %38 iter_args(%arg12 = %27, %arg13 = %15, %arg14 = %16, %arg15 = %39) -> (!tt.ptr, 1>, i32, i32, i32) : i32 {
+        %41 = arith.cmpi ne, %arg11, %3 : i32
+        %42 = arith.ori %41, %32 : i1
+        scf.if %42 {
+          triton_nvidia_gpu.lock %33 {agent.mutex_role = 0 : i32} : !triton_nvidia_gpu.mutex
        } {agent.mutex_role = 0 : i32}
-        %27 = arith.divsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %28 = arith.remsi %arg9, %14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %29 = arith.muli %27, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %30 = tt.splat %29 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-        %31 = arith.addi %30, %17 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
-        %32 = arith.muli %28, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %33 = tt.splat %32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
-        %34 = arith.addi %33, %16 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
+        %43 = arith.divsi %arg11, %8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %44 = arith.muli %43, %c8_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %45 = arith.subi %7, %44 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %46 = arith.minsi %45, %c8_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %47 = arith.remsi %arg11, %8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %48 = arith.remsi %47, %46 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %49 = arith.addi %44, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %50 = arith.divsi %47, %46 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %51 = arith.subi %49, %arg13 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %52 = arith.muli %51, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %53 = arith.subi %50, %arg14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %54 = arith.muli %53, %c64_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
        %c3_i32_3 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32
-        %35 = arith.subi %arg5, %c0_i32_1 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %36 = arith.divui %35, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %37 = arith.muli %arg10, %36 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        %c3_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 3 : i32
-        %38:2 = scf.for %arg11 = %c0_i32_1 to %arg5 step %c16_i32 iter_args(%arg12 = %cst, %arg13 = %37) -> (tensor<64x64xf32, #mma>, i32) : i32 {
-          %48 = arith.remsi %arg13, %c3_i32_4 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-          triton_nvidia_gpu.consumer_wait %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-          %49 = triton_gpu.extract_slice %0[%48, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared>
-          %50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
-          %51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
-          %52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
-          %53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
-          triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
-          %c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
-          %54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-          scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %53, %54 : tensor<64x64xf32, #mma>, i32
-        } {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>}
-        triton_nvidia_gpu.unlock %8 : !triton_nvidia_gpu.mutex
-        scf.if %26 {
-          triton_nvidia_gpu.lock %9 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex
+        %c0_i32_4 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
+        %55 = arith.subi %arg7, %c0_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %56 = arith.addi %55, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %c1_i32_5 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
+        %c2_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
+        %57 = arith.subi %56, %c1_i32_5 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %58 = arith.divui %57, %c16_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %59 = arith.muli %arg15, %58 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %60 = arith.divui %59, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %61 = arith.muli %60, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %62 = arith.subi %59, %61 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %63 = arith.andi %60, %c1_i32_5 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %64 = arith.trunci %63 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32 to i1
+        %65:3 = scf.for %arg16 = %c0_i32 to %arg7 step %c16_i32 iter_args(%arg17 = %cst, %arg18 = %64, %arg19 = %62) -> (tensor<64x64xf32, #mma>, i1, i32) : i32 {
+          triton_nvidia_gpu.consumer_wait %2, %arg19 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+          %74 = triton_gpu.extract_slice %0[%arg19, 0, 0] [1, 64, 16] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x64x16xf16, #shared> to tensor<64x16xf16, #shared>
+          %75 = triton_gpu.convert_layout %74 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
+          %76 = triton_gpu.extract_slice %1[%arg19, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
+          %77 = triton_gpu.convert_layout %76 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
+          %78 = triton_nvidia_gpu.dot_async %75, %77, %arg17 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+          %79 = arith.cmpi sgt, %arg16, %c0_i32 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+          scf.if %79 {
+            %c0_i32_13 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
+            %c1_i32_14 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
+            %c2_i32_15 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
+            %89 = arith.subi %arg19, %c1_i32_14 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+            %90 = arith.cmpi eq, %arg19, %c0_i32_13 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+            %91 = arith.select %90, %c2_i32_15, %89 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+            triton_nvidia_gpu.consumer_release %2, %91 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+          } {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>}
+          %c1_i32_11 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
+          %c0_i32_12 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
+          %true = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} true
+          %80 = arith.addi %arg19, %c1_i32_11 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+          %81 = arith.cmpi uge, %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+          %82 = arith.cmpi ult, %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+          %83 = arith.subi %80, %c3_i32_3 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+          %84 = arith.select %81, %83, %80 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+          %85 = arith.xori %arg18, %true {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
+          %86 = arith.andi %81, %85 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
+          %87 = arith.andi %82, %arg18 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
+          %88 = arith.ori %86, %87 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i1
+          scf.yield {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} %78, %88, %84 : tensor<64x64xf32, #mma>, i1, i32
        } {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>}
+        triton_nvidia_gpu.unlock %33 : !triton_nvidia_gpu.mutex
+        %66 = triton_nvidia_gpu.dot_wait %65#0 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>, pendings = 0 : i32} : tensor<64x64xf32, #mma>
+        %c0_i32_7 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 0 : i32
+        %c1_i32_8 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
+        %c2_i32_9 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 2 : i32
+        %67 = arith.subi %65#2, %c1_i32_8 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %68 = arith.cmpi eq, %65#2, %c0_i32_7 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        %69 = arith.select %68, %c2_i32_9, %67 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
+        triton_nvidia_gpu.consumer_release %2, %69 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
+        scf.if %42 {
+          triton_nvidia_gpu.lock %34 {agent.mutex_role = 1 : i32} : !triton_nvidia_gpu.mutex
        } {agent.mutex_role = 1 : i32}
-        %39 = tt.expand_dims %31 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64x1xi32, #blocked2>
-        %40 = arith.muli %39, %18 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked2>
-        %41 = tt.addptr %19, %40 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked2>, tensor<64x1xi32, #blocked2>
-        %42 = tt.expand_dims %34 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
-        %43 = tt.broadcast %41 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x1x!tt.ptr, #blocked2>) -> tensor<64x64x!tt.ptr, #blocked2>
-        %44 = tt.broadcast %42 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
-        %45 = tt.addptr %43, %44 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked2>, tensor<64x64xi32, #blocked2>
-        %46 = triton_gpu.convert_layout %38#0 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked2>
-        tt.store %45, %46 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32} : tensor<64x64xf32, #blocked2>
-        triton_nvidia_gpu.unlock %9 : !triton_nvidia_gpu.mutex
-        %c1_i32_5 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
-        %47 = arith.addi %arg10, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32
-        scf.yield {async_agent = dense<1> : vector<1xi32>} %47 : i32
+        %70 = arith.truncf %66 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma>
+        %71 = tt.advance %arg12, [%52, %54] {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : , 1>
+        %72 = triton_gpu.convert_layout %70 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #mma>) -> tensor<64x64xf16, #blocked2>
+        tt.store %71, %72 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>,
boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1>, tensor<64x64xf16, #blocked2> + triton_nvidia_gpu.unlock %34 : !triton_nvidia_gpu.mutex + %c1_i32_10 = arith.constant {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32 + %73 = arith.addi %arg15, %c2_i32 {agent.mutex_role = 1 : i32, async_agent = dense<1> : vector<1xi32>} : i32 + scf.yield {async_agent = dense<1> : vector<1xi32>} %71, %49, %50, %73 : !tt.ptr, 1>, i32, i32, i32 } {async_agent = dense<1> : vector<1xi32>} } {"agent.num-roles" = 2 : i32, async_agent = dense<1> : vector<1xi32>} tt.return diff --git a/test/TritonGPU/wsmutex.mlir b/test/TritonGPU/wsmutex.mlir index 1c0ad7712..3c3e5b995 100644 --- a/test/TritonGPU/wsmutex.mlir +++ b/test/TritonGPU/wsmutex.mlir @@ -2,9 +2,9 @@ // CHECK: scf.if +// CHECK: triton_nvidia_gpu.create_mutex +// CHECK: triton_nvidia_gpu.create_mutex // CHECK: scf.for -// CHECK: triton_nvidia_gpu.create_mutex -// CHECK: triton_nvidia_gpu.create_mutex // CHECK: triton_nvidia_gpu.lock // CHECK: agent.mutex_role = 0 : i32 // CHECK: triton_nvidia_gpu.unlock diff --git a/test/TritonGPU/wspipeline.mlir b/test/TritonGPU/wspipeline.mlir index a42ca46d0..5356002b2 100644 --- a/test/TritonGPU/wspipeline.mlir +++ b/test/TritonGPU/wspipeline.mlir @@ -22,9 +22,10 @@ // CHECK: triton_gpu.extract_slice // CHECK: triton_gpu.extract_slice // CHECK: triton_nvidia_gpu.dot_async -// CHECK: triton_nvidia_gpu.dot_wait +// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 1 // CHECK: triton_nvidia_gpu.consumer_release // CHECK: scf.yield +// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 0 // CHECK: async_agent = dense<1> : vector<1xi32> #blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp index c91f668bd..674e1e5e3 100644 --- a/test/lib/Analysis/TestAlias.cpp +++ b/test/lib/Analysis/TestAlias.cpp @@ -20,7 +20,7 @@ struct TestAliasPass return opName; } - static void print(StringRef name, SmallVector &vals, + static void print(StringRef name, SmallVector &vals, raw_ostream &os) { if (vals.empty()) return; @@ -57,7 +57,7 @@ struct TestAliasPass auto getAllocOpNames = [&](Value value) { dataflow::Lattice *latticeElement = analysis->getLatticeElement(value); - SmallVector opNames; + SmallVector opNames; if (latticeElement) { auto &info = latticeElement->getValue(); for (auto &alias : info.getAllocs()) { diff --git a/third_party/triton_shared b/third_party/triton_shared index d0ac5898f..07ea84207 160000 --- a/third_party/triton_shared +++ b/third_party/triton_shared @@ -1 +1 @@ -Subproject commit d0ac5898ff97ab33c2839306ec10bfa4fab816f5 +Subproject commit 07ea84207ac7763af16206e6d790c7bc37c6c2d9