diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
index 1273e6dac..8f7ce75f2 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -111,6 +111,8 @@ bool isExpensiveLoadOrStore(Operation *op);
 bool canFoldIntoConversion(Operation *op, Attribute targetEncoding);
 
+// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
+// updated and needs to be updated separately for the loop to be correct.
 scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
                                         ValueRange newIterOperands);
 
@@ -143,7 +145,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
 Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
                 ArrayRef<unsigned> shape);
 
-
 } // namespace mlir
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 361750ac0..30b1ef5fd 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1421,18 +1421,37 @@ struct IndexCastOpLowering
   }
 };
 
+struct SelectOpConversion
+    : ElementwiseOpConversionBase<mlir::arith::SelectOp, SelectOpConversion> {
+  using Base =
+      ElementwiseOpConversionBase<mlir::arith::SelectOp, SelectOpConversion>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
+
+  SmallVector<Value> createDestOps(mlir::arith::SelectOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type elemTy, MultipleOperandsRange operands,
+                                   Location loc) const {
+    std::array<Value, 3> llvmOperands;
+    if (operands[0].size() == 2) {
+      // Case of scalar condition with tensor operands.
+      assert(op.getCondition().getType().isInteger(1));
+      llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]};
+    } else {
+      llvmOperands = {operands[0][0], operands[0][1], operands[0][2]};
+    }
+    return {rewriter.create<LLVM::SelectOp>(
+        loc, llvmOperands[1].getType(), llvmOperands,
+        adaptor.getAttributes().getValue())};
+  }
+};
+
 void populateElementwiseOpToLLVMPatterns(
     TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
     ModuleAllocation &allocation,
     ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
     int computeCapability, PatternBenefit benefit) {
-#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
-  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
-      typeConverter, axisInfoAnalysis, benefit);
-  POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
-#undef POPULATE_TERNARY_OP
-
 #define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
   patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
       typeConverter, axisInfoAnalysis, benefit);
@@ -1486,6 +1505,7 @@ void populateElementwiseOpToLLVMPatterns(
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index a06555c5a..2cd6e2672 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -5,7 +5,9 @@ add_mlir_dialect_library(TritonGPUTransforms
   OptimizeDotOperands.cpp
   OptimizeEpilogue.cpp
   OptimizeThreadLocality.cpp
-  Pipeline.cpp
+  Pipeliner/MatmulLoopPipeline.cpp
+  Pipeliner/PipelineExpander.cpp
+
Pipeliner/SoftwarePipeliner.cpp Prefetch.cpp RemoveLayoutConversions.cpp ReorderInstructions.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp deleted file mode 100644 index f48466b6f..000000000 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ /dev/null @@ -1,1940 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/ADT/MapVector.h" -#include "llvm/Support/Debug.h" - -//===----------------------------------------------------------------------===// -// This file implements software pipelining for loops. The implementation here -// is inspired by the pipeline pass in Triton (version 2.0) and SCF's -// LoopPipelining. -// -// We divide the loop body into the following phases: -// a. Pre-load operations: for instance, index computation. -// b. Load operations: loading from global memory to shared memory. -// c. Compute operations: for instance, Triton dot. -// d. Post-load operations: for instance, index computation. -// -// To pipeline the loop, we need to: -// - Hoist the pipelinable load operations for the first numStages-1 iterations -// to the loop pre-header -// - Find all the dependencies of the load operations. -// - Rematerialize the dependencies for their values at the first numStage-1 -// iterations -// - Assemble the loop body (numStage) and prefetch (numStage + 1). -// -// In the prologue, the sequence of operations is the same as the original loop -// body, following the (a) -> (b) -> (c) -> (d) order. In the loop body, -// however, we first execute the compute operations, then pre-load operations, -// post-load operations, and eventually the asynchronous load operations - in -// the (c) -> (a) -> (d) -> (b) order. This is used to better hide the latency -// of the load operations. Because of this, if post-load operations have direct -// dependencies on the load operations, we could repeat the post-load -// operations. More specifically, this occurs when: -// 1. Any load operand has an immediate dependency argument used at numStage-1. -// 2. The argument is first defined at numStage-2. -// To avoid the repeat, we peeled off post-load operations in the prologue that -// satisfy the above two conditions. See the example below for the definition of -// immediate and non-immediate dependencies. -// If we have a load that immediately depends on a block argument in the -// current iteration, it is an immediate dependency. Otherwise, it is a -// non-immediate dependency, which means the load depends on a block argument -// in the previous iterations. -// For example: -// scf.for (%arg0, %arg1, %arg2) { -// %0 = load %arg0 <--- immediate dep, this address is initialized before -// numStages-1. -// %1 = load %arg1 -// %2 = add %1, %arg2 -// %3 = load %2 <--- non-immediate dep, %arg1 must be an -// update-to-date value. -// } -// -// Our pipelining pass share some common characteristics with SCF's -// LoopPipelining. 
However, it is also noteworthy that our pipelining pass has -// the following characteristics different from SCF's LoopPipelining: -// 1. It can handle loop-carried dependencies of distance greater than 1. -// 2. It does not have a complicated epilogue but instead uses masking to handle -// boundary conditions. -// 3. Each operation/loop-carried argument cannot provide values to both -// immediate and non-immediate dependencies. Otherwise, we have to rematerialize -// the operation and arguments, which would likely increase register pressure. -// For example: -// scf.for (%arg0, %arg1, %arg2) { -// %0 = load %arg0 -// %1 = load %arg1, %0 <--- %0 is both a post-load op at numStages-2 and a -// pre-load op at numStages-1, so that we need two versions of %0. -// %2 = add %0, %arg2 -// scf.yield %arg0, %2, %arg2 -// } -// -//===----------------------------------------------------------------------===// - -using llvm::MapVector; -using namespace mlir; -namespace tt = mlir::triton; -namespace ttg = mlir::triton::gpu; -/// FIXME(Keren): The pipeline pass shouldn't be aware of nvidia_gpu dialect -namespace ttng = mlir::triton::nvidia_gpu; - -#define GEN_PASS_CLASSES -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -#define int_attr(num) builder.getI64IntegerAttr(num) - -namespace { - -// Pass named attrs (e.g., tt.contiguity) from Triton to Triton -void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { - NamedAttrList attrs = op->getDiscardableAttrs(); - // Collect the attributes to propagate: the ones in dictAttrs and not yet on - // the operation. - SmallVector toPropagate; - for (const NamedAttribute attr : dictAttrs.getValue()) { - if (!attrs.get(attr.getName())) - toPropagate.push_back(attr); - } - // If we found any, let's set them here as a single step. - if (toPropagate.size()) { - attrs.append(toPropagate); - op->setDiscardableAttrs(attrs); - } -} - -struct ConsumerReleaseInfo { - Value iterVar; - Value stageVar; - Value phaseVar; - Value nextIVVar; - Value stepVar; - Value upperBoundVar; - ttg::CTALayoutAttr CTALayout; - DenseMap consumerStageMap; -}; -typedef DenseMap - ConsumerReleaseMap; - -class LoopPipeliner { - /// Cache of ForOp and YieldOp related to this pipeliner. - scf::ForOp forOp; - scf::YieldOp yieldOp; - - /// Loads to be pipelined - SetVector validLoads; - /// The value that each load will be mapped to (after layout conversion) - DenseMap loadsMapping; - /// load => buffer - DenseMap loadsBuffer; - /// load => buffer type (with shared layout after swizzling) - DenseMap loadsBufferType; - /// load => buffer at stage N - DenseMap> loadStageBuffer; - /// load => after extract - DenseMap loadsExtract; - - /// XXX(Keren): The following are h100 only and disabled - /// load => full barrier arrive - DenseMap loadsBarrierArvOp; - /// load => mbarriers - DenseMap loadsFullBarriers; - DenseMap loadsEmptyBarriers; - /// load => null value or previous load which can share barrier with - DenseMap loadsCanShareBarriers; - /// Maintains the information to emit consumer_release mbarrier_arrive - ConsumerReleaseMap &consumerReleaseMap; - bool hasHopperDot = false; - // XXX(Keren): why the variable name is hopper dot and why do we need this - // check? - void checkHopperDots(SetVector &ops); - // XXX(Keren): it looks more like an optimization to be, not sure if it should - // exist in the base pipeliner - void checkOpShareBarriers(SetVector &ops); - int numLoadsRequireAsyncWait = 0; - int numLoadsRequireMBarrier = 0; - // Number of buffers to allocate for each input. 
- int numSharedMemorySlices = 0; - - /// Iterator values - Value nextIV; - Value pipelineIterIdx; - Value curWaitIdx; - - // Only needed when numLoadsRequireMBarrier > 0 - Value loopIterIdx; - Value curPhase; - Value curEmptyPhase; - - /// Yield values - SmallVector nextBuffers; - SmallVector extractSlices; - SmallVector yieldValues; - - /// The number of stages in the pipeline. - /// Stages in the range of [0, numStages-1) are in the prologue. - /// numStages-1 is appended after the loop body. - int numStages; - - /// Arg indicies - size_t bufferIdx, loadIdx, depArgsBeginIdx, ivIdx; - DenseMap depArgsIdx; - - /// XXX(Keren): The mode parameter is hacky, should be refactored - // false: legacy mode as a temporary solution for backward compatibility - // true: new mode for hopper - bool mode; - int numWarps; - int numCTAs; - - /// value (in loop) => value at stage N - DenseMap> valueMapping; - /// loop iter arg => value - DenseMap depArgsMapping; - /// forOp value => newForOp value - IRMapping mapping; - /// forOp value => prefetch value - IRMapping nextMapping; - - /// Dependency ops by program order - SmallVector orderedDeps; - - /// arg => source operand defined stages - DenseMap> immediateArgStages; - - /// block arguments that loads depend on - SetVector depArgs; - - /// operation => source operand defined stages - DenseMap> immediateOpStages; - - /// operations that loads depend on - SetVector depOps; - - /// Collect all pipelinable ops - LogicalResult collectOps(SetVector &ops); - - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &opDeps); - - /// Collect all op dependencies - void collectDeps(SetVector &ops, - MapVector> &opDeps); - - /// Check if none of the ops has valid uses - LogicalResult checkOpUses(SetVector &ops); - - /// Check if ops have dependencies that are not pipelinable - void checkOpDeps(SetVector &ops); - - void createBufferTypes(); - - void createOrderedDeps(); - - /// Return the stage at which `v` is defined prior to `stage` - int getValueDefStage(Value v, int stage); - - /// Map `origin` to `newValue` at `stage` - void setValueMapping(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at `stage` according to the association between - /// yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at the next stage according to the association - /// between yieldOp and forOp - void setValueMappingYield(scf::ForOp newForOp, Value origin, Value newValue); - - /// Return the value mapped to `origin` at `stage`, if it exists. - Value lookupOrDefault(Value origin, int stage); - - /// Get the load mask for `loadOp`, given the mapped mask `mappedMask` (if - /// exists) and the current iteration's `loopCond`. 
- Value getLoadMask(tt::LoadOp loadOp, Value mappedMask, Value loopCond, - OpBuilder &builder); - - /// Return an empty buffer of size - ttg::AllocTensorOp allocateEmptyBuffer(tt::LoadOp loadOp, OpBuilder &builder); - - /// Collect all args of the new loop - SmallVector collectNewLoopArgs(); - - /// Clone the forOp and return the new forOp - scf::ForOp cloneForOp(ArrayRef newLoopArgs, OpBuilder &builder); - - /// Prefetch the next iteration for `newForOp` - void prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder); - - /// Check if curIdx is out of bound and wrap value around if necessary - Value getBoundedIterationValue(OpBuilder &builder, Value curIdx, - Value upperBoundIdx, Value curValue, - Value initValue); - - /// Assemble `newForOp`'s yield op - void finalizeYield(scf::ForOp newForOp, OpBuilder &builder); - -public: - LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs, - bool mode, int numSharedMemorySlices, - ConsumerReleaseMap &consumerReleaseMap) - : forOp(forOp), numStages(numStages), numWarps(numWarps), - numCTAs(numCTAs), mode(mode), - numSharedMemorySlices(numSharedMemorySlices), - consumerReleaseMap(consumerReleaseMap) { - // cache yieldOp - yieldOp = cast(forOp.getBody()->getTerminator()); - } - - LoopPipeliner() = delete; - - /// Collect loads to pipeline. Return success if we can pipeline this loop - LogicalResult initialize(); - - /// Emit pipelined loads (before loop body) - void emitPrologue(); - - /// emit pipelined loads (after loop body) - void emitEpilogue(); - - /// create the new ForOp (add new args & insert prefetched ops) - scf::ForOp createNewForOp(); - - friend struct PipelinePass; -}; - -/// Collect loads to pipeline. Return success if we can pipeline this loop -LogicalResult LoopPipeliner::collectOps(SetVector &ops) { - ModuleOp moduleOp = forOp->getParentOfType(); - ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); - - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) - if (auto loadOp = dyn_cast(&op)) { - if (isLoadFromTensorPtr(loadOp)) { - ops.insert(loadOp); - } else { - auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - if (auto mask = loadOp.getMask()) - vec = - std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - - auto tensorTy = ptr.getType().dyn_cast(); - if (!tensorTy || tensorTy.getRank() < 2) - continue; - auto ty = - tensorTy.getElementType().cast().getPointeeType(); - unsigned width = vec * ty.getIntOrFloatBitWidth(); - // We do not pipeline all loads for the following reasons: - // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. - // 2. It's likely that pipling small loads won't offer much performance - // improvement and may even hurt performance by increasing register - // pressure. 
- if (width >= 32) - ops.insert(loadOp); - } - } - - if (ops.empty()) - return failure(); - else - return success(); -} - -void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps) { - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getRegion()) - return; - - // Since we only need to peel the loop numStages-1 times, don't worry - // about depends that are too far away - if (stage < 0) - return; - - if (auto arg = v.dyn_cast()) { - if (arg.getArgNumber() > 0) { - deps.insert(v); - collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps); - } - } else { // value - deps.insert(v); - for (Value op : v.getDefiningOp()->getOperands()) - collectValueDep(op, stage, deps); - } -} - -void LoopPipeliner::collectDeps( - SetVector &ops, - MapVector> &valueDeps) { - for (auto op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); - valueDeps[op] = deps; - } - } -} - -LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { - DenseSet invalidOps; - // Collect all ops' dependencies - MapVector> opDeps; - collectDeps(ops, opDeps); - - for (Operation *op : ops) { - if (auto loadOp = dyn_cast(op)) { - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other->getResult(0))) { - isCandidate = false; - break; - } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse() && - !isLoadFromTensorPtr(loadOp)) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - Operation *preUse = nullptr; - - // Advance to the first conversion as long as the use resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - use->getResult(0).getType().dyn_cast(); - if (!tensorType.getEncoding().isa()) - break; - preUse = use; - use = *use->getResult(0).getUsers().begin(); - } - - if (auto convertLayout = llvm::dyn_cast(use)) { - if (auto tensorType = convertLayout.getResult() - .getType() - .dyn_cast()) - if (auto dotOpEnc = tensorType.getEncoding() - .dyn_cast()) { - isCandidate = true; - loadsMapping[loadOp] = convertLayout; - } - } else if (preUse && isa(use)) { - isCandidate = false; - // for MMAv3 whose dot take SharedEncoding as operands directly - Operation *post = *loadOp.getResult().getUsers().begin(); - auto newOrder = post->getResult(0) - .getType() - .cast() - .getEncoding() - .cast() - .getOrder(); - auto ty = loadOp.getType().cast(); - auto oldOrder = ttg::getOrder(ty.getEncoding()); - // The operand of MMAv3 is in SharedEncoding and it's order should not - // be changed after FuseTranspositions Pass. So we only pipeline the - // load if the order of the loaded BlockedEncoding is the same as the - // order of the SharedEncoding it is converted to. 
- // TODO: remove this constraint once the LoadOp supports transpose - // fusion - if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { - isCandidate = true; - loadsMapping[loadOp] = preUse->getResult(0); - } - } - } else if (isCandidate && mode && isLoadFromTensorPtr(loadOp)) { - loadsMapping[loadOp] = loadOp.getResult(); - } else - isCandidate = false; - - if (!isCandidate) - invalidOps.insert(loadOp); - else { - validLoads.insert(loadOp); - if (!isLoadFromTensorPtr(loadOp)) - numLoadsRequireAsyncWait++; - else - numLoadsRequireMBarrier++; - } - } - } - - for (Operation *op : invalidOps) - ops.remove(op); - - if (ops.empty()) - return failure(); - else - return success(); -} - -void LoopPipeliner::checkHopperDots(SetVector &ops) { - // dots to be pipelined - SetVector dots; - for (Operation &op : forOp) { - if (auto dotOp = dyn_cast(&op)) { - auto resTy = dotOp.getResult().getType().dyn_cast(); - if (auto resEnc = resTy.getEncoding().dyn_cast()) { - if (resEnc && resEnc.isHopper()) { - // Don't pipeline valid dots that depend on ops other than scf.yield - // and scf.for - auto dot = dotOp.getResult(); - bool valid = true; - - // all users of dot should be scf.yield - if (!dot.hasOneUse()) - valid = false; - if (!isa(*dot.getUsers().begin())) - valid = false; - - // C should be a block argument - auto CArg = dotOp.getOperand(2).dyn_cast(); - if (!CArg || !CArg.hasOneUse()) - valid = false; - - if (valid) - dots.insert(dotOp); - } - } - } - } - - hasHopperDot = true; -} - -void LoopPipeliner::checkOpShareBarriers(SetVector &ops) { - // Check if loads can share barriers - auto canShare = [&](Value load0, Value load1) -> bool { - if (!load0.hasOneUse() || !load1.hasOneUse()) - return false; - auto use0 = *load0.getUsers().begin(); - auto use1 = *load1.getUsers().begin(); - if (!use0->hasOneUse() || !use1->hasOneUse()) - return false; - if (*use0->getUsers().begin() != *use1->getUsers().begin()) - return false; - return true; - }; - // XXX(Keren): the logic here is pretty weird and might be incomplete - for (Value loadOp : validLoads) { - Value depLoad; - for (auto oldPair : loadsCanShareBarriers) { - Value oldLoad = oldPair.first; - if (canShare(loadOp, oldLoad)) { - depLoad = oldLoad; - break; - } - } - loadsCanShareBarriers[loadOp] = depLoad; - } -} - -void LoopPipeliner::checkOpDeps(SetVector &ops) { - SetVector nonImmediateDepArgs; - SetVector nonImmediateOps; - for (Operation *op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - collectValueDep(v, numStages - 1, deps); - int defStage = getValueDefStage(v, numStages - 1); - assert(defStage >= 0 && - "newLoopArgs has null args without a define op. Consider either " - "rewrite the loop to reduce cross iteration dependencies or " - "increase the num_stages value."); - for (auto dep : deps) { - auto immediate = deps.front().isa(); - if (auto arg = dyn_cast(dep)) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } else { - depOps.insert(dep.getDefiningOp()); - if (immediate) - immediateOpStages[dep.getDefiningOp()].insert(defStage); - else - nonImmediateOps.insert(dep.getDefiningOp()); - } - } - } - } - - // We could remove the following constraints if we can rematerialize in the - // loop. Check if immediateDepArgs and nonImmediateDepArgs are disjoint. 
- for (auto &[arg, stages] : immediateArgStages) { - assert(stages.size() == 1 && - "Triton doesn't support an argument provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateDepArgs.contains(arg) && - stages.contains(numStages - 2)) && - "Loop-carried arguments provide values for both immediate and " - "non-immediate operands of loads. Please consider removing " - "pre/post load instructions dependency on this argument."); - } - - // Check if immediateOps and nonImmediateOps are disjoint. - for (auto &[op, stages] : immediateOpStages) { - assert(stages.size() == 1 && - "Triton doesn't support an operation provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) && - "Operations provide values for both immediate and " - "non-immediate operands of loads. Please consider " - "removing pre/post load instructions dependency on this " - "operation."); - } -} - -// helpers -void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - valueMapping[origin] = SmallVector(numStages); - valueMapping[origin][stage] = newValue; -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue, - int stage) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto value = forOp.getRegionIterArgs()[yieldIdx]; - setValueMapping(value, newValue, stage); - } - } -} - -void LoopPipeliner::setValueMappingYield(scf::ForOp newForOp, Value origin, - Value newValue) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]]; - auto originArg = forOp.getRegionIterArgs()[yieldIdx]; - nextMapping.map(originArg, newValue); - auto newArg = newForOp.getRegionIterArgs()[depYieldIdx]; - if (!depArgsMapping.contains(newArg)) - depArgsMapping[newArg] = newValue; - } - } -} - -Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - return origin; - return valueMapping[origin][stage]; -} - -void LoopPipeliner::createBufferTypes() { - for (auto loadCvt : loadsMapping) { - auto loadOp = loadCvt.first; - Value cvt = loadCvt.second; - auto ty = loadOp.getType().cast(); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - bufferShape.insert(bufferShape.begin(), numSharedMemorySlices); - auto CTALayout = ttg::getCTALayout(ty.getEncoding()); - Attribute sharedEnc; - if (auto dotOpEnc = cvt.getType() - .cast() - .getEncoding() - .dyn_cast()) { - // MMAv1 and MMAv2 - bool needTrans = dyn_cast_or_null( - cvt.getDefiningOp()->getOperand(0).getDefiningOp()); - unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); - sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); - } else { - // MMAv3 - sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), - ttg::getOrder(ty.getEncoding()), - CTALayout, ty.getElementType()); - } - // FIXME(Keren): block ptr not handled - loadsBufferType[loadOp] = - RankedTensorType::get(bufferShape, 
ty.getElementType(), sharedEnc); - } -} - -void LoopPipeliner::createOrderedDeps() { - for (Operation &op : *forOp.getBody()) { - if (depOps.contains(&op)) - orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) - orderedDeps.push_back(&op); - } - assert(depOps.size() + validLoads.size() == orderedDeps.size() && - "depOps contains invalid values"); -} - -int LoopPipeliner::getValueDefStage(Value v, int stage) { - if (stage < 0) - return -1; - if (auto arg = v.dyn_cast()) { - if (arg.getArgNumber() > 0) - return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1), - stage - 1); - llvm_unreachable("Loop induction variable should not be a dependency"); - } else - return stage; -} - -ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(tt::LoadOp loadOp, - OpBuilder &builder) { - // Allocate a buffer for each pipelined tensor - // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16> - Value convertLayout = loadsMapping[loadOp]; - if (auto tensorType = convertLayout.getType().dyn_cast()) - return builder.create(convertLayout.getLoc(), - loadsBufferType[loadOp]); - llvm_unreachable("Async copy's return should be of RankedTensorType"); -} - -LogicalResult LoopPipeliner::initialize() { - // All ops that maybe pipelined - SetVector ops; - - if (collectOps(ops).failed()) - return failure(); - - if (checkOpUses(ops).failed()) - return failure(); - - // XXX(Keren): hopper specific, should be cleaned up - checkHopperDots(ops); - - checkOpShareBarriers(ops); - - checkOpDeps(ops); - - createBufferTypes(); - - createOrderedDeps(); - - return success(); -} - -Value LoopPipeliner::getLoadMask(tt::LoadOp loadOp, Value mappedMask, - Value loopCond, OpBuilder &builder) { - Type maskType = tt::getI1SameShape(loadOp.getType()); - Value mask = loadOp.getMask(); - Value newMask; - if (mask) { - Value cond = loopCond; - if (isa(maskType)) { - cond = builder.create(mask.getLoc(), maskType, loopCond); - } - newMask = builder.create(mask.getLoc(), mappedMask, cond); - } else { - if (isa(maskType)) { - newMask = - builder.create(loopCond.getLoc(), maskType, loopCond); - } else { - newMask = loopCond; - } - } - return newMask; -} - -void LoopPipeliner::emitPrologue() { - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { - OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); - setValueMapping(arg, operand.get(), 0); - } - - // Alloc a vector of MBarriers in size numStages for each load to be pipelined - bool isMcast = false; - for (Value loadOp : validLoads) { - auto load = cast(loadOp.getDefiningOp()); - if (isLoadFromTensorPtr(load)) { - auto loadTy = loadOp.getType().cast(); - auto CTALayout = ttg::CTALayoutAttr::get( - load.getContext(), - /*CTAsPerCGA*/ {static_cast(numCTAs)}, - /*CTASplitNum*/ {1}, - /*CTAOrder*/ {0}); - auto sharedEncoding = ttg::SharedEncodingAttr::get( - load.getContext(), 1, 1, 1, {0}, CTALayout, false); - auto mBarriersTy = RankedTensorType::get( - {numStages}, builder.getIntegerType(64), sharedEncoding); - - if (!loadsCanShareBarriers[loadOp]) { - Value fullBarriers = builder.create( - load.getLoc(), mBarriersTy, 1); - loadsFullBarriers[loadOp] = fullBarriers; - } - auto layout = loadTy.getEncoding(); - auto CTASplitNum = ttg::getCTASplitNum(layout); - auto CTAsPerCGA = ttg::getCTAsPerCGA(layout); - if (CTASplitNum != CTAsPerCGA) { - isMcast = true; - // FIXME: numConsumerThreads could be 32 as well instead of 128 - // incase the consumer is 
not GMMA - unsigned arriveCnt = ttg::getNumWarpsPerCTA(layout); - if (hasHopperDot) - arriveCnt /= 4; - arriveCnt *= - product(CTAsPerCGA) / product(CTASplitNum); - - Value emptyBarriers = builder.create( - load.getLoc(), mBarriersTy, arriveCnt); - loadsEmptyBarriers[loadOp] = emptyBarriers; - } - } - } - - if (isMcast) { - builder.create(forOp.getLoc(), /*relaxed*/ 1); - builder.create(forOp.getLoc()); - } - - // prologue from [0, numStage-1) - Value iv = forOp.getLowerBound(); - pipelineIterIdx = builder.create(iv.getLoc(), 0, 32); - for (int stage = 0; stage < numStages - 1; ++stage) { - // Special handling for induction variable as the increment is implicit - if (stage != 0) - iv = builder.create(iv.getLoc(), iv, forOp.getStep()); - setValueMapping(forOp.getInductionVar(), iv, stage); - - // Special handling for loop condition as there is no condition in ForOp - Value loopCond = builder.create( - iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound()); - for (Operation *op : orderedDeps) { - Operation *newOp = nullptr; - if (validLoads.contains(op->getResult(0))) { - auto load = cast(op); - // Allocate empty buffer - if (stage == 0) { - loadsBuffer[load] = allocateEmptyBuffer(load, builder); - loadStageBuffer[load] = {loadsBuffer[load]}; - } - // load => copy async - if (auto loadOp = llvm::dyn_cast(op)) { - Value newMask = - getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), - loopCond, builder); - - if (mode && isLoadFromTensorPtr(loadOp)) { - auto loc = op->getLoc(); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - Value stageVal = - builder.create(loc, stage, 32); - // producer_acquire - if (loadsEmptyBarriers.count(loadOp)) { - Value emptyBarrier = builder.create( - loc, mBarTy, loadsEmptyBarriers[loadOp], stageVal); - auto trueVal = - builder.create(loc, 1, /*bitWidth*/ 1); - builder.create(loc, emptyBarrier, trueVal); - } - - // producer_commit - Value fullBarrier; - if (!loadsCanShareBarriers[loadOp]) { - fullBarrier = builder.create( - loc, mBarTy, loadsFullBarriers[loadOp], stageVal); - loadsExtract[loadOp] = fullBarrier; - } else { - // Reuse the barrier from previouse load. - fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; - } - - auto loadTy = loadOp.getType().dyn_cast(); - assert(loadTy); - auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); - auto shapePerSlice = - ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); - unsigned elems = - std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), 1, - std::multiplies{}); - elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); - - if (!loadsCanShareBarriers[loadOp]) { - Value _0 = builder.create(loc, 0, 32); - Value threadId = builder.create(loc); - Value pred = builder.create( - loc, arith::CmpIPredicate::eq, threadId, _0); - pred = builder.create(loc, pred, loopCond); - Operation *barrierArvOp = builder.create( - loc, fullBarrier, pred, - /*remoteCtaId*/ nullptr, /*trackAsyncOp*/ false, elems); - loadsBarrierArvOp[loadOp] = barrierArvOp; - } else { - // Increase the transcnt for barrier of previouse load by the - // bytes of current load. 
- Operation *barrierArvOp = - loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; - unsigned base_elems = - barrierArvOp->getAttr("txCount").cast().getInt(); - barrierArvOp->setAttr("txCount", - IntegerAttr::get(builder.getIntegerType(32), - base_elems + elems)); - } - newOp = builder.create( - loc, loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, fullBarrier, - newMask, lookupOrDefault(loadOp.getOther(), stage), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), - /*axis*/ 0); - } else { - newOp = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - lookupOrDefault(loadOp.getPtr(), stage), - loadStageBuffer[loadOp][stage], pipelineIterIdx, newMask, - lookupOrDefault(loadOp.getOther(), stage), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - } - loadStageBuffer[loadOp].push_back(newOp->getResult(0)); - } else - llvm_unreachable("This should be LoadOp"); - } else { - if (auto loadOp = dyn_cast(op)) { - Value newMask = - getLoadMask(loadOp, lookupOrDefault(loadOp.getMask(), stage), - loopCond, builder); - newOp = builder.create( - loadOp.getLoc(), loadOp.getResult().getType(), - lookupOrDefault(loadOp.getPtr(), stage), newMask, - lookupOrDefault(loadOp.getOther(), stage), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - addNamedAttrs(newOp, op->getDiscardableAttrDictionary()); - } else - newOp = builder.clone(*op); - // Update loop-carried uses - for (unsigned opIdx = 0; opIdx < op->getNumOperands(); ++opIdx) { - auto it = valueMapping.find(op->getOperand(opIdx)); - if (it != valueMapping.end()) { - Value v = it->second[stage]; - assert(v && "Value not found in valueMapping"); - newOp->setOperand(opIdx, v); - } // else, op at opIdx is a loop-invariant value - } - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) { - Value originResult = op->getResult(dstIdx); - if (validLoads.contains(originResult)) - break; - setValueMapping(originResult, newOp->getResult(dstIdx), stage); - // Update mapping for loop-carried values (args) - setValueMappingYield(op->getResult(dstIdx), newOp->getResult(dstIdx), - stage + 1); - } - } // for (Operation *op : orderedDeps) - - // Update pipeline index - pipelineIterIdx = builder.create( - iv.getLoc(), pipelineIterIdx, - builder.create(iv.getLoc(), 1, 32)); - Value numSlices = builder.create( - iv.getLoc(), numSharedMemorySlices, 32); - Value _0 = builder.create(iv.getLoc(), 0, 32); - pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx, - numSlices, pipelineIterIdx, _0); - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(arg, valueMapping[arg][stage], stage + 1); - } // for (int stage = 0; stage < numStages - 1; ++stage) - - // async.wait & extract_slice - if (numLoadsRequireAsyncWait > 0) - builder.create(validLoads.front().getLoc(), - validLoads.size() * (numStages - 2)); - for (Value loadOp : validLoads) { - auto bufferType = loadStageBuffer[loadOp][numStages - 1] - .getType() - .cast(); - auto bufferShape = bufferType.getShape(); - auto sliceType = loadsMapping[loadOp].getType().cast(); - sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, - sliceType.getElementType(), - loadsBufferType[loadOp].getEncoding()); - Value extractSlice = builder.create( - loadOp.getLoc(), sliceType, 
loadStageBuffer[loadOp][numStages - 1], - SmallVector{int_attr(0), int_attr(0), int_attr(0)}, - SmallVector{int_attr(1), - int_attr(sliceType.getShape()[0]), - int_attr(sliceType.getShape()[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - loadsExtract[loadOp] = extractSlice; - } - curWaitIdx = builder.create(iv.getLoc(), 0, 32); - loopIterIdx = builder.create(iv.getLoc(), 0, 32); - curPhase = builder.create(iv.getLoc(), 0, 1); - curEmptyPhase = builder.create(iv.getLoc(), 1, 1); -} - -void LoopPipeliner::emitEpilogue() { - // If there's any outstanding async copies, we need to wait for them. - if (numLoadsRequireAsyncWait > 0) { - OpBuilder builder(forOp); - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfter(forOp); - builder.create(forOp.getLoc(), 0); - } -} - -SmallVector LoopPipeliner::collectNewLoopArgs() { - // Order of new args: - // (original args) - // (insertSliceAsync buffer at stage numStages - 1) for each load - // (extracted tensor) for each load - // (depArgs at stage numStages - 1) - // (depArgs at stage numStages - 2) - // ... - // (iv at stage numStages - 2) - // (pipeline iteration index) - // (loop iteration index) - // (wait index) - // (phase index) - // (empty phase index) - - // We need this to update operands for yield - // original block arg => new arg's idx - SmallVector newLoopArgs; - for (auto v : forOp.getInitArgs()) - newLoopArgs.push_back(v); - - bufferIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadStageBuffer[loadOp].back()); - - loadIdx = newLoopArgs.size(); - for (auto loadOp : validLoads) - newLoopArgs.push_back(loadsExtract[loadOp]); - - depArgsBeginIdx = newLoopArgs.size(); - for (auto depArg : depArgs) { - depArgsIdx[depArg] = newLoopArgs.size(); - if (immediateArgStages[depArg].contains(numStages - 2)) - // Peel off post load ops in numStage-1 - newLoopArgs.push_back(valueMapping[depArg][numStages - 2]); - else - newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); - } - - ivIdx = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]); - newLoopArgs.push_back(pipelineIterIdx); - newLoopArgs.push_back(curWaitIdx); - if (numLoadsRequireMBarrier > 0) { - newLoopArgs.push_back(loopIterIdx); - newLoopArgs.push_back(curPhase); - newLoopArgs.push_back(curEmptyPhase); - } - - return newLoopArgs; -} - -scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, - OpBuilder &builder) { - // Clone the original ForOp - auto newForOp = builder.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newLoopArgs); - - // Set mapping on body of the new ForOp - builder.setInsertionPointToStart(newForOp.getBody()); - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); - mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - - // Loop iteration args - Value upperBound = newForOp.getUpperBound(); - Value step = newForOp.getStep(); - Value curIV = newForOp.getRegionIterArgs()[ivIdx]; - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - if (numLoadsRequireMBarrier > 0) { - loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; - curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; - curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; - } - - // Clone the loop body, replace original args with args of the new ForOp. 
- SmallVector loadsFromTensorPtr; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (auto cvtOp = dyn_cast(op)) { - auto result = op.getResult(0); - auto cvtDstTy = result.getType().cast(); - auto it = - std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); - if (it != validLoads.end()) { - auto loadArgIdx = std::distance(validLoads.begin(), it); - if (cvtDstTy.getEncoding().isa()) { - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (dotOperand) = cvt %0 - // We replace the use new load use with a convert layout - auto cvt = builder.create( - result.getLoc(), cvtDstTy, - newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - mapping.map(result, cvt.getResult()); - continue; - } else if (cvtDstTy.getEncoding().isa()) { - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (sharedEncoding) = cvt %0 - // We replace the use new load use with insert_slice_async's result - mapping.map(result, - newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - continue; - } - } - } else if (auto loadOp = dyn_cast(op)) { - if (isLoadFromTensorPtr(loadOp)) { - // XXX(Keren): The comparison operator using std::find on tensor ptr - // doesn't work as expected - auto operand = loadOp.getPtr(); - auto tensorTy = - operand.getType().cast().getPointeeType(); - auto loadArgIdx = 0; - for (auto validLoad : validLoads) { - auto defOp = cast(validLoad.getDefiningOp()); - if (isLoadFromTensorPtr(defOp)) { - auto validOperand = defOp.getOperand(0); - auto validTensorTy = - validOperand.getType().cast().getPointeeType(); - if (tensorTy == validTensorTy) - break; - } - loadArgIdx++; - } - // consumer_wait, emitted before the first consumer - auto firstConsumer = getFirstUser(loadOp); - mapping.map(loadOp, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); - - // If current load can reuse barriers shared by previous load, then we - // do nothing. - if (!loadsCanShareBarriers[loadOp]) { - // emit mbarrier wait before the first consumer of the loaD - OpBuilder mBarBuilder(firstConsumer); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - Value fullBarrier = mBarBuilder.create( - loadOp.getLoc(), mBarTy, loadsFullBarriers[loadOp], curWaitIdx); - mBarBuilder.create(loadOp.getLoc(), fullBarrier, - curPhase); - } - - loadsFromTensorPtr.push_back(loadOp); - continue; - } - } - cloneWithInferType(builder, &op, mapping); - } - - for (Value load : loadsFromTensorPtr) { - // consumer_relase, emitted after the last consumer - // 'the last consumer' might be updated in the following Phase_1 since - // some of the consumers might be pipelined. Thus we maintain this - // information in 'consumerReleaseMap' and move the position of - // consumer_release barrier in a seperate Phase_2 in case necessary. - if (loadsEmptyBarriers.count(load)) { - auto users = mapping.lookup(load).getUsers(); - DenseMap consumerStageMap; - for (Operation *user : users) { - // All the stage is initialized to zero before Phase_1, - // since no consumers has been pipelined yet. 
- consumerStageMap[user] = 0; - } - auto CTALayout = ttg::getCTALayout( - load.getType().cast().getEncoding()); - ConsumerReleaseInfo info{ - loopIterIdx, pipelineIterIdx, curEmptyPhase, curIV, - step, upperBound, CTALayout, consumerStageMap}; - consumerReleaseMap[loadsEmptyBarriers[load]] = info; - } - } - - // Remove redundant conversions - // e.g., %145 = triton_gpu.convert_layout %arg15 : (tensor<128x64xf16, - // #shared1>) -> tensor<128x64xf16, #shared1> - for (Operation &op : newForOp.getBody()->without_terminator()) { - if (auto convert_layout = dyn_cast(op)) { - auto result = op.getResult(0); - auto cvtDstTy = result.getType(); - auto operand = convert_layout.getOperand(); - auto tensorTy = operand.getType(); - if (cvtDstTy == tensorTy) - result.replaceAllUsesWith(operand); - } - } - - return newForOp; -} - -Value LoopPipeliner::getBoundedIterationValue(OpBuilder &builder, Value curIdx, - Value upperBoundIdx, - Value curValue, Value initValue) { - Value cond = builder.create( - curIdx.getLoc(), arith::CmpIPredicate::uge, curIdx, upperBoundIdx); - Value selectValue = builder.create( - curIdx.getLoc(), cond, initValue, curValue); - return selectValue; -} - -void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, - OpBuilder &builder) { - // Map the dep args of the next iteration to the dep args of the current - size_t argIdx = 0; - for (auto depArg : depArgs) { - BlockArgument nextArg = - newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]; - nextMapping.map(depArg, nextArg); - ++argIdx; - } - - // Update loop iteration args - Value curIV = newForOp.getRegionIterArgs()[ivIdx]; - pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - curWaitIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - if (numLoadsRequireMBarrier > 0) { - loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 3]; - curPhase = newForOp.getRegionIterArgs()[ivIdx + 4]; - curEmptyPhase = newForOp.getRegionIterArgs()[ivIdx + 5]; - } - - // Special handling for iv & loop condition - auto idxLoc = curIV.getLoc(); - nextIV = builder.create(idxLoc, curIV, newForOp.getStep()); - Value nextLoopCond = builder.create( - idxLoc, arith::CmpIPredicate::slt, nextIV, newForOp.getUpperBound()); - - // Constants - Value _0 = builder.create(idxLoc, 0, 32); - Value _1 = builder.create(idxLoc, 1, 32); - Value numStagesVal = - builder.create(idxLoc, numStages, 32); - Value numSlices = - builder.create(idxLoc, numSharedMemorySlices, 32); - - // nextWaitIdx - Value waitIdxPlusOne = builder.create(idxLoc, curWaitIdx, _1); - Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne, - numSlices, waitIdxPlusOne, _0); - - // Indices of InsertSliceAsyncOp and ExtractSliceOp - Value insertSliceIndex = pipelineIterIdx; - Value extractSliceIndex = nextWaitIdx; - - // Prefetch load deps - // If a load-dependent instruction that uses a block argument, we - // shouldn't update the new mapping of the block argument in the current - // iteration. - // For example. - // %a = add %arg0, %c - // %b = add %arg0, %d - // - // Update %arg0 will cause the value of %b to be incorrect. - // We do need to use the next iteration value of %arg0 because it could be a - // immediate arg of a load op. - // load %arg0 - // %a = add %arg0, %c - // yield %a - // - // We reroder instructions so %a and yield are actually before load. load - // %arg0 should use the updated %arg0. 
- IRMapping curMapping = nextMapping; - for (Operation *op : orderedDeps) - if (!validLoads.contains(op->getResult(0))) { - if (immediateOpStages[op].contains(numStages - 2)) - // A post load op that provides values for numStage - 2 - curMapping.map(forOp.getInductionVar(), curIV); - else - curMapping.map(forOp.getInductionVar(), nextIV); - Operation *nextOp; - if (auto loadOp = dyn_cast(op)) { - auto newMask = - getLoadMask(loadOp, curMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - nextOp = builder.create( - loadOp.getLoc(), loadOp.getResult().getType(), - curMapping.lookupOrDefault(loadOp.getPtr()), newMask, - curMapping.lookupOrDefault(loadOp.getOther()), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - addNamedAttrs(nextOp, op->getDiscardableAttrDictionary()); - curMapping.map(loadOp.getResult(), nextOp->getResult(0)); - nextMapping.map(loadOp.getResult(), nextOp->getResult(0)); - } else { - nextOp = builder.clone(*op, curMapping); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(newForOp, op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - - // loads -> async loads - for (Operation *op : orderedDeps) { - Operation *nextOp = nullptr; - // Update loading mask - if (validLoads.contains(op->getResult(0))) { - auto loadOp = llvm::cast(op); - auto mask = loadOp.getMask(); - auto newMask = - getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - if (mask) { - // If mask is defined outside the loop, don't update the map more than - // once - if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) - nextMapping.map(mask, newMask); - newMask = nextMapping.lookupOrDefault(loadOp.getMask()); - } - Value insertedVal; - if (mode && isLoadFromTensorPtr(loadOp)) { - auto loc = op->getLoc(); - auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); - - // producer_acquire - if (loadsEmptyBarriers.count(loadOp)) { - auto ifOp = builder.create(loc, ArrayRef{}, - nextLoopCond, false); - builder.setInsertionPointToStart(ifOp.thenBlock()); - Value emptyBarrier = builder.create( - loc, mBarTy, loadsEmptyBarriers[loadOp], insertSliceIndex); - builder.create(loc, emptyBarrier, - curEmptyPhase); - builder.setInsertionPointAfter(ifOp); - } - - // producer_commit - Value fullBarrier; - if (!loadsCanShareBarriers[loadOp]) { - fullBarrier = builder.create( - loc, mBarTy, loadsFullBarriers[loadOp], insertSliceIndex); - loadsExtract[loadOp] = fullBarrier; - } else { - // Reuse the barrier from previouse load. 
- fullBarrier = loadsExtract[loadsCanShareBarriers[loadOp]]; - } - - auto loadTy = loadOp.getType().dyn_cast(); - assert(loadTy); - auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); - auto shapePerSlice = - ttg::getShapePerCTA(CTASplitNum, loadTy.getShape()); - unsigned elems = std::accumulate( - shapePerSlice.begin(), shapePerSlice.end(), 1, std::multiplies{}); - elems *= (loadTy.getElementType().getIntOrFloatBitWidth() / 8); - if (!loadsCanShareBarriers[loadOp]) { - Value _0 = builder.create(loc, 0, 32); - Value threadId = builder.create(loc); - Value pred = builder.create( - loc, arith::CmpIPredicate::eq, threadId, _0); - pred = builder.create(loc, pred, nextLoopCond); - Operation *barrierArvOp = builder.create( - loc, fullBarrier, pred, - /*remoteCtaId*/ nullptr, - /*trackAsyncOp*/ false, elems); - loadsBarrierArvOp[loadOp] = barrierArvOp; - } else { - // Increase the transcnt for barrier of previouse load by the bytes of - // current load. - Operation *barrierArvOp = - loadsBarrierArvOp[loadsCanShareBarriers[loadOp]]; - unsigned base_elems = - barrierArvOp->getAttr("txCount").cast().getInt(); - barrierArvOp->setAttr( - "txCount", - IntegerAttr::get(builder.getIntegerType(32), base_elems + elems)); - } - insertedVal = builder.create( - loc, loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, fullBarrier, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - } else { - insertedVal = builder.create( - op->getLoc(), loadsBuffer[loadOp].getType(), - nextMapping.lookupOrDefault(loadOp.getPtr()), - newForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()], - insertSliceIndex, newMask, - nextMapping.lookupOrDefault(loadOp.getOther()), loadOp.getCache(), - loadOp.getEvict(), loadOp.getIsVolatile(), /*axis*/ 0); - builder.create(op->getLoc()); - } - nextBuffers.push_back(insertedVal); - // Extract slice - auto bufferType = insertedVal.getType().cast(); - auto bufferShape = bufferType.getShape(); - auto sliceType = loadsMapping[loadOp].getType().cast(); - sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]}, - sliceType.getElementType(), - loadsBufferType[loadOp].getEncoding()); - - nextOp = builder.create( - op->getLoc(), sliceType, insertedVal, - SmallVector{extractSliceIndex, int_attr(0), - int_attr(0)}, - SmallVector{int_attr(1), - int_attr(sliceType.getShape()[0]), - int_attr(sliceType.getShape()[1])}, - SmallVector{int_attr(1), int_attr(1), int_attr(1)}); - extractSlices.push_back(nextOp->getResult(0)); - - // Update mapping of results - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - // If this is a loop-carried value, update the mapping for yield - setValueMappingYield(newForOp, op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - } - - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(newForOp, arg, - newForOp.getRegionIterArgs()[depArgsIdx[arg]]); - - // async.wait & extract_slice - if (numLoadsRequireAsyncWait > 0) { - Operation *asyncWait = builder.create( - validLoads[0].getLoc(), validLoads.size() * (numStages - 2)); - for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) { - // move extract_slice after asyncWait - it->getDefiningOp()->moveAfter(asyncWait); - } - } - - // Bump pipelineIterIdx - Value pipelineIterIdxPlusOne = - 
builder.create(idxLoc, pipelineIterIdx, _1); - pipelineIterIdx = getBoundedIterationValue( - builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0); - - // Bump curWaitIdx - curWaitIdx = nextWaitIdx; - - if (numLoadsRequireMBarrier > 0) { - // Bump loopIterIdx - loopIterIdx = builder.create(idxLoc, loopIterIdx, _1); - - Value _1_1b = builder.create(idxLoc, 1, 1); - - // Flip curPhase - Value nextPhase = builder.create(idxLoc, curPhase, _1_1b); - curPhase = getBoundedIterationValue(builder, waitIdxPlusOne, numStagesVal, - curPhase, nextPhase); - - // Flip curEmptyPhase - Value nextEmptyPhase = - builder.create(idxLoc, curEmptyPhase, _1_1b); - curEmptyPhase = - getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal, - curEmptyPhase, nextEmptyPhase); - } -} - -void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) { - SmallVector yieldValues; - for (Value v : yieldOp->getOperands()) - yieldValues.push_back(mapping.lookup(v)); - for (Value nextBuffer : nextBuffers) - yieldValues.push_back(nextBuffer); - for (Value nextSlice : extractSlices) - yieldValues.push_back(nextSlice); - - for (size_t i = depArgsBeginIdx; i < ivIdx; ++i) { - auto arg = newForOp.getRegionIterArgs()[i]; - assert(depArgsMapping.count(arg) && "Missing loop-carried value"); - yieldValues.push_back(depArgsMapping[arg]); - } - - // Loop iteration args - yieldValues.push_back(nextIV); - yieldValues.push_back(pipelineIterIdx); - yieldValues.push_back(curWaitIdx); - if (numLoadsRequireMBarrier > 0) { - yieldValues.push_back(loopIterIdx); - yieldValues.push_back(curPhase); - yieldValues.push_back(curEmptyPhase); - } - - builder.setInsertionPointToEnd(newForOp.getBody()); - builder.create(yieldOp->getLoc(), yieldValues); -} - -scf::ForOp LoopPipeliner::createNewForOp() { - OpBuilder builder(forOp); - auto newLoopArgs = collectNewLoopArgs(); - auto newForOp = cloneForOp(newLoopArgs, builder); - prefetchNextIteration(newForOp, builder); - finalizeYield(newForOp, builder); - return newForOp; -} - -// ref: mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp -struct PipelinePass : public TritonGPUPipelineBase { - PipelinePass() = default; - PipelinePass(int numStages, int numWarps, int numCTAs, - int computeCapability) { - this->numStages = numStages; - this->numWarps = numWarps; - this->numCTAs = numCTAs; - this->computeCapability = computeCapability; - } - - void runOnOperation() override { - // TODO[goostavz]: mode = 0 is temporary for backward compatible, will be - // deprecated after the refactor of pipeline fully gets done - // TODO[goostavz]: When mode = 1, the mask of prefetch insert_slice in the - // prologue is currently not properly provided. Need some second thought on - // the mask definition of InsertSliceOp when the src is ptr - bool mode = - computeCapability >= 90 && ::triton::tools::getBoolEnv("ENABLE_TMA"); - if (this->numStages <= 1) - return; - - // phase 0: pipeline loads in loops - // Pre-processing - // we make sure element-wise ops are done *after* the conversion - // to dot operands - // we can achieve this with simple recursive pattern matching - // MLIRContext *context = &getContext(); - // mlir::RewritePatternSet patterns(context); - // patterns.add(context); - // auto didPreprocess = - // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - - llvm::SmallVector newForOps; - - // Currently we schedule stage 0 after stage `numStages - 1` during - // pipelining therefore we only need `numStages - 1` slice of memory. 
- // On Hopper we have a separate post-processing that pipelines wgmma so we - // need an extra buffer for each input. - // Note that an alternative would be to keep allocating `numStages` buffers - // and remove the barrier between the loads from shared memory and the - // copies from global to shared. This would require improving existing - // membar analysis. - int numSharedMemorySlices = - computeCapability < 90 ? numStages - 1 : numStages; - - // Do the pipelining - getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps, - this->numCTAs, mode, numSharedMemorySlices, - consumerReleaseMap); - if (pipeliner.initialize().failed()) - return; - - pipeliner.emitPrologue(); - scf::ForOp newForOp = pipeliner.createNewForOp(); - pipeliner.emitEpilogue(); - newForOps.push_back(newForOp); - - // Replace the original loop - for (unsigned i = 0; i < forOp->getNumResults(); ++i) - forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); - forOp->erase(); - }); - - // phase 1: pipeline dots in loops - // A tt.dot suitable for GMMA will be converted to ttg.dot_async. And a - // ttg.DotWaitOp will synchronize it lagging just one iteration, which is - // a hueristic rule. - for (auto forOp : newForOps) - asyncLaunchDots(forOp); - - // phase 2: emit consumer_release (empty barrier arrive) logics in case of - // TMA multicast. - // For each load ops, it is emitted after its last consumer, if the consumer - // is another async op, find its associated sync op. Each async load will be - // emitted with a consumer_release action. The merge of redundant mbarriers - // will be processed in the consequent OptimizeBarriers pass. - for (const auto &item : consumerReleaseMap) - emitConsumerRelease(item.first, item.second, numStages); - } - -private: - Value getRemoteCTAId(OpBuilder &b, Location loc, ttg::CTALayoutAttr CTALayout, - Value remoteCTAIdIdx) const; - void updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, int stage); - void asyncLaunchDots(scf::ForOp forOp); - void emitConsumerRelease(Value mbarTensor, const ConsumerReleaseInfo &info, - int numStages); - bool selfDepend(tt::DotOp op, scf::ForOp forOp, Operation **firstUse); - void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, bool hasDotWait0); - ConsumerReleaseMap consumerReleaseMap; -}; - -void PipelinePass::updateConsumerReleaseInfo(Operation *oldOp, Operation *newOp, - int stage) { - for (auto &item : consumerReleaseMap) { - auto &m = item.second.consumerStageMap; - if (m.count(oldOp)) { - m.erase(oldOp); - m[newOp] = stage; - } - - for (Value operand : oldOp->getOperands()) { - Operation *op = operand.getDefiningOp(); - if (op && isa(op)) { - auto cvt = cast(op); - auto src = cvt.getSrc(); - auto srcEncoding = src.getType().cast().getEncoding(); - auto dstEncoding = - cvt.getResult().getType().cast().getEncoding(); - if (srcEncoding == dstEncoding && m.count(op)) { - m.erase(op); - m[newOp] = stage; - } - } - } - } -} - -bool PipelinePass::selfDepend(tt::DotOp dotOp, scf::ForOp forOp, - Operation **firstUse) { - std::function dependOn = - [&dependOn](Value v, int argId, scf::ForOp forOp) { - auto op = v.getDefiningOp(); - if (isa(v)) { - auto iterArgs = forOp.getRegionIterArgs(); - auto iter = std::find(iterArgs.begin(), iterArgs.end(), v); - if (iter != iterArgs.end()) - return std::distance(iterArgs.begin(), iter) == argId; - } else { - if (!op) - return false; - for (auto operand : op->getOperands()) { - if (dependOn(operand, argId, forOp)) - return true; - } 
- } - return false; - }; - auto result = dotOp.getResult(); - auto yieldOp = forOp.getBody()->getTerminator(); - int argIdx = -1; - auto iter = std::find(yieldOp->getOperands().begin(), - yieldOp->getOperands().end(), result); - if (iter != yieldOp->getOperands().end()) - argIdx = std::distance(yieldOp->getOperands().begin(), iter); - if (argIdx == -1) - return false; - for (auto operand : dotOp.getOperands()) { - if (dependOn(operand, argIdx, forOp)) { - auto iterArgs = forOp.getRegionIterArgs(); - *firstUse = iterArgs[argIdx].use_begin().getUser(); - return true; - } - } - return false; -} - -void PipelinePass::removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, - bool hasDotWait0) { - if (hasDotWait0) { - for (auto &item : consumerReleaseMap) { - auto &m = item.second.consumerStageMap; - if (m.count(dotWaitOp)) { - m.erase(dotWaitOp); - } - } - dotWaitOp->erase(); - } -} - -void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { - Block *loop = forOp.getBody(); - auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) { - if (!op) - return -1l; - auto lastOp = op; - while (op->getBlock()->getParentOp() != forOp) { - lastOp = op; - op = op->getBlock()->getParentOp(); - } - return std::distance(lastOp->getBlock()->getParent()->begin(), - lastOp->getBlock()->getIterator()); - }; - /// XXX(Keren): Clean up the following duplicate code with checkDotOp - /// dots to be pipelined - bool hasSyncDot = false; - bool hasDotWait0 = false; - SmallVector allDots; - SmallVector dots; - SmallVector resultNeedSync; - for (Operation &op : *loop) { - if (auto dotWaitOp = dyn_cast(&op)) { - auto attr = dotWaitOp->getAttrOfType("pendings"); - auto pendingCount = attr.getInt(); - if (pendingCount == 0) - hasDotWait0 = true; - } - if (auto dotOp = dyn_cast(&op)) { - allDots.push_back(dotOp); - } - } - for (Operation &op : *loop) { - if (auto dotOp = dyn_cast(&op)) { - auto resTy = dotOp.getResult().getType().dyn_cast(); - if (auto resEnc = resTy.getEncoding().dyn_cast()) { - if (resEnc && resEnc.isHopper()) { - auto dot = dotOp.getResult(); - bool valid = true; - - // all users of dot should be scf.yield - if (!dot.hasOneUse()) - valid = false; - if (!isa(*dot.getUsers().begin())) - valid = false; - - Operation *firstUse = nullptr; - auto depend = selfDepend(dotOp, forOp, &firstUse); - bool selfDirectDepend = (dotOp == firstUse); - for (auto tempInAll : allDots) { - auto iter = std::find(dots.begin(), dots.end(), tempInAll); - if (iter != dots.end()) - continue; - auto db = getBlockNumInFor(tempInAll, forOp); - auto fb = getBlockNumInFor(firstUse, forOp); - if (db < fb || - (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse))) - hasSyncDot = true; - } - auto CArg = dotOp.getOperand(2); - if (!(selfDirectDepend || - (depend && !selfDirectDepend && hasSyncDot)) || - !CArg.hasOneUse()) - valid = false; - - if (valid) { - dots.push_back(dotOp); - resultNeedSync.push_back( - dotOp->getUses().begin()->getOperandNumber()); - } - } - } - } - } - - // Early stop: no need to continue if there is no valid dot in the loop. - if (dots.empty()) - return; - - OpBuilder builder(forOp); - // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline - // wgmma ops by one stage. - // This is needed to prevent shared memory inputs to be overriden before the - // operation is completed. - // TODO: merge this with the rest of the pipelining transformation and look at - // a better representation for async dots. 
- tt::DotOp lastDot = dots.back(); - auto loc = lastDot.getLoc(); - builder.setInsertionPointAfter(lastDot); - auto dotWait = builder.create( - lastDot.getLoc(), lastDot.getResult(), dots.size()); - - // 1. replace Dot with DotAsync - for (size_t idx = 0; idx < dots.size(); ++idx) { - tt::DotOp dotOp = dots[idx]; - builder.setInsertionPoint(dotOp); - auto dotAsync = builder.create( - dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(), - dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); - dotOp.replaceAllUsesWith(dotAsync.getResult()); - updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1); - dotOp->erase(); - } - - hasDotWait0 = hasDotWait0 || hasSyncDot; - - // 2. If there's any outstanding DotAsyncOps, we need to wait for them. - builder.setInsertionPointAfter(forOp); - SmallVector resultTypes(resultNeedSync.size()); - SmallVector yieldThenValues(resultNeedSync.size()); - SmallVector yieldElseValues(resultNeedSync.size()); - for (int i = 0; i < resultNeedSync.size(); ++i) { - resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType(); - yieldThenValues[i] = forOp->getResult(resultNeedSync[i]); - yieldElseValues[i] = forOp->getResult(resultNeedSync[i]); - } - Value loopNotEmpty = builder.create( - loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), - forOp.getUpperBound()); - auto ifOp = builder.create(loc, resultTypes, loopNotEmpty, - /*hasElse*/ true); - builder.setInsertionPointToStart(ifOp.thenBlock()); - for (int i = 0; i < resultNeedSync.size(); ++i) { - Value result = forOp->getResult(resultNeedSync[i]); - if (result.use_empty()) - continue; - auto dotWait = - builder.create(forOp.getLoc(), result, 0); - result.replaceAllUsesExcept(ifOp.getResult(i), dotWait); - yieldThenValues[i] = dotWait.getResult(); - } - auto yieldOpThen = builder.create(loc, yieldThenValues); - builder.setInsertionPointToEnd(ifOp.elseBlock()); - auto yieldOpElse = builder.create(loc, yieldElseValues); - - // 3. potentially remove redundant dot_wait after dot_async if having mutiple - // DotOp in the loop - removeExtraWait(dotWait, hasDotWait0); -} - -Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc, - ttg::CTALayoutAttr CTALayout, - Value remoteCTAIdIdx) const { - auto CTAsPerCGA = CTALayout.getCTAsPerCGA(); - auto CTAOrder = CTALayout.getCTAOrder(); - auto CTASplitNum = CTALayout.getCTASplitNum(); - - // Short path when bcastMask is a constant - bool isConstMcastMask = true; - for (unsigned s : CTASplitNum) { - if (s > 1) { - isConstMcastMask = false; - break; - } - } - if (isConstMcastMask) - return remoteCTAIdIdx; - - Value linearCTAId = b.create(loc); - SmallVector multiDimCTAId = - delinearize(b, loc, linearCTAId, CTAsPerCGA, CTAOrder); - auto rank = CTAOrder.size(); - int bcastDim = -1; - for (size_t i = 0; i < rank; ++i) { - if (CTAsPerCGA[i] != CTASplitNum[i]) { - assert(bcastDim < 0 && "bcast in multiple dims is not expected"); - bcastDim = i; - } - } - multiDimCTAId[bcastDim] = remoteCTAIdIdx; - return linearize(b, loc, multiDimCTAId, CTAsPerCGA, CTAOrder); -} - -void PipelinePass::emitConsumerRelease(Value mbarTensor, - const ConsumerReleaseInfo &info, - int numStages) { - Value iterVar = info.iterVar; - Value stage = info.stageVar; - Value phase = info.phaseVar; - Value nextIV = info.nextIVVar; - Value step = info.stepVar; - Value upperBound = info.upperBoundVar; - - const auto &consumerStageMap = info.consumerStageMap; - // find the the last consumer among all the consumers with the largest stage. 
- SmallVector consumersWithLargestStage; - int maxStage = 0; - for (const auto &it : consumerStageMap) { - if (it.second > maxStage) { - consumersWithLargestStage.clear(); - consumersWithLargestStage.push_back(it.first); - maxStage = it.second; - } else if (it.second == maxStage) { - consumersWithLargestStage.push_back(it.first); - } - } - assert(consumersWithLargestStage.size() > 0); - DenseMap operationId; - consumersWithLargestStage[0]->getBlock()->walk( - [&](Operation *op) { operationId[op] = operationId.size(); }); - size_t maxId = 0; - Operation *lastUserWithLargestStage; - for (Operation *op : consumersWithLargestStage) { - assert(operationId.find(op) != operationId.end()); - size_t userId = operationId[op]; - if (userId > maxId) { - maxId = userId; - lastUserWithLargestStage = op; - } - } - - OpBuilder b(&getContext()); - b.setInsertionPointAfter(lastUserWithLargestStage); - auto loc = lastUserWithLargestStage->getLoc(); - auto maxStageVal = b.create(loc, maxStage, 32); - - // pred = (iterVar >= maxStage) && - // (threadId % (numConsumerThreads / numRemoteCTAs) == 0); - - // [benzh] maybe we can simplify the logics here - auto cmpOp = arith::CmpIPredicate::sge; - if (maxStage == 0) - cmpOp = arith::CmpIPredicate::sgt; - Value pred = b.create(loc, cmpOp, iterVar, maxStageVal); - - Value threadId = b.create(loc); - auto CTAsPerCGA = info.CTALayout.getCTAsPerCGA(); - auto CTASplitNum = info.CTALayout.getCTASplitNum(); - auto numRemoteCTAs = std::accumulate(CTAsPerCGA.begin(), CTAsPerCGA.end(), 1, - std::multiplies{}) / - std::accumulate(CTASplitNum.begin(), CTASplitNum.end(), - 1, std::multiplies{}); - auto numConsumerThreads = - isa(lastUserWithLargestStage) ? 128 : 32; - Value _0 = b.create(loc, 0, 32); - Value numArrives = b.create( - loc, numConsumerThreads / numRemoteCTAs, 32); - pred = b.create( - loc, pred, - b.create( - loc, arith::CmpIPredicate::eq, - b.create(loc, threadId, numArrives), _0)); - // remoteCtaIdIdx = (threadId % numConsumerThreads) / (numConsumerThreads / - // numRemoteCTAs); - Value remoteCTAIdIdx = b.create( - loc, - b.create( - loc, threadId, - b.create(loc, numConsumerThreads, 32)), - numArrives); - Value remoteCTAId = getRemoteCTAId(b, loc, info.CTALayout, remoteCTAIdIdx); - Value emptyBarrier = b.create( - loc, tt::PointerType::get(b.getIntegerType(64), 3), mbarTensor, stage); - - Value newNextIV = b.create(loc, nextIV, step); - Value nextLoopCond = b.create(loc, arith::CmpIPredicate::slt, - newNextIV, upperBound); - auto ifOp = b.create(loc, ArrayRef{}, nextLoopCond, - /*hasElse*/ false); - b.setInsertionPointToStart(ifOp.thenBlock()); - - b.create(loc, emptyBarrier, pred, remoteCTAId, - /*trackAsyncOp*/ false); -} - -} // anonymous namespace - -std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, - int numWarps, - int numCTAs, - int computeCapability) { - return std::make_unique(numStages, numWarps, numCTAs, - computeCapability); -} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp new file mode 100644 index 000000000..bfee0aaac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -0,0 +1,826 @@ +#include "PipelineExpander.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include 
"triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// TODO: We can extra some helpers into common utilities once we add more +// schedules. + +/// Replace the yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands().begin(), + yieldOp->getOperands().end()); + operands.append(newOperands.begin(), newOperands.end()); + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx) { + OpBuilder builder(forOp); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + auto insertOp = builder.create( + loc, alloc.getType(), loadOp.getPtr(), alloc, insertIdx, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile(), /*axis*/ 0); + auto commmit = builder.create(loc); + + // Extract part. + auto allocType = alloc.getType().cast(); + RankedTensorType sliceType = RankedTensorType::get( + {allocType.getShape()[1], allocType.getShape()[2]}, + allocType.getElementType(), allocType.getEncoding()); + auto extract = builder.create( + loc, sliceType, insertOp.getResult(), + SmallVector{extractIdx, int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); + Operation *user = *loadOp.getResult().getUsers().begin(); + auto convertLayout = llvm::cast(user); + auto newCvt = builder.create( + convertLayout->getLoc(), convertLayout.getType(), extract.getResult()); + convertLayout->replaceAllUsesWith(newCvt->getResults()); + convertLayout->erase(); + loadOp.erase(); + + // Fix up the yield op. + appendToYield(forOp, {insertOp}); +} + +static void createTMALoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value phase) { + OpBuilder builder(forOp); + Location loc = loadOp.getLoc(); + auto CTALayout = ttg::CTALayoutAttr::get(loadOp.getContext(), + /*CTAsPerCGA*/ {1}, + /*CTASplitNum*/ {1}, + /*CTAOrder*/ {0}); + auto sharedEncoding = ttg::SharedEncodingAttr::get(loadOp.getContext(), 1, 1, + 1, {0}, CTALayout, false); + int64_t numBuffers = alloc.getType().cast().getShape()[0]; + auto mBarriersTy = RankedTensorType::get( + {numBuffers}, builder.getIntegerType(64), sharedEncoding); + // Allocate an array of mbarrier objects outside the loop. + Value barrierArray = + builder.create(loc, mBarriersTy, 1); + // extract the barrier and emit arriver/copy/wait/extract code sequence. 
+ builder.setInsertionPoint(loadOp); + auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3); + Value barrier = builder.create( + loc, mBarTy, barrierArray, insertIdx); + Value zero = builder.create(loc, 0, 32); + Value threadId = builder.create(loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, zero); + + auto loadTy = loadOp.getType().dyn_cast(); + auto loadShape = loadTy.getShape(); + auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding()); + auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape); + auto elemTy = loadTy.getElementType(); + unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(), + 1, std::multiplies{}); + elems *= (elemTy.getIntOrFloatBitWidth() / 8); + builder.create(loc, barrier, pred, + /*remoteCtaId*/ nullptr, + /*trackAsyncOp*/ false, elems); + auto allocType = alloc.getType().cast(); + auto insertOp = builder.create( + loc, allocType, loadOp.getPtr(), alloc, + /*index*/ insertIdx, barrier, loadOp.getMask(), loadOp.getOther(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + /*axis*/ 0); + + RankedTensorType sliceType = RankedTensorType::get( + {allocType.getShape()[1], allocType.getShape()[2]}, + allocType.getElementType(), allocType.getEncoding()); + auto extract = builder.create( + loc, sliceType, insertOp.getResult(), + SmallVector{extractIdx, int_attr(0), int_attr(0)}, + SmallVector{int_attr(1), int_attr(sliceType.getShape()[0]), + int_attr(sliceType.getShape()[1])}, + SmallVector{int_attr(1), int_attr(1), int_attr(1)}); + + Value barrierWait = builder.create( + loc, mBarTy, barrierArray, extractIdx); + builder.create(loc, barrierWait, phase); + + Operation *user = *loadOp.getResult().getUsers().begin(); + auto convertLayout = llvm::cast(user); + auto newCvt = builder.create( + convertLayout->getLoc(), convertLayout.getType(), extract.getResult()); + convertLayout->replaceAllUsesWith(newCvt->getResults()); + convertLayout->erase(); + loadOp.erase(); + + // Fix up the yield op. + appendToYield(forOp, {insertOp}); +} + +/// Create an async load equivalent to the given load. +static void createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value phase) { + if (isLoadFromTensorPtr(loadOp)) { + createTMALoad(forOp, loadOp, alloc, insertIdx, extractIdx, phase); + } else { + createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx); + } +} + +// Return the transitive use of the load which is a dot operand. 
+static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) { + // We only pipeline loads that have one covert_layout (to dot_op) use + // TODO: lift this constraint in the future + bool isCandidate = false; + if (!loadOp.getResult().hasOneUse()) + return Value(); + + Operation *use = *loadOp.getResult().getUsers().begin(); + Operation *preUse = nullptr; + + // Advance to the first conversion as long as the use resides in shared + // memory and it has a single use itself + while (use) { + if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) + break; + auto tensorType = use->getResult(0).getType().dyn_cast(); + if (!tensorType.getEncoding().isa()) + break; + preUse = use; + use = *use->getResult(0).getUsers().begin(); + } + + if (auto convertLayout = llvm::dyn_cast(use)) { + if (auto tensorType = + convertLayout.getResult().getType().dyn_cast()) { + if (auto dotOpEnc = tensorType.getEncoding() + .dyn_cast()) { + return convertLayout.getResult(); + } + } + } else if (preUse && isa(use)) { + // for MMAv3 whose dot take SharedEncoding as operands directly + Operation *post = *loadOp.getResult().getUsers().begin(); + auto newOrder = post->getResult(0) + .getType() + .cast() + .getEncoding() + .cast() + .getOrder(); + auto ty = loadOp.getType().cast(); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + // The operand of MMAv3 is in SharedEncoding and it's order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + // TODO: remove this constraint once the LoadOp supports transpose + // fusion + if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) { + hasMMAV3 = true; + return preUse->getResult(0); + } + } + return Value(); +} + +namespace { +struct LoadDotOperand { + LoadDotOperand(tt::LoadOp load, Value dotOperand) + : load(load), dotOperand(dotOperand) {} + tt::LoadOp load; + Value dotOperand; +}; +} // namespace + +/// Collect loads to pipeline. Return success if we can pipeline this loop +static void collectOpsToPipeline(scf::ForOp forOp, + SmallVectorImpl &ops, + bool &hasMMAV3) { + ModuleOp moduleOp = forOp->getParentOfType(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // We cannot use forOp.walk(...) here because we only want to visit the + // operations in the loop body block. Nested blocks are handled separately. + for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) { + bool candidate = false; + if (isLoadFromTensorPtr(loadOp)) { + // Map to TMA load. + candidate = true; + } else { + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = + std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy || tensorTy.getRank() < 2) + continue; + auto ty = + tensorTy.getElementType().cast().getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. 
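The `width >= 32` check that follows encodes constraint 1: each thread must move at least one 32-bit chunk for the load to map onto cp.async. A standalone illustration of the arithmetic is sketched below; the contiguity values are invented for the example.

    #include <cstdio>

    // width = per-thread contiguity (from axis analysis, possibly clamped by
    // the mask alignment) * element bit width; anything below 32 bits cannot
    // use cp.async, whose cp-size is 4, 8 or 16 bytes.
    static bool isPipelineCandidate(unsigned contiguity, unsigned elemBits) {
      return contiguity * elemBits >= 32;
    }

    int main() {
      std::printf("%d\n", isPipelineCandidate(8, 16)); // f16 x8 -> 128 bits, yes
      std::printf("%d\n", isPipelineCandidate(1, 16)); // f16 x1 -> 16 bits, no
      std::printf("%d\n", isPipelineCandidate(1, 32)); // f32 x1 -> 32 bits, yes
    }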
+ if (width >= 32) + candidate = true; + } + if (!candidate) + continue; + Value dotOperand = loadDotOperand(loadOp, hasMMAV3); + if (!dotOperand) + continue; + ops.emplace_back(loadOp, dotOperand); + } + } +} + +// Create an allocation that can old distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, Value dotOperand, + unsigned distance) { + OpBuilder builder(forOp); + auto ty = loadOp.getType().cast(); + if (!loadOp.getResult().hasOneUse()) + return Value(); + Attribute sharedEnc; + auto CTALayout = ttg::getCTALayout(ty.getEncoding()); + auto tensorType = dotOperand.getType().cast(); + if (auto dotOpEnc = + tensorType.getEncoding().dyn_cast()) { + auto convertLayout = dotOperand.getDefiningOp(); + bool needTrans = dyn_cast_or_null( + convertLayout->getOperand(0).getDefiningOp()); + unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); + sharedEnc = ttg::SharedEncodingAttr::get( + ty.getContext(), dotOpEnc, ty.getShape(), + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); + } else { + // MMAv3 + sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), + ttg::getOrder(ty.getEncoding()), + CTALayout, ty.getElementType()); + } + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type allocType = + RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); + Value alloc = builder.create( + loadOp.getLoc(), allocType); + return alloc; +} + +// Convert load ops into their asyn version and apply multi-buffering based on +// the number of stages. +static void createAsynOps(scf::ForOp &forOp, ArrayRef loads, + int numStages, bool hasMMAV3) { + struct AsyncLoad { + AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} + tt::LoadOp loadOp; + Value alloc; + }; + int numBuffers = numStages - 1; + // For MMAv3 we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. + // TODO: Improve modeling of wgmma pipelining. + if (hasMMAV3) + numBuffers++; + SmallVector asyncLoads; + SmallVector newOperands; + bool needsMbarrierPhase = false; + bool needsAsyncWait = false; + for (const LoadDotOperand &loadOperand : loads) { + tt::LoadOp loadOp = loadOperand.load; + Value dotOperand = loadOperand.dotOperand; + Value alloc = createAlloc(forOp, loadOp, dotOperand, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + newOperands.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); + if (isLoadFromTensorPtr(loadOp)) + needsMbarrierPhase = true; + else + needsAsyncWait = true; + } + + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + Value phase; + if (needsMbarrierPhase) { + phase = builder.create(loc, 0, 1); + newOperands.push_back(phase); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. 
+ scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + for (int i = 0; i < asyncLoads.size(); i++) { + asyncLoads[i].alloc = newForOp.getBody()->getArgument(newOperandIndex + i); + } + insertIdx = + newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size()); + extractIdx = + newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size() + 1); + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(asyncLoads.front().loadOp); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + if (needsMbarrierPhase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + + asyncLoads.size() + 2); + Value oneI1 = builder.create(loc, 1, 1); + Value nextPhase = builder.create(loc, phase, oneI1); + phase = builder.create(loc, cndExt, phase, nextPhase); + } + + bool firstLoad = true; + for (AsyncLoad &asyncLoad : asyncLoads) { + createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx, + extractIdx, phase); + firstLoad = false; + } + // Insert a waitOp after the first async copy. This does make the assumption + // that the wait will be scheduled in a different stage that all the async + // copy but we cannot guarantee that one wait is enough otherwise. + for (auto &op : forOp.getBody()->without_terminator()) { + if (isa(op)) { + OpBuilder builder(op.getContext()); + builder.setInsertionPointAfter(&op); + builder.create(op.getLoc(), 0); + break; + } + } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (needsMbarrierPhase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. + appendToYield(forOp, newYieldOperands); +} + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (maskType.isa()) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. 
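Since this schedule runs the expander with peelEpilogue = false, operations from early stages keep executing in the final iterations and must be neutralized by a predicate rather than peeled off; predicateOp below folds that predicate into whatever masking each op already supports. A scalar model of the mask-combining rule implemented by getPredMask (illustrative names and values only):

    #include <cstdio>
    #include <vector>

    // Scalar model of getPredMask: splat the i1 predicate to the mask shape,
    // then AND it with the mask the op already had (if any).
    static std::vector<bool> combineMask(const std::vector<bool> *currentMask,
                                         bool pred, size_t size) {
      std::vector<bool> mask(size, pred);
      if (currentMask)
        for (size_t i = 0; i < size; ++i)
          mask[i] = mask[i] && (*currentMask)[i];
      return mask;
    }

    int main() {
      std::vector<bool> existing = {true, false, true, true};
      auto masked = combineMask(&existing, /*pred=*/false, existing.size());
      for (bool b : masked)
        std::printf("%d ", (int)b); // all zeros: the predicated op is disabled
      std::printf("\n");
    }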
+static Operation *predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto insertOp = dyn_cast(op)) { + rewriter.setInsertionPoint(insertOp); + Value mask = getPredMask(rewriter, insertOp.getSrc().getType(), + insertOp.getMask(), pred); + insertOp.getMaskMutable().assign(mask); + return op; + } + if (auto insertOp = dyn_cast(op)) { + rewriter.setInsertionPoint(insertOp); + Value mask = getPredMask( + rewriter, + insertOp.getSrc().getType().cast().getPointeeType(), + insertOp.getMask(), pred); + insertOp.getMaskMutable().assign(mask); + return op; + } + if (auto arriveOp = dyn_cast(op)) { + rewriter.setInsertionPoint(arriveOp); + Value mask = getPredMask(rewriter, rewriter.getIntegerType(1), + arriveOp.getPred(), pred); + arriveOp.getPredMutable().assign(mask); + return op; + } + if (isa(op)) { + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + + assert("don't know how to predicate this op" && false); + return op; +} + +static void setWaitNum(Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration, unsigned numLoadsInStage) { + if (auto waitOp = dyn_cast(op)) { + waitOp.setNum(numLoadsInStage); + } +} + +/// Helper to recursively add dependencies to the same stage. +static void addDep(Operation *op, DenseSet &deps, + bool includeArg = true, + DenseSet *filter = nullptr) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = v.dyn_cast()) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the shedule with the given stage based on the filter +// function. +static void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages, bool prefetchExtract) { + SmallVector insertOps; + SmallVector extractOps; + // Find the insert/extract ops that will go respectively in stage 0 and stage + // `numStages - 2`. All the other operations will go in stage `numStages - 1`. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + if (prefetchExtract) { + if (isa(op)) + extractOps.emplace_back(&op); + } + } + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + addDep(op, insertAndDeps, false); + } + + // Find depenencies with distance of 1. 
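addDep collects the backward slice of an op inside the loop body; when it hits a block argument it can hop to the corresponding scf.yield operand, which is how a producer from the previous iteration (distance 1) is pulled into the schedule, and the code after this comment applies exactly that hop to the insert ops' operands. A toy model of the traversal with the loop-carried edge made explicit (all op names are invented):

    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    // Each op lists the ops producing its operands; a loop-carried value is
    // modeled as an edge from a block-argument name to the op that yields it.
    using Graph = std::map<std::string, std::vector<std::string>>;

    static void addDep(const std::string &op, const Graph &defs,
                       const Graph &yieldedBy, std::set<std::string> &deps) {
      if (!deps.insert(op).second)
        return;
      auto it = defs.find(op);
      if (it == defs.end())
        return;
      for (const std::string &operand : it->second) {
        // Follow a loop-carried argument to the op yielding it (distance 1).
        auto carried = yieldedBy.find(operand);
        const std::string &producer =
            carried != yieldedBy.end() ? carried->second.front() : operand;
        if (defs.count(producer))
          addDep(producer, defs, yieldedBy, deps);
      }
    }

    int main() {
      Graph defs = {{"insert_slice_async", {"ptr_update", "%arg_ptr"}},
                    {"ptr_update", {"%arg_ptr"}},
                    {"dot", {"extract_slice"}}};
      Graph yieldedBy = {{"%arg_ptr", {"ptr_update"}}};
      std::set<std::string> deps;
      addDep("insert_slice_async", defs, yieldedBy, deps);
      // deps now holds {insert_slice_async, ptr_update}: the address update is
      // scheduled together with the async copy.
      for (const std::string &d : deps)
        std::printf("stage-0 dep: %s\n", d.c_str());
    }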
+ SmallVector distanceOneUsers; + for (Operation *op : insertAndDeps) { + for (Value operand : op->getOperands()) { + if (auto arg = operand.dyn_cast()) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && insertAndDeps.count(defOp) == 0) { + distanceOneUsers.push_back(defOp); + } + } + } + } + } + // Schedule loads with a distance of 1 in stage 0 + for (Operation *op : distanceOneUsers) { + if (isa(op)) { + addDep(op, insertAndDeps, true); + } + } + // For the rest of the ops we can move then into stage 1 so that they can be + // closer to their uses. + DenseSet stage1deps; + for (Operation *op : distanceOneUsers) { + if (!isa(op)) { + addDep(op, stage1deps, true, &insertAndDeps); + } + } + + DenseSet extractAndDeps; + for (Operation *op : extractOps) { + addDep(op, extractAndDeps, true, &insertAndDeps); + } + std::vector> schedule; + // Schedule stage `numStage - 1` first. + addOps(forOp, numStages - 1, schedule, [&](Operation *op) { + return insertAndDeps.count(op) == 0 && stage1deps.count(op) == 0 && + extractAndDeps.count(op) == 0; + }); + + // Schedule some dependencies with distance of 1 into stage 1 to reduce + // pressure. + addOps(forOp, 1, schedule, + [&](Operation *op) { return stage1deps.count(op); }); + + // Then Schedule stage 0. + addOps(forOp, 0, schedule, + [&](Operation *op) { return insertAndDeps.count(op); }); + + // Finally schedule the extract ops in stage `numStage - 2` so that they get + // pre-fetched and play well with pretech pass. + addOps(forOp, numStages - 2, schedule, + [&](Operation *op) { return extractAndDeps.count(op); }); + return schedule; +} + +bool mlir::triton::preProcessLoopAndGetSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + // 1. First collect "interesting" operations with a stage where to schedule + // them. This gives a coarse scheduling for the loop. + SmallVector loads; + bool hasMMAV3 = false; + collectOpsToPipeline(forOp, loads, hasMMAV3); + if (loads.empty()) + return false; + bool hasAsynCp = llvm::any_of(loads, [](LoadDotOperand &load) { + return !isLoadFromTensorPtr(load.load); + }); + // 2. Convert the loads into async loads and create the allocs. + createAsynOps(forOp, loads, numStages, hasMMAV3); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages, /*prefetchExtract=*/!hasMMAV3); + + // 4. Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = predicateOp; + options.supportDynamicLoops = true; + unsigned numLoadsInStage = (numStages - 2) * loads.size(); + options.annotateFn = + [numLoadsInStage](Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration) { + return setWaitNum(op, part, iteration, numLoadsInStage); + }; + + if (hasAsynCp) { + // Insert a wait 0 after the loop + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), 0); + } + return true; +} + +/// MMA V3 post-processing. 
+static bool selfDepend(tt::DotOp dotOp, scf::ForOp forOp, + Operation **firstUse) { + std::function dependOn = + [&dependOn](Value v, int argId, scf::ForOp forOp) { + auto op = v.getDefiningOp(); + if (isa(v)) { + auto iterArgs = forOp.getRegionIterArgs(); + auto iter = std::find(iterArgs.begin(), iterArgs.end(), v); + if (iter != iterArgs.end()) + return std::distance(iterArgs.begin(), iter) == argId; + } else { + if (!op) + return false; + for (auto operand : op->getOperands()) { + if (dependOn(operand, argId, forOp)) + return true; + } + } + return false; + }; + auto result = dotOp.getResult(); + auto yieldOp = forOp.getBody()->getTerminator(); + int argIdx = -1; + auto iter = std::find(yieldOp->getOperands().begin(), + yieldOp->getOperands().end(), result); + if (iter != yieldOp->getOperands().end()) + argIdx = std::distance(yieldOp->getOperands().begin(), iter); + if (argIdx == -1) + return false; + for (auto operand : dotOp.getOperands()) { + if (dependOn(operand, argIdx, forOp)) { + auto iterArgs = forOp.getRegionIterArgs(); + *firstUse = iterArgs[argIdx].use_begin().getUser(); + return true; + } + } + return false; +} + +static void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp, + bool hasDotWait0) { + if (hasDotWait0) { + dotWaitOp->erase(); + } +} + +void mlir::triton::asyncLaunchDots(scf::ForOp forOp) { + Block *loop = forOp.getBody(); + auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) { + if (!op) + return -1l; + auto lastOp = op; + while (op->getBlock()->getParentOp() != forOp) { + lastOp = op; + op = op->getBlock()->getParentOp(); + } + return std::distance(lastOp->getBlock()->getParent()->begin(), + lastOp->getBlock()->getIterator()); + }; + /// XXX(Keren): Clean up the following duplicate code with checkDotOp + /// dots to be pipelined + bool hasSyncDot = false; + bool hasDotWait0 = false; + SmallVector allDots; + SmallVector dots; + SmallVector resultNeedSync; + for (Operation &op : *loop) { + if (auto dotWaitOp = dyn_cast(&op)) { + auto attr = dotWaitOp->getAttrOfType("pendings"); + auto pendingCount = attr.getInt(); + if (pendingCount == 0) + hasDotWait0 = true; + } + if (auto dotOp = dyn_cast(&op)) { + allDots.push_back(dotOp); + } + } + for (Operation &op : *loop) { + if (auto dotOp = dyn_cast(&op)) { + auto resTy = dotOp.getResult().getType().dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) { + if (resEnc && resEnc.isHopper()) { + auto dot = dotOp.getResult(); + bool valid = true; + + // all users of dot should be scf.yield + if (!dot.hasOneUse()) + valid = false; + if (!isa(*dot.getUsers().begin())) + valid = false; + + Operation *firstUse = nullptr; + auto depend = selfDepend(dotOp, forOp, &firstUse); + bool selfDirectDepend = (dotOp == firstUse); + for (auto tempInAll : allDots) { + auto iter = std::find(dots.begin(), dots.end(), tempInAll); + if (iter != dots.end()) + continue; + auto db = getBlockNumInFor(tempInAll, forOp); + auto fb = getBlockNumInFor(firstUse, forOp); + if (db < fb || + (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse))) + hasSyncDot = true; + } + auto CArg = dotOp.getOperand(2); + if (!(selfDirectDepend || + (depend && !selfDirectDepend && hasSyncDot)) || + !CArg.hasOneUse()) + valid = false; + + if (valid) { + dots.push_back(dotOp); + resultNeedSync.push_back( + dotOp->getUses().begin()->getOperandNumber()); + } + } + } + } + } + + // Early stop: no need to continue if there is no valid dot in the loop. + if (dots.empty()) + return; + + OpBuilder builder(forOp); + // 0. 
insert dot_wait after the last dot in the loop as we implicitly pipeline + // wgmma ops by one stage. + // This is needed to prevent shared memory inputs to be overriden before the + // operation is completed. + // TODO: merge this with the rest of the pipelining transformation and look at + // a better representation for async dots. + tt::DotOp lastDot = dots.back(); + auto loc = lastDot.getLoc(); + builder.setInsertionPointAfter(lastDot); + auto dotWait = builder.create( + lastDot.getLoc(), lastDot.getResult(), dots.size()); + + // 1. replace Dot with DotAsync + for (size_t idx = 0; idx < dots.size(); ++idx) { + tt::DotOp dotOp = dots[idx]; + builder.setInsertionPoint(dotOp); + auto dotAsync = builder.create( + dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); + dotOp.replaceAllUsesWith(dotAsync.getResult()); + dotOp->erase(); + } + + hasDotWait0 = hasDotWait0 || hasSyncDot; + + // 2. If there's any outstanding DotAsyncOps, we need to wait for them. + builder.setInsertionPointAfter(forOp); + SmallVector resultTypes(resultNeedSync.size()); + SmallVector yieldThenValues(resultNeedSync.size()); + SmallVector yieldElseValues(resultNeedSync.size()); + for (int i = 0; i < resultNeedSync.size(); ++i) { + resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType(); + yieldThenValues[i] = forOp->getResult(resultNeedSync[i]); + yieldElseValues[i] = forOp->getResult(resultNeedSync[i]); + } + Value loopNotEmpty = builder.create( + loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), + forOp.getUpperBound()); + auto ifOp = builder.create(loc, resultTypes, loopNotEmpty, + /*hasElse*/ true); + builder.setInsertionPointToStart(ifOp.thenBlock()); + for (int i = 0; i < resultNeedSync.size(); ++i) { + Value result = forOp->getResult(resultNeedSync[i]); + if (result.use_empty()) + continue; + auto dotWait = + builder.create(forOp.getLoc(), result, 0); + result.replaceAllUsesExcept(ifOp.getResult(i), dotWait); + yieldThenValues[i] = dotWait.getResult(); + } + auto yieldOpThen = builder.create(loc, yieldThenValues); + builder.setInsertionPointToEnd(ifOp.elseBlock()); + auto yieldOpElse = builder.create(loc, yieldElseValues); + + // 3. potentially remove redundant dot_wait after dot_async if having mutiple + // DotOp in the loop + removeExtraWait(dotWait, hasDotWait0); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 000000000..18df341ac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,704 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. 
+// -Fix bug when a yielded value is used outside the loop and the value def is
+// not in the last stage. If we are not peeling the epilogue we need to remap
+// the output correctly.
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/Support/Debug.h"
+
+#include "PipelineExpander.h"
+
+#define DEBUG_TYPE "triton-loop-pipelining"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+using namespace mlir::scf;
+using namespace mlir::triton;
+
+namespace {
+
+/// Helper to keep internal information during pipelining transformation.
+struct LoopPipelinerInternal {
+  /// Coarse liverange information for ops used across stages.
+  struct LiverangeInfo {
+    unsigned lastUseStage = 0;
+    unsigned defStage = 0;
+  };
+
+protected:
+  ForOp forOp;
+  unsigned maxStage = 0;
+  DenseMap stages;
+  std::vector opOrder;
+  Value ub;
+  Value lb;
+  Value step;
+  bool dynamicLoop;
+  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
+  bool peelEpilogue;
+  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
+
+  // When peeling the kernel we generate several versions of each value for
+  // the different stages of the prologue. This map tracks the mapping between
+  // original Values in the loop and the different versions
+  // peeled from the loop.
+  DenseMap> valueMapping;
+
+  /// Assign a value to `valueMapping`; this means `el` represents the version
+  /// `idx` of `key` in the epilogue.
+  void setValueMapping(Value key, Value el, int64_t idx);
+
+  std::pair getDefiningOpAndDistance(Value value);
+
+public:
+  /// Initialize the information for the given `op`, return true if it
+  /// satisfies the pre-condition to apply pipelining.
+  bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options);
+  /// Emits the prologue; this creates `maxStage - 1` parts which will contain
+  /// operations from stages [0; i], where i is the part index.
+  void emitPrologue(RewriterBase &rewriter);
+  /// Gather liverange information for Values that are used in a different stage
+  /// than their definition.
+  llvm::MapVector analyzeCrossStageValues();
+  scf::ForOp createKernelLoop(
+      const llvm::MapVector &crossStageValues,
+      RewriterBase &rewriter,
+      llvm::DenseMap, unsigned> &loopArgMap);
+  /// Emits the pipelined kernel. This clones loop operations following user
+  /// order and remaps operands defined in a different stage as their use.
+  LogicalResult createKernel(
+      scf::ForOp newForOp,
+      const llvm::MapVector &crossStageValues,
+      const llvm::DenseMap, unsigned> &loopArgMap,
+      RewriterBase &rewriter);
+  /// Emits the epilogue; this creates `maxStage - 1` parts which will contain
+  /// operations from stages [i; maxStage], where i is the part index.
+ llvm::SmallVector emitEpilogue(RewriterBase &rewriter); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Only support loop carried dependency with a distance of 1. This means the + // source of all the scf.yield operands needs to be defined by operations in + // the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || !stages.contains(def); + })) { + LDBG("--only support loop carried dependency with a distance of 1 -> BAIL"); + return false; + } + annotateFn = options.annotateFn; + return true; +} + +/// Clone `op` and call `callback` on the cloned op's oeprands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. 
+static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + for (OpOperand &operand : clone->getOpOperands()) + callback(&operand); + clone->walk([&](Operation *nested) { + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (BlockArgument &arg : forOp.getRegionIterArgs()) { + OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + for (int64_t i = 0; i < maxStage; i++) { + Value predicate; + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicate = rewriter.create(loc, arith::CmpIPredicate::slt, + iv, ub); + } + + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + if (predicate) { + newOp = predicateFn(rewriter, newOp, predicate); + assert(newOp && "failed to predicate op."); + } + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + if (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. 
+ if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + newUb = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, -maxStage)))); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. 
+ if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. 
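Each cross-stage value receives one extra iter arg per stage it has to survive, and loopArgMap is keyed by the pair (value, useStage - defStage) so that uses at different distances pick different versions. A compact model of that lookup, with hypothetical values and indices:

    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>

    int main() {
      // Key: (cross-stage value, number of stages between its def and this use).
      // Value: index of the extra region iter-arg carrying that version.
      std::map<std::pair<std::string, unsigned>, unsigned> loopArgMap = {
          {{"%a_shared", 1}, 4}, // used one stage after its definition
          {{"%a_shared", 2}, 5}, // used two stages after its definition
      };
      unsigned defStage = 0, useStage = 2;
      auto it = loopArgMap.find({"%a_shared", useStage - defStage});
      assert(it != loopArgMap.end());
      assert(it->second == 5);   // the cloned op reads region iter-arg #5
      return 0;
    }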
+ Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yielOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yielOperand.get()); + // When we don't peel the epilogue the yield value is used outside the loop + // we need to make sure we return the version from numStages - defStage. + if (!peelEpilogue && + !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) { + auto [def, distance] = getDefiningOpAndDistance(yielOperand.get()); + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value pred = predicates[defStage->second]; + if (pred) { + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yielOperand.getOperandNumber() + 1]); + } + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + if (defStage > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +llvm::SmallVector +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) { + llvm::SmallVector returnValues(forOp->getNumResults()); + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. 
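Concretely, the index of the last executed iteration is ((ub - 1) - lb) / step, and version maxStage - i of the induction variable is lb + step * (lastIndex - i), which is what the arithmetic just below rebuilds. A quick numeric check of the formula, with bounds chosen arbitrarily:

    #include <cstdio>

    int main() {
      long lb = 4, ub = 29, step = 5;          // iterations: 4, 9, 14, 19, 24
      long maxStage = 2;
      long lastIndex = ((ub - 1) - lb) / step; // 24 / 5 = 4 -> last value is 24
      for (long i = 0; i < maxStage; ++i) {
        long iv = lb + step * (lastIndex - i); // i = 0 -> 24, i = 1 -> 19
        std::printf("IV version %ld: %ld\n", maxStage - i, iv);
      }
    }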
+ for (int64_t i = 0; i < maxStage; i++) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totlaNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totlaNumIteration, minusI))); + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } + return returnValues; +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. 
Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h new file mode 100644 index 000000000..0a3d736c6 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h @@ -0,0 +1,101 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. 
It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as an option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by the user is valid. +/// For example, if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parentheses as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h new file mode 100644 index 000000000..67ee2ca83 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h @@ -0,0 +1,27 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "PipelineExpander.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include + +namespace mlir { +namespace triton { + +/// This fills out the pipelining options, including the schedule and +/// annotations for wait ops. It also does pre-processing by converting some of +/// the loads into async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic.
+void asyncLaunchDots(scf::ForOp forOp); + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 000000000..573151115 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,88 @@ +#include "PipelineExpander.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file creates a schedule that is handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces: one that creates a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that pre-processes the IR +// to create async operations and builds a modulo schedule. Then we call the +// expander to generate the prologue and the new loop. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static void pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::triton::PipeliningOption options; + // Skip loops with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return; + + bool foundSchedule = false; + foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options); + + // TODO: add more pipelining strategies.
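The schedule returned here is expanded into the prologue / kernel / epilogue structure documented in PipelineExpander.h. A minimal standalone sketch of that expansion, assuming three stages and a made-up trip count:

#include <cstdio>

int main() {
  const int N = 6;        // trip count of the original loop (assumed)
  const int maxStage = 2; // numStages - 1, i.e. stages S0, S1, S2
  // Prologue: peeled iteration p runs stages 0..p of the first iterations.
  for (int p = 0; p < maxStage; ++p) {
    std::printf("prologue:");
    for (int s = 0; s <= p; ++s)
      std::printf(" S%d(%d)", s, p - s);
    std::printf("\n");
  }
  // Kernel: every stage runs, each on a different original iteration.
  for (int i = 0; i + maxStage < N; ++i) {
    std::printf("kernel:  ");
    for (int s = 0; s <= maxStage; ++s)
      std::printf(" S%d(%d)", s, i + maxStage - s);
    std::printf("\n");
  }
  // Epilogue: drain the last iterations through the remaining stages.
  for (int e = 1; e <= maxStage; ++e) {
    std::printf("epilogue:");
    for (int s = e; s <= maxStage; ++s)
      std::printf(" S%d(%d)", s, N - 1 - (s - e));
    std::printf("\n");
  }
  return 0;
}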
+ if (!foundSchedule) + return; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); + + if (succeeded(newForOp)) + mlir::triton::asyncLaunchDots(newForOp.value()); +} + +namespace { +struct PipelinePass : public TritonGPUPipelineBase { + PipelinePass() = default; + PipelinePass(int numStages, int numWarps, int numCTAs, + int computeCapability) { + this->numStages = numStages; + this->numWarps = numWarps; + this->numCTAs = numCTAs; + this->computeCapability = computeCapability; + } + + void runOnOperation() override { + if (this->numStages <= 1) + return; + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + pipelineLoop(forOp, numStages); + } + } +}; +} // anonymous namespace + +std::unique_ptr mlir::createTritonGPUPipelinePass(int numStages, + int numWarps, + int numCTAs, + int computeCapability) { + return std::make_unique(numStages, numWarps, numCTAs, + computeCapability); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 388687c82..75375cdba 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -355,10 +355,6 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { triton::MakeRangeOp, triton::SplatOp>(op); } -// - -// Replace ForOp with a new ForOp with extra operands. The YieldOp is not -// updated and needs to be updated separatly for the loop to be correct. scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands) { OpBuilder::InsertionGuard g(rewriter); diff --git a/python/test/unit/hopper/test_gemm.py b/python/test/unit/hopper/test_gemm.py index a3e3f80b9..5c57fd17c 100644 --- a/python/test/unit/hopper/test_gemm.py +++ b/python/test/unit/hopper/test_gemm.py @@ -339,6 +339,9 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, '16-32-64-8-2-256-256-256-True', ]: pytest.skip('Known legacy issue, ldmatrix can only support x4') + enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() + if NUM_CTAS > 1 and enable_tma in ["on", "true", "1"]: + pytest.skip('multi-CTA with TMA not supported in MaterializeLoadStore') M = BLOCK_M if M is None else M N = BLOCK_N if N is None else N diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 8bfb7b576..03872748b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -10,16 +10,15 @@ #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop +// CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 -// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] -// CHECK: 
%[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] @@ -29,18 +28,24 @@ // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -93,31 +98,37 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 
0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop_nested + +// CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} -// CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][{{.*}}, 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK-DAG: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK-DAG: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = 
triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -171,23 +182,28 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[CONSTANT_3:.*]] = arith.constant 3 : i32 // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][{{.*}}, 0, 0] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 7907618bd..f3a9b01aa 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -11,15 +11,15 
@@ #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -// CHECK: tt.func @matmul_loop +// CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_A]] -// CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]], %[[LOOP_COND_0_SPLAT_B]] // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] @@ -29,25 +29,25 @@ // CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: 
triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -92,36 +92,36 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, tt.return %loop#2: tensor<128x128xf32, #C> } -// CHECK: tt.func @matmul_loop_nested +// CHECK-LABEL: tt.func @matmul_loop_nested // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 // CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = triton_gpu.alloc_tensor -// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[BBUFFER:.*]] = triton_gpu.alloc_tensor +// CHECK: %[[A0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0] -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] -// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] -// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] -// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] -// CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi 
%[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK-DAG: %[[A0:.*]] = triton_gpu.extract_slice %[[A0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK-DAG: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]] +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_0]], {{.*}} +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] +// CHECK: triton_gpu.async_wait {num = 2 : i32} +// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -168,7 +168,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, } -// CHECK: tt.func @matmul_loop_single_pipeline +// CHECK-LABEL: tt.func @matmul_loop_single_pipeline // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 @@ -176,20 +176,20 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]] // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_0]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B0BUFFER]][%[[CONSTANT_0]], 0, 0] +// CHECK: scf.for {{.*}} iter_args({{.*}}, %[[INS_IDX:.*]] = %[[CONSTANT_1]], %[[EXT_IDX:.*]] = %[[CONSTANT_0]]{{.*}}, %[[arg_b0:.*]] = %[[B0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = 
arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] : i32 -// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[NEXT_LOOP_IDX]], %[[CONSTANT_2]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[NEXT_LOOP_IDX]] -// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[PIPELINE_IDX]] +// CHECK-DAG: %[[INS_IDX_2:.*]] = arith.addi %[[INS_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_INS:.*]] = arith.cmpi slt, %[[INS_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[INS_IDX_3:.*]] = arith.select %[[CMP_INS]], %[[INS_IDX_2]], %[[CONSTANT_0]] +// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INS_IDX_3]] +// CHECK-DAG: %[[EXT_IDX_2:.*]] = arith.addi %[[EXT_IDX]], %[[CONSTANT_1]] : i32 +// CHECK-DAG: %[[CMP_EXT:.*]] = arith.cmpi slt, %[[EXT_IDX_2]], %[[CONSTANT_2]] +// CHECK-DAG: %[[EXT_IDX_3:.*]] = arith.select %[[CMP_EXT]], %[[EXT_IDX_2]], %[[CONSTANT_0]] // CHECK: triton_gpu.async_wait {num = 1 : i32} -// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[PIPELINE_IDX_PLUS_ONE:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX_PLUS_ONE]], %[[CONSTANT_2]] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX_PLUS_ONE]] -// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[EXTRACT_IDX]] +// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %{{.+}}[%[[EXT_IDX_3]], 0, 0] +// CHECK: scf.yield {{.*}}, %[[NEXT_B_BUFFER]], %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -228,18 +228,18 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, tt.return %loop#1 : tensor<128x128xf32, #C> } -// CHECK: tt.func @lut_bmm_scalar +// CHECK-LABEL: tt.func @lut_bmm_scalar // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.async_commit_group -// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} +// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] +// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[LUT_BUFFER_1:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_0]] // CHECK: %[[LUT_BUFFER_2:.*]] = tt.splat %[[LUT_BUFFER_1]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_2]] -// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}} -// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] // CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, @@ -271,19 +271,19 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK: tt.func @lut_bmm_vector +// CHECK-LABEL: tt.func @lut_bmm_vector // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.insert_slice_async // CHECK: triton_gpu.async_commit_group -// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %arg15, {{.*}} +// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, 
{{.*}} +// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] +// CHECK: %[[LUT_BUFFER_0:.*]] = tt.load %{{.*}}, {{.*}} // CHECK: %[[LUT_BUFFER_1:.*]] = tt.expand_dims %[[LUT_BUFFER_0]] {axis = 1 : i32} // CHECK: %[[LUT_BUFFER_2:.*]] = tt.broadcast %[[LUT_BUFFER_1]] // CHECK: %[[LUT_BUFFER_3:.*]] = arith.muli {{.*}}, %[[LUT_BUFFER_2]] // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[LUT_BUFFER_3]] -// CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %arg14, {{.*}} -// CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_1]] // CHECK: triton_gpu.insert_slice_async %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {num = 2 : i32} tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, @@ -317,11 +317,11 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK: tt.func @post_load_inv +// CHECK-LABEL: tt.func @post_load_inv // CHECK: scf.for -// CHECK: arith.index_cast // CHECK-DAG: %[[IV:.*]] = arith.index_cast // CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// CHECK: arith.index_cast // CHECK-NOT: arith.addi %[[NEXT_IV]] tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, @@ -373,17 +373,11 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %85#0 : tensor<32x32xf32, #C> } -// CHECK: tt.func @cross_iter_dep -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: triton_gpu.async_commit_group -// CHECK: %[[PTR0:.*]] = tt.addptr -// CHECK: %[[PTR1:.*]] = tt.addptr -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[BUF0:.*]] = %[[PTR0]], {{.*}}, %[[BUF1:.*]] = %[[PTR1]] +// CHECK-LABEL: tt.func @cross_iter_dep +// TODO: enable pipelining with distance of 2 +// CHECK-NOT: triton_gpu.async_commit_group +// CHECK: scf.for // CHECK: scf.yield -// CHECK-SAME: %[[BUF0]] -// CHECK-SAME: %[[BUF1]] tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -436,7 +430,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// CHECK: tt.func @dep_arg_two_uses +// CHECK-LABEL: tt.func @dep_arg_two_uses // CHECK: tt.expand_dims // CHECK: tt.expand_dims // CHECK: tt.expand_dims %arg5 diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 777e36546..56c55816d 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -1,4 +1,4 @@ -// RUN: ENABLE_TMA=1 ENABLE_MMA_V3=1 triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-rewrite-tensor-pointer -canonicalize -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> @@ -11,6 +11,7 @@ #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 
1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: two_dependent_dot tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: f32 , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: !tt.ptr {tt.divisibility = 16 : i32} , %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg18: i32 , %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} , %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32} ) attributes {noinline = false} { %cst = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma> %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> @@ -72,8 +73,9 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %81 = arith.truncf %68 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %82 = triton_gpu.convert_layout %60 : (tensor<64x128xf16, #blocked2>) -> tensor<64x128xf16, #shared> %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - // CHECK-LABEL: triton_nvidia_gpu.dot_async - // CHECK-LABEL-NOT: triton_nvidia_gpu.dot_wait + // CHECK: triton_nvidia_gpu.dot_async + // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: scf.yield %84 = tt.dot %83, %82, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %85 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %87 = arith.addf %85, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
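The rewritten CHECK lines above all match the same double-buffer bookkeeping: separate insert and extract indices that are advanced with an addi / cmpi slt / select sequence and wrap back to zero. A minimal standalone sketch of that rotation, assuming the two-buffer configuration and start values seen in the checks, with a made-up trip count:

#include <cstdio>

int main() {
  const int numBuffers = 2; // CONSTANT_2 in the CHECK lines
  int insIdx = 1;           // INS_IDX starts at CONSTANT_1
  int extIdx = 0;           // EXT_IDX starts at CONSTANT_0
  for (int iter = 0; iter < 5; ++iter) {
    // idx = (idx + 1 < numBuffers) ? idx + 1 : 0, i.e. the addi / cmpi slt /
    // select sequence matched by INS_IDX_3 and EXT_IDX_3.
    insIdx = (insIdx + 1 < numBuffers) ? insIdx + 1 : 0;
    extIdx = (extIdx + 1 < numBuffers) ? extIdx + 1 : 0;
    std::printf("iter %d: insert into buffer %d, extract from buffer %d\n",
                iter, insIdx, extIdx);
  }
  return 0;
}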