#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
// NOTE(review): the header name of this include was lost in the source dump;
// <memory> is the conventional choice for pass files (std::unique_ptr) — confirm.
#include <memory>

using namespace mlir;

namespace {

using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;

// -----------------------------------------------------------------------------
//
// -----------------------------------------------------------------------------

// NOTE(review): an unresolved merge conflict was found here. Resolution keeps
// both sides: HEAD's DecomposeDotOperand pattern below, and the incoming
// branch's comment which documents ConvertDotConvert further down.

// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
// this is a heuristic to accommodate some pattern seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
class DecomposeDotOperand : public mlir::RewritePattern {
public:
  explicit DecomposeDotOperand(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  // Split `convert` into a convert to `encoding` (mma/mfma) followed by a
  // convert to the original dot-operand layout, but only if a DotOp appears
  // in the backward slice of the convert (i.e. the value comes from a dot).
  template <typename encTy>
  mlir::LogicalResult processEncoding(encTy encoding,
                                      triton::gpu::ConvertLayoutOp convert,
                                      RankedTensorType &dstType,
                                      mlir::PatternRewriter &rewriter) const {
    SetVector<Operation *> bwdSlices;
    mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
    if (llvm::find_if(bwdSlices, [](Operation *op) {
          return isa<triton::DotOp>(op);
        }) == bwdSlices.end())
      return mlir::failure();

    auto tmpType = RankedTensorType::get(dstType.getShape(),
                                         dstType.getElementType(), encoding);
    auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
        convert.getLoc(), tmpType, convert.getOperand());
    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(convert, dstType,
                                                              tmp);
    return mlir::success();
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
      return mlir::failure();
    auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
    auto dstType = convert.getType().cast<RankedTensorType>();
    // Only decompose blocked -> dot_operand conversions.
    if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
        dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
      auto dstDotOperand =
          dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
      auto dstParent = dstDotOperand.getParent();
      // Only operand A (opIdx 0) with an mma/mfma parent layout is handled.
      if (dstDotOperand.getOpIdx() == 1 ||
          (!dstParent.isa<MmaEncodingAttr>() &&
           !dstParent.isa<triton::gpu::MfmaEncodingAttr>()))
        return mlir::failure();
      if (dstParent.isa<MmaEncodingAttr>()) {
        auto dstParentMma = dstParent.cast<MmaEncodingAttr>();
        // Volta mma and multi-warp-N layouts are excluded from the heuristic.
        if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
          return mlir::failure();
        return processEncoding(dstParentMma, convert, dstType, rewriter);
      }
      if (dstParent.isa<triton::gpu::MfmaEncodingAttr>()) {
        auto dstParentMfma = dstParent.cast<triton::gpu::MfmaEncodingAttr>();
        if (dstParentMfma.getWarpsPerCTA()[1] > 1)
          return mlir::failure();
        return processEncoding(dstParentMfma, convert, dstType, rewriter);
      }
    }
    return mlir::failure();
  }
};

// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
class ConvertDotConvert : public mlir::RewritePattern {
public:
  ConvertDotConvert(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
    auto dotOp = dstOp.getSrc().getDefiningOp<triton::DotOp>();
    if (!dotOp)
      return mlir::failure();
    // Require single-use chains so the rewrite does not duplicate work.
    if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 ||
        std::distance(dotOp->user_begin(), dotOp->user_end()) != 1)
      return mlir::failure();
    auto cvtOp =
        dotOp.getOperand(2).getDefiningOp<triton::gpu::ConvertLayoutOp>();
    if (!cvtOp)
      return mlir::failure();
    // The accumulator must come from a load for this rewrite to pay off.
    if (!cvtOp.getSrc().getDefiningOp<triton::LoadOp>())
      return failure();
    auto dstTy = dstOp.getResult().getType().cast<RankedTensorType>();
    auto srcTy = cvtOp.getOperand().getType().cast<RankedTensorType>();
    if (dstTy != srcTy)
      return mlir::failure();

    // Rebuild the dot with a zero accumulator, then add the loaded value
    // after the conversion out of the dot layout.
    auto _0f = rewriter.create<arith::ConstantOp>(
        op->getLoc(), dstTy.getElementType(),
        rewriter.getZeroAttr(dstTy.getElementType()));
    auto _0 = rewriter.create<triton::SplatOp>(
        op->getLoc(), dotOp.getResult().getType(), _0f);
    auto newDot = rewriter.create<triton::DotOp>(
        op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0),
        dotOp.getOperand(1), _0, dotOp.getAllowTF32(),
        dotOp.getMaxNumImpreciseAcc());
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), dstTy, newDot.getResult());
    rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getOperand());
    return mlir::success();
  }
};

// Class to propagate layout globally within a function.
// The current algorithm works by analysis the IR and doing a one shot rewrite
// based on the analysis. The algorithm is as follows:
// 1. Find all the anchor ops. These are ops that have a layout we want to
// preserve.
//
// 2. Propagate the layout to every op reachable which is a transitive child of
// an anchor op until we reach a fix point.
// An op can have multiple transitive anchor parents therefore at this stage
// it may have multiple layout associated to it.
//
// 3. Resolve conflicts by deciding which of the multiple layouts the op should
// keep. If one of the parents has a different layout than what is picked a
// convert operation will be inserted. After this stage each value should have
// only one layout associated.
//
// 4. Rewrite the IR by walking the function following dominance order. Since we
// assume the IR is structured we just need to process the regions in the
// correct order. For each op rewrite it using the layout decided by the
// analysis phase.
class LayoutPropagation {
public:
  // Structure to keep track of the layout associated to a value.
  struct LayoutInfo {
    LayoutInfo(Attribute encoding) { encodings.insert(encoding); }
    LayoutInfo() {}
    llvm::SmallSetVector<Attribute, 8> encodings;
  };
  LayoutPropagation(triton::FuncOp F) : funcOp(F) {}
  // Find the anchor ops and set their layout in the data structure.
  void initAnchorLayout();
  // Recursively Propagate the layout to all the users of the anchor ops until
  // we reach a fix point.
  void propagateLayout();
  // Add layouts given in `Info` to the uses of `value`.
  SmallVector<Value> propagateToUsers(Value value, LayoutInfo &info);
  // Set the encoding to all the values and fill out the values with new layout
  // in `changed`.
  void setEncoding(ValueRange values, LayoutInfo &info,
                   SmallVector<Value> &changed, Operation *op);
  // Resolve cases where a value has multiple layouts associated to it.
  void resolveConflicts();
  // Rewrite the IR for the full module.
  void rewrite();
  // Rewrite the IR for a region.
  void rewriteRegion(Region &R);
  // Rewrite an op based on the layout picked by the analysis.
  Operation *rewriteOp(Operation *op);
  // Rewrite a for op based on the layout picked by the analysis.
  Operation *rewriteForOp(scf::ForOp forOp);
  Operation *rewriteWhileOp(scf::WhileOp whileOp);
  Operation *rewriteIfOp(scf::IfOp ifOp);
  void rewriteYieldOp(scf::YieldOp yieldOp);
  void rewriteConditionOp(scf::ConditionOp conditionOp);
  void rewriteReduceToScalar(Operation *reduceOp);
  Operation *cloneElementwise(OpBuilder &rewriter, Operation *op,
                              Attribute encoding);
  // Map the original value to the rewritten one.
  void map(Value old, Value newV);
  // Return the mapped value in the given encoding. This will insert a convert
  // if the encoding is different than the encoding decided at resolve time.
  Value getValueAs(Value value, Attribute encoding);
  // Dump the current stage of layout information.
  void dump();

private:
  // map from value to layout information.
  llvm::MapVector<Value, LayoutInfo> layouts;
  // map of the values rewrite based on their encoding.
  DenseMap<std::pair<Value, Attribute>, Value> rewriteMapping;
  std::vector<Operation *> opToDelete;
  triton::FuncOp funcOp;
};

} // namespace

// Look ahead to at the transitive uses and see if there is a convert to mma
// operations.
static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
  SmallVector<Value> queue = {op->getResult(0)};
  SetVector<Operation *> forwardSlice;
  llvm::SmallDenseSet<Value> seen;
  while (!queue.empty()) {
    Value currentValue = queue.back();
    queue.pop_back();
    getForwardSlice(currentValue, &forwardSlice);
    for (Operation *op : forwardSlice) {
      if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
        Attribute dstEncoding = convertOp.getResult()
                                    .getType()
                                    .cast<RankedTensorType>()
                                    .getEncoding();
        if (auto mmaLayout = dstEncoding.dyn_cast<MmaEncodingAttr>())
          return (mmaLayout.getVersionMajor() > 1) ? true
                                                   : mmaLayout == encoding;
        if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
          return encoding.cast<MmaEncodingAttr>().getVersionMajor() > 1;
      }
      // Follow values yielded by enclosing for loops back to the matching
      // region iter args so the whole loop-carried chain is explored.
      auto yield = dyn_cast<scf::YieldOp>(op);
      if (!yield)
        continue;
      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
      if (!forOp)
        continue;
      for (OpOperand &operand : yield->getOpOperands()) {
        Operation *def = operand.get().getDefiningOp();
        if (def && forwardSlice.count(def) &&
            (seen.insert(operand.get()).second == true))
          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
      }
    }
  }
  return false;
}

#ifdef USE_ROCM
// Look ahead to at the transitive uses and see if there is a convert to mfma
// operations.
// TODO: unify with hasConvertToMMATransisitiveUse?
static bool hasConvertToMFMATransisitiveUse(Operation *op, Attribute encoding) {
  SmallVector<Value> queue = {op->getResult(0)};
  SetVector<Operation *> forwardSlice;
  llvm::SmallDenseSet<Value> seen;
  while (!queue.empty()) {
    Value currentValue = queue.back();
    queue.pop_back();
    getForwardSlice(currentValue, &forwardSlice);
    for (Operation *op : forwardSlice) {
      if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
        if (convertOp.getResult()
                .getType()
                .cast<RankedTensorType>()
                .getEncoding() == encoding)
          return true;
      }
      auto yield = dyn_cast<scf::YieldOp>(op);
      if (!yield)
        continue;
      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
      if (!forOp)
        continue;
      for (OpOperand &operand : yield->getOpOperands()) {
        Operation *def = operand.get().getDefiningOp();
        if (def && forwardSlice.count(def) &&
            (seen.insert(operand.get()).second == true))
          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
      }
    }
  }
  return false;
}
#endif

// Return true if the op is an op with a layout we don't want to change. We will
// propagate the layout starting from anchor ops.
static bool isLayoutAnchor(Operation *op) { if (isa(op)) return isExpensiveLoadOrStore(op); if (isa(op)) return true; return false; } void LayoutPropagation::initAnchorLayout() { funcOp.walk([&](Operation *op) { if (isLayoutAnchor(op)) { for (auto result : op->getResults()) { if (auto tensorType = result.getType().dyn_cast()) { // Workaround, don't popagate MMA layout unless there is a convert // back to mma further down to avoid generating reduction with MMA // layout that may have lower performance. // This can be improved with more aggressive backward propagation. if (tensorType.getEncoding().isa() && !hasConvertToMMATransisitiveUse(op, tensorType.getEncoding())) continue; #ifdef USE_ROCM // Workaround to not propagate MFMA layout in case there are // no chained dots MFMA layout is expensive to convert, so we want // to convert it to something else as soon as possible. // It saves LDS space in some cases. // // TODO: rework this heuristic if we can store MFMA layout directly // into global memory. 
if (tensorType.getEncoding().isa() && !hasConvertToMFMATransisitiveUse(op, tensorType.getEncoding())) continue; #endif layouts.insert({result, LayoutInfo(tensorType.getEncoding())}); } } } }); } void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, SmallVector &changed, Operation *op) { for (Value value : values) { if (!value.getType().isa()) continue; bool hasChanged = false; for (auto encoding : info.encodings) { auto dstEncoding = inferDstEncoding(op, encoding); if (dstEncoding) hasChanged |= layouts[value].encodings.insert(*dstEncoding); } if (hasChanged) changed.push_back(value); } } SmallVector LayoutPropagation::propagateToUsers(Value value, LayoutInfo &info) { SmallVector changed; for (OpOperand &use : value.getUses()) { Operation *user = use.getOwner(); if (auto forOp = dyn_cast(user)) { Value arg = forOp.getRegionIterArgForOpOperand(use); Value result = forOp.getResultForOpOperand(use); setEncoding({arg, result}, info, changed, user); continue; } if (auto whileOp = dyn_cast(user)) { Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; setEncoding({arg}, info, changed, user); continue; } if (auto yieldOp = dyn_cast(user)) { auto parent = yieldOp->getParentOp(); SmallVector valuesToPropagate; if (isa(parent)) valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); if (auto forOp = dyn_cast(parent)) valuesToPropagate.push_back( forOp.getRegionIterArg(use.getOperandNumber())); if (auto whileOp = dyn_cast(parent)) { valuesToPropagate.push_back( whileOp.getBeforeArguments()[use.getOperandNumber()]); valuesToPropagate.push_back( whileOp->getOperand(use.getOperandNumber())); } if (isa(parent)) setEncoding(valuesToPropagate, info, changed, user); continue; } if (auto conditionOp = dyn_cast(user)) { auto whileOp = cast(conditionOp->getParentOp()); // Skip arg 0 as it is the condition. 
unsigned argIndex = use.getOperandNumber() - 1; Value afterArg = whileOp.getAfterArguments()[argIndex]; Value result = whileOp->getResult(argIndex); setEncoding({afterArg, result}, info, changed, user); continue; } // Workaround: don't propagate through truncI if (isa(user)) continue; if (user->hasTrait() || user->hasTrait() || isa(user)) { #ifdef USE_ROCM if (auto convertOp = dyn_cast(user)) { if (triton::gpu::isSharedEncoding(convertOp.getResult()) || triton::gpu::isSharedEncoding(convertOp.getOperand())) continue; } #endif setEncoding(user->getResults(), info, changed, user); continue; } } return changed; } void LayoutPropagation::propagateLayout() { SmallVector queue; for (auto it : layouts) { queue.push_back(it.first); } while (!queue.empty()) { Value currentValue = queue.back(); LayoutInfo info = layouts[currentValue]; queue.pop_back(); SmallVector changed = propagateToUsers(currentValue, info); queue.insert(queue.end(), changed.begin(), changed.end()); } } void LayoutPropagation::resolveConflicts() { for (auto &it : layouts) { Operation *op = it.first.getDefiningOp(); LayoutInfo &info = it.second; if (info.encodings.size() <= 1) continue; // Hacky resolve, prefer block encoding. // TODO: add a proper heuristic. 
int maxSizePerThread = 1; Attribute encoding = *info.encodings.begin(); bool isLoadOrStore = op && isa(op); for (Attribute e : info.encodings) { if ((isLoadOrStore && e.isa()) || (!isLoadOrStore && e.isa())) { encoding = e; break; } } info.encodings.clear(); info.encodings.insert(encoding); } } void LayoutPropagation::dump() { for (auto it : layouts) { llvm::errs() << "Value: "; OpPrintingFlags flags; flags.skipRegions(); it.first.print(llvm::errs(), flags); llvm::errs() << " \n encoding:\n"; for (auto encoding : it.second.encodings) { encoding.print(llvm::errs()); llvm::errs() << "\n"; } llvm::errs() << "--\n"; } } void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } static bool reduceToScalar(Operation *op) { // For reductions returning a scalar we can change the src encoding without // affecting the output. return isa(op) && !op->getResultTypes()[0].isa(); } void LayoutPropagation::rewriteRegion(Region ®ion) { SmallVector queue = {®ion}; while (!queue.empty()) { Region *currentRegion = queue.back(); queue.pop_back(); for (Operation &op : currentRegion->getOps()) { bool needRewrite = false; SmallVector results = op.getResults(); for (Value result : results) { auto it = layouts.find(result); // If we haven't mapped this value skip. if (it == layouts.end()) continue; LayoutInfo &info = it->second; assert(info.encodings.size() == 1 && "we should have resolved to a single encoding"); auto encoding = result.getType().cast().getEncoding(); // If the encoding is already what we want skip. 
if (encoding == *info.encodings.begin()) continue; needRewrite = true; } if (needRewrite) { Operation *newOp = rewriteOp(&op); for (Region &R : newOp->getRegions()) queue.push_back(&R); } else if (auto yieldOp = dyn_cast(&op)) { rewriteYieldOp(yieldOp); } else if (auto conditionOp = dyn_cast(&op)) { rewriteConditionOp(conditionOp); } else if (reduceToScalar(&op)) { rewriteReduceToScalar(&op); } else { // If we don't need to rewrite the op we still need to remap the // operands. for (OpOperand &operand : op.getOpOperands()) { auto it = layouts.find(operand.get()); if (it == layouts.end()) continue; Attribute encoding = operand.get().getType().cast().getEncoding(); Value newOperand = getValueAs(operand.get(), encoding); op.setOperand(operand.getOperandNumber(), newOperand); } for (Region &R : op.getRegions()) queue.push_back(&R); } } } for (Operation *op : llvm::reverse(opToDelete)) op->erase(); } void LayoutPropagation::map(Value old, Value newV) { rewriteMapping[{old, newV.getType().cast().getEncoding()}] = newV; } Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { if (auto tensorType = value.getType().dyn_cast()) { Value rewrittenValue; auto layoutIt = layouts.find(value); if (layoutIt == layouts.end()) { rewrittenValue = value; } else { assert(layoutIt->second.encodings.size() == 1 && "we should have resolved to a single encoding"); Attribute encodingPicked = *(layoutIt->second.encodings.begin()); if (encodingPicked == tensorType.getEncoding()) rewrittenValue = value; else rewrittenValue = rewriteMapping[{value, encodingPicked}]; } assert(rewrittenValue); if (rewrittenValue.getType().cast().getEncoding() == encoding) return rewrittenValue; OpBuilder rewriter(value.getContext()); rewriter.setInsertionPointAfterValue(rewrittenValue); // Workaround: The pipeliner will insert async.wait after a pipelined loop // to ensure that there is no pending copies and it is safe to re-use shared // memory. 
We shouldn't insert ops that may use shared memory in between the // loop and the async.wait. This is a hack until we fix the IR // representation of async wait. if (Operation *op = rewrittenValue.getDefiningOp()) { if (isa(op->getNextNode())) rewriter.setInsertionPointAfter(op->getNextNode()); } auto tmpType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), encoding); Value converted = rewriter.create( value.getLoc(), tmpType, rewrittenValue); // TODO: we could cache the conversion. return converted; } return value; } Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, Operation *op, Attribute encoding) { Operation *newOp = rewriter.clone(*op); for (OpOperand &operand : op->getOpOperands()) newOp->setOperand( operand.getOperandNumber(), getValueAs(operand.get(), *inferSrcEncoding(op, encoding))); for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { auto origType = op->getResult(i).getType().dyn_cast(); if (!origType) continue; auto newType = RankedTensorType::get(origType.getShape(), origType.getElementType(), encoding); newOp->getResult(i).setType(newType); } return newOp; } Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { SmallVector operands; OpBuilder rewriter(forOp); for (auto [operand, result] : llvm::zip(forOp.getInitArgs(), forOp.getResults())) { Value convertedOperand = operand; if (layouts.count(result)) convertedOperand = getValueAs(operand, *layouts[result].encodings.begin()); operands.push_back(convertedOperand); } auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), operands); newForOp.getBody()->getOperations().splice( newForOp.getBody()->getOperations().begin(), forOp.getBody()->getOperations()); for (auto [oldResult, newResult] : llvm::zip(forOp.getResults(), newForOp.getResults())) { if (oldResult.getType() == newResult.getType()) { oldResult.replaceAllUsesWith(newResult); continue; } map(oldResult, newResult); } for (auto 
[oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), newForOp.getBody()->getArguments())) { if (oldArg.getType() == newArg.getType()) { oldArg.replaceAllUsesWith(newArg); continue; } map(oldArg, newArg); } return newForOp.getOperation(); } Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { SmallVector operands; SmallVector returnTypes; OpBuilder rewriter(whileOp); for (auto [operand, arg] : llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { Value convertedOperand = operand; if (layouts.count(arg)) convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); operands.push_back(convertedOperand); } for (Value ret : whileOp.getResults()) { auto it = layouts.find(ret); if (it == layouts.end()) { returnTypes.push_back(ret.getType()); continue; } auto origType = ret.getType().dyn_cast(); auto newType = RankedTensorType::get(origType.getShape(), origType.getElementType(), it->second.encodings[0]); returnTypes.push_back(newType); } auto newWhileOp = rewriter.create(whileOp.getLoc(), returnTypes, operands); SmallVector argsTypesBefore; for (Value operand : operands) argsTypesBefore.push_back(operand.getType()); SmallVector bbArgLocsBefore(argsTypesBefore.size(), whileOp.getLoc()); SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, bbArgLocsBefore); rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); for (int i = 0; i < whileOp.getNumRegions(); ++i) { newWhileOp->getRegion(i).front().getOperations().splice( newWhileOp->getRegion(i).front().getOperations().begin(), whileOp->getRegion(i).front().getOperations()); } auto remapArg = [&](Value oldVal, Value newVal) { if (oldVal.getType() == newVal.getType()) oldVal.replaceAllUsesWith(newVal); else map(oldVal, newVal); }; for (auto [oldResult, newResult] : llvm::zip(whileOp.getResults(), newWhileOp.getResults())) remapArg(oldResult, newResult); for (auto [oldArg, 
newArg] : llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) remapArg(oldArg, newArg); for (auto [oldArg, newArg] : llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) remapArg(oldArg, newArg); return newWhileOp.getOperation(); } Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { SmallVector operands; OpBuilder rewriter(ifOp); SmallVector newResultTypes(ifOp->getResultTypes()); for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { auto it = layouts.find(ifOp->getResult(i)); if (it == layouts.end()) continue; auto origType = ifOp->getResult(i).getType().cast(); Attribute encoding = *(it->second.encodings.begin()); newResultTypes[i] = RankedTensorType::get( origType.getShape(), origType.getElementType(), encoding); } auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, ifOp.getCondition(), true, true); newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); for (auto [oldResult, newResult] : llvm::zip(ifOp.getResults(), newIfOp.getResults())) { if (oldResult.getType() == newResult.getType()) { oldResult.replaceAllUsesWith(newResult); continue; } map(oldResult, newResult); } return newIfOp.getOperation(); } void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { Operation *parentOp = yieldOp->getParentOp(); for (OpOperand &operand : yieldOp->getOpOperands()) { Type yieldType = operand.get().getType(); if (isa(parentOp)) yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); if (auto whileOp = dyn_cast(parentOp)) yieldType = whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); auto tensorType = yieldType.dyn_cast(); if (!tensorType) continue; Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); yieldOp->setOperand(operand.getOperandNumber(), newOperand); } } void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { scf::WhileOp whileOp = cast(conditionOp->getParentOp()); 
for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { OpOperand &operand = conditionOp->getOpOperand(i); Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); auto tensorType = argType.dyn_cast(); if (!tensorType) continue; Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); conditionOp->setOperand(operand.getOperandNumber(), newOperand); } } void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { OpBuilder rewriter(reduceOp); Attribute srcEncoding; // Since all the operands need to have the same encoding pick the first one // and use it for all the operands. for (Value operand : reduceOp->getOperands()) { auto it = layouts.find(operand); if (it != layouts.end()) { srcEncoding = it->second.encodings[0]; break; } } if (!srcEncoding) return; for (OpOperand &operand : reduceOp->getOpOperands()) { Value newOperand = getValueAs(operand.get(), srcEncoding); reduceOp->setOperand(operand.getOperandNumber(), newOperand); } } Operation *LayoutPropagation::rewriteOp(Operation *op) { opToDelete.push_back(op); if (auto forOp = dyn_cast(op)) return rewriteForOp(forOp); if (auto whileOp = dyn_cast(op)) return rewriteWhileOp(whileOp); if (auto ifOp = dyn_cast(op)) return rewriteIfOp(ifOp); OpBuilder rewriter(op); Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); if (auto convertOp = dyn_cast(op)) { Attribute srcEncoding = convertOp.getOperand().getType().cast().getEncoding(); auto it = layouts.find(convertOp.getOperand()); if (it != layouts.end()) srcEncoding = *(it->second.encodings.begin()); Value src = getValueAs(convertOp.getOperand(), srcEncoding); auto tensorType = op->getResult(0).getType().cast(); auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), encoding); auto cvt = rewriter.create(op->getLoc(), newType, src); map(op->getResult(0), cvt.getResult()); return cvt.getOperation(); } if (canFoldIntoConversion(op, encoding)) { Operation *newOp = 
rewriter.clone(*op); auto tensorType = op->getResult(0).getType().cast(); auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), encoding); auto cvt = rewriter.create( op->getLoc(), newType, newOp->getResult(0)); map(op->getResult(0), cvt.getResult()); return cvt.getOperation(); } if (op->hasTrait() || op->hasTrait() || isa( op)) { Operation *newOp = cloneElementwise(rewriter, op, encoding); for (auto [oldResult, newResult] : llvm::zip(op->getResults(), newOp->getResults())) map(oldResult, newResult); return newOp; } assert(0 && "unexpected op in rewrite"); return nullptr; } static bool canBeRemat(Operation *op) { if (isa(op)) return !isExpensiveLoadOrStore(op); if (isa(op)) return false; if (isa(op)) return false; return true; } // Replace ForOp with a new ForOp with extra operands. The YieldOp is not // updated and needs to be updated separatly for the loop to be correct. static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(loop); // Create a new loop before the existing one, with the extra operands. 
rewriter.setInsertionPoint(loop); auto operands = llvm::to_vector<4>(loop.getInitArgs()); operands.append(newIterOperands.begin(), newIterOperands.end()); scf::ForOp newLoop = rewriter.create( loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), operands); newLoop.getBody()->erase(); newLoop.getRegion().getBlocks().splice( newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); for (Value operand : newIterOperands) newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( loop.getNumResults()))) std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); return newLoop; } static void rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, IRMapping &mapping) { SetVector opsToRewrite; for (Value v : slice) { if (v.getDefiningOp()) { opsToRewrite.insert(v.getDefiningOp()); } else { opsToRewrite.insert(v.cast().getOwner()->getParentOp()); // We also need to rewrite the yield op. opsToRewrite.insert(v.cast().getOwner()->getTerminator()); } } opsToRewrite = multiRootTopologicalSort(opsToRewrite); SmallVector deadLoops; OpBuilder builder(slice.begin()->getContext()); for (Operation *op : opsToRewrite) { if (auto forOp = dyn_cast(op)) { // Keep a mapping of the operands index to the new operands index. SmallVector> argMapping; SmallVector newOperands; for (auto arg : forOp.getRegionIterArgs()) { if (slice.count(arg)) { OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); argMapping.push_back(std::make_pair( forOp.getResultForOpOperand(initVal).getResultNumber(), forOp.getInitArgs().size() + newOperands.size())); newOperands.push_back(mapping.lookup(initVal.get())); } } // Create a new for loop with the new operands. 
// NOTE(review): this chunk is extraction-mangled — the contents of most
// template-argument lists ("<...>") were stripped (e.g. `SetVector`,
// `DenseMap`, `dyn_cast(op)`, `builder.create(...)`, `std::make_unique()`).
// Usage below indicates the intended arguments (flagged per call site), but
// they must be confirmed against the upstream RemoveLayoutConversions file.

// ---- tail of rewriteSlice(slice, layout, convertOp, mapping) ----
// The head of this function lies before this chunk; what is visible is the
// per-op rewrite loop plus the final replacement of the convert op.

// scf.for whose results/iter-args are in the slice: rebuild it with an
// extended signature carrying the duplicated (layout-converted) operands,
// and queue the old loop for deletion after the walk.
scf::ForOp newForOp = replaceForOpWithNewSignature(builder, forOp, newOperands);
deadLoops.push_back(forOp.getOperation());
Block &loopBody = *newForOp.getBody();
// Record old->new correspondences for both loop results and block arguments
// so subsequent clones resolve to the rewritten values. Block arguments are
// offset by the induction variable(s); results are not.
for (auto m : argMapping) {
  mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
  int numIndVars = newForOp.getNumInductionVars();
  mapping.map(loopBody.getArgument(m.first + numIndVars),
              loopBody.getArgument(m.second + numIndVars));
}
continue;
}
builder.setInsertionPoint(op);
// scf.yield: append the remapped value for every yielded operand that belongs
// to the slice, then rebuild the yield in place and erase the old one.
// NOTE(review): stripped template args — presumably dyn_cast<scf::YieldOp>
// and builder.create<scf::YieldOp>; confirm upstream.
if (auto yieldOp = dyn_cast(op)) {
  auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
  for (Value operand : yieldOp.getOperands()) {
    if (slice.count(operand) == 0)
      continue;
    yieldOperands.push_back(mapping.lookup(operand));
  }
  builder.create(op->getLoc(), yieldOperands);
  op->erase();
  continue;
}
// Constant-like op: clone it unmapped and materialize the target layout with
// an explicit convert, mapping the original result to the converted value.
// NOTE(review): stripped template args — presumably isa<arith::ConstantOp>,
// cast<RankedTensorType>, and builder.create<ConvertLayoutOp>; confirm.
if (isa(op)) {
  Operation *newOp = builder.clone(*op);
  auto tensorType = op->getResult(0).getType().cast();
  auto newType = RankedTensorType::get(tensorType.getShape(),
                                       tensorType.getElementType(),
                                       layout[op->getResult(0)]);
  auto cvt = builder.create(op->getLoc(), newType, newOp->getResult(0));
  mapping.map(op->getResult(0), cvt.getResult());
  continue;
}
// Generic op: clone through the mapping, then retype each result that the
// slice assigns a new encoding to (same shape/element type, new layout).
Operation *newOp = builder.clone(*op, mapping);
for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) {
  auto it = layout.find(old);
  if (it == layout.end())
    continue;
  auto newType = RankedTensorType::get(
      old.getType().cast().getShape(),
      old.getType().cast().getElementType(), it->second);
  newV.setType(newType);
}
}
// The convert is now redundant: its operand's remapped value already carries
// the destination layout. Replace, then delete the loops made dead above.
convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getOperand()));
convertOp.erase();
for (Operation *op : deadLoops)
  op->erase();
}

// Convenience overload: rewrite a slice starting from an empty value mapping.
// NOTE(review): stripped template args — presumably SetVector<Value> and
// DenseMap<Value, Attribute>; confirm upstream.
static void rewriteSlice(SetVector &slice, DenseMap &layout,
                         ConvertLayoutOp convertOp) {
  IRMapping mapping;
  rewriteSlice(slice, layout, convertOp, mapping);
}

/// Compute the backward slice of `root` that would have to be duplicated to
/// propagate `rootEncoding` upward, and verify every op in it can be
/// rematerialized. Fails if the slice is empty or any op is not remat-able.
/// `stopPropagation` (optional) cuts the traversal at selected ops.
/// NOTE(review): stripped template args — presumably SetVector<Value>,
/// DenseMap<Value, Attribute>, std::function<bool(Operation *)>; confirm.
static LogicalResult getRematerializableSlice(
    Value root, Attribute rootEncoding, SetVector &slice,
    DenseMap &layout, std::function stopPropagation = nullptr) {
  LogicalResult result =
      getConvertBackwardSlice(root, slice, rootEncoding, layout,
                              stopPropagation);
  if (result.failed() || slice.empty())
    return failure();
  // Check if all the operations in the slice can be rematerialized.
  for (Value v : slice) {
    if (Operation *op = v.getDefiningOp()) {
      if (!canBeRemat(op))
        return failure();
    }
  }
  return success();
}

/// Try to remove one convert by rematerializing its producer slice directly
/// in the destination layout.
static void backwardRematerialization(ConvertLayoutOp convertOp) {
  // We don't want to rematerialize any conversion to/from shared memory.
  if (triton::gpu::isSharedEncoding(convertOp.getResult()) ||
      triton::gpu::isSharedEncoding(convertOp.getOperand()))
    return;
  // We don't handle conversions to DotOperandEncodingAttr; this is a
  // heuristic to accommodate fused attention.
  // NOTE(review): stripped template args — presumably
  // cast<RankedTensorType>() / isa<DotOperandEncodingAttr>(); confirm.
  auto targetType = convertOp->getResultTypes()[0].cast();
  if (targetType.getEncoding().isa())
    return;
  // 1. Take a backward slice of all the tensor dependencies that can be
  // rematerialized.
  SetVector slice;
  DenseMap layout;
  LogicalResult result = getRematerializableSlice(
      convertOp.getOperand(), targetType.getEncoding(), slice, layout);
  if (result.failed())
    return;
  // 2. Rewrite the slice.
  rewriteSlice(slice, layout, convertOp);
}

// For converts left over, try to hoist them above type extensions/broadcasts
// to reduce the cost of the convert (it then runs on the smaller type).
static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
  // We don't want to rematerialize any conversion to/from shared memory.
  if (triton::gpu::isSharedEncoding(convertOp.getResult()) ||
      triton::gpu::isSharedEncoding(convertOp.getOperand()))
    return;
  // We don't handle conversions to DotOperandEncodingAttr; this is a
  // heuristic to accommodate fused attention.
  auto targetType = convertOp->getResultTypes()[0].cast();
  if (targetType.getEncoding().isa())
    return;
  // NOTE(review): stripped template args — the isa<...> list of ext/broadcast
  // op kinds was lost; confirm against upstream.
  auto isExtOrBroadcastOp = [](Operation *op) {
    return isa(op);
  };
  // 1. Take a backward slice of all the tensor dependencies, stopping the
  // traversal at ext/broadcast ops.
  SetVector slice;
  DenseMap layout;
  LogicalResult result =
      getRematerializableSlice(convertOp.getOperand(),
                               targetType.getEncoding(), slice, layout,
                               isExtOrBroadcastOp);
  if (result.failed())
    return;
  // 2. Find the single ext/broadcast op (if any) the convert must hoist over.
  // (sic: "extOrBroadcatOp" — existing variable-name typo, kept as-is.)
  Operation *extOrBroadcatOp = nullptr;
  unsigned sliceSize = slice.size();
  for (unsigned i = 0; i < sliceSize; i++) {
    Value v = slice[i];
    Operation *op = v.getDefiningOp();
    if (!op)
      continue;
    if (isExtOrBroadcastOp(op)) {
      SetVector tempSlice;
      DenseMap tempLayout;
      // NOTE(review): presumably std::optional<Attribute>; confirm.
      std::optional srcEncoding = inferSrcEncoding(op, layout[v]);
      if (!srcEncoding)
        return;
      LogicalResult result = getRematerializableSlice(
          op->getOperand(0), *srcEncoding, tempSlice, tempLayout);
      // If we can rematerialize the rest of the ext slice we can ignore this
      // ext as it won't need a convert.
      if (result.succeeded()) {
        slice.insert(tempSlice.begin(), tempSlice.end());
        layout.insert(tempLayout.begin(), tempLayout.end());
        continue;
      }
      // Only apply it if there is a single ext op, otherwise we would have to
      // duplicate the convert.
      if (extOrBroadcatOp != nullptr)
        return;
      extOrBroadcatOp = op;
    }
  }
  if (extOrBroadcatOp == nullptr)
    return;
  std::optional srcEncoding =
      inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]);
  if (!srcEncoding)
    return;
  // Move the convert before the ext op: build a convert of the (narrower)
  // source operand into the inferred source encoding, and seed the mapping
  // with it so the slice rewrite uses the hoisted convert.
  OpBuilder builder(extOrBroadcatOp);
  auto tensorType = extOrBroadcatOp->getOperand(0).getType().cast();
  auto newType = RankedTensorType::get(
      tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
  auto newConvertOp = builder.create(
      convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
  IRMapping mapping;
  mapping.map(extOrBroadcatOp->getOperand(0), newConvertOp.getResult());
  // 3. Rewrite the slice.
  rewriteSlice(slice, layout, convertOp, mapping);
}

/// Collect every ConvertLayoutOp in the module first, then attempt backward
/// rematerialization on each (collect-then-mutate avoids invalidating the
/// walk while ops are erased).
/// NOTE(review): presumably SmallVector<ConvertLayoutOp>; confirm.
static void backwardRematerialization(ModuleOp module) {
  SmallVector convertOps;
  module.walk(
      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
  for (ConvertLayoutOp convertOp : convertOps) {
    backwardRematerialization(convertOp);
  }
}

/// Same collect-then-mutate pattern for the ext/broadcast hoisting rewrite.
static void hoistConvert(ModuleOp module) {
  SmallVector convertOps;
  module.walk(
      [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
  for (ConvertLayoutOp convertOp : convertOps) {
    hoistConvertOnTopOfExtOrBroadcast(convertOp);
  }
}

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

/// Pass pipeline: (1) forward layout propagation from anchor ops,
/// (2) backward rematerialization of producer slices, (3) hoisting converts
/// above ext/broadcast, then decomposition patterns and final cleanup.
class TritonGPURemoveLayoutConversionsPass
    : public TritonGPURemoveLayoutConversionsBase<
          TritonGPURemoveLayoutConversionsPass> {
public:
  TritonGPURemoveLayoutConversionsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();
    // 1. Propagate layout forward starting from "anchor" ops.
    m.walk([](triton::FuncOp funcOp) {
      LayoutPropagation layoutPropagation(funcOp);
      layoutPropagation.initAnchorLayout();
      layoutPropagation.propagateLayout();
      layoutPropagation.resolveConflicts();
      layoutPropagation.rewrite();
    });

    // Canonicalize converts made redundant by the propagation above.
    mlir::RewritePatternSet cleanUpPatterns(context);
    ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context);
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns))
            .failed()) {
      signalPassFailure();
    }
    // 2. For convert ops left, try to rematerialize the slice of the producer
    // operation to avoid having to convert.
    backwardRematerialization(m);
    // 3. For converts left, try to hoist them above casts generating larger
    // size types in order to reduce the cost of the convert op.
    hoistConvert(m);

    // NOTE(review): the pattern list template args were stripped — likely
    // add<DecomposeDotOperand, ConvertDotConvert> given the classes defined
    // earlier in this file; confirm upstream.
    mlir::RewritePatternSet decomposePatterns(context);
    decomposePatterns.add(context);
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
            .failed()) {
      signalPassFailure();
    }
    // 4. Apply cleanup patterns to remove dead converts and dead code
    // generated by the previous transformations.
    mlir::RewritePatternSet cleanUpPatterns2(context);
    populateForOpDeadArgumentElimination(cleanUpPatterns2);
    scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context);
    ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context);
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns2))
            .failed()) {
      signalPassFailure();
    }
  }
};

/// Public factory for the pass.
/// NOTE(review): stripped template args — presumably std::unique_ptr<Pass>
/// and std::make_unique<TritonGPURemoveLayoutConversionsPass>(); confirm.
std::unique_ptr mlir::createTritonGPURemoveLayoutConversionsPass() {
  return std::make_unique();
}