// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.

// NOTE: the header names of the original include list are not preserved in
// this copy; the includes below are inferred from the code in this file.
// Project-specific headers (the batchable-op interface, static-loop analysis
// helpers and pass declarations) are additionally required.
#include <limits>

#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/ADT/TypeSwitch.h>

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/ImplicitLocOpBuilder.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <mlir/Transforms/RegionUtils.h>

namespace mlir {
namespace concretelang {

// Uniform accessors for the offsets, sizes and strides of indexed tensor
// operations (tensor.extract/insert and tensor.extract_slice/insert_slice).
template <typename IndexedOpTy> struct IndexedOpInfo {
  static mlir::SmallVector<mlir::OpFoldResult> getOffsets(IndexedOpTy op);
  static mlir::Value getTensor(IndexedOpTy op);
  static int64_t getSize(IndexedOpTy op, int64_t dim);
  static int64_t getStride(IndexedOpTy op, int64_t dim);
  static bool hasAllStaticSizesAndStrides(IndexedOpTy op);
};

template <> struct IndexedOpInfo<mlir::tensor::ExtractOp> {
  static mlir::SmallVector<mlir::OpFoldResult>
  getOffsets(mlir::tensor::ExtractOp op) {
    return llvm::to_vector_of<mlir::OpFoldResult>(op.getIndices());
  }

  static mlir::Value getTensor(mlir::tensor::ExtractOp op) {
    return op.getTensor();
  }

  static int64_t getSize(mlir::tensor::ExtractOp op, int64_t dim) { return 1; }

  static int64_t getStride(mlir::tensor::ExtractOp, int64_t dim) { return 1; }

  static bool hasAllStaticSizesAndStrides(mlir::tensor::ExtractOp op) {
    return true;
  }
};

template <> struct IndexedOpInfo<mlir::tensor::InsertOp> {
  static mlir::SmallVector<mlir::OpFoldResult>
  getOffsets(mlir::tensor::InsertOp op) {
    return llvm::to_vector_of<mlir::OpFoldResult>(op.getIndices());
  }

  static mlir::Value getTensor(mlir::tensor::InsertOp op) {
    return op.getDest();
  }

  static int64_t getSize(mlir::tensor::InsertOp op, int64_t dim) { return 1; }

  static int64_t getStride(mlir::tensor::InsertOp, int64_t dim) { return 1; }

  static bool hasAllStaticSizesAndStrides(mlir::tensor::InsertOp op) {
    return true;
  }
};

template <> struct IndexedOpInfo<mlir::tensor::ExtractSliceOp> {
  static mlir::SmallVector<mlir::OpFoldResult>
  getOffsets(mlir::tensor::ExtractSliceOp op) {
    return op.getMixedOffsets();
  }

  static mlir::Value getTensor(mlir::tensor::ExtractSliceOp op) {
    return op.getSource();
  }

  static int64_t getSize(mlir::tensor::ExtractSliceOp op, int64_t dim) {
    return op.getStaticSizes()[dim];
  }

  static int64_t getStride(mlir::tensor::ExtractSliceOp op, int64_t dim) {
    return op.getStaticStrides()[dim];
  }

  static bool hasAllStaticSizesAndStrides(mlir::tensor::ExtractSliceOp op) {
    for (size_t i = 0; i < op.getSizes().size(); i++) {
      if (op.isDynamicSize(i) || op.isDynamicStride(i))
        return false;
    }

    return true;
  }
};

template <> struct IndexedOpInfo<mlir::tensor::InsertSliceOp> {
  static mlir::SmallVector<mlir::OpFoldResult>
  getOffsets(mlir::tensor::InsertSliceOp op) {
    return op.getMixedOffsets();
  }

  static mlir::Value getTensor(mlir::tensor::InsertSliceOp op) {
    return op.getDest();
  }

  static int64_t getSize(mlir::tensor::InsertSliceOp op, int64_t dim) {
    return op.getStaticSizes()[dim];
  }

  static int64_t getStride(mlir::tensor::InsertSliceOp op, int64_t dim) {
    return op.getStaticStrides()[dim];
  }

  static bool hasAllStaticSizesAndStrides(mlir::tensor::InsertSliceOp op) {
    for (size_t i = 0; i < op.getSizes().size(); i++) {
      if (op.isDynamicSize(i) || op.isDynamicStride(i))
        return false;
    }

    return true;
  }
};

// Returns the intersection of two dense sets
template <typename T>
llvm::DenseSet<T> intersectSets(llvm::DenseSet<T> &a, llvm::DenseSet<T> &b) {
  llvm::DenseSet<T> res;

  for (T element : a) {
    if (b.contains(element))
      res.insert(element);
  }

  return res;
}

// Returns a set with all elements of `a` which do not appear in `b`
template <typename T>
llvm::DenseSet<T> setMinus(llvm::DenseSet<T> &a, llvm::DenseSet<T> &b) {
  llvm::DenseSet<T> res;

  for (T element : a) {
    if (!b.contains(element)) {
      res.insert(element);
    }
  }

  return res;
}

// Returns a vector which contains all elements of `a` that appear
// in `f`. Order is preserved.
template <typename T>
llvm::SmallVector<T> filterVector(llvm::SmallVector<T> &a,
                                  llvm::DenseSet<T> &f) {
  llvm::SmallVector<T> res;

  for (T &elt : a) {
    if (f.contains(elt))
      res.push_back(elt);
  }

  return res;
}

// Returns the index of the first operand of `op` that uses `v`. The
// value `v` must be referenced by at least one operand, otherwise an
// assertion is triggered.
unsigned getOperandIndexForValue(mlir::Operation *op, mlir::Value v) {
  for (auto it : llvm::enumerate(op->getOperands())) {
    if (it.value() == v)
      return it.index();
  }

  llvm_unreachable("Attempted to get operand index of value that is not an "
                   "operand of the operation");
}

// Walks up the use-def-chain of the value `v`, executing `cb`
// for any value not previously encountered in `visited`.
static void walkUseDefChainRec(mlir::DenseSet<mlir::Value> &visited,
                               mlir::Value v,
                               llvm::function_ref<void(mlir::Value)> cb) {
  if (visited.contains(v))
    return;

  // Mark the value as visited so that `cb` is invoked at most once per value
  visited.insert(v);

  cb(v);

  if (mlir::Operation *op = v.getDefiningOp()) {
    for (mlir::Value operand : op->getOperands()) {
      walkUseDefChainRec(visited, operand, cb);
    }
  }
}

// Walks up the use-def-chain of the value `v`, executing `cb` once
// for every value encountered.
static void walkUseDefChain(mlir::Value v,
                            llvm::function_ref<void(mlir::Value)> cb) {
  mlir::DenseSet<mlir::Value> visited;
  walkUseDefChainRec(visited, v, cb);
}

// Helper function that applies a function `func` to each element of
// a container `ins` and returns the result as a `llvm::SmallVector`.
template <typename ContainerTy, typename FuncTy>
auto map(ContainerTy &&ins, FuncTy func) {
  return llvm::to_vector(llvm::map_range(ins, func));
}

// Returns true if the operation `op` is the only direct user of `v`.
static bool isSoleUser(mlir::Value v, mlir::Operation *op) {
  return !v.getUsers().empty() &&
         std::next(v.getUsers().begin()) == v.getUsers().end() &&
         (*v.getUsers().begin()) == op;
}

// Sorts the `scf.for` loops from `forOps` from the innermost to the
// outermost loop. The loops must be embedded one into another.
template <typename ContainerTy>
SmallVector<mlir::scf::ForOp>
sortLoopsInnermostToOutermost(ContainerTy &forOps) {
  SmallVector<mlir::scf::ForOp> res;

  for (mlir::scf::ForOp forOp : forOps) {
    size_t i = 0;

    for (; i < res.size(); i++) {
      if (!forOp->isAncestor(res[i]))
        break;
    }

    res.insert(res.begin() + i, forOp);
  }

  return res;
}

// Sorts the `scf.for` loops from `forOps` from the outermost to the
// innermost loop. The loops must be embedded one into another.
template <typename ContainerTy>
SmallVector<mlir::scf::ForOp>
sortLoopsOutermostToInnermost(ContainerTy &forOps) {
  SmallVector<mlir::scf::ForOp> nest = sortLoopsInnermostToOutermost(forOps);
  std::reverse(nest.begin(), nest.end());
  return nest;
}

// Takes a set of loops `forOps` and finds the longest sequence of
// perfectly nested loops starting with the innermost loop of `forOps`,
// in which the predicate `parentChildPredicate` (if specified) holds
// for all loops and their direct child loops.
template <typename ContainerTy>
void getLongestPerfectLoopnest(
    ContainerTy forOps, mlir::scf::ForOp &innermost,
    mlir::scf::ForOp &outermost,
    llvm::function_ref<bool(mlir::scf::ForOp, mlir::scf::ForOp)>
        parentChildPredicate = nullptr) {
  assert(forOps.size() > 0);

  SmallVector<mlir::scf::ForOp> innermostToOutermost =
      sortLoopsInnermostToOutermost(forOps);

  innermost = innermostToOutermost[0];
  outermost = innermost;

  for (size_t i = 1; i < innermostToOutermost.size(); i++) {
    if (outermost->getParentOp() != innermostToOutermost[i].getOperation())
      break;

    if (parentChildPredicate &&
        !parentChildPredicate(innermostToOutermost[i], outermost)) {
      break;
    }

    outermost = innermostToOutermost[i];
  }
}

// Returns a value that corresponds to the tensor `v` with all
// dimensions but the last `trailingDimensions` dimensions collapsed
// into a single dimension.
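//
// For example (illustrative): for `v` of type `tensor<4x5x3xi64>`,
// `trailingDimensions = 1` yields a `tensor<20x3xi64>`, while
// `trailingDimensions = 0` yields a `tensor<60xi64>`.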
static mlir::Value flattenTensor(mlir::ImplicitLocOpBuilder &builder, mlir::Value v, unsigned trailingDimensions = 0) { mlir::RankedTensorType type = v.getType().dyn_cast(); assert(type && "Value type is not a ranked tensor"); if (type.getShape().size() - trailingDimensions == 1) { return v; } else { mlir::ReassociationIndices prefixCollapseGroup; llvm::SmallVector collapseGroups; for (unsigned i = 0; i < type.getShape().size() - trailingDimensions; i++) prefixCollapseGroup.push_back(i); collapseGroups.push_back(prefixCollapseGroup); for (unsigned i = type.getShape().size() - trailingDimensions; i < type.getShape().size(); i++) { mlir::ReassociationIndices suffixGroup; suffixGroup.push_back(i); collapseGroups.push_back(suffixGroup); } return builder.create(v, collapseGroups); } } // Returns a tensor with all the elements of the flat tensor `v`, but // shaped as a tensor with the type `targetType`. static mlir::Value unflattenTensor(mlir::ImplicitLocOpBuilder &builder, mlir::Value v, mlir::RankedTensorType targetType) { mlir::RankedTensorType type = v.getType().dyn_cast(); assert(type && type.getShape().size() == 1 && "Value is not a tensor of rank 1"); if (targetType.getShape().size() == 1) { return v; } else { mlir::ReassociationIndices expandGroup; for (unsigned i = 0; i < targetType.getShape().size(); i++) expandGroup.push_back(i); return builder.create( targetType, v, llvm::ArrayRef{expandGroup}); } } // Rewrites a perfect loop nest yielding M results from the innermost // loop, such that the n-th of these results is omitted, causing the // nest to return M-1 results. // // Example: // // scf.for ... { // %v0:5 = scf.for ... { // %v1:5 = scf.for ... { // %v2:5 = scf.for ... { // %inner0 = ... // %inner1 = ... // %inner2 = ... // %inner3 = ... // %inner4 = ... // ... // scf.yield %inner0, %inner1, %inner2, %inner3, %inner4 // } // scf.yield %v2#0, %v2#1, %v2#2, %v2#3, %v2#4 // } // scf.yield %v1#0, %v1#1, %v1#2, %v1#3, %v1#4 // } // scf.yield %v0#0, %v0#1, %v0#2, %v0#3, %v0#4 // } // // with n=2 becomes: // // scf.for ... { // %v0:4 = scf.for ... { // %v1:4 = scf.for ... { // %v2:4 = scf.for ... { // %inner0 = ... // %inner1 = ... // %inner2 = ... // %inner3 = ... // %inner4 = ... // ... 
// scf.yield %inner0, %inner1, %%inner3, %inner4 // } // scf.yield %v2#0, %v2#1, %v2#2, %v2#3 // } // scf.yield %v1#0, %v1#1, %v1#2, %v1#3 // } // scf.yield %v0#0, %v0#1, %v0#2, %v0#3 // } static void rewritePerfectLoopNestWithReplacedNthResult( mlir::PatternRewriter &rewriter, mlir::scf::ForOp innermostFor, mlir::scf::ForOp outermostFor, unsigned n, mlir::Value replacement) { // Assemble loop nest from innermost and outermost loop llvm::SmallVector nest; for (mlir::Operation *forOp = innermostFor.getOperation(); forOp != outermostFor.getOperation()->getParentOp(); forOp = forOp->getParentOp()) { nest.push_back(llvm::dyn_cast(forOp)); } // Dismiss n-th operand from all yields from outermost to innermost loop for (mlir::scf::ForOp currFor : llvm::reverse(nest)) { // Build new, empty loop nest rewriter.setInsertionPoint(currFor); SmallVector newInitArgs; for (auto i : llvm::enumerate(currFor.getInitArgs())) { if (i.index() != n) { newInitArgs.push_back(i.value()); } } scf::ForOp newFor = rewriter.create( currFor.getLoc(), currFor.getLowerBound(), currFor.getUpperBound(), currFor.getStep(), newInitArgs, [&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {}); // Copy all attributes from old for loop newFor->setAttrs(currFor->getAttrs()); // Move operations from old for op to new one without yield mlir::Block *newBody = newFor.getBody(); auto &newOperations = newBody->getOperations(); mlir::Block *oldBody = currFor.getBody(); auto end = oldBody->end(); end--; newOperations.splice(newOperations.begin(), oldBody->getOperations(), oldBody->begin(), end); mlir::scf::YieldOp oldYield = llvm::dyn_cast(currFor.getBody()->getTerminator()); // Assemble list of new yielded values, i.e., all old operands // without the n-th operand llvm::SmallVector newYieldOperands; for (size_t i = 0; i < oldYield->getNumOperands(); i++) { if (i != n) { newYieldOperands.push_back(oldYield->getOperand(i)); } } rewriter.setInsertionPointToEnd(newBody); rewriter.create(oldYield.getLoc(), newYieldOperands); // Remap iter args for (size_t i = 0; i < currFor.getNumRegionIterArgs(); i++) { if (i != n) { size_t idx = (i < n) ? i : i - 1; replaceAllUsesInRegionWith(currFor.getRegionIterArg(i), newFor.getRegionIterArg(idx), newFor.getRegion()); } else { replaceAllUsesInRegionWith(currFor.getRegionIterArg(i), replacement, newFor.getRegion()); } } // Remap IV replaceAllUsesInRegionWith(currFor.getInductionVar(), newFor.getInductionVar(), newFor.getRegion()); // Remap results in old yield for (auto i : llvm::enumerate(currFor.getInitArgs())) { replaceAllUsesInRegionWith(oldYield.getOperand(i.index()), currFor.getRegionIterArg(i.index()), currFor.getRegion()); } // Assemble list of values that the old outermost for loop is // replaced with (i.e., all retained yielded values and the // replacement value for the n-th operand) if (currFor == outermostFor) { llvm::SmallVector newResults; for (size_t i = 0; i < currFor->getNumResults(); i++) { if (i < n) { newResults.push_back(newFor->getResult(i)); } else if (i == n) { newResults.push_back(replacement); } else { newResults.push_back(newFor->getResult(i - 1)); } } rewriter.replaceOp(currFor, newResults); } else { // An inner loop has been rewritten -> remap uses of results of // the old loop to the new loop mlir::scf::ForOp parentFor = llvm::dyn_cast(currFor->getParentOp()); for (auto it : llvm::enumerate(currFor.getResults())) { if (it.index() != n) { mlir::Value newResult = (it.index() < n) ? 
newFor.getResult(it.index()) : newFor.getResult(it.index() - 1); replaceAllUsesInRegionWith(it.value(), newResult, parentFor.getRegion()); } } } } } /// Checks if the value `v` is defined outside of the `loop` or a pure /// operation that can be safely replicated outside the loop (i.e., all /// of its operands are also recursively either defined outside of the /// loop or pure). static bool isHoistable(mlir::Value v, mlir::scf::ForOp loop) { mlir::Operation *op = v.getDefiningOp(); return loop.isDefinedOutsideOfLoop(v) || (op && mlir::isPure(op) && op->getNumResults() == 1 && llvm::all_of(op->getOperands(), [&](mlir::Value operand) { return isHoistable(operand, loop); })); } llvm::SmallVector buildNormalizedIndexes(mlir::PatternRewriter &rewriter, llvm::ArrayRef nest) { assert(nest.size() > 0); mlir::scf::ForOp innermost = nest[nest.size() - 1]; llvm::SmallVector res; rewriter.setInsertionPointToStart(innermost.getBody()); for (mlir::scf::ForOp forOp : nest) { mlir::ImplicitLocOpBuilder ilob(forOp.getLoc(), rewriter); mlir::Value idx = normalizeInductionVar( ilob, forOp.getInductionVar(), forOp.getLowerBound(), forOp.getStep()); res.push_back(idx); } return res; } /// Returns an `OpFoldResult` with an `IntegerAttr` value if `v` is /// produced by a constant, otherwise an `OpFoldResult` containing `v` /// itself. static mlir::OpFoldResult getValueAsOpFoldResult(mlir::Value v) { if (mlir::arith::ConstantOp cstOp = dyn_cast_or_null(v.getDefiningOp())) { return cstOp.getValue(); } return v; } // Checks whether the `OpFoldResult` `v` is a `mlir::Value` generated // by a `ConstantOp`. If so, an `OpFoldResult` with an attribute corresponding // to the value of the constant is returned. Otherwise, `v` is returned // unchanged. static mlir::OpFoldResult opFoldConstantValueToAttribute(mlir::OpFoldResult v) { if (mlir::Value dynV = v.dyn_cast()) { if (isConstantIndexValue(dynV)) { return mlir::IntegerAttr::get( IndexType::get(dynV.getContext()), llvm::APInt(64, getConstantIndexValue(dynV))); } } return v; } /// Performs an arithmetic operation on `a` and `b`, where both values /// can be any combination of `IntegerAttr` and `Value`. template mlir::OpFoldResult opFoldExpr(mlir::ImplicitLocOpBuilder &builder, mlir::OpFoldResult a, mlir::OpFoldResult b) { static IsNeutralElementFunctor isNeutralElement; auto exprValVal = [&](mlir::Value a, mlir::Value b) -> mlir::Value { return builder.create(a, b); }; auto exprAttrVal = [&](mlir::IntegerAttr attr, mlir::Value v) -> mlir::Value { mlir::Value cst = builder.create(attr.getInt()); return exprValVal(cst, v); }; auto exprValAttr = [&](mlir::Value v, mlir::IntegerAttr attr) -> mlir::Value { mlir::Value cst = builder.create(attr.getInt()); return exprValVal(v, cst); }; auto exprAttrAttr = [&](mlir::IntegerAttr a, mlir::IntegerAttr b) -> mlir::IntegerAttr { static ArithFunctor f; return builder.getIndexAttr(f(a.getInt(), b.getInt())); }; if (a.is()) { if (b.is()) { return exprValVal(a.get(), b.get()); } else { mlir::IntegerAttr bAttr = b.get().cast(); if (isNeutralElement(bAttr.getValue().getSExtValue())) { return a; } else { return exprValAttr(a.get(), bAttr); } } } else { mlir::IntegerAttr aAttr = a.get().cast(); if (b.is()) { return exprAttrVal(aAttr, b.get()); } else { mlir::IntegerAttr bAttr = b.get().cast(); if (isNeutralElement(bAttr.getValue().getSExtValue())) return a; else return exprAttrAttr(aAttr, bAttr); } } } /// Hoists the pure operation producing the value `v` out of /// `outermostFor` recursively. 
/// All newly created mappings are collected in `mapping`.
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
                             mlir::scf::ForOp outermostFor,
                             mlir::IRMapping &mapping, mlir::Value v) {
  if (outermostFor.isDefinedOutsideOfLoop(v))
    return v;

  mlir::Operation *op = v.getDefiningOp();

  assert(op && mlir::isPure(op) && op->getNumResults() == 1);

  for (mlir::Value operand : op->getOperands()) {
    if (!mapping.contains(operand))
      mapping.map(operand,
                  hoistPure(rewriter, outermostFor, mapping, operand));
  }

  rewriter.setInsertionPoint(outermostFor);
  mlir::Operation *clonedOp = rewriter.clone(*op, mapping);

  return clonedOp->getResult(0);
}

/// Hoists the pure operation producing the value `v` out of
/// `outermostFor` recursively.
static mlir::Value hoistPure(mlir::PatternRewriter &rewriter,
                             mlir::scf::ForOp outermostFor, mlir::Value v) {
  mlir::IRMapping mapping;
  return hoistPure(rewriter, outermostFor, mapping, v);
}

// Returns the product of the static stride and size of `op` in the
// dimension `dimIdx`.
template <typename IndexedOpTy>
int64_t getSliceExtents(IndexedOpTy op, size_t dimIdx) {
  int64_t stride = IndexedOpInfo<IndexedOpTy>::getStride(op, dimIdx);
  int64_t size = IndexedOpInfo<IndexedOpTy>::getSize(op, dimIdx);

  assert(stride > 0 && size > 0 &&
         std::numeric_limits<int64_t>::max() / stride >= size);

  return stride * size;
}

/// Hoists an operation that is embedded into a loop nest and that is
/// indexed using quasi-affine expressions of the loops' IVs as a
/// tensor operation outside of the outermost loop (e.g.,
/// `tensor.extract` becomes `tensor.extract_slice`). This function
/// only takes care of calculating offsets, sizes and strides. The
/// actual tensor operation must be built by the callback function
/// `tensorOpBuilder`.
template <typename IndexedOpTy>
mlir::Value hoistIndexedOp(
    mlir::PatternRewriter &rewriter, mlir::scf::ForOp outermostFor,
    mlir::Value tensorizedOperands, IndexedOpTy indexedOp,
    llvm::function_ref<mlir::Value(
        mlir::ImplicitLocOpBuilder &, mlir::Value,
        llvm::ArrayRef<mlir::OpFoldResult>, llvm::ArrayRef<mlir::OpFoldResult>,
        llvm::ArrayRef<mlir::OpFoldResult>, llvm::ArrayRef<bool>)>
        tensorOpBuilder) {
  llvm::SmallVector<mlir::OpFoldResult> offsets;
  llvm::SmallVector<mlir::OpFoldResult> sizes;
  llvm::SmallVector<mlir::OpFoldResult> strides;
  llvm::SmallVector<bool> ivIndexedDims;

  rewriter.setInsertionPoint(outermostFor);
  mlir::ImplicitLocOpBuilder ilob(indexedOp.getLoc(), rewriter);

  for (auto it :
       llvm::enumerate(IndexedOpInfo<IndexedOpTy>::getOffsets(indexedOp))) {
    mlir::OpFoldResult idxExpr = it.value();
    size_t dimIdx = it.index();
    mlir::scf::ForOp forOp;
    bool isAffine = isQuasiAffineIVExpression(idxExpr, &forOp);

    int64_t stride = IndexedOpInfo<IndexedOpTy>::getStride(indexedOp, dimIdx);
    int64_t size = IndexedOpInfo<IndexedOpTy>::getSize(indexedOp, dimIdx);

    if (isAffine && forOp &&
        (forOp == outermostFor || outermostFor->isAncestor(forOp))) {
      std::optional<LoopsBoundsAndStep> bas =
          getBoundsOfQuasiAffineIVExpression(idxExpr, forOp);

      assert(bas.has_value());
      assert(bas->step != 0);

      int64_t sliceExtents = getSliceExtents(indexedOp, dimIdx);

      assert(sliceExtents == 1 || sliceExtents == bas->step);

      int64_t hoistedSliceSize = size * getStaticTripCount(*bas);
      int64_t hoistedStride = (sliceExtents == 1) ?
bas->step : stride; offsets.push_back(rewriter.getIndexAttr(bas->lb)); sizes.push_back(rewriter.getIndexAttr(hoistedSliceSize)); strides.push_back(rewriter.getIndexAttr(hoistedStride)); ivIndexedDims.push_back(true); } else if (isAffine || idxExpr.is() || outermostFor.isDefinedOutsideOfLoop( idxExpr.dyn_cast())) { offsets.push_back(opFoldConstantValueToAttribute(idxExpr)); sizes.push_back(rewriter.getIndexAttr(size)); strides.push_back(rewriter.getIndexAttr(stride)); ivIndexedDims.push_back(false); } else { llvm_unreachable("Unknown type of index found"); } } return tensorOpBuilder(ilob, tensorizedOperands, offsets, sizes, strides, ivIndexedDims); } /// Hoists a tensor.extract operation embedded into a loop nest as a /// `tensor.extract_slice` outside of the outermost loop of the nest static mlir::Value hoistExtractOp(mlir::PatternRewriter &rewriter, mlir::scf::ForOp outermostFor, mlir::tensor::ExtractOp extractOp) { return hoistIndexedOp( rewriter, outermostFor, extractOp.getTensor(), extractOp, [](mlir::ImplicitLocOpBuilder &builder, mlir::Value tensorizedOperands, llvm::ArrayRef offsets, llvm::ArrayRef sizes, llvm::ArrayRef strides, llvm::ArrayRef ivIndexedDims) -> mlir::Value { mlir::tensor::ExtractSliceOp slice = builder.create( tensorizedOperands, offsets, sizes, strides); // The extract slice operation above preserves non-IV-indexed // dimensions of the original extract operation as 1-sized // dimensions, e.g., a `tensor.extract[cst, i, j, cst, k]` // results in a slice with the shape `1xMxNx1xK` (where M, N // and K are the maximum values for i, j and k, assuming a // loop step of 1). // // If there is any non-IV-indexed dimension, add a collapse // shape operation that collapses the 1-sized dimensions // originating from non-IV-indexed dimensions of the extract // operation into the IV-indexed dimensions. I.e., in the // above example, produce a slice with the shape `MxNxK` // rather than `1xMxNx1xK`. if (llvm::all_of(ivIndexedDims, [](bool v) { return v; })) { return slice; } else { llvm::SmallVector collapseGroups; mlir::ReassociationIndices currCollapseGroup; bool prefixDone = false; for (auto i : llvm::enumerate(ivIndexedDims)) { // If this is a non-IV-indexed dimension, accumulate // dimension in the current group of collapsed dimensions if (!i.value()) { currCollapseGroup.push_back(i.index()); } else { // If there were only non-IV-indexed dimensions before, // add this first IV-indexed dimension to the current // group of collapsed dimensions and try to accumulate // with following, non-IV-indexed dimensions. 
              if (!prefixDone) {
                currCollapseGroup.push_back(i.index());
                prefixDone = true;
              } else {
                collapseGroups.push_back(currCollapseGroup);
                currCollapseGroup = mlir::ReassociationIndices();
                currCollapseGroup.push_back(i.index());
              }
            }
          }

          // Add last collapse group for trailing series of
          // non-IV-indexed dimensions
          if (!currCollapseGroup.empty())
            collapseGroups.push_back(currCollapseGroup);

          mlir::tensor::CollapseShapeOp cso =
              builder.create<mlir::tensor::CollapseShapeOp>(slice,
                                                            collapseGroups);

          return cso;
        }
      });
}

/// Hoists a tensor.extract_slice operation embedded into a loop nest
/// outside of the outermost loop of the nest
static mlir::Value hoistExtractOp(mlir::PatternRewriter &rewriter,
                                  mlir::scf::ForOp outermostFor,
                                  mlir::tensor::ExtractSliceOp extractSliceOp) {
  return hoistIndexedOp(
      rewriter, outermostFor, extractSliceOp.getSource(), extractSliceOp,
      [](mlir::ImplicitLocOpBuilder &builder, mlir::Value tensorizedOperands,
         llvm::ArrayRef<mlir::OpFoldResult> offsets,
         llvm::ArrayRef<mlir::OpFoldResult> sizes,
         llvm::ArrayRef<mlir::OpFoldResult> strides,
         llvm::ArrayRef<bool> ivIndexedDims) -> mlir::Value {
        mlir::tensor::ExtractSliceOp slice =
            builder.create<mlir::tensor::ExtractSliceOp>(
                tensorizedOperands, offsets, sizes, strides);

        return slice;
      });
}

/// Hoists a tensor.insert operation embedded into a loop nest as a
/// tensor.insert_slice outside of the outermost loop of the nest
static mlir::Value hoistInsertOp(mlir::PatternRewriter &rewriter,
                                 mlir::Value tensorizedOperands,
                                 mlir::Value targetTensor,
                                 mlir::scf::ForOp outermostFor,
                                 mlir::tensor::InsertOp insertOp) {
  return hoistIndexedOp(
      rewriter, outermostFor, targetTensor, insertOp,
      [&](mlir::ImplicitLocOpBuilder &builder, mlir::Value targetTensor,
          llvm::ArrayRef<mlir::OpFoldResult> offsets,
          llvm::ArrayRef<mlir::OpFoldResult> sizes,
          llvm::ArrayRef<mlir::OpFoldResult> strides,
          llvm::ArrayRef<bool> ivIndexedDims) -> mlir::Value {
        return builder.create<mlir::tensor::InsertSliceOp>(
            tensorizedOperands, targetTensor, offsets, sizes, strides);
      });
}

/// Hoists a tensor.insert_slice operation embedded into a loop nest
/// outside of the outermost loop of the nest
static mlir::Value hoistInsertOp(mlir::PatternRewriter &rewriter,
                                 mlir::Value tensorizedOperands,
                                 mlir::Value targetTensor,
                                 mlir::scf::ForOp outermostFor,
                                 mlir::tensor::InsertSliceOp insertSliceOp) {
  return hoistIndexedOp(
      rewriter, outermostFor, targetTensor, insertSliceOp,
      [&](mlir::ImplicitLocOpBuilder &builder, mlir::Value targetTensor,
          llvm::ArrayRef<mlir::OpFoldResult> offsets,
          llvm::ArrayRef<mlir::OpFoldResult> sizes,
          llvm::ArrayRef<mlir::OpFoldResult> strides,
          llvm::ArrayRef<bool> ivIndexedDims) -> mlir::Value {
        return builder.create<mlir::tensor::InsertSliceOp>(
            tensorizedOperands, targetTensor, offsets, sizes, strides);
      });
}

// Recursively extends the set of values `next` with the producers of
// the value `v`: producers inside `forOp` are followed transitively,
// while values that are defined just outside of `forOp` are collected
// in `next`. All intermediate values are added to `visited`. The
// function `valPredicate` is evaluated for each value encountered
// inside the loop; if the predicate does not hold for any such value,
// `extendFrontierRec` halts and returns false.
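//
// For illustration, assuming a loop of the form
//
//   scf.for %i = %c0 to %cN step %c1 {
//     %a = arith.addi %i, %c1 : index
//     %s = tensor.extract %T[%a] : tensor<?xi64>
//     ...
//   }
//
// extending the frontier { %s } over this loop visits %s and %a and
// yields the next frontier { %T, %c1 }, since both of these values are
// defined outside of the loop.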
bool extendFrontierRec(mlir::Value v, llvm::DenseSet &next, llvm::DenseSet &visited, mlir::scf::ForOp forOp, llvm::function_ref valPredicate) { if (visited.contains(v)) return true; if (forOp.isDefinedOutsideOfLoop(v)) { next.insert(v); } else { if (!valPredicate(v)) return false; visited.insert(v); if (mlir::BlockArgument ba = llvm::dyn_cast(v)) { return true; } else { mlir::Operation *definingOp = v.getDefiningOp(); assert(definingOp); for (mlir::Value operand : definingOp->getOperands()) { if (!extendFrontierRec(operand, next, visited, forOp, valPredicate)) return false; } } } return true; } // See `extendFrontierRec` std::optional> extendFrontier(const llvm::DenseSet &prev, llvm::DenseSet &visited, mlir::scf::ForOp forOp, llvm::function_ref valPredicate) { llvm::DenseSet next; for (mlir::Value v : prev) { if (!extendFrontierRec(v, next, visited, forOp, valPredicate)) return std::nullopt; } return next; } // Returns true if `v` is a region iteration argument of `forOp`. static bool valueIsRegionIterArgOf(mlir::Value v, mlir::scf::ForOp forOp) { return llvm::any_of(forOp.getRegionIterArgs(), [=](mlir::Value iterArg) { return iterArg == v; }); } // Checks if `v` refers to the iteration argument of an `scf.for` loop static std::optional valueIsRegionIterArg(mlir::Value v) { if (mlir::BlockArgument ba = llvm::dyn_cast(v)) { if (mlir::scf::ForOp forOp = llvm::dyn_cast(ba.getOwner()->getParentOp())) { if (llvm::any_of( forOp.getRegionIterArgs(), [=](mlir::BlockArgument otherArg) { return otherArg == ba; })) return forOp; } } return std::nullopt; } // Separates the operands of a batchable operation into a vector of // batchable and a vector of non-batchable operands according to the // specification of the `variant`-th batching variant of the // operation. static void splitOperands(BatchableOpInterface batchableOp, unsigned variant, llvm::SmallVector &batchableOperands, llvm::SmallVector &nonBatchableOperands) { for (mlir::OpOperand &batchableOperand : batchableOp.getBatchableOperands(variant)) { batchableOperands.push_back(&batchableOperand); } for (mlir::OpOperand &operand : batchableOp->getOpOperands()) { if (llvm::none_of(batchableOp.getBatchableOperands(variant), [&](mlir::OpOperand &batchableOperand) { return operand.getOperandNumber() == batchableOperand.getOperandNumber(); })) { nonBatchableOperands.push_back(&operand); } } } // Converts the ordered sequence `operands` into an unordered set of // operands llvm::DenseSet operandsToValueSet(llvm::ArrayRef operands) { llvm::DenseSet set; for (mlir::OpOperand *operand : operands) set.insert(operand->get()); return set; } /// Pattern that replaces a batchable operation embedded into a static /// loop nest with the batched version of the operation, e.g., /// /// scf.for %i = c0 to %cN step %c1 { /// scf.for %j = c0 to %cM step %c1 { /// scf.for %k = c0 to %cK step %c1 { /// %s0 = tensor.extract %T[%i, %j, %k] /// %s1 = tensor.extract %U[%k, %j, %i] /// %res = batchable_op %s0, %s1, ... /// ... /// } /// } /// } /// /// is replaced with: /// /// %batchedSlice0 = tensor.extract_slice /// %T[%c0, %c0, %c0] [%cN, %cM, %cK] [%c1, %c1, %c1] /// %batchedSlice1 = tensor.extract_slice /// %U[%c0, %c0, %c0] [%cK, %cM, %cN] [%c1, %c1, %c1] /// %flatSlice0 = tensor.collapse_shape %batchedSlice0 /// %flatSlice1 = tensor.collapse_shape %batchedSlice1 /// %resTFlat = batchedOp %flatSlice0, %flatSlice1, ... 
/// %resT = tensor.expand_shape %resTFlat /// /// scf.for %i = c0 to %cN step %c1 { /// scf.for %j = c0 to %cM step %c1 { /// scf.for %k = c0 to %cK step %c1 { /// %res = tensor.extract %resT[%i, %j, %k] /// ... /// } /// } /// } /// /// Any index may be a quasi-affine expression on a single loop /// induction variable, but the distance between the result for any /// two successive values of the IV must be constant. /// /// The element type of batched operands can either be scalar or /// tensors themselves. class BatchingPattern : public mlir::OpRewritePattern { public: BatchingPattern(mlir::MLIRContext *context, int64_t maxBatchSize = std::numeric_limits::max()) : mlir::OpRewritePattern(context), maxBatchSize(maxBatchSize) {} mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp func, mlir::PatternRewriter &rewriter) const override { // Operation that will be hoisted out of the loop nest and // replaced by the batched version of the operation BatchableOpInterface targetOp; // The perfect loop nest containing the target operation from // the outermost to the innermost loop llvm::SmallVector nest; // Total number of elements that within a batch int64_t batchSize; // Selected batching variant of the batchable operation unsigned variant; // Sequences of batchable and non-batchable operands of the // batchable operation llvm::SmallVector batchableOperands; llvm::SmallVector nonBatchableOperands; // Sets of values that recursively produce the batched and // non-batched operands and that are contained in the loop nest llvm::DenseSet visitedBatched; llvm::DenseSet visitedNonBatched; // Find a batchable op which is embedded into a loop nest func.walk([&](BatchableOpInterface scalarOp) { // Predicate checking whether an scf.for op is a valid candidate // to expand the loop nest upwards towards the outermost loop auto isCandidateLoop = [](mlir::scf::ForOp forOp) -> bool { return isStaticLoop(forOp); }; // Only batchable operations within at least one loop are of // interest mlir::scf::ForOp innermostFor = llvm::dyn_cast_or_null(scalarOp->getParentOp()); if (!innermostFor || !isCandidateLoop(innermostFor)) return mlir::WalkResult::skip(); unsigned candidateVariant = 0; // Try all batching variants of the batchable operation in // sequence until all conditions for batching are met. 
for (candidateVariant = 0; candidateVariant < scalarOp.getNumBatchingVariants(); candidateVariant++) { llvm::SmallVector candidateBatchableOperands; llvm::SmallVector candidateNonBatchableOperands; splitOperands(scalarOp, candidateVariant, candidateBatchableOperands, candidateNonBatchableOperands); // Construct initial frontiers from the values used directly // by the batchable operation llvm::DenseSet frontierBatched = operandsToValueSet(candidateBatchableOperands); llvm::DenseSet frontierNonBatched = operandsToValueSet(candidateNonBatchableOperands); // Predicate for all values that a batchable operand depends // on auto batchableOperandProducerPredicate = [](mlir::Value v) { mlir::Operation *definingOp = v.getDefiningOp(); // Skip operations with regions so that // tests always only need to be // performed upwards and never have to // descend return (!definingOp || (mlir::isPure(definingOp) && definingOp->getNumRegions() == 0)) && !valueIsRegionIterArg(v); }; // Predicate for all values that a non-batchable operand depends // on auto nonBatchableOperandProducerPredicate = [](mlir::Value v, mlir::scf::ForOp forOp) { mlir::Operation *definingOp = v.getDefiningOp(); return isHoistable(v, forOp) && (!definingOp || definingOp->getNumRegions() == 0); }; // Check that predicates hold for the initial frontiers if (!llvm::all_of(frontierBatched, batchableOperandProducerPredicate) || !llvm::all_of(frontierNonBatched, [&](mlir::Value v) { return nonBatchableOperandProducerPredicate(v, innermostFor); })) { continue; } // Walk up the loop nest to find the outermost loop that // satisfies the conditions for the batchable and non-batchable // operands. llvm::DenseSet candidateVisitedBatched; llvm::DenseSet candidateVisitedNonBatched; llvm::SmallVector revNest; int64_t candidateBatchSize = 1; for (mlir::scf::ForOp forOp = innermostFor; forOp && isCandidateLoop(forOp); forOp = llvm::dyn_cast(forOp->getParentOp())) { int64_t thisTripCount = getStaticTripCount(forOp); if (maxBatchSize / candidateBatchSize < thisTripCount) break; else candidateBatchSize *= thisTripCount; std::optional> nextFrontierBatched = extendFrontier(frontierBatched, candidateVisitedBatched, forOp, batchableOperandProducerPredicate); if (!nextFrontierBatched.has_value()) break; // non-batchable operands must be defined outside or hoistable std::optional> nextFrontierNonBatched = extendFrontier(frontierNonBatched, candidateVisitedNonBatched, forOp, [&](mlir::Value v) { return nonBatchableOperandProducerPredicate( v, forOp); }); if (!nextFrontierNonBatched.has_value()) break; frontierBatched = nextFrontierBatched.value(); frontierNonBatched = nextFrontierNonBatched.value(); revNest.push_back(forOp); } // Skip if no loop nest satisfying constraints had been found if (revNest.size() == 0) continue; int64_t revBatchSize = getNestedStaticTripCount(revNest); // Skip empty loop nests if (revBatchSize == 0) continue; nest = llvm::to_vector(llvm::reverse(revNest)); batchSize = revBatchSize; targetOp = scalarOp; variant = candidateVariant; batchableOperands = std::move(candidateBatchableOperands); nonBatchableOperands = std::move(candidateNonBatchableOperands); visitedBatched = std::move(candidateVisitedBatched); visitedNonBatched = std::move(candidateVisitedNonBatched); return mlir::WalkResult::interrupt(); } return mlir::WalkResult::skip(); }); // if no suitable batchable operation was found, bail out if (!targetOp) return mlir::failure(); mlir::scf::ForOp outermostFor = nest[0]; rewriter.setInsertionPoint(outermostFor); mlir::SmallVector 
iterArgs; // Create a tensor for each batchable operand with the right size // to be used as loop-carried dependencies in a loop nest // collecting the input values of the batched operation for (mlir::OpOperand *batchableOperand : batchableOperands) { mlir::Type nonBatchedType = batchableOperand->get().getType(); llvm::SmallVector batchedShape = map(nest, static_cast(&getStaticTripCount)); mlir::Type batchedElementType; if (mlir::RankedTensorType nonBatchedTensorType = nonBatchedType.dyn_cast()) { batchedShape.append(nonBatchedTensorType.getShape().begin(), nonBatchedTensorType.getShape().end()); batchedElementType = nonBatchedTensorType.getElementType(); } else { batchedElementType = nonBatchedType; } mlir::RankedTensorType batchedTensorType = RankedTensorType::get(batchedShape, batchedElementType); // TODO check that there are no dynamic sizes mlir::Value iterArg = rewriter.create( outermostFor.getLoc(), batchedTensorType, mlir::ValueRange{}); iterArgs.push_back(iterArg); } // Now reconstruct the loop nest with normalized IV ranges and // clone all necessary operations mlir::IRMapping mapping; llvm::SmallVector clonedNest; // Create a loop nest producing the tensors for the batched input // operands. The operations producing the individual scalar values // are cloned from the original loop nest containing the batchable // operation. for (mlir::scf::ForOp forOp : nest) { mlir::scf::ForOp clonedForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), iterArgs); clonedNest.push_back(clonedForOp); iterArgs = llvm::to_vector_of(clonedForOp.getRegionIterArgs()); rewriter.setInsertionPoint(clonedForOp.getBody(), clonedForOp.getBody()->begin()); mapping.map(forOp.getInductionVar(), clonedForOp.getInductionVar()); for (mlir::Operation &op : forOp.getBody()->getOperations()) { if (llvm::any_of(op.getResults(), [&](mlir::Value res) { return visitedBatched.find(res) != visitedBatched.end() || visitedNonBatched.find(res) != visitedNonBatched.end(); })) { rewriter.clone(op, mapping); } } } // Build the index for the batchable operands in their tensors mlir::ImplicitLocOpBuilder ilob(outermostFor.getLoc(), rewriter); llvm::SmallVector idx = buildNormalizedIndexes(rewriter, clonedNest); // Build the operations that insert the scalar values of the // batches into the respective tensors llvm::SmallVector yields; for (auto batchableOperandAndIterArg : llvm::zip(batchableOperands, iterArgs)) { mlir::Value batchableOperand = std::get<0>(batchableOperandAndIterArg)->get(); mlir::Value iterArg = std::get<1>(batchableOperandAndIterArg); if (mlir::RankedTensorType batchElementType = llvm::dyn_cast( batchableOperand.getType())) { llvm::SmallVector offsets = map(idx, getValueAsOpFoldResult); llvm::SmallVector strides(idx.size(), ilob.getI64IntegerAttr(1)); llvm::SmallVector sizes(idx.size(), ilob.getI64IntegerAttr(1)); offsets.append(batchElementType.getShape().size(), ilob.getI64IntegerAttr(0)); strides.append(batchElementType.getShape().size(), ilob.getI64IntegerAttr(1)); for (int64_t dim : batchElementType.getShape()) { sizes.push_back(ilob.getI64IntegerAttr(dim)); } mlir::Value updatedBatchedVector = ilob.create( mapping.lookupOrDefault(batchableOperand), iterArg, offsets, sizes, strides); yields.push_back(updatedBatchedVector); } else { mlir::Value updatedBatchedVector = ilob.create( mapping.lookupOrDefault(batchableOperand), iterArg, idx); yields.push_back(updatedBatchedVector); } } for (mlir::scf::ForOp forOp : llvm::reverse(clonedNest)) { 
ilob.setInsertionPointToEnd(forOp.getBody()); ilob.create(yields); yields = llvm::to_vector_of(forOp.getResults()); } // Hoist all non-batchable operands out of the loop nest llvm::SmallVector hoistedNonBatchableValues; for (mlir::OpOperand *nonBatchableOperand : nonBatchableOperands) { hoistedNonBatchableValues.push_back( hoistPure(rewriter, outermostFor, nonBatchableOperand->get())); } ilob.setInsertionPoint(outermostFor); // Flatten all batched tensors before passing them to the batched // operation llvm::SmallVector batchedOperands; for (auto batchableOperandAndBatchedOperand : llvm::zip(batchableOperands, yields)) { mlir::OpOperand *batchableOperand = std::get<0>(batchableOperandAndBatchedOperand); mlir::Value structuredBatchedOperand = std::get<1>(batchableOperandAndBatchedOperand); unsigned trailingDimensions = 0; if (mlir::RankedTensorType batchedTensorType = batchableOperand->get() .getType() .dyn_cast()) { trailingDimensions = batchedTensorType.getShape().size(); } batchedOperands.push_back( flattenTensor(ilob, structuredBatchedOperand, trailingDimensions)); } // Created the actual batched operation through the op interface mlir::Value batchedResult = targetOp.createBatchedOperation( variant, ilob, batchedOperands, hoistedNonBatchableValues); mlir::RankedTensorType batchedResultType = llvm::dyn_cast(batchedResult.getType()); assert(batchedResultType); // Recreate the original shape of the batched results with the // normalized dimensions of the original loop nest llvm::SmallVector structuredBatchedShape = map( nest, static_cast(&getStaticTripCount)); if (batchedResultType.getShape().size() > 1) { structuredBatchedShape.push_back( batchedResultType .getShape()[batchedResultType.getShape().size() - 1]); } mlir::RankedTensorType structuredBatchedResultType = mlir::RankedTensorType::get(structuredBatchedShape, batchedResultType.getElementType()); mlir::Value structuredBatchedResult = unflattenTensor(ilob, batchedResult, structuredBatchedResultType); // Replace the original batchable operation with an operation that // extracts the respective scalar result from the batch of results // produced by the batched operation mlir::ImplicitLocOpBuilder ilob2(targetOp.getLoc(), rewriter); llvm::SmallVector idxUse = buildNormalizedIndexes(rewriter, nest); rewriter.setInsertionPoint(targetOp); if (batchedResultType.getShape().size() == 1) { rewriter.replaceOpWithNewOp( targetOp, structuredBatchedResult, idxUse); } else { llvm::SmallVector offsets = map(idxUse, getValueAsOpFoldResult); llvm::SmallVector strides(1, ilob2.getI64IntegerAttr(1)); llvm::SmallVector sizes(1, ilob2.getI64IntegerAttr(1)); offsets.append(batchedResultType.getShape().size() - 1, ilob2.getI64IntegerAttr(0)); strides.append(batchedResultType.getShape().size() - 1, ilob2.getI64IntegerAttr(1)); for (int64_t dim : batchedResultType.getShape().take_front()) { strides.push_back(ilob2.getI64IntegerAttr(dim)); sizes.push_back(ilob2.getI64IntegerAttr(dim)); } rewriter.replaceOpWithNewOp( targetOp, batchedResult, offsets, sizes, strides); } return mlir::success(); } private: int64_t maxBatchSize; }; // Returns a pair containing: // // - the set of loops whose IVs are referenced in the indexing // dimensions of `op` and which only appear in pure quasi-affine // expressions with a constant step wrt. to the iteration space and // where the step is equal to the size times the offset of the // dimension indexed by the expression. 
// // - an array defining the order in which these loop IVs are // referenced in the indexes // template std::pair, llvm::SmallVector> getLoopsForCandidateIndexes(IndexedOpTy op) { llvm::DenseSet allIVs; llvm::DenseSet qaIVs; llvm::SmallVector orderedQAIVs; for (auto it : llvm::enumerate(IndexedOpInfo::getOffsets(op))) { mlir::OpFoldResult expr = it.value(); size_t dimIdx = it.index(); if (mlir::Value dynExpr = expr.dyn_cast()) { walkUseDefChain(dynExpr, [&](mlir::Value v) { if (auto loop = valueIsRegionIterArg(v)) allIVs.insert(*loop); }); } mlir::scf::ForOp qaLoop; LoopsBoundsAndStep bas; if (isQuasiAffineIVExpressionWithConstantStep(expr, &qaLoop, &bas)) { if (qaLoop) { int64_t sliceExtents = getSliceExtents(op, dimIdx); if (sliceExtents == 1 || sliceExtents == bas.step) { qaIVs.insert(qaLoop); orderedQAIVs.push_back(qaLoop); } } } } llvm::DenseSet res = setMinus(qaIVs, allIVs); llvm::SmallVector orderedRes = filterVector(orderedQAIVs, res); return std::make_pair(std::move(res), std::move(orderedRes)); } /// Cleanup pattern that replaces a chain of a `tensor.extract` / /// `tensor.extract_slice`, a `tensor.insert` / `tensor.insert_slice` and an /// `scf.yield` op with `tensor.extract_slice` and `tensor.insert_slice` ops. /// E.g., /// /// %res0 = scf.for %i = c0 to %cN step %c1 iter_args(%arg0 = %T1) { /// %res1 = scf.for %j = c0 to %cM step %c1 iter_args(%arg1 = %arg0) { /// %res2 = scf.for %k = c0 to %cK step %c1 iter_args(%arg2 = %arg1) { /// %s = tensor.extract %T2[%i, %j, %k] /// %TRes = tensor.insert %s into %arg2 /// scf.yield %arg2 /// } /// scf.yield %res2 /// } /// scf.yield %res1 /// } /// /// is replaced with: /// /// %tmp = tensor.extract_slice %T2 /// %res0 = tensor.insert_slice %tmp into %T1 /// /// Any index may be a quasi-affine expression on a single loop /// induction variable, but the distance between the result for any /// two successive values of the IV must be constant. template class CleanupPattern : public mlir::OpRewritePattern { public: CleanupPattern(mlir::MLIRContext *context) : mlir::OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp func, mlir::PatternRewriter &rewriter) const override { mlir::scf::ForOp outermostFor; mlir::scf::ForOp innermostFor; ExtractOpTy extractOp; InsertOpTy insertOp; unsigned yieldOperandNumber; mlir::Value dstTensor; func.walk([&](ExtractOpTy currExtractOp) { // Check that we deal with only static sizes and strides; // otherwise the conservative hoisting implementation would fail if (!IndexedOpInfo::hasAllStaticSizesAndStrides( currExtractOp)) { return mlir::WalkResult::skip(); } // First check that the extract op is embedded in a for loop mlir::scf::ForOp currInnermostFor = llvm::dyn_cast(currExtractOp->getParentOp()); if (!currInnermostFor) return mlir::WalkResult::skip(); if (!currInnermostFor.isDefinedOutsideOfLoop( IndexedOpInfo::getTensor(currExtractOp))) { return mlir::WalkResult::skip(); } mlir::scf::YieldOp yield = llvm::dyn_cast( currInnermostFor.getBody()->getTerminator()); // Next, find a chain of the 3 relevant operations: // // %s = tensor.extract %T[...] (or tensor.extract_slice) // %U' = tensor.insert %s into %U[...] (or tensor.insert_slice) // scf.yield %U' // // with all affine indexes. The indexes used by the extract must // be a suffix of the indexes of the insert op and the tensor // that is updated must not be derived from the tensor from // which the element is extracted. 
InsertOpTy currInsertOp; std::pair, llvm::SmallVector> qaLoopsExtract = getLoopsForCandidateIndexes(currExtractOp); for (mlir::OpOperand &extractUser : currExtractOp->getUses()) { if (InsertOpTy currInsertOp = llvm::dyn_cast(extractUser.getOwner())) { // Insert and extract must be in the same loop if (currInsertOp->getParentOp() != currInnermostFor.getOperation()) continue; // Insert op must also have static sizes and strides in // order to be hoistable if (!IndexedOpInfo::hasAllStaticSizesAndStrides( currInsertOp)) { continue; } if (!isSoleUser(currInsertOp.getResult(), yield)) continue; // Insertion must be into a tensor that is a region // iteration argument and must be the only user of that // value (since it will be erased upon hoisting) if (!valueIsRegionIterArgOf(currInsertOp.getDest(), currInnermostFor) || !isSoleUser(currInsertOp.getDest(), currInsertOp)) continue; std::pair, llvm::SmallVector> qaLoopsInsert = getLoopsForCandidateIndexes(currInsertOp); llvm::DenseSet qaLoopsBoth = intersectSets(qaLoopsExtract.first, qaLoopsInsert.first); if (qaLoopsBoth.size() == 0) continue; if (!qaLoopsBoth.contains(currInnermostFor)) continue; // Indexes must appear in the same order and the same number // of times, such that the extracted and inserted slices // after the cleanup have the same shape and order llvm::SmallVector orderedLoopsExtract = filterVector(qaLoopsExtract.second, qaLoopsBoth); llvm::SmallVector orderedLoopsInsert = filterVector(qaLoopsInsert.second, qaLoopsBoth); if (orderedLoopsExtract.size() != orderedLoopsInsert.size() || orderedLoopsExtract != orderedLoopsInsert) { continue; } mlir::scf::ForOp candidateOutermostFor; mlir::scf::ForOp candidateInnermostFor; unsigned currYieldOperandNumber = getOperandIndexForValue(yield, currInsertOp.getResult()); getLongestPerfectLoopnest( qaLoopsBoth, candidateInnermostFor, candidateOutermostFor, [&](mlir::scf::ForOp parent, mlir::scf::ForOp child) { // Check that the nest from the outermost to // the innermost loop is perfect and forwards // the result of the innermost loop to the // outermost one mlir::scf::YieldOp parentYield = llvm::dyn_cast( parent.getBody()->getTerminator()); if (!parent.isDefinedOutsideOfLoop( IndexedOpInfo::getTensor(currExtractOp))) return false; return parentYield.getOperand(currYieldOperandNumber) == child.getResult(currYieldOperandNumber); }); insertOp = currInsertOp; yieldOperandNumber = currYieldOperandNumber; outermostFor = candidateOutermostFor; innermostFor = currInnermostFor; extractOp = currExtractOp; dstTensor = candidateOutermostFor.getInitArgs()[currYieldOperandNumber]; return mlir::WalkResult::interrupt(); } } return mlir::WalkResult::skip(); }); if (!outermostFor) return mlir::failure(); mlir::Value slice = hoistExtractOp(rewriter, outermostFor, extractOp); mlir::Value insertedSlice = hoistInsertOp(rewriter, slice, dstTensor, outermostFor, insertOp); rewritePerfectLoopNestWithReplacedNthResult( rewriter, innermostFor, outermostFor, yieldOperandNumber, insertedSlice); return mlir::success(); } }; // Folds the operation `op` by recursively folding all // producers. Occurrences of `arg` are replaced with `argVal`. // All encountered operations must produce a single result. 
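//
// For illustration, given the chain
//
//   %s = tensor.extract %cst[%i]       (this is `arg`)
//   %e = arith.extui %s : i9 to i64    (its user is `op`)
//
// `fold(op, arg, argVal)` folds the producer of %s to `argVal` and then
// invokes the folder of `arith.extui` on it, yielding the zero-extended
// integer attribute.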
static mlir::Attribute fold(mlir::Operation *op, mlir::Value arg, mlir::Attribute argVal) { assert(op->getNumResults() == 1); // Check if `arg` needs to be replaced with `argVal` if (op->getResult(0) == arg) return argVal; // Constants are just folded to their value attributes if (mlir::arith::ConstantOp cstOp = llvm::dyn_cast(op)) { return cstOp.getValue(); } // Recursively fold all producers and collect the folding results // for each operand llvm::SmallVector foldedOperands; for (mlir::OpOperand &operand : op->getOpOperands()) { mlir::Operation *producer = operand.get().getDefiningOp(); assert(producer); foldedOperands.push_back(fold(producer, arg, argVal)); } // Invoke the folder for this operation llvm::SmallVector res; mlir::LogicalResult foldRes = op->fold(foldedOperands, res); assert(foldRes.succeeded()); assert(res.size() == 1); mlir::Attribute resAttr = res[0].dyn_cast(); assert(resAttr); return resAttr; } // Folding pattern that collapses operations on constant dense tensors // into a new constant. E.g., // // %cst = arith.constant dense<...> : tensor // %res = scf.for %i = %c0 to %cN { // %cst_i9 = tensor.extract %cst[%i] // %cst_i64 = arith.extui %cst_i9 : i64 // ... // } // // becomes: // // %cst = arith.constant dense<...> : tensor // %res = scf.for %i = %c0 to %cN { // %cst_i64 = tensor.extract %cst[%i] // ... // } class ConstantDenseFoldingPattern : public mlir::OpRewritePattern { protected: // Checks if an operation is foldable static bool isFoldableOp(mlir::Operation *op) { return op->getNumResults() == 1 && mlir::isPure(op) && llvm::TypeSwitch(op) .Case([](auto op) { return true; }) .Default([](auto op) { return false; }); } // Checks if `v` can be calculated statically given that the values // in `foldables` are static. The function recursively collects all // intermediate values which have been found to be static in // `foldables`. static bool isFoldableValue(mlir::Value v, llvm::DenseSet &foldables) { if (foldables.contains(v)) return true; mlir::Operation *op = v.getDefiningOp(); if (!op || !isFoldableOp(op)) return false; if (llvm::all_of(op->getOperands(), [&](mlir::Value v) { return isFoldableValue(v, foldables); })) { for (mlir::Value v : op->getOperands()) foldables.insert(v); return true; } return false; } // Generates a flat index from a tensor with the shape `shape` // indexed by `idx` static int64_t linearizeIndex(llvm::ArrayRef shape, llvm::ArrayRef idx) { int64_t flatIdx = 0; int64_t mul = 1; int64_t n = shape.size(); for (int64_t i = 0; i < n; i++) { flatIdx += mul * idx[n - i - 1]; mul *= shape[n - i - 1]; } return flatIdx; }; public: ConstantDenseFoldingPattern(mlir::MLIRContext *context) : mlir::OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp func, mlir::PatternRewriter &rewriter) const override { mlir::tensor::ExtractOp extractOp; mlir::Operation *targetOp = nullptr; llvm::SmallVector nest; llvm::SmallVector idxMap; llvm::SmallVector idxBounds; mlir::arith::ConstantOp cdo; mlir::RankedTensorType constantType; mlir::Type origElementType; mlir::Type foldedElementType; func.walk([&](mlir::tensor::ExtractOp currExtractOp) { // Check that the extraction in on a value produced by an // `arith.constant_dense` operation. 
mlir::arith::ConstantOp currCdo = llvm::dyn_cast_or_null( currExtractOp.getTensor().getDefiningOp()); if (!currCdo) return mlir::WalkResult::skip(); if (!isSoleUser(currCdo.getResult(), currExtractOp)) return mlir::WalkResult::skip(); mlir::RankedTensorType currConstantType = currExtractOp.getTensor() .getType() .dyn_cast(); if (!currConstantType) return mlir::WalkResult::skip(); // First check that the extract op is embedded in a for loop mlir::scf::ForOp currInnermostFor = llvm::dyn_cast(currExtractOp->getParentOp()); if (!currInnermostFor) return mlir::WalkResult::skip(); llvm::DenseSet nestUnsorted; llvm::SmallVector currIdxMap; llvm::SmallVector currIdxBounds; // Make sure that the extract operation uses only quasi affine // expressions on IVs, where each index uses at most a single // IV. for (mlir::Value idx : currExtractOp.getIndices()) { mlir::scf::ForOp forOp; LoopsBoundsAndStep bas; if (!isQuasiAffineIVExpressionWithConstantStep(idx, &forOp, &bas)) return mlir::WalkResult::skip(); if (forOp) nestUnsorted.insert(forOp); currIdxBounds.push_back(bas); currIdxMap.push_back(forOp); } llvm::DenseSet foldables; foldables.insert(currExtractOp.getResult()); if (!currExtractOp.getResult().hasOneUse()) return mlir::WalkResult::skip(); mlir::Operation *firstUser = currExtractOp.getResult().getUses().begin()->getOwner(); mlir::Operation *currOp = firstUser; mlir::Operation *currTargetOp = nullptr; mlir::Type currOrigElementType = currConstantType.getElementType(); mlir::Type currFoldedElementType = currOrigElementType; // Walk down the def-use chain from the extract operation until // an operation is found that is not foldable while (true) { if (!isFoldableOp(currOp)) break; if (currOp->getNumResults() != 1 || !currOp->getResult(0).hasOneUse() || currOp->getParentOp() != currInnermostFor.getOperation()) break; if (!llvm::all_of(currOp->getOperands(), [&](mlir::Value v) { return isFoldableValue(v, foldables); })) break; currFoldedElementType = currOp->getResult(0).getType(); currTargetOp = currOp; currOp = currOp->getUses().begin()->getOwner(); } if (!currTargetOp) return mlir::WalkResult::skip(); // Check constraints on the index space of the extract // operation. I.e., if the type changes during the folding, // ensure that the index space covers the entire tensor and that // there are no out-of-bounds accesses. 
for (auto it : llvm::enumerate(currExtractOp.getIndices())) { mlir::scf::ForOp forOp; LoopsBoundsAndStep bas; mlir::Value idx = it.value(); size_t i = it.index(); if (!isQuasiAffineIVExpressionWithConstantStep(idx, &forOp, &bas)) return mlir::WalkResult::skip(); // If the type changes by the folding, the entire tensor needs // to be rewritten if (currFoldedElementType != currOrigElementType) { if (bas.lb != 0 || bas.ub != currConstantType.getDimSize(i) || bas.step != 1) return mlir::WalkResult::skip(); } // Otherwise, just make sure that there are no out-of-bounds // accesses else { if (bas.ub - bas.step >= currConstantType.getDimSize(i)) return mlir::WalkResult::skip(); } } extractOp = currExtractOp; targetOp = currTargetOp; nest = sortLoopsOutermostToInnermost(nestUnsorted); idxMap = std::move(currIdxMap); idxBounds = std::move(currIdxBounds); cdo = currCdo; constantType = currConstantType; origElementType = currOrigElementType; foldedElementType = currFoldedElementType; return mlir::WalkResult::interrupt(); }); if (!targetOp) return mlir::failure(); // Original tensor of constants auto denseVals = cdo.getValueAttr() .cast() .getValues(); // Updated tensor of constants initialized with original values SmallVector newDenseVals(denseVals.begin(), denseVals.end()); mlir::SmallVector tripCounts = map( nest, [](mlir::scf::ForOp forOp) { return getStaticTripCount(forOp); }); // Number of iterations already performed for each loop mlir::SmallVector trips(nest.size(), 0); // current index mlir::SmallVector idx = map(idxBounds, [](LoopsBoundsAndStep &bas) { return bas.lb; }); // Maps the index of each IV in the loop nest to the indexes of // the extract operation mlir::SmallVector> revIdxMap(nest.size()); for (size_t i = 0; i < idxMap.size(); i++) { for (size_t j = 0; j < nest.size(); j++) { if (nest[j] == idxMap[i]) { revIdxMap[j].push_back(i); break; } } } size_t i = nest.size() - 1; // Reset the trip count for a loop back to zero and reinitializes // all indexes using the associated IV auto resetTrips = [&](size_t loopIdx) { trips[loopIdx] = 0; for (size_t i : revIdxMap[loopIdx]) { idx[i] = idxBounds[i].lb; } }; // Increases the trip count of a loop by one and calculates the // next value of all indexes using the associated IV auto incTrips = [&](size_t loopIdx) { trips[loopIdx] += 1; for (size_t i : revIdxMap[loopIdx]) { idx[i] += idxBounds[i].step; } }; // Iterate over the entire iteration space of the loop nest. The // variable i represents the index of the loop that is currently // stepped in the nest while (true) { // Loop has reached its maximum trip count. If the loop ist the // first in the nest, the entire space has been // covered. Otherwise, reset the trip count of the current loop // and step the loop above. if (trips[i] == tripCounts[i]) { if (i == 0) break; resetTrips(i); i--; incTrips(i); } else { // Trip count of the current loop hasn't been reached. If this // is the innermost loop, calculate a new index, fold all // values and write the result to the new tensor of // constants. Otherwise, switch to the next loop in the nest. 
if (i == nest.size() - 1) { size_t flatIdx = linearizeIndex(constantType.getShape(), idx); mlir::Attribute newVal = fold(targetOp, extractOp.getResult(), denseVals[flatIdx]); newDenseVals[flatIdx] = newVal; incTrips(i); } else { i++; } } } // Create a new `arith.constant` operation with the updated tensor // of constants mlir::DenseElementsAttr newDenseElementsAttr = mlir::DenseElementsAttr::get( mlir::RankedTensorType::get(constantType.getShape(), foldedElementType), newDenseVals); rewriter.setInsertionPoint(cdo); mlir::arith::ConstantOp newCdo = rewriter.create( cdo.getLoc(), newDenseElementsAttr); rewriter.setInsertionPoint(targetOp); // Replace the last op in the chain of foldable operations with a // `tensor.extract` op on the new tensor of constants. rewriter.replaceOpWithNewOp( targetOp, targetOp->getResult(0).getType(), newCdo, extractOp.getIndices()); return mlir::success(); } }; class TensorAllocationCleanupPattern : public mlir::OpRewritePattern { public: TensorAllocationCleanupPattern(mlir::MLIRContext *context) : mlir::OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp func, mlir::PatternRewriter &rewriter) const override { mlir::LogicalResult res = mlir::failure(); func.walk([&](mlir::bufferization::AllocTensorOp allocOp) { if (allocOp->getUses().empty()) { rewriter.eraseOp(allocOp); res = mlir::success(); } }); return res; } }; class BatchingPass : public BatchingBase { public: BatchingPass(int64_t maxBatchSize) : maxBatchSize(maxBatchSize) {} void runOnOperation() override { mlir::Operation *op = getOperation(); mlir::RewritePatternSet patterns(op->getContext()); patterns.add(op->getContext(), maxBatchSize); patterns .add, CleanupPattern, ConstantDenseFoldingPattern, TensorAllocationCleanupPattern>( op->getContext()); if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) this->signalPassFailure(); } private: int64_t maxBatchSize; }; std::unique_ptr> createBatchingPass(int64_t maxBatchSize) { return std::make_unique(maxBatchSize); } } // namespace concretelang } // namespace mlir