From ca8f110617c0996f9db4966eca286dc4bef91fc3 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 2 Nov 2023 09:56:39 -0700 Subject: [PATCH] [BACKEND] Pipeliner refactoring (#2565) Refactor the pipeliner pass in order to make it more generic. The main change is that the pipeliner is now broken into 2 pieces one calculating a modulo schedule and create async ops based on the IR and an expander that will generate the pipelined IR based on the modulo schedule. The advantage of separating the two pieces is that it will allow us to create different schedule without having to change the expander and it will allow for more complex schedules. For now the schedule generated for matmul case matches rougly the schedule picked by the previous pipeliner in order to avoid changes. This also creates a different sequence of insert/extract slice for the alloc. We should probably change shared alloc to use memory semantic. --- .../Dialect/TritonGPU/Transforms/Utility.h | 3 +- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 32 +- .../TritonGPU/Transforms/CMakeLists.txt | 4 +- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 1940 ----------------- .../Pipeliner/MatmulLoopPipeline.cpp | 826 +++++++ .../Transforms/Pipeliner/PipelineExpander.cpp | 704 ++++++ .../Transforms/Pipeliner/PipelineExpander.h | 101 + .../TritonGPU/Transforms/Pipeliner/Schedule.h | 27 + .../Pipeliner/SoftwarePipeliner.cpp | 88 + lib/Dialect/TritonGPU/Transforms/Utility.cpp | 4 - python/test/unit/hopper/test_gemm.py | 3 + test/TritonGPU/loop-pipeline-hopper.mlir | 88 +- test/TritonGPU/loop-pipeline.mlir | 134 +- .../pipeline-hopper-remove-wait.mlir | 8 +- 14 files changed, 1901 insertions(+), 2061 deletions(-) delete mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h create mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h create mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 1273e6dac..8f7ce75f2 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -111,6 +111,8 @@ bool isExpensiveLoadOrStore(Operation *op); bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separatly for the loop to be correct. scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands); @@ -143,7 +145,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, ArrayRef shape); - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 361750ac0..30b1ef5fd 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1421,18 +1421,37 @@ struct IndexCastOpLowering } }; +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(mlir::arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; + void populateElementwiseOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, int computeCapability, PatternBenefit benefit) { -#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \ - patterns.add>( \ - typeConverter, axisInfoAnalysis, benefit); - POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp) -#undef POPULATE_TERNARY_OP - #define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ patterns.add>( \ typeConverter, axisInfoAnalysis, benefit); @@ -1486,6 +1505,7 @@ void populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index a06555c5a..2cd6e2672 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -5,7 +5,9 @@ add_mlir_dialect_library(TritonGPUTransforms OptimizeDotOperands.cpp OptimizeEpilogue.cpp OptimizeThreadLocality.cpp - Pipeline.cpp + Pipeliner/MatmulLoopPipeline.cpp + Pipeliner/PipelineExpander.cpp + Pipeliner/SoftwarePipeliner.cpp Prefetch.cpp RemoveLayoutConversions.cpp ReorderInstructions.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp deleted file mode 100644 index f48466b6f..000000000 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ /dev/null @@ -1,1940 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/ADT/MapVector.h" -#include "llvm/Support/Debug.h" - -//===----------------------------------------------------------------------===// -// This file implements software pipelining for loops. The implementation here -// is inspired by the pipeline pass in Triton (version 2.0) and SCF's -// LoopPipelining. -// -// We divide the loop body into the following phases: -// a. Pre-load operations: for instance, index computation. -// b. Load operations: loading from global memory to shared memory. -// c. Compute operations: for instance, Triton dot. -// d. Post-load operations: for instance, index computation. -// -// To pipeline the loop, we need to: -// - Hoist the pipelinable load operations for the first numStages-1 iterations -// to the loop pre-header -// - Find all the dependencies of the load operations. -// - Rematerialize the dependencies for their values at the first numStage-1 -// iterations -// - Assemble the loop body (numStage) and prefetch (numStage + 1). -// -// In the prologue, the sequence of operations is the same as the original loop -// body, following the (a) -> (b) -> (c) -> (d) order. In the loop body, -// however, we first execute the compute operations, then pre-load operations, -// post-load operations, and eventually the asynchronous load operations - in -// the (c) -> (a) -> (d) -> (b) order. This is used to better hide the latency -// of the load operations. Because of this, if post-load operations have direct -// dependencies on the load operations, we could repeat the post-load -// operations. More specifically, this occurs when: -// 1. Any load operand has an immediate dependency argument used at numStage-1. -// 2. The argument is first defined at numStage-2. -// To avoid the repeat, we peeled off post-load operations in the prologue that -// satisfy the above two conditions. See the example below for the definition of -// immediate and non-immediate dependencies. -// If we have a load that immediately depends on a block argument in the -// current iteration, it is an immediate dependency. Otherwise, it is a -// non-immediate dependency, which means the load depends on a block argument -// in the previous iterations. -// For example: -// scf.for (%arg0, %arg1, %arg2) { -// %0 = load %arg0 <--- immediate dep, this address is initialized before -// numStages-1. -// %1 = load %arg1 -// %2 = add %1, %arg2 -// %3 = load %2 <--- non-immediate dep, %arg1 must be an -// update-to-date value. -// } -// -// Our pipelining pass share some common characteristics with SCF's -// LoopPipelining. However, it is also noteworthy that our pipelining pass has -// the following characteristics different from SCF's LoopPipelining: -// 1. It can handle loop-carried dependencies of distance greater than 1. -// 2. It does not have a complicated epilogue but instead uses masking to handle -// boundary conditions. -// 3. Each operation/loop-carried argument cannot provide values to both -// immediate and non-immediate dependencies. Otherwise, we have to rematerialize -// the operation and arguments, which would likely increase register pressure. -// For example: -// scf.for (%arg0, %arg1, %arg2) { -// %0 = load %arg0 -// %1 = load %arg1, %0 <--- %0 is both a post-load op at numStages-2 and a -// pre-load op at numStages-1, so that we need two versions of %0. -// %2 = add %0, %arg2 -// scf.yield %arg0, %2, %arg2 -// } -// -//===----------------------------------------------------------------------===// - -using llvm::MapVector; -using namespace mlir; -namespace tt = mlir::triton; -namespace ttg = mlir::triton::gpu; -/// FIXME(Keren): The pipeline pass shouldn't be aware of nvidia_gpu dialect -namespace ttng = mlir::triton::nvidia_gpu; - -#define GEN_PASS_CLASSES -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -#define int_attr(num) builder.getI64IntegerAttr(num) - -namespace { - -// Pass named attrs (e.g., tt.contiguity) from Triton to Triton -void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { - NamedAttrList attrs = op->getDiscardableAttrs(); - // Collect the attributes to propagate: the ones in dictAttrs and not yet on - // the operation. - SmallVector toPropagate; - for (const NamedAttribute attr : dictAttrs.getValue()) { - if (!attrs.get(attr.getName())) - toPropagate.push_back(attr); - } - // If we found any, let's set them here as a single step. - if (toPropagate.size()) { - attrs.append(toPropagate); - op->setDiscardableAttrs(attrs); - } -} - -struct ConsumerReleaseInfo { - Value iterVar; - Value stageVar; - Value phaseVar; - Value nextIVVar; - Value stepVar; - Value upperBoundVar; - ttg::CTALayoutAttr CTALayout; - DenseMap consumerStageMap; -}; -typedef DenseMap - ConsumerReleaseMap; - -class LoopPipeliner { - /// Cache of ForOp and YieldOp related to this pipeliner. - scf::ForOp forOp; - scf::YieldOp yieldOp; - - /// Loads to be pipelined - SetVector validLoads; - /// The value that each load will be mapped to (after layout conversion) - DenseMap loadsMapping; - /// load => buffer - DenseMap loadsBuffer; - /// load => buffer type (with shared layout after swizzling) - DenseMap loadsBufferType; - /// load => buffer at stage N - DenseMap> loadStageBuffer; - /// load => after extract - DenseMap loadsExtract; - - /// XXX(Keren): The following are h100 only and disabled - /// load => full barrier arrive - DenseMap loadsBarrierArvOp; - /// load => mbarriers - DenseMap loadsFullBarriers; - DenseMap loadsEmptyBarriers; - /// load => null value or previous load which can share barrier with - DenseMap loadsCanShareBarriers; - /// Maintains the information to emit consumer_release mbarrier_arrive - ConsumerReleaseMap &consumerReleaseMap; - bool hasHopperDot = false; - // XXX(Keren): why the variable name is hopper dot and why do we need this - // check? - void checkHopperDots(SetVector &ops); - // XXX(Keren): it looks more like an optimization to be, not sure if it should - // exist in the base pipeliner - void checkOpShareBarriers(SetVector &ops); - int numLoadsRequireAsyncWait = 0; - int numLoadsRequireMBarrier = 0; - // Number of buffers to allocate for each input. - int numSharedMemorySlices = 0; - - /// Iterator values - Value nextIV; - Value pipelineIterIdx; - Value curWaitIdx; - - // Only needed when numLoadsRequireMBarrier > 0 - Value loopIterIdx; - Value curPhase; - Value curEmptyPhase; - - /// Yield values - SmallVector nextBuffers; - SmallVector extractSlices; - SmallVector yieldValues; - - /// The number of stages in the pipeline. - /// Stages in the range of [0, numStages-1) are in the prologue. - /// numStages-1 is appended after the loop body. - int numStages; - - /// Arg indicies - size_t bufferIdx, loadIdx, depArgsBeginIdx, ivIdx; - DenseMap depArgsIdx; - - /// XXX(Keren): The mode parameter is hacky, should be refactored - // false: legacy mode as a temporary solution for backward compatibility - // true: new mode for hopper - bool mode; - int numWarps; - int numCTAs; - - /// value (in loop) => value at stage N - DenseMap> valueMapping; - /// loop iter arg => value - DenseMap depArgsMapping; - /// forOp value => newForOp value - IRMapping mapping; - /// forOp value => prefetch value - IRMapping nextMapping; - - /// Dependency ops by program order - SmallVector orderedDeps; - - /// arg => source operand defined stages - DenseMap> immediateArgStages; - - /// block arguments that loads depend on - SetVector depArgs; - - /// operation => source operand defined stages - DenseMap> immediateOpStages; - - /// operations that loads depend on - SetVector depOps; - - /// Collect all pipelinable ops - LogicalResult collectOps(SetVector &ops); - - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &opDeps); - - /// Collect all op dependencies - void collectDeps(SetVector &ops, - MapVector> &opDeps); - - /// Check if none of the ops has valid uses - LogicalResult checkOpUses(SetVector &ops); - - /// Check if ops have dependencies that are not pipelinable - void checkOpDeps(SetVector &ops); - - void createBufferTypes(); - - void createOrderedDeps(); - - /// Return the stage at which `v` is defined prior to `stage` - int getValueDefStage(Value v, int stage); - - /// Map `origin` to `newValue` at `stage` - void setValueMapping(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at `stage` according to the association between - /// yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at the next stage according to the association - /// between yieldOp and forOp - void setValueMappingYield(scf::ForOp newForOp, Value origin, Value newValue); - - /// Return the value mapped to `origin` at `stage`, if it exists. - Value lookupOrDefault(Value origin, int stage); - - /// Get the load mask for `loadOp`, given the mapped mask `mappedMask` (if - /// exists) and the current iteration's `loopCond`. - Value getLoadMask(tt::LoadOp loadOp, Value mappedMask, Value loopCond, - OpBuilder &builder); - - /// Return an empty buffer of size - ttg::AllocTensorOp allocateEmptyBuffer(tt::LoadOp loadOp, OpBuilder &builder); - - /// Collect all args of the new loop - SmallVector collectNewLoopArgs(); - - /// Clone the forOp and return the new forOp - scf::ForOp cloneForOp(ArrayRef newLoopArgs, OpBuilder &builder); - - /// Prefetch the next iteration for `newForOp` - void prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder); - - /// Check if curIdx is out of bound and wrap value around if necessary - Value getBoundedIterationValue(OpBuilder &builder, Value curIdx, - Value upperBoundIdx, Value curValue, - Value initValue); - - /// Assemble `newForOp`'s yield op - void finalizeYield(scf::ForOp newForOp, OpBuilder &builder); - -public: - LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs, - bool mode, int numSharedMemorySlices, - ConsumerReleaseMap &consumerReleaseMap) - : forOp(forOp), numStages(numStages), numWarps(numWarps), - numCTAs(numCTAs), mode(mode), - numSharedMemorySlices(numSharedMemorySlices), - consumerReleaseMap(consumerReleaseMap) { - // cache yieldOp - yieldOp = cast(forOp.getBody()->getTerminator()); - } - - LoopPipeliner() = delete; - - /// Collect loads to pipeline. Return success if we can pipeline this loop - LogicalResult initialize(); - - /// Emit pipelined loads (before loop body) - void emitPrologue(); - - /// emit pipelined loads (after loop body) - void emitEpilogue(); - - /// create the new ForOp (add new args & insert prefetched ops) - scf::ForOp createNewForOp(); - - friend struct PipelinePass; -}; - -/// Collect loads to pipeline. Return success if we can pipeline this loop -LogicalResult LoopPipeliner::collectOps(SetVector &ops) { - ModuleOp moduleOp = forOp->getParentOfType(); - ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); - - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) - if (auto loadOp = dyn_cast(&op)) { - if (isLoadFromTensorPtr(loadOp)) { - ops.insert(loadOp); - } else { - auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - if (auto mask = loadOp.getMask()) - vec = - std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - - auto tensorTy = ptr.getType().dyn_cast(); - if (!tensorTy || tensorTy.getRank() < 2) - continue; - auto ty = - tensorTy.getElementType().cast().getPointeeType(); - unsigned width = vec * ty.getIntOrFloatBitWidth(); - // We do not pipeline all loads for the following reasons: - // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. - // 2. It's likely that pipling small loads won't offer much performance - // improvement and may even hurt performance by increasing register - // pressure. - if (width >= 32) - ops.insert(loadOp); - } - } - - if (ops.empty()) - return failure(); - else - return success(); -} - -void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps) { - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getRegion()) - return; - - // Since we only need to peel the loop numStages-1 times, don't worry - // about depends that are too far away - if (stage < 0) - return; - - if (auto arg = v.dyn_cast()) { - if (arg.getArgNumber() > 0) { - deps.insert(v); - collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps); - } - } else { // value - deps.insert(v); - for (Value op : v.getDefiningOp()->getOperands()) - collectValueDep(op, stage, deps); - } -} - -void LoopPipeliner::collectDeps( - SetVector &ops, - MapVector> &valueDeps) { - for (auto op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); - valueDeps[op] = deps; - } - } -} - -LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { - DenseSet invalidOps; - // Collect all ops' dependencies - MapVector> opDeps; - collectDeps(ops, opDeps); - - for (Operation *op : ops) { - if (auto loadOp = dyn_cast(op)) { - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other->getResult(0))) { - isCandidate = false; - break; - } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse() && - !isLoadFromTensorPtr(loadOp)) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - Operation *preUse = nullptr; - - // Advance to the first conversion as long as the use resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - use->getResult(0).getType().dyn_cast(); - if (!tensorType.getEncoding().isa()) - break; - preUse = use; - use = *use->getResult(0).getUsers().begin(); - } - - if (auto convertLayout = llvm::dyn_cast(use)) { - if (auto tensorType = convertLayout.getResult() - .getType() - .dyn_cast()) - if (auto dotOpEnc = tensorType.getEncoding() - .dyn_cast()) { - isCandidate = true; - loadsMapping[loadOp] = convertLayout; - } - } else if (preUse && isa(use)) { - isCandidate = false; - // for MMAv3 whose dot take SharedEncoding as operands directly - Operation *post = *loadOp.getResult().getUsers().begin(); - auto newOrder = post->getResult(0) - .getType() - .cast() - .getEncoding() - .cast() - .getOrder(); - auto ty = loadOp.getType().cast(); - auto oldOrder = ttg::getOrder(ty.getEncoding()); - // The operand of MMAv3 is in SharedEncoding and it's order should not - // be changed after FuseTranspositions Pass. So we only pipeline the - // load if the order of the loaded BlockedEncoding is the same as the - // order of the SharedEncoding it is converted to. - // TODO: remove this constraint once the LoadOp supports transpose - // fusion - if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { - isCandidate = true; - loadsMapping[loadOp] = preUse->getResult(0); - } - } - } else if (isCandidate && mode && isLoadFromTensorPtr(loadOp)) { - loadsMapping[loadOp] = loadOp.getResult(); - } else - isCandidate = false; - - if (!isCandidate) - invalidOps.insert(loadOp); - else { - validLoads.insert(loadOp); - if (!isLoadFromTensorPtr(loadOp)) - numLoadsRequireAsyncWait++; - else - numLoadsRequireMBarrier++; - } - } - } - - for (Operation *op : invalidOps) - ops.remove(op); - - if (ops.empty()) - return failure(); - else - return success(); -} - -void LoopPipeliner::checkHopperDots(SetVector &ops) { - // dots to be pipelined - SetVector dots; - for (Operation &op : forOp) { - if (auto dotOp = dyn_cast(&op)) { - auto resTy = dotOp.getResult().getType().dyn_cast(); - if (auto resEnc = resTy.getEncoding().dyn_cast()) { - if (resEnc && resEnc.isHopper()) { - // Don't pipeline valid dots that depend on ops other than scf.yield - // and scf.for - auto dot = dotOp.getResult(); - bool valid = true; - - // all users of dot should be scf.yield - if (!dot.hasOneUse()) - valid = false; - if (!isa(*dot.getUsers().begin())) - valid = false; - - // C should be a block argument - auto CArg = dotOp.getOperand(2).dyn_cast(); - if (!CArg || !CArg.hasOneUse()) - valid = false; - - if (valid) - dots.insert(dotOp); - } - } - } - } - - hasHopperDot = true; -} - -void LoopPipeliner::checkOpShareBarriers(SetVector &ops) { - // Check if loads can share barriers - auto canShare = [&](Value load0, Value load1) -> bool { - if (!load0.hasOneUse() || !load1.hasOneUse()) - return false; - auto use0 = *load0.getUsers().begin(); - auto use1 = *load1.getUsers().begin(); - if (!use0->hasOneUse() || !use1->hasOneUse()) - return false; - if (*use0->getUsers().begin() != *use1->getUsers().begin()) - return false; - return true; - }; - // XXX(Keren): the logic here is pretty weird and might be incomplete - for (Value loadOp : validLoads) { - Value depLoad; - for (auto oldPair : loadsCanShareBarriers) { - Value oldLoad = oldPair.first; - if (canShare(loadOp, oldLoad)) { - depLoad = oldLoad; - break; - } - } - loadsCanShareBarriers[loadOp] = depLoad; - } -} - -void LoopPipeliner::checkOpDeps(SetVector &ops) { - SetVector nonImmediateDepArgs; - SetVector nonImmediateOps; - for (Operation *op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); - int defStage = getValueDefStage(v, numStages - 1); - assert(defStage >= 0 && - "newLoopArgs has null args without a define op. Consider either " - "rewrite the loop to reduce cross iteration dependencies or " - "increase the num_stages value."); - for (auto dep : deps) { - auto immediate = deps.front().isa(); - if (auto arg = dyn_cast(dep)) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } else { - depOps.insert(dep.getDefiningOp()); - if (immediate) - immediateOpStages[dep.getDefiningOp()].insert(defStage); - else - nonImmediateOps.insert(dep.getDefiningOp()); - } - } - } - } - - // We could remove the following constraints if we can rematerialize in the - // loop. Check if immediateDepArgs and nonImmediateDepArgs are disjoint. - for (auto &[arg, stages] : immediateArgStages) { - assert(stages.size() == 1 && - "Triton doesn't support an argument provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateDepArgs.contains(arg) && - stages.contains(numStages - 2)) && - "Loop-carried arguments provide values for both immediate and " - "non-immediate operands of loads. Please consider removing " - "pre/post load instructions dependency on this argument."); - } - - // Check if immediateOps and nonImmediateOps are disjoint. - for (auto &[op, stages] : immediateOpStages) { - assert(stages.size() == 1 && - "Triton doesn't support an operation provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) && - "Operations provide values for both immediate and " - "non-immediate operands of loads. Please consider " - "removing pre/post load instructions dependency on this " - "operation."); - } -} - -// helpers -void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - valueMapping[origin] = SmallVector(numStages); - valueMapping[origin][stage] = newValue; -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue, - int stage) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto value = forOp.getRegionIterArgs()[yieldIdx]; - setValueMapping(value, newValue, stage); - } - } -} - -void LoopPipeliner::setValueMappingYield(scf::ForOp newForOp, Value origin, - Value newValue) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]]; - auto originArg = forOp.getRegionIterArgs()[yieldIdx]; - nextMapping.map(originArg, newValue); - auto newArg = newForOp.getRegionIterArgs()[depYieldIdx]; - if (!depArgsMapping.contains(newArg)) - depArgsMapping[newArg] = newValue; - } - } -} - -Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - return origin; - return valueMapping[origin][stage]; -} - -void LoopPipeliner::createBufferTypes() { - for (auto loadCvt : loadsMapping) { - auto loadOp = loadCvt.first; - Value cvt = loadCvt.second; - auto ty = loadOp.getType().cast(); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - bufferShape.insert(bufferShape.begin(), numSharedMemorySlices); - auto CTALayout = ttg::getCTALayout(ty.getEncoding()); - Attribute sharedEnc; - if (auto dotOpEnc = cvt.getType() - .cast() - .getEncoding() - .dyn_cast()) { - // MMAv1 and MMAv2 - bool needTrans = dyn_cast_or_null( - cvt.getDefiningOp()->getOperand(0).getDefiningOp()); - unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); - sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); - } else { - // MMAv3 - sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), - ttg::getOrder(ty.getEncoding()), - CTALayout, ty.getElementType()); - } - // FIXME(Keren): block ptr not handled - loadsBufferType[loadOp] = - RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); - } -} - -void LoopPipeliner::createOrderedDeps() { - for (Operation &op : *forOp.getBody()) { - if (depOps.contains(&op)) - orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) - orderedDeps.push_back(&op); - } - assert(depOps.size() + validLoads.size() == orderedDeps.size() && - "depOps contains invalid values"); -} - -int LoopPipeliner::getValueDefStage(Value v, int stage) { - if (stage < 0) - return -1; - if (auto arg = v.dyn_cast()) { - if (arg.getArgNumber() > 0) - return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1), - stage - 1); - llvm_unreachable("Loop induction variable should not be a dependency"); - } else - return stage; -} - -ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(tt::LoadOp loadOp, - OpBuilder &builder) { - // Allocate a buffer for each pipelined tensor - // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16> - Value convertLayout = loadsMapping[loadOp]; - if (auto tensorType = convertLayout.getType().dyn_cast()) - return builder.create(convertLayout.getLoc(), - loadsBufferType[loadOp]); - llvm_unreachable("Async copy's return should be of RankedTensorType"); -} - -LogicalResult LoopPipeliner::initialize() { - // All ops that maybe pipelined - SetVector ops; - - if (collectOps(ops).failed()) - return failure(); - - if (checkOpUses(ops).failed()) - return failure(); - - // XXX(Keren): hopper specific, should be cleaned up - checkHopperDots(ops); - - checkOpShareBarriers(ops); - - checkOpDeps(ops); - - createBufferTypes(); - - createOrderedDeps(); - - return success(); -} - -Value LoopPipeliner::getLoadMask(tt::LoadOp loadOp, Value mappedMask, - Value loopCond, OpBuilder &builder) { - Type maskType = tt::getI1SameShape(loadOp.getType()); - Value mask = loadOp.getMask(); - Value newMask; - if (mask) { - Value cond = loopCond; - if (isa(maskType)) { - cond = builder.create(mask.getLoc(), maskType, loopCond); - } - newMask = builder.create(mask.getLoc(), mappedMask, cond); - } else { - if (isa(maskType)) { - newMask = - builder.create(loopCond.getLoc(), maskType, loopCond); - } else { - newMask = loopCond; - } - } - return newMask; -} - -void LoopPipeliner::emitPrologue() { - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { - OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); - setValueMapping(arg, operand.get(), 0); - } - - // Alloc a vector of MBarriers in size numStages for each load to be pipelined - bool isMcast = false; - for (Value loadOp : validLoads) { - auto load = cast(loadOp.getDefiningOp()); - if (isLoadFromTensorPtr(load)) { - auto loadTy = loadOp.getType().cast(); - auto CTALayout = ttg::CTALayoutAttr::get( - load.getContext(), - /*CTAsPerCGA*/ {static_cast(numCTAs)}, - /*CTASplitNum*/ {1}, - /*CTAOrder*/ {0}); - auto sharedEncoding = ttg::SharedEncodingAttr::get( - load.getContext(), 1, 1, 1, {0}, CTALayout, false); - auto mBarriersTy = RankedTensorType::get( - {numStages}, builder.getIntegerType(64), sharedEncoding); - - if (!loadsCanShareBarriers[loadOp]) { - Value fullBarriers = builder.create( - load.getLoc(), mBarriersTy, 1); - loadsFullBarriers[loadOp] = fullBarriers; - } - auto layout = loadTy.getEncoding(); - auto CTASplitNum = ttg::getCTASplitNum(layout); - auto CTAsPerCGA = ttg::getCTAsPerCGA(layout); - if (CTASplitNum != CTAsPerCGA) { - isMcast = true; - // FIXME: numConsumerThreads could be 32 as well instead of 128 - // incase the consumer is not GMMA - unsigned arriveCnt = ttg::getNumWarpsPerCTA(layout); - if (hasHopperDot) - arriveCnt /= 4; - arriveCnt *= - product(CTAsPerCGA) / product(CTASplitNum); - - Value emptyBarriers = builder.create( - load.getLoc(), mBarriersTy, arriveCnt); - loadsEmptyBarriers[loadOp] = emptyBarriers; - } - } - } - - if (isMcast) { - builder.create(forOp.getLoc(), /*relaxed*/ 1); - builder.create(forOp.getLoc()); - } - - // prologue from [0, numStage-1) - Value iv = forOp.getLowerBound(); - pipelineIterIdx = builder.create(iv.getLoc(), 0, 32); - for (int stage = 0; stage < numStages - 1; ++stage) { - // Special handling for induction variable as the increment is implicit - if (stage != 0) - iv = builder.create(iv.getLoc(), iv, forOp.getStep()); - setValueMapping(forOp.getInductionVar(), iv, stage); - - // Special handling for loop condition as there is no condition in ForOp - Value loopCond = builder.create( - iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); - for (Operation *op : orderedDeps) { - Operation *newOp = nullptr; - if (validLoads.contains(op->getResult(0))) { - auto load = cast(op); - // Allocate empty buffer - if (stage == 0) { - loadsBuffer[load] = allocateEmptyBuffer(load, builder); - loadStageBuffer[load] = {loadsBuffer[load]}; - } - // load => copy async - if (auto loadOp = llvm::dyn_cast(op)) { - Value newMask = - getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), - loopCond, builder); - - if (mode && isLoadFromTensorPtr(loadOp)) { - auto loc = op->getLoc(); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - Value stageVal = - builder.create(loc, stage, 32); - // producer_acquire - if (loadsEmptyBarriers.count(loadOp)) { - Value emptyBarrier = builder.create( - loc, mBarTy, loadsEmptyBarriers[loadOp], stageVal); - auto trueVal = - builder.create(loc, 1, /*bitWidth*/ 1); - builder.create(loc, emptyBarrier, trueVal); - } - - // producer_commit - Value fullBarrier; - if (!loadsCanShareBarriers[loadOp]) { - fullBarrier = builder.create( - loc, mBarTy, loadsFullBarriers[loadOp], stageVal); - loadsExtract[loadOp] = fullBarrier; - } else { - // Reuse the barrier from previouse load. - fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; - } - - auto loadTy = loadOp.getType().dyn_cast(); - assert(loadTy); - auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); - auto shapePerSlice = - ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); - unsigned elems = - std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), 1, - std::multiplies{}); - elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); - - if (!loadsCanShareBarriers[loadOp]) { - Value _0 = builder.create(loc, 0, 32); - Value threadId = builder.create(loc); - Value pred = builder.create( - loc, arith::CmpIPredicate::eq, threadId, _0); - pred = builder.create(loc, pred, loopCond); - Operation *barrierArvOp = builder.create( - loc, fullBarrier, pred, - /*remoteCtaId*/ nullptr, /*trackAsyncOp*/ false, elems); - loadsBarrierArvOp[loadOp] = barrierArvOp; - } else { - // Increase the transcnt for barrier of previouse load by the - // bytes of current load. - Operation *barrierArvOp = - loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; - unsigned base_elems = - barrierArvOp->getAttr("txCount").cast().getInt(); - barrierArvOp->setAttr("txCount", - IntegerAttr::get(builder.getIntegerType(32), - base_elems + elems)); - } - newOp = builder.create( - loc, loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, fullBarrier, - newMask, lookupOrDefault(loadOp.getOther(), stage), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), - /*axis*/ 0); - } else { - newOp = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask, - lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - } - loadStageBuffer[loadOp].push_back(newOp->getResult(0)); - } else - llvm_unreachable("This should be LoadOp"); - } else { - if (auto loadOp = dyn_cast(op)) { - Value newMask = - getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), - loopCond, builder); - newOp = builder.create( - loadOp.getLoc(), loadOp.getResult().getType(), - lookupOrDefault(loadOp.getPtr(), stage), newMask, - lookupOrDefault(loadOp.getOther(), stage), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - addNamedAttrs(newOp, op->getDiscardableAttrDictionary()); - } else - newOp = builder.clone(*op); - // Update loop-carried uses - for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { - auto it = valueMapping.find(op->getOperand(opIdx)); - if (it != valueMapping.end()) { - Value v = it->second[stage]; - assert(v && "Value not found in valueMapping"); - newOp->setOperand(opIdx, v); - } // else, op at opIdx is a loop-invariant value - } - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { - Value originResult = op->getResult(dstIdx); - if (validLoads.contains(originResult)) - break; - setValueMapping(originResult, newOp->getResult(dstIdx), stage); - // Update mapping for loop-carried values (args) - setValueMappingYield(op->getResult(dstIdx), newOp->getResult(dstIdx), - stage + 1); - } - } // for (Operation *op : orderedDeps) - - // Update pipeline index - pipelineIterIdx = builder.create( - iv.getLoc(), pipelineIterIdx, - builder.create(iv.getLoc(), 1, 32)); - Value numSlices = builder.create( - iv.getLoc(), numSharedMemorySlices, 32); - Value _0 = builder.create(iv.getLoc(), 0, 32); - pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx, - numSlices, pipelineIterIdx, _0); - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(arg, valueMapping[arg][stage], stage + 1); - } // for (int stage = 0; stage < numStages - 1; ++stage) - - // async.wait & extract_slice - if (numLoadsRequireAsyncWait > 0) - builder.create(validLoads.front().getLoc(), - validLoads.size() * (numStages - 2)); - for (Value loadOp : validLoads) { - auto bufferType = loadStageBuffer[loadOp][numStages - 1] - .getType() - .cast(); - auto bufferShape = bufferType.getShape(); - auto sliceType = loadsMapping[loadOp].getType().cast(); - sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, - sliceType.getElementType(), - loadsBufferType[loadOp].getEncoding()); - Value extractSlice = builder.create( - loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1], - SmallVector{int_attr(0), int_attr(0), int_attr(0)}, - SmallVector{int_attr(1), - int_attr(sliceType.getShape()[0]), - int_attr(sliceType.getShape()[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - loadsExtract[loadOp] = extractSlice; - } - curWaitIdx = builder.create(iv.getLoc(), 0, 32); - loopIterIdx = builder.create(iv.getLoc(), 0, 32); - curPhase = builder.create(iv.getLoc(), 0, 1); - curEmptyPhase = builder.create(iv.getLoc(), 1, 1); -} - -void LoopPipeliner::emitEpilogue() { - // If there's any outstanding async copies, we need to wait for them. - if (numLoadsRequireAsyncWait > 0) { - OpBuilder builder(forOp); - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfter(forOp); - builder.create(forOp.getLoc(), 0); - } -} - -SmallVector LoopPipeliner::collectNewLoopArgs() { - // Order of new args: - // (original args) - // (insertSliceAsync buffer at stage numStages - 1) for each load - // (extracted tensor) for each load - // (depArgs at stage numStages - 1) - // (depArgs at stage numStages - 2) - // ... - // (iv at stage numStages - 2) - // (pipeline iteration index) - // (loop iteration index) - // (wait index) - // (phase index) - // (empty phase index) - - // We need this to update operands for yield - // original block arg => new arg's idx - SmallVector newLoopArgs; - for (auto v : forOp.getInitArgs()) - newLoopArgs.push_back(v); - - bufferIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadStageBuffer[loadOp].back()); - - loadIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadsExtract[loadOp]); - - depArgsBeginIdx = newLoopArgs.size(); - for (auto depArg : depArgs) { - depArgsIdx[depArg] = newLoopArgs.size(); - if (immediateArgStages[depArg].contains(numStages - 2)) - // Peel off post load ops in numStage-1 - newLoopArgs.push_back(valueMapping[depArg][numStages - 2]); - else - newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); - } - - ivIdx = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]); - newLoopArgs.push_back(pipelineIterIdx); - newLoopArgs.push_back(curWaitIdx); - if (numLoadsRequireMBarrier > 0) { - newLoopArgs.push_back(loopIterIdx); - newLoopArgs.push_back(curPhase); - newLoopArgs.push_back(curEmptyPhase); - } - - return newLoopArgs; -} - -scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, - OpBuilder &builder) { - // Clone the original ForOp - auto newForOp = builder.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newLoopArgs); - - // Set mapping on body of the new ForOp - builder.setInsertionPointToStart(newForOp.getBody()); - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - - // Loop iteration args - Value upperBound = newForOp.getUpperBound(); - Value step = newForOp.getStep(); - Value curIV = newForOp.getRegionIterArgs()[ivIdx]; - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - if (numLoadsRequireMBarrier > 0) { - loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; - curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; - curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; - } - - // Clone the loop body, replace original args with args of the new ForOp. - SmallVector loadsFromTensorPtr; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (auto cvtOp = dyn_cast(op)) { - auto result = op.getResult(0); - auto cvtDstTy = result.getType().cast(); - auto it = - std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); - if (it != validLoads.end()) { - auto loadArgIdx = std::distance(validLoads.begin(), it); - if (cvtDstTy.getEncoding().isa()) { - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (dotOperand) = cvt %0 - // We replace the use new load use with a convert layout - auto cvt = builder.create( - result.getLoc(), cvtDstTy, - newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - mapping.map(result, cvt.getResult()); - continue; - } else if (cvtDstTy.getEncoding().isa()) { - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (sharedEncoding) = cvt %0 - // We replace the use new load use with insert_slice_async's result - mapping.map(result, - newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - continue; - } - } - } else if (auto loadOp = dyn_cast(op)) { - if (isLoadFromTensorPtr(loadOp)) { - // XXX(Keren): The comparison operator using std::find on tensor ptr - // doesn't work as expected - auto operand = loadOp.getPtr(); - auto tensorTy = - operand.getType().cast().getPointeeType(); - auto loadArgIdx = 0; - for (auto validLoad : validLoads) { - auto defOp = cast(validLoad.getDefiningOp()); - if (isLoadFromTensorPtr(defOp)) { - auto validOperand = defOp.getOperand(0); - auto validTensorTy = - validOperand.getType().cast().getPointeeType(); - if (tensorTy == validTensorTy) - break; - } - loadArgIdx++; - } - // consumer_wait, emitted before the first consumer - auto firstConsumer = getFirstUser(loadOp); - mapping.map(loadOp, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - - // If current load can reuse barriers shared by previous load, then we - // do nothing. - if (!loadsCanShareBarriers[loadOp]) { - // emit mbarrier wait before the first consumer of the loaD - OpBuilder mBarBuilder(firstConsumer); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - Value fullBarrier = mBarBuilder.create( - loadOp.getLoc(), mBarTy, loadsFullBarriers[loadOp], curWaitIdx); - mBarBuilder.create(loadOp.getLoc(), fullBarrier, - curPhase); - } - - loadsFromTensorPtr.push_back(loadOp); - continue; - } - } - cloneWithInferType(builder, &op, mapping); - } - - for (Value load : loadsFromTensorPtr) { - // consumer_relase, emitted after the last consumer - // 'the last consumer' might be updated in the following Phase_1 since - // some of the consumers might be pipelined. Thus we maintain this - // information in 'consumerReleaseMap' and move the position of - // consumer_release barrier in a seperate Phase_2 in case necessary. - if (loadsEmptyBarriers.count(load)) { - auto users = mapping.lookup(load).getUsers(); - DenseMap consumerStageMap; - for (Operation *user : users) { - // All the stage is initialized to zero before Phase_1, - // since no consumers has been pipelined yet. - consumerStageMap[user] = 0; - } - auto CTALayout = ttg::getCTALayout( - load.getType().cast().getEncoding()); - ConsumerReleaseInfo info{ - loopIterIdx, pipelineIterIdx, curEmptyPhase, curIV, - step, upperBound, CTALayout, consumerStageMap}; - consumerReleaseMap[loadsEmptyBarriers[load]] = info; - } - } - - // Remove redundant conversions - // e.g., %145 = triton_gpu.convert_layout %arg15 : (tensor<128x64xf16, - // #shared1>) -> tensor<128x64xf16, #shared1> - for (Operation &op : newForOp.getBody()->without_terminator()) { - if (auto convert_layout = dyn_cast(op)) { - auto result = op.getResult(0); - auto cvtDstTy = result.getType(); - auto operand = convert_layout.getOperand(); - auto tensorTy = operand.getType(); - if (cvtDstTy == tensorTy) - result.replaceAllUsesWith(operand); - } - } - - return newForOp; -} - -Value LoopPipeliner::getBoundedIterationValue(OpBuilder &builder, Value curIdx, - Value upperBoundIdx, - Value curValue, Value initValue) { - Value cond = builder.create( - curIdx.getLoc(), arith::CmpIPredicate::uge, curIdx, upperBoundIdx); - Value selectValue = builder.create( - curIdx.getLoc(), cond, initValue, curValue); - return selectValue; -} - -void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, - OpBuilder &builder) { - // Map the dep args of the next iteration to the dep args of the current - size_t argIdx = 0; - for (auto depArg : depArgs) { - BlockArgument nextArg = - newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]; - nextMapping.map(depArg, nextArg); - ++argIdx; - } - - // Update loop iteration args - Value curIV = newForOp.getRegionIterArgs()[ivIdx]; - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - if (numLoadsRequireMBarrier > 0) { - loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; - curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; - curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; - } - - // Special handling for iv & loop condition - auto idxLoc = curIV.getLoc(); - nextIV = builder.create(idxLoc, curIV, newForOp.getStep()); - Value nextLoopCond = builder.create( - idxLoc, arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); - - // Constants - Value _0 = builder.create(idxLoc, 0, 32); - Value _1 = builder.create(idxLoc, 1, 32); - Value numStagesVal = - builder.create(idxLoc, numStages, 32); - Value numSlices = - builder.create(idxLoc, numSharedMemorySlices, 32); - - // nextWaitIdx - Value waitIdxPlusOne = builder.create(idxLoc, curWaitIdx, _1); - Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne, - numSlices, waitIdxPlusOne, _0); - - // Indices of InsertSliceAsyncOp and ExtractSliceOp - Value insertSliceIndex = pipelineIterIdx; - Value extractSliceIndex = nextWaitIdx; - - // Prefetch load deps - // If a load-dependent instruction that uses a block argument, we - // shouldn't update the new mapping of the block argument in the current - // iteration. - // For example. - // %a = add %arg0, %c - // %b = add %arg0, %d - // - // Update %arg0 will cause the value of %b to be incorrect. - // We do need to use the next iteration value of %arg0 because it could be a - // immediate arg of a load op. - // load %arg0 - // %a = add %arg0, %c - // yield %a - // - // We reroder instructions so %a and yield are actually before load. load - // %arg0 should use the updated %arg0. - IRMapping curMapping = nextMapping; - for (Operation *op : orderedDeps) - if (!validLoads.contains(op->getResult(0))) { - if (immediateOpStages[op].contains(numStages - 2)) - // A post load op that provides values for numStage - 2 - curMapping.map(forOp.getInductionVar(), curIV); - else - curMapping.map(forOp.getInductionVar(), nextIV); - Operation *nextOp; - if (auto loadOp = dyn_cast(op)) { - auto newMask = - getLoadMask(loadOp, curMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - nextOp = builder.create( - loadOp.getLoc(), loadOp.getResult().getType(), - curMapping.lookupOrDefault(loadOp.getPtr()), newMask, - curMapping.lookupOrDefault(loadOp.getOther()), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - addNamedAttrs(nextOp, op->getDiscardableAttrDictionary()); - curMapping.map(loadOp.getResult(), nextOp->getResult(0)); - nextMapping.map(loadOp.getResult(), nextOp->getResult(0)); - } else { - nextOp = builder.clone(*op, curMapping); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(newForOp, op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - - // loads -> async loads - for (Operation *op : orderedDeps) { - Operation *nextOp = nullptr; - // Update loading mask - if (validLoads.contains(op->getResult(0))) { - auto loadOp = llvm::cast(op); - auto mask = loadOp.getMask(); - auto newMask = - getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - if (mask) { - // If mask is defined outside the loop, don't update the map more than - // once - if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) - nextMapping.map(mask, newMask); - newMask = nextMapping.lookupOrDefault(loadOp.getMask()); - } - Value insertedVal; - if (mode && isLoadFromTensorPtr(loadOp)) { - auto loc = op->getLoc(); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - - // producer_acquire - if (loadsEmptyBarriers.count(loadOp)) { - auto ifOp = builder.create(loc, ArrayRef{}, - nextLoopCond, false); - builder.setInsertionPointToStart(ifOp.thenBlock()); - Value emptyBarrier = builder.create( - loc, mBarTy, loadsEmptyBarriers[loadOp], insertSliceIndex); - builder.create(loc, emptyBarrier, - curEmptyPhase); - builder.setInsertionPointAfter(ifOp); - } - - // producer_commit - Value fullBarrier; - if (!loadsCanShareBarriers[loadOp]) { - fullBarrier = builder.create( - loc, mBarTy, loadsFullBarriers[loadOp], insertSliceIndex); - loadsExtract[loadOp] = fullBarrier; - } else { - // Reuse the barrier from previouse load. - fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; - } - - auto loadTy = loadOp.getType().dyn_cast(); - assert(loadTy); - auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); - auto shapePerSlice = - ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); - unsigned elems = std::accumulate( - shapePerSlice.begin(), shapePerSlice.end(), 1, std::multiplies{}); - elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); - if (!loadsCanShareBarriers[loadOp]) { - Value _0 = builder.create(loc, 0, 32); - Value threadId = builder.create(loc); - Value pred = builder.create( - loc, arith::CmpIPredicate::eq, threadId, _0); - pred = builder.create(loc, pred, nextLoopCond); - Operation *barrierArvOp = builder.create( - loc, fullBarrier, pred, - /*remoteCtaId*/ nullptr, - /*trackAsyncOp*/ false, elems); - loadsBarrierArvOp[loadOp] = barrierArvOp; - } else { - // Increase the transcnt for barrier of previouse load by the bytes of - // current load. - Operation *barrierArvOp = - loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; - unsigned base_elems = - barrierArvOp->getAttr("txCount").cast().getInt(); - barrierArvOp->setAttr( - "txCount", - IntegerAttr::get(builder.getIntegerType(32), base_elems + elems)); - } - insertedVal = builder.create( - loc, loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, fullBarrier, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - } else { - insertedVal = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - } - nextBuffers.push_back(insertedVal); - // Extract slice - auto bufferType = insertedVal.getType().cast(); - auto bufferShape = bufferType.getShape(); - auto sliceType = loadsMapping[loadOp].getType().cast(); - sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, - sliceType.getElementType(), - loadsBufferType[loadOp].getEncoding()); - - nextOp = builder.create( - op->getLoc(), sliceType, insertedVal, - SmallVector{extractSliceIndex, int_attr(0), - int_attr(0)}, - SmallVector{int_attr(1), - int_attr(sliceType.getShape()[0]), - int_attr(sliceType.getShape()[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - extractSlices.push_back(nextOp->getResult(0)); - - // Update mapping of results - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - // If this is a loop-carried value, update the mapping for yield - setValueMappingYield(newForOp, op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - } - - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(newForOp, arg, - newForOp.getRegionIterArgs()[depArgsIdx[arg]]); - - // async.wait & extract_slice - if (numLoadsRequireAsyncWait > 0) { - Operation *asyncWait = builder.create( - validLoads[0].getLoc(), validLoads.size() * (numStages - 2)); - for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) { - // move extract_slice after asyncWait - it->getDefiningOp()->moveAfter(asyncWait); - } - } - - // Bump pipelineIterIdx - Value pipelineIterIdxPlusOne = - builder.create(idxLoc, pipelineIterIdx, _1); - pipelineIterIdx = getBoundedIterationValue( - builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0); - - // Bump curWaitIdx - curWaitIdx = nextWaitIdx; - - if (numLoadsRequireMBarrier > 0) { - // Bump loopIterIdx - loopIterIdx = builder.create(idxLoc, loopIterIdx, _1); - - Value _1_1b = builder.create(idxLoc, 1, 1); - - // Flip curPhase - Value nextPhase = builder.create(idxLoc, curPhase, _1_1b); - curPhase = getBoundedIterationValue(builder, waitIdxPlusOne, numStagesVal, - curPhase, nextPhase); - - // Flip curEmptyPhase - Value nextEmptyPhase = - builder.create(idxLoc, curEmptyPhase, _1_1b); - curEmptyPhase = - getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal, - curEmptyPhase, nextEmptyPhase); - } -} - -void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) { - SmallVector yieldValues; - for (Value v : yieldOp->getOperands()) - yieldValues.push_back(mapping.lookup(v)); - for (Value nextBuffer : nextBuffers) - yieldValues.push_back(nextBuffer); - for (Value nextSlice : extractSlices) - yieldValues.push_back(nextSlice); - - for (size_t i = depArgsBeginIdx; i < ivIdx; ++i) { - auto arg = newForOp.getRegionIterArgs()[i]; - assert(depArgsMapping.count(arg) && "Missing loop-carried value"); - yieldValues.push_back(depArgsMapping[arg]); - } - - // Loop iteration args - yieldValues.push_back(nextIV); - yieldValues.push_back(pipelineIterIdx); - yieldValues.push_back(curWaitIdx); - if (numLoadsRequireMBarrier > 0) { - yieldValues.push_back(loopIterIdx); - yieldValues.push_back(curPhase); - yieldValues.push_back(curEmptyPhase); - } - - builder.setInsertionPointToEnd(newForOp.getBody()); - builder.create(yieldOp->getLoc(), yieldValues); -} - -scf::ForOp LoopPipeliner::createNewForOp() { - OpBuilder builder(forOp); - auto newLoopArgs = collectNewLoopArgs(); - auto newForOp = cloneForOp(newLoopArgs, builder); - prefetchNextIteration(newForOp, builder); - finalizeYield(newForOp, builder); - return newForOp; -} - -// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp -struct PipelinePass : public TritonGPUPipelineBase { - PipelinePass() = default; - PipelinePass(int numStages, int numWarps, int numCTAs, - int computeCapability) { - this->numStages = numStages; - this->numWarps = numWarps; - this->numCTAs = numCTAs; - this->computeCapability = computeCapability; - } - - void runOnOperation() override { - // TODO[goostavz]: mode = 0 is temporary for backward compatible, will be - // deprecated after the refactor of pipeline fully gets done - // TODO[goostavz]: When mode = 1, the mask of prefetch insert_slice in the - // prologue is currently not properly provided. Need some second thought on - // the mask definition of InsertSliceOp when the src is ptr - bool mode = - computeCapability >= 90 && ::triton::tools::getBoolEnv("ENABLE_TMA"); - if (this->numStages <= 1) - return; - - // phase 0: pipeline loads in loops - // Pre-processing - // we make sure element-wise ops are done *after* the conversion - // to dot operands - // we can achieve this with simple recursive pattern matching - // MLIRContext *context = &getContext(); - // mlir::RewritePatternSet patterns(context); - // patterns.add(context); - // auto didPreprocess = - // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - - llvm::SmallVector newForOps; - - // Currently we schedule stage 0 after stage `numStages - 1` during - // pipelining therefore we only need `numStages - 1` slice of memory. - // On Hopper we have a separate post-processing that pipelines wgmma so we - // need an extra buffer for each input. - // Note that an alternative would be to keep allocating `numStages` buffers - // and remove the barrier between the loads from shared memory and the - // copies from global to shared. This would require improving existing - // membar analysis. - int numSharedMemorySlices = - computeCapability < 90 ? numStages - 1 : numStages; - - // Do the pipelining - getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps, - this->numCTAs, mode, numSharedMemorySlices, - consumerReleaseMap); - if (pipeliner.initialize().failed()) - return; - - pipeliner.emitPrologue(); - scf::ForOp newForOp = pipeliner.createNewForOp(); - pipeliner.emitEpilogue(); - newForOps.push_back(newForOp); - - // Replace the original loop - for (unsigned i = 0; i < forOp->getNumResults(); ++i) - forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); - forOp->erase(); - }); - - // phase 1: pipeline dots in loops - // A tt.dot suitable for GMMA will be converted to ttg.dot_async. And a - // ttg.DotWaitOp will synchronize it lagging just one iteration, which is - // a hueristic rule. - for (auto forOp : newForOps) - asyncLaunchDots(forOp); - - // phase 2: emit consumer_release (empty barrier arrive) logics in case of - // TMA multicast. - // For each load ops, it is emitted after its last consumer, if the consumer - // is another async op, find its associated sync op. Each async load will be - // emitted with a consumer_release action. The merge of redundant mbarriers - // will be processed in the consequent OptimizeBarriers pass. - for (const auto &item : consumerReleaseMap) - emitConsumerRelease(item.first, item.second, numStages); - } - -private: - Value getRemoteCTAId(OpBuilder &b, Location loc, ttg::CTALayoutAttr CTALayout, - Value remoteCTAIdIdx) const; - void updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, int stage); - void asyncLaunchDots(scf::ForOp forOp); - void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info, - int numStages); - bool selfDepend(tt::DotOp op, scf::ForOp forOp, Operation **firstUse); - void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, bool hasDotWait0); - ConsumerReleaseMap consumerReleaseMap; -}; - -void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, - int stage) { - for (auto &item : consumerReleaseMap) { - auto &m = item.second.consumerStageMap; - if (m.count(oldOp)) { - m.erase(oldOp); - m[newOp] = stage; - } - - for (Value operand : oldOp->getOperands()) { - Operation *op = operand.getDefiningOp(); - if (op && isa(op)) { - auto cvt = cast(op); - auto src = cvt.getSrc(); - auto srcEncoding = src.getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - if (srcEncoding == dstEncoding && m.count(op)) { - m.erase(op); - m[newOp] = stage; - } - } - } - } -} - -bool PipelinePass::selfDepend(tt::DotOp dotOp, scf::ForOp forOp, - Operation **firstUse) { - std::function dependOn = - [&dependOn](Value v, int argId, scf::ForOp forOp) { - auto op = v.getDefiningOp(); - if (isa(v)) { - auto iterArgs = forOp.getRegionIterArgs(); - auto iter = std::find(iterArgs.begin(), iterArgs.end(), v); - if (iter != iterArgs.end()) - return std::distance(iterArgs.begin(), iter) == argId; - } else { - if (!op) - return false; - for (auto operand : op->getOperands()) { - if (dependOn(operand, argId, forOp)) - return true; - } - } - return false; - }; - auto result = dotOp.getResult(); - auto yieldOp = forOp.getBody()->getTerminator(); - int argIdx = -1; - auto iter = std::find(yieldOp->getOperands().begin(), - yieldOp->getOperands().end(), result); - if (iter != yieldOp->getOperands().end()) - argIdx = std::distance(yieldOp->getOperands().begin(), iter); - if (argIdx == -1) - return false; - for (auto operand : dotOp.getOperands()) { - if (dependOn(operand, argIdx, forOp)) { - auto iterArgs = forOp.getRegionIterArgs(); - *firstUse = iterArgs[argIdx].use_begin().getUser(); - return true; - } - } - return false; -} - -void PipelinePass::removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, - bool hasDotWait0) { - if (hasDotWait0) { - for (auto &item : consumerReleaseMap) { - auto &m = item.second.consumerStageMap; - if (m.count(dotWaitOp)) { - m.erase(dotWaitOp); - } - } - dotWaitOp->erase(); - } -} - -void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { - Block *loop = forOp.getBody(); - auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) { - if (!op) - return -1l; - auto lastOp = op; - while (op->getBlock()->getParentOp() != forOp) { - lastOp = op; - op = op->getBlock()->getParentOp(); - } - return std::distance(lastOp->getBlock()->getParent()->begin(), - lastOp->getBlock()->getIterator()); - }; - /// XXX(Keren): Clean up the following duplicate code with checkDotOp - /// dots to be pipelined - bool hasSyncDot = false; - bool hasDotWait0 = false; - SmallVector allDots; - SmallVector dots; - SmallVector resultNeedSync; - for (Operation &op : *loop) { - if (auto dotWaitOp = dyn_cast(&op)) { - auto attr = dotWaitOp->getAttrOfType("pendings"); - auto pendingCount = attr.getInt(); - if (pendingCount == 0) - hasDotWait0 = true; - } - if (auto dotOp = dyn_cast(&op)) { - allDots.push_back(dotOp); - } - } - for (Operation &op : *loop) { - if (auto dotOp = dyn_cast(&op)) { - auto resTy = dotOp.getResult().getType().dyn_cast(); - if (auto resEnc = resTy.getEncoding().dyn_cast()) { - if (resEnc && resEnc.isHopper()) { - auto dot = dotOp.getResult(); - bool valid = true; - - // all users of dot should be scf.yield - if (!dot.hasOneUse()) - valid = false; - if (!isa(*dot.getUsers().begin())) - valid = false; - - Operation *firstUse = nullptr; - auto depend = selfDepend(dotOp, forOp, &firstUse); - bool selfDirectDepend = (dotOp == firstUse); - for (auto tempInAll : allDots) { - auto iter = std::find(dots.begin(), dots.end(), tempInAll); - if (iter != dots.end()) - continue; - auto db = getBlockNumInFor(tempInAll, forOp); - auto fb = getBlockNumInFor(firstUse, forOp); - if (db < fb || - (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse))) - hasSyncDot = true; - } - auto CArg = dotOp.getOperand(2); - if (!(selfDirectDepend || - (depend && !selfDirectDepend && hasSyncDot)) || - !CArg.hasOneUse()) - valid = false; - - if (valid) { - dots.push_back(dotOp); - resultNeedSync.push_back( - dotOp->getUses().begin()->getOperandNumber()); - } - } - } - } - } - - // Early stop: no need to continue if there is no valid dot in the loop. - if (dots.empty()) - return; - - OpBuilder builder(forOp); - // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline - // wgmma ops by one stage. - // This is needed to prevent shared memory inputs to be overriden before the - // operation is completed. - // TODO: merge this with the rest of the pipelining transformation and look at - // a better representation for async dots. - tt::DotOp lastDot = dots.back(); - auto loc = lastDot.getLoc(); - builder.setInsertionPointAfter(lastDot); - auto dotWait = builder.create( - lastDot.getLoc(), lastDot.getResult(), dots.size()); - - // 1. replace Dot with DotAsync - for (size_t idx = 0; idx < dots.size(); ++idx) { - tt::DotOp dotOp = dots[idx]; - builder.setInsertionPoint(dotOp); - auto dotAsync = builder.create( - dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(), - dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); - dotOp.replaceAllUsesWith(dotAsync.getResult()); - updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1); - dotOp->erase(); - } - - hasDotWait0 = hasDotWait0 || hasSyncDot; - - // 2. If there's any outstanding DotAsyncOps, we need to wait for them. - builder.setInsertionPointAfter(forOp); - SmallVector resultTypes(resultNeedSync.size()); - SmallVector yieldThenValues(resultNeedSync.size()); - SmallVector yieldElseValues(resultNeedSync.size()); - for (int i = 0; i < resultNeedSync.size(); ++i) { - resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType(); - yieldThenValues[i] = forOp->getResult(resultNeedSync[i]); - yieldElseValues[i] = forOp->getResult(resultNeedSync[i]); - } - Value loopNotEmpty = builder.create( - loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), - forOp.getUpperBound()); - auto ifOp = builder.create(loc, resultTypes, loopNotEmpty, - /*hasElse*/ true); - builder.setInsertionPointToStart(ifOp.thenBlock()); - for (int i = 0; i < resultNeedSync.size(); ++i) { - Value result = forOp->getResult(resultNeedSync[i]); - if (result.use_empty()) - continue; - auto dotWait = - builder.create(forOp.getLoc(), result, 0); - result.replaceAllUsesExcept(ifOp.getResult(i), dotWait); - yieldThenValues[i] = dotWait.getResult(); - } - auto yieldOpThen = builder.create(loc, yieldThenValues); - builder.setInsertionPointToEnd(ifOp.elseBlock()); - auto yieldOpElse = builder.create(loc, yieldElseValues); - - // 3. potentially remove redundant dot_wait after dot_async if having mutiple - // DotOp in the loop - removeExtraWait(dotWait, hasDotWait0); -} - -Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc, - ttg::CTALayoutAttr CTALayout, - Value remoteCTAIdIdx) const { - auto CTAsPerCGA = CTALayout.getCTAsPerCGA(); - auto CTAOrder = CTALayout.getCTAOrder(); - auto CTASplitNum = CTALayout.getCTASplitNum(); - - // Short path when bcastMask is a constant - bool isConstMcastMask = true; - for (unsigned s : CTASplitNum) { - if (s > 1) { - isConstMcastMask = false; - break; - } - } - if (isConstMcastMask) - return remoteCTAIdIdx; - - Value linearCTAId = b.create(loc); - SmallVector multiDimCTAId = - delinearize(b, loc, linearCTAId, CTAsPerCGA, CTAOrder); - auto rank = CTAOrder.size(); - int bcastDim = -1; - for (size_t i = 0; i < rank; ++i) { - if (CTAsPerCGA[i] != CTASplitNum[i]) { - assert(bcastDim < 0 && "bcast in multiple dims is not expected"); - bcastDim = i; - } - } - multiDimCTAId[bcastDim] = remoteCTAIdIdx; - return linearize(b, loc, multiDimCTAId, CTAsPerCGA, CTAOrder); -} - -void PipelinePass::emitConsumerRelease(Value mbarTensor, - const ConsumerReleaseInfo &info, - int numStages) { - Value iterVar = info.iterVar; - Value stage = info.stageVar; - Value phase = info.phaseVar; - Value nextIV = info.nextIVVar; - Value step = info.stepVar; - Value upperBound = info.upperBoundVar; - - const auto &consumerStageMap = info.consumerStageMap; - // find the the last consumer among all the consumers with the largest stage. - SmallVector consumersWithLargestStage; - int maxStage = 0; - for (const auto &it : consumerStageMap) { - if (it.second > maxStage) { - consumersWithLargestStage.clear(); - consumersWithLargestStage.push_back(it.first); - maxStage = it.second; - } else if (it.second == maxStage) { - consumersWithLargestStage.push_back(it.first); - } - } - assert(consumersWithLargestStage.size() > 0); - DenseMap operationId; - consumersWithLargestStage[0]->getBlock()->walk( - [&](Operation *op) { operationId[op] = operationId.size(); }); - size_t maxId = 0; - Operation *lastUserWithLargestStage; - for (Operation *op : consumersWithLargestStage) { - assert(operationId.find(op) != operationId.end()); - size_t userId = operationId[op]; - if (userId > maxId) { - maxId = userId; - lastUserWithLargestStage = op; - } - } - - OpBuilder b(&getContext()); - b.setInsertionPointAfter(lastUserWithLargestStage); - auto loc = lastUserWithLargestStage->getLoc(); - auto maxStageVal = b.create(loc, maxStage, 32); - - // pred = (iterVar >= maxStage) && - // (threadId % (numConsumerThreads / numRemoteCTAs) == 0); - - // [benzh] maybe we can simplify the logics here - auto cmpOp = arith::CmpIPredicate::sge; - if (maxStage == 0) - cmpOp = arith::CmpIPredicate::sgt; - Value pred = b.create(loc, cmpOp, iterVar, maxStageVal); - - Value threadId = b.create(loc); - auto CTAsPerCGA = info.CTALayout.getCTAsPerCGA(); - auto CTASplitNum = info.CTALayout.getCTASplitNum(); - auto numRemoteCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(), 1, - std::multiplies{}) / - std::accumulate(CTASplitNum.begin(), CTASplitNum.end(), - 1, std::multiplies{}); - auto numConsumerThreads = - isa(lastUserWithLargestStage) ? 128 : 32; - Value _0 = b.create(loc, 0, 32); - Value numArrives = b.create( - loc, numConsumerThreads / numRemoteCTAs, 32); - pred = b.create( - loc, pred, - b.create( - loc, arith::CmpIPredicate::eq, - b.create(loc, threadId, numArrives), _0)); - // remoteCtaIdIdx = (threadId % numConsumerThreads) / (numConsumerThreads / - // numRemoteCTAs); - Value remoteCTAIdIdx = b.create( - loc, - b.create( - loc, threadId, - b.create(loc, numConsumerThreads, 32)), - numArrives); - Value remoteCTAId = getRemoteCTAId(b, loc, info.CTALayout, remoteCTAIdIdx); - Value emptyBarrier = b.create( - loc, tt::PointerType::get(b.getIntegerType(64), 3), mbarTensor, stage); - - Value newNextIV = b.create(loc, nextIV, step); - Value nextLoopCond = b.create(loc, arith::CmpIPredicate::slt, - newNextIV, upperBound); - auto ifOp = b.create(loc, ArrayRef{}, nextLoopCond, - /*hasElse*/ false); - b.setInsertionPointToStart(ifOp.thenBlock()); - - b.create(loc, emptyBarrier, pred, remoteCTAId, - /*trackAsyncOp*/ false); -} - -} // anonymous namespace - -std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, - int numWarps, - int numCTAs, - int computeCapability) { - return std::make_unique(numStages, numWarps, numCTAs, - computeCapability); -} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp new file mode 100644 index 000000000..bfee0aaac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -0,0 +1,826 @@ +#include "PipelineExpander.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// TODO: We can extra some helpers into common utilities once we add more +// schedules. + +/// Replace the yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands().begin(), + yieldOp->getOperands().end()); + operands.append(newOperands.begin(), newOperands.end()); + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx) { + OpBuilder builder(forOp); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + auto insertOp = builder.create( + loc, alloc.getType(), loadOp.getPtr(), alloc, insertIdx, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile(), /*axis*/ 0); + auto commmit = builder.create(loc); + + // Extract part. + auto allocType = alloc.getType().cast(); + RankedTensorType sliceType = RankedTensorType::get( + {allocType.getShape()[1], allocType.getShape()[2]}, + allocType.getElementType(), allocType.getEncoding()); + auto extract = builder.create( + loc, sliceType, insertOp.getResult(), + SmallVector{extractIdx, int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); + Operation *user = *loadOp.getResult().getUsers().begin(); + auto convertLayout = llvm::cast(user); + auto newCvt = builder.create( + convertLayout->getLoc(), convertLayout.getType(), extract.getResult()); + convertLayout->replaceAllUsesWith(newCvt->getResults()); + convertLayout->erase(); + loadOp.erase(); + + // Fix up the yield op. + appendToYield(forOp, {insertOp}); +} + +static void createTMALoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value phase) { + OpBuilder builder(forOp); + Location loc = loadOp.getLoc(); + auto CTALayout = ttg::CTALayoutAttr::get(loadOp.getContext(), + /*CTAsPerCGA*/ {1}, + /*CTASplitNum*/ {1}, + /*CTAOrder*/ {0}); + auto sharedEncoding = ttg::SharedEncodingAttr::get(loadOp.getContext(), 1, 1, + 1, {0}, CTALayout, false); + int64_t numBuffers = alloc.getType().cast().getShape()[0]; + auto mBarriersTy = RankedTensorType::get( + {numBuffers}, builder.getIntegerType(64), sharedEncoding); + // Allocate an array of mbarrier objects outside the loop. + Value barrierArray = + builder.create(loc, mBarriersTy, 1); + // extract the barrier and emit arriver/copy/wait/extract code sequence. + builder.setInsertionPoint(loadOp); + auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); + Value barrier = builder.create( + loc, mBarTy, barrierArray, insertIdx); + Value zero = builder.create(loc, 0, 32); + Value threadId = builder.create(loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, zero); + + auto loadTy = loadOp.getType().dyn_cast(); + auto loadShape = loadTy.getShape(); + auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); + auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape); + auto elemTy = loadTy.getElementType(); + unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), + 1, std::multiplies{}); + elems *= (elemTy.getIntOrFloatBitWidth() / 8); + builder.create(loc, barrier, pred, + /*remoteCtaId*/ nullptr, + /*trackAsyncOp*/ false, elems); + auto allocType = alloc.getType().cast(); + auto insertOp = builder.create( + loc, allocType, loadOp.getPtr(), alloc, + /*index*/ insertIdx, barrier, loadOp.getMask(), loadOp.getOther(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + /*axis*/ 0); + + RankedTensorType sliceType = RankedTensorType::get( + {allocType.getShape()[1], allocType.getShape()[2]}, + allocType.getElementType(), allocType.getEncoding()); + auto extract = builder.create( + loc, sliceType, insertOp.getResult(), + SmallVector{extractIdx, int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); + + Value barrierWait = builder.create( + loc, mBarTy, barrierArray, extractIdx); + builder.create(loc, barrierWait, phase); + + Operation *user = *loadOp.getResult().getUsers().begin(); + auto convertLayout = llvm::cast(user); + auto newCvt = builder.create( + convertLayout->getLoc(), convertLayout.getType(), extract.getResult()); + convertLayout->replaceAllUsesWith(newCvt->getResults()); + convertLayout->erase(); + loadOp.erase(); + + // Fix up the yield op. + appendToYield(forOp, {insertOp}); +} + +/// Create an async load equivalent to the given load. +static void createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value phase) { + if (isLoadFromTensorPtr(loadOp)) { + createTMALoad(forOp, loadOp, alloc, insertIdx, extractIdx, phase); + } else { + createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx); + } +} + +// Return the transitive use of the load which is a dot operand. +static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) { + // We only pipeline loads that have one covert_layout (to dot_op) use + // TODO: lift this constraint in the future + bool isCandidate = false; + if (!loadOp.getResult().hasOneUse()) + return Value(); + + Operation *use = *loadOp.getResult().getUsers().begin(); + Operation *preUse = nullptr; + + // Advance to the first conversion as long as the use resides in shared + // memory and it has a single use itself + while (use) { + if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) + break; + auto tensorType = use->getResult(0).getType().dyn_cast(); + if (!tensorType.getEncoding().isa()) + break; + preUse = use; + use = *use->getResult(0).getUsers().begin(); + } + + if (auto convertLayout = llvm::dyn_cast(use)) { + if (auto tensorType = + convertLayout.getResult().getType().dyn_cast()) { + if (auto dotOpEnc = tensorType.getEncoding() + .dyn_cast()) { + return convertLayout.getResult(); + } + } + } else if (preUse && isa(use)) { + // for MMAv3 whose dot take SharedEncoding as operands directly + Operation *post = *loadOp.getResult().getUsers().begin(); + auto newOrder = post->getResult(0) + .getType() + .cast() + .getEncoding() + .cast() + .getOrder(); + auto ty = loadOp.getType().cast(); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + // The operand of MMAv3 is in SharedEncoding and it's order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + // TODO: remove this constraint once the LoadOp supports transpose + // fusion + if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { + hasMMAV3 = true; + return preUse->getResult(0); + } + } + return Value(); +} + +namespace { +struct LoadDotOperand { + LoadDotOperand(tt::LoadOp load, Value dotOperand) + : load(load), dotOperand(dotOperand) {} + tt::LoadOp load; + Value dotOperand; +}; +} // namespace + +/// Collect loads to pipeline. Return success if we can pipeline this loop +static void collectOpsToPipeline(scf::ForOp forOp, + SmallVectorImpl &ops, + bool &hasMMAV3) { + ModuleOp moduleOp = forOp->getParentOfType(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // We cannot use forOp.walk(...) here because we only want to visit the + // operations in the loop body block. Nested blocks are handled separately. + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) { + bool candidate = false; + if (isLoadFromTensorPtr(loadOp)) { + // Map to TMA load. + candidate = true; + } else { + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = + std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy || tensorTy.getRank() < 2) + continue; + auto ty = + tensorTy.getElementType().cast().getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + if (width >= 32) + candidate = true; + } + if (!candidate) + continue; + Value dotOperand = loadDotOperand(loadOp, hasMMAV3); + if (!dotOperand) + continue; + ops.emplace_back(loadOp, dotOperand); + } + } +} + +// Create an allocation that can old distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, Value dotOperand, + unsigned distance) { + OpBuilder builder(forOp); + auto ty = loadOp.getType().cast(); + if (!loadOp.getResult().hasOneUse()) + return Value(); + Attribute sharedEnc; + auto CTALayout = ttg::getCTALayout(ty.getEncoding()); + auto tensorType = dotOperand.getType().cast(); + if (auto dotOpEnc = + tensorType.getEncoding().dyn_cast()) { + auto convertLayout = dotOperand.getDefiningOp(); + bool needTrans = dyn_cast_or_null( + convertLayout->getOperand(0).getDefiningOp()); + unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); + sharedEnc = ttg::SharedEncodingAttr::get( + ty.getContext(), dotOpEnc, ty.getShape(), + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); + } else { + // MMAv3 + sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), + ttg::getOrder(ty.getEncoding()), + CTALayout, ty.getElementType()); + } + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type allocType = + RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); + Value alloc = builder.create( + loadOp.getLoc(), allocType); + return alloc; +} + +// Convert load ops into their asyn version and apply multi-buffering based on +// the number of stages. +static void createAsynOps(scf::ForOp &forOp, ArrayRef loads, + int numStages, bool hasMMAV3) { + struct AsyncLoad { + AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} + tt::LoadOp loadOp; + Value alloc; + }; + int numBuffers = numStages - 1; + // For MMAv3 we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. + // TODO: Improve modeling of wgmma pipelining. + if (hasMMAV3) + numBuffers++; + SmallVector asyncLoads; + SmallVector newOperands; + bool needsMbarrierPhase = false; + bool needsAsyncWait = false; + for (const LoadDotOperand &loadOperand : loads) { + tt::LoadOp loadOp = loadOperand.load; + Value dotOperand = loadOperand.dotOperand; + Value alloc = createAlloc(forOp, loadOp, dotOperand, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + newOperands.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); + if (isLoadFromTensorPtr(loadOp)) + needsMbarrierPhase = true; + else + needsAsyncWait = true; + } + + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + Value phase; + if (needsMbarrierPhase) { + phase = builder.create(loc, 0, 1); + newOperands.push_back(phase); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + for (int i = 0; i < asyncLoads.size(); i++) { + asyncLoads[i].alloc = newForOp.getBody()->getArgument(newOperandIndex + i); + } + insertIdx = + newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size()); + extractIdx = + newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size() + 1); + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(asyncLoads.front().loadOp); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + if (needsMbarrierPhase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + + asyncLoads.size() + 2); + Value oneI1 = builder.create(loc, 1, 1); + Value nextPhase = builder.create(loc, phase, oneI1); + phase = builder.create(loc, cndExt, phase, nextPhase); + } + + bool firstLoad = true; + for (AsyncLoad &asyncLoad : asyncLoads) { + createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx, + extractIdx, phase); + firstLoad = false; + } + // Insert a waitOp after the first async copy. This does make the assumption + // that the wait will be scheduled in a different stage that all the async + // copy but we cannot guarantee that one wait is enough otherwise. + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op)) { + OpBuilder builder(op.getContext()); + builder.setInsertionPointAfter(&op); + builder.create(op.getLoc(), 0); + break; + } + } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (needsMbarrierPhase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. + appendToYield(forOp, newYieldOperands); +} + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (maskType.isa()) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +static Operation *predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto insertOp = dyn_cast(op)) { + rewriter.setInsertionPoint(insertOp); + Value mask = getPredMask(rewriter, insertOp.getSrc().getType(), + insertOp.getMask(), pred); + insertOp.getMaskMutable().assign(mask); + return op; + } + if (auto insertOp = dyn_cast(op)) { + rewriter.setInsertionPoint(insertOp); + Value mask = getPredMask( + rewriter, + insertOp.getSrc().getType().cast().getPointeeType(), + insertOp.getMask(), pred); + insertOp.getMaskMutable().assign(mask); + return op; + } + if (auto arriveOp = dyn_cast(op)) { + rewriter.setInsertionPoint(arriveOp); + Value mask = getPredMask(rewriter, rewriter.getIntegerType(1), + arriveOp.getPred(), pred); + arriveOp.getPredMutable().assign(mask); + return op; + } + if (isa(op)) { + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + + assert("don't know how to predicate this op" && false); + return op; +} + +static void setWaitNum(Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration, unsigned numLoadsInStage) { + if (auto waitOp = dyn_cast(op)) { + waitOp.setNum(numLoadsInStage); + } +} + +/// Helper to recursively add dependencies to the same stage. +static void addDep(Operation *op, DenseSet &deps, + bool includeArg = true, + DenseSet *filter = nullptr) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = v.dyn_cast()) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the shedule with the given stage based on the filter +// function. +static void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages, bool prefetchExtract) { + SmallVector insertOps; + SmallVector extractOps; + // Find the insert/extract ops that will go respectively in stage 0 and stage + // `numStages - 2`. All the other operations will go in stage `numStages - 1`. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + if (prefetchExtract) { + if (isa(op)) + extractOps.emplace_back(&op); + } + } + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + addDep(op, insertAndDeps, false); + } + + // Find depenencies with distance of 1. + SmallVector distanceOneUsers; + for (Operation *op : insertAndDeps) { + for (Value operand : op->getOperands()) { + if (auto arg = operand.dyn_cast()) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && insertAndDeps.count(defOp) == 0) { + distanceOneUsers.push_back(defOp); + } + } + } + } + } + // Schedule loads with a distance of 1 in stage 0 + for (Operation *op : distanceOneUsers) { + if (isa(op)) { + addDep(op, insertAndDeps, true); + } + } + // For the rest of the ops we can move then into stage 1 so that they can be + // closer to their uses. + DenseSet stage1deps; + for (Operation *op : distanceOneUsers) { + if (!isa(op)) { + addDep(op, stage1deps, true, &insertAndDeps); + } + } + + DenseSet extractAndDeps; + for (Operation *op : extractOps) { + addDep(op, extractAndDeps, true, &insertAndDeps); + } + std::vector> schedule; + // Schedule stage `numStage - 1` first. + addOps(forOp, numStages - 1, schedule, [&](Operation *op) { + return insertAndDeps.count(op) == 0 && stage1deps.count(op) == 0 && + extractAndDeps.count(op) == 0; + }); + + // Schedule some dependencies with distance of 1 into stage 1 to reduce + // pressure. + addOps(forOp, 1, schedule, + [&](Operation *op) { return stage1deps.count(op); }); + + // Then Schedule stage 0. + addOps(forOp, 0, schedule, + [&](Operation *op) { return insertAndDeps.count(op); }); + + // Finally schedule the extract ops in stage `numStage - 2` so that they get + // pre-fetched and play well with pretech pass. + addOps(forOp, numStages - 2, schedule, + [&](Operation *op) { return extractAndDeps.count(op); }); + return schedule; +} + +bool mlir::triton::preProcessLoopAndGetSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + // 1. First collect "interesting" operations with a stage where to schedule + // them. This gives a coarse scheduling for the loop. + SmallVector loads; + bool hasMMAV3 = false; + collectOpsToPipeline(forOp, loads, hasMMAV3); + if (loads.empty()) + return false; + bool hasAsynCp = llvm::any_of(loads, [](LoadDotOperand &load) { + return !isLoadFromTensorPtr(load.load); + }); + // 2. Convert the loads into async loads and create the allocs. + createAsynOps(forOp, loads, numStages, hasMMAV3); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages, /*prefetchExtract=*/!hasMMAV3); + + // 4. Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = predicateOp; + options.supportDynamicLoops = true; + unsigned numLoadsInStage = (numStages - 2) * loads.size(); + options.annotateFn = + [numLoadsInStage](Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration) { + return setWaitNum(op, part, iteration, numLoadsInStage); + }; + + if (hasAsynCp) { + // Insert a wait 0 after the loop + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), 0); + } + return true; +} + +/// MMA V3 post-processing. +static bool selfDepend(tt::DotOp dotOp, scf::ForOp forOp, + Operation **firstUse) { + std::function dependOn = + [&dependOn](Value v, int argId, scf::ForOp forOp) { + auto op = v.getDefiningOp(); + if (isa(v)) { + auto iterArgs = forOp.getRegionIterArgs(); + auto iter = std::find(iterArgs.begin(), iterArgs.end(), v); + if (iter != iterArgs.end()) + return std::distance(iterArgs.begin(), iter) == argId; + } else { + if (!op) + return false; + for (auto operand : op->getOperands()) { + if (dependOn(operand, argId, forOp)) + return true; + } + } + return false; + }; + auto result = dotOp.getResult(); + auto yieldOp = forOp.getBody()->getTerminator(); + int argIdx = -1; + auto iter = std::find(yieldOp->getOperands().begin(), + yieldOp->getOperands().end(), result); + if (iter != yieldOp->getOperands().end()) + argIdx = std::distance(yieldOp->getOperands().begin(), iter); + if (argIdx == -1) + return false; + for (auto operand : dotOp.getOperands()) { + if (dependOn(operand, argIdx, forOp)) { + auto iterArgs = forOp.getRegionIterArgs(); + *firstUse = iterArgs[argIdx].use_begin().getUser(); + return true; + } + } + return false; +} + +static void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, + bool hasDotWait0) { + if (hasDotWait0) { + dotWaitOp->erase(); + } +} + +void mlir::triton::asyncLaunchDots(scf::ForOp forOp) { + Block *loop = forOp.getBody(); + auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) { + if (!op) + return -1l; + auto lastOp = op; + while (op->getBlock()->getParentOp() != forOp) { + lastOp = op; + op = op->getBlock()->getParentOp(); + } + return std::distance(lastOp->getBlock()->getParent()->begin(), + lastOp->getBlock()->getIterator()); + }; + /// XXX(Keren): Clean up the following duplicate code with checkDotOp + /// dots to be pipelined + bool hasSyncDot = false; + bool hasDotWait0 = false; + SmallVector allDots; + SmallVector dots; + SmallVector resultNeedSync; + for (Operation &op : *loop) { + if (auto dotWaitOp = dyn_cast(&op)) { + auto attr = dotWaitOp->getAttrOfType("pendings"); + auto pendingCount = attr.getInt(); + if (pendingCount == 0) + hasDotWait0 = true; + } + if (auto dotOp = dyn_cast(&op)) { + allDots.push_back(dotOp); + } + } + for (Operation &op : *loop) { + if (auto dotOp = dyn_cast(&op)) { + auto resTy = dotOp.getResult().getType().dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) { + if (resEnc && resEnc.isHopper()) { + auto dot = dotOp.getResult(); + bool valid = true; + + // all users of dot should be scf.yield + if (!dot.hasOneUse()) + valid = false; + if (!isa(*dot.getUsers().begin())) + valid = false; + + Operation *firstUse = nullptr; + auto depend = selfDepend(dotOp, forOp, &firstUse); + bool selfDirectDepend = (dotOp == firstUse); + for (auto tempInAll : allDots) { + auto iter = std::find(dots.begin(), dots.end(), tempInAll); + if (iter != dots.end()) + continue; + auto db = getBlockNumInFor(tempInAll, forOp); + auto fb = getBlockNumInFor(firstUse, forOp); + if (db < fb || + (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse))) + hasSyncDot = true; + } + auto CArg = dotOp.getOperand(2); + if (!(selfDirectDepend || + (depend && !selfDirectDepend && hasSyncDot)) || + !CArg.hasOneUse()) + valid = false; + + if (valid) { + dots.push_back(dotOp); + resultNeedSync.push_back( + dotOp->getUses().begin()->getOperandNumber()); + } + } + } + } + } + + // Early stop: no need to continue if there is no valid dot in the loop. + if (dots.empty()) + return; + + OpBuilder builder(forOp); + // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline + // wgmma ops by one stage. + // This is needed to prevent shared memory inputs to be overriden before the + // operation is completed. + // TODO: merge this with the rest of the pipelining transformation and look at + // a better representation for async dots. + tt::DotOp lastDot = dots.back(); + auto loc = lastDot.getLoc(); + builder.setInsertionPointAfter(lastDot); + auto dotWait = builder.create( + lastDot.getLoc(), lastDot.getResult(), dots.size()); + + // 1. replace Dot with DotAsync + for (size_t idx = 0; idx < dots.size(); ++idx) { + tt::DotOp dotOp = dots[idx]; + builder.setInsertionPoint(dotOp); + auto dotAsync = builder.create( + dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); + dotOp.replaceAllUsesWith(dotAsync.getResult()); + dotOp->erase(); + } + + hasDotWait0 = hasDotWait0 || hasSyncDot; + + // 2. If there's any outstanding DotAsyncOps, we need to wait for them. + builder.setInsertionPointAfter(forOp); + SmallVector resultTypes(resultNeedSync.size()); + SmallVector yieldThenValues(resultNeedSync.size()); + SmallVector yieldElseValues(resultNeedSync.size()); + for (int i = 0; i < resultNeedSync.size(); ++i) { + resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType(); + yieldThenValues[i] = forOp->getResult(resultNeedSync[i]); + yieldElseValues[i] = forOp->getResult(resultNeedSync[i]); + } + Value loopNotEmpty = builder.create( + loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), + forOp.getUpperBound()); + auto ifOp = builder.create(loc, resultTypes, loopNotEmpty, + /*hasElse*/ true); + builder.setInsertionPointToStart(ifOp.thenBlock()); + for (int i = 0; i < resultNeedSync.size(); ++i) { + Value result = forOp->getResult(resultNeedSync[i]); + if (result.use_empty()) + continue; + auto dotWait = + builder.create(forOp.getLoc(), result, 0); + result.replaceAllUsesExcept(ifOp.getResult(i), dotWait); + yieldThenValues[i] = dotWait.getResult(); + } + auto yieldOpThen = builder.create(loc, yieldThenValues); + builder.setInsertionPointToEnd(ifOp.elseBlock()); + auto yieldOpElse = builder.create(loc, yieldElseValues); + + // 3. potentially remove redundant dot_wait after dot_async if having mutiple + // DotOp in the loop + removeExtraWait(dotWait, hasDotWait0); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 000000000..18df341ac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,704 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. +// -Fix bug when a value yield is used outside the loop and the value def is not +// in the last stage. If we are not peeling the epilgue we need to remap the +// output correctly. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" + +#include "PipelineExpander.h" + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + std::pair getDefiningOpAndDistance(Value value); + +public: + /// Initalize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + void emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + llvm::SmallVector emitEpilogue(RewriterBase &rewriter); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Only support loop carried dependency with a distance of 1. This means the + // source of all the scf.yield operands needs to be defined by operations in + // the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || !stages.contains(def); + })) { + LDBG("--only support loop carried dependency with a distance of 1 -> BAIL"); + return false; + } + annotateFn = options.annotateFn; + return true; +} + +/// Clone `op` and call `callback` on the cloned op's oeprands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + for (OpOperand &operand : clone->getOpOperands()) + callback(&operand); + clone->walk([&](Operation *nested) { + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (BlockArgument &arg : forOp.getRegionIterArgs()) { + OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + for (int64_t i = 0; i < maxStage; i++) { + Value predicate; + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicate = rewriter.create(loc, arith::CmpIPredicate::slt, + iv, ub); + } + + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + if (predicate) { + newOp = predicateFn(rewriter, newOp, predicate); + assert(newOp && "failed to predicate op."); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + if (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + newUb = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, -maxStage)))); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yielOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yielOperand.get()); + // When we don't peel the epilogue the yield value is used outside the loop + // we need to make sure we return the version from numStages - defStage. + if (!peelEpilogue && + !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) { + auto [def, distance] = getDefiningOpAndDistance(yielOperand.get()); + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value pred = predicates[defStage->second]; + if (pred) { + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yielOperand.getOperandNumber() + 1]); + } + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + if (defStage > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +llvm::SmallVector +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) { + llvm::SmallVector returnValues(forOp->getNumResults()); + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + for (int64_t i = 0; i < maxStage; i++) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totlaNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totlaNumIteration, minusI))); + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } + return returnValues; +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h new file mode 100644 index 000000000..0a3d736c6 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h @@ -0,0 +1,101 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h new file mode 100644 index 000000000..67ee2ca83 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h @@ -0,0 +1,27 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "PipelineExpander.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include + +namespace mlir { +namespace triton { + +/// This fill out the pipelining options including schedule and annotations for +/// wait ops. This also does pre-processing by converting some of the loads into +/// async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 000000000..573151115 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,88 @@ +#include "PipelineExpander.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static void pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::triton::PipeliningOption options; + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return; + + bool foundSchedule = false; + foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options); + + // TODO: add more pipelines strategy. + if (!foundSchedule) + return; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); + + if (succeeded(newForOp)) + mlir::triton::asyncLaunchDots(newForOp.value()); +} + +namespace { +struct PipelinePass : public TritonGPUPipelineBase { + PipelinePass() = default; + PipelinePass(int numStages, int numWarps, int numCTAs, + int computeCapability) { + this->numStages = numStages; + this->numWarps = numWarps; + this->numCTAs = numCTAs; + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + if (this->numStages <= 1) + return; + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + pipelineLoop(forOp, numStages); + } + } +}; +} // anonymous namespace + +std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, + int numWarps, + int numCTAs, + int computeCapability) { + return std::make_unique(numStages, numWarps, numCTAs, + computeCapability); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 388687c82..75375cdba 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -355,10 +355,6 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { triton::MakeRangeOp, triton::SplatOp>(op); } -// - -// Replace ForOp with a new ForOp with extra operands. The YieldOp is not -// updated and needs to be updated separatly for the loop to be correct. scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands) { OpBuilder::InsertionGuard g(rewriter); diff --git a/python/test/unit/hopper/test_gemm.py b/python/test/unit/hopper/test_gemm.py index a3e3f80b9..5c57fd17c 100644 --- a/python/test/unit/hopper/test_gemm.py +++ b/python/test/unit/hopper/test_gemm.py @@ -339,6 +339,9 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, '16-32-64-8-2-256-256-256-True', ]: pytest.skip('Known legacy issue, ldmatrix can only support x4') + enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + if NUM_CTAS > 1 and enable_tma in ["on", "true", "1"]: + pytest.skip('multi-CTA with TMA not supported in MaterializeLoadStore') M = BLOCK_M if M is None else M N = BLOCK_N if N is None else N diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 8bfb7b576..03872748b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -10,16 +10,15 @@ #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop +// CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 -// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] -// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] @@ -29,18 +28,24 @@ // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -93,31 +98,37 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop_nested + +// CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK-DAG: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK-DAG: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -171,23 +182,28 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 7907618bd..f3a9b01aa 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -11,15 +11,15 @@ #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop +// CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] -// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] @@ -29,25 +29,25 @@ // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -92,36 +92,36 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, tt.return %loop#2: tensor<128x128xf32, #C> } -// CHECK: tt.func @matmul_loop_nested +// CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK-DAG: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK-DAG: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -168,7 +168,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, } -// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 @@ -176,20 +176,20 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -228,18 +228,18 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, tt.return %loop#1 : tensor<128x128xf32, #C> } -// CHECK: tt.func @lut_bmm_scalar +// CHECK-LABEL: tt.func @lut_bmm_scalar // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.async_commit_group -// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] +// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[LUT_BUFFER_1:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_0]] // CHECK: %[[LUT_BUFFER_2:.*]] = tt.splat %[[LUT_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_2]] -// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}} -// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] // CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, @@ -271,19 +271,19 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK: tt.func @lut_bmm_vector +// CHECK-LABEL: tt.func @lut_bmm_vector // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.async_commit_group -// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] +// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[LUT_BUFFER_1:.*]] = tt.expand_dims %[[LUT_BUFFER_0]] {axis = 1 : i32} // CHECK: %[[LUT_BUFFER_2:.*]] = tt.broadcast %[[LUT_BUFFER_1]] // CHECK: %[[LUT_BUFFER_3:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_2]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_3]] -// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}} -// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] // CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, @@ -317,11 +317,11 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK: tt.func @post_load_inv +// CHECK-LABEL: tt.func @post_load_inv // CHECK: scf.for -// CHECK: arith.index_cast // CHECK-DAG: %[[IV:.*]] = arith.index_cast // CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// CHECK: arith.index_cast // CHECK-NOT: arith.addi %[[NEXT_IV]] tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -373,17 +373,11 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %85#0 : tensor<32x32xf32, #C> } -// CHECK: tt.func @cross_iter_dep -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: %[[PTR0:.*]] = tt.addptr -// CHECK: %[[PTR1:.*]] = tt.addptr -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[BUF0:.*]] = %[[PTR0]], {{.*}}, %[[BUF1:.*]] = %[[PTR1]] +// CHECK-LABEL: tt.func @cross_iter_dep +// TODO: enable pipelining with distance of 2 +// CHECK-NOT: triton_gpu.async_commit_group +// CHECK: scf.for // CHECK: scf.yield -// CHECK-SAME: %[[BUF0]] -// CHECK-SAME: %[[BUF1]] tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -436,7 +430,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// CHECK: tt.func @dep_arg_two_uses +// CHECK-LABEL: tt.func @dep_arg_two_uses // CHECK: tt.expand_dims // CHECK: tt.expand_dims // CHECK: tt.expand_dims %arg5 diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 777e36546..56c55816d 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_TMA=1 ENABLE_MMA_V3=1 triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer -canonicalize -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> @@ -11,6 +11,7 @@ #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: two_dependent_dot tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: f32 , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: !tt.ptr {tt.divisibility = 16 : i32} , %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg18: i32 , %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} ) attributes {noinline = false} { %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> @@ -72,8 +73,9 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %81 = arith.truncf %68 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %82 = triton_gpu.convert_layout %60 : (tensor<64x128xf16, #blocked2>) -> tensor<64x128xf16, #shared> %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - // CHECK-LABEL: triton_nvidia_gpu.dot_async - // CHECK-LABEL-NOT: triton_nvidia_gpu.dot_wait + // CHECK: triton_nvidia_gpu.dot_async + // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: scf.yield %84 = tt.dot %83, %82, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %85 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %87 = arith.addf %85, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>