diff --git a/compiler/include/concretelang/Transforms/Passes.h b/compiler/include/concretelang/Transforms/Passes.h index 987f16380..72f94f2e7 100644 --- a/compiler/include/concretelang/Transforms/Passes.h +++ b/compiler/include/concretelang/Transforms/Passes.h @@ -18,6 +18,7 @@ namespace mlir { namespace concretelang { std::unique_ptr> createForLoopToParallel(); +std::unique_ptr> createBatchingPass(); } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Transforms/Passes.td b/compiler/include/concretelang/Transforms/Passes.td index bff23255c..b0d99a513 100644 --- a/compiler/include/concretelang/Transforms/Passes.td +++ b/compiler/include/concretelang/Transforms/Passes.td @@ -11,4 +11,11 @@ def ForLoopToParallel : Pass<"for-loop-to-parallel", "mlir::ModuleOp"> { let dependentDialects = ["mlir::scf::SCFDialect"]; } +def Batching : Pass<"concrete", "mlir::ModuleOp"> { + let summary = + "Hoists operation for which a batched version exists out of loops applying " + "the operation to values stored in a tensor."; + let constructor = "mlir::concretelang::createBatchingPass()"; +} + #endif diff --git a/compiler/lib/Transforms/Batching.cpp b/compiler/lib/Transforms/Batching.cpp new file mode 100644 index 000000000..f8f104e8b --- /dev/null +++ b/compiler/lib/Transforms/Batching.cpp @@ -0,0 +1,920 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace mlir { +namespace concretelang { +/// Checks if `forOp` has constant bounds and a constant step. +static bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb = nullptr, + int64_t *iub = nullptr, int64_t *istep = nullptr) { + + mlir::Operation *lbOp = forOp.getLowerBound().getDefiningOp(); + mlir::Operation *ubOp = forOp.getUpperBound().getDefiningOp(); + mlir::Operation *stepOp = forOp.getStep().getDefiningOp(); + + if (!lbOp || !ubOp || !stepOp) + return false; + + mlir::arith::ConstantIndexOp lb = + llvm::dyn_cast(lbOp); + mlir::arith::ConstantIndexOp ub = + llvm::dyn_cast(ubOp); + mlir::arith::ConstantIndexOp step = + llvm::dyn_cast(stepOp); + + if (lb && ub && step) { + if (ilb) + *ilb = lb.value(); + + if (iub) + *iub = ub.value(); + + if (istep) + *istep = step.value(); + + return true; + } + + return false; +} + +/// Checks if `forOp` is a loop with a lower bound of 0, a constant +/// upper bound and a constant step of 1 +static bool isStaticNormalizedLoop(mlir::scf::ForOp forOp) { + int64_t lb; + int64_t step; + + if (isStaticLoop(forOp, &lb, nullptr, &step)) + return (lb == 0 && step == 1); + + return false; +} + +/// Returns an `OpFoldResult` with an `IntegerAttr` value if `v` is +/// produced by a constant, otherwise an `OpFoldResult` containing `v` +/// itself. +static mlir::OpFoldResult getValueAsOpFoldResult(mlir::Value v) { + if (mlir::arith::ConstantOp cstOp = + dyn_cast_or_null(v.getDefiningOp())) { + return cstOp.getValue(); + } + + return v; +} + +/// Assumes that `v` is a constant index operation and returns the +/// constant value as an `int64_t`. +static int64_t getConstantIndexValue(mlir::Value v) { + assert(v.getDefiningOp() && + llvm::isa(*v.getDefiningOp())); + + return llvm::dyn_cast(*v.getDefiningOp()) + .value(); +} + +/// Returns a `Value` from an `OpFoldResult`. If the `OpFoldResult` is +/// a already a value, the value is returned as is. 
Otherwise a +/// constant op is created using `builder`. +static mlir::Value getOpFoldResultAsValue(mlir::ImplicitLocOpBuilder &builder, + mlir::OpFoldResult v) { + if (v.is()) { + return v.dyn_cast(); + } else { + return builder.create( + v.get().cast().getInt()); + } +} + +/// Performs an arithmetic operation on `a` and `b`, where both values +/// can be any combination of `IntegerAttr` and `Value`. +template +mlir::OpFoldResult opFoldExpr(mlir::ImplicitLocOpBuilder &builder, + mlir::OpFoldResult a, mlir::OpFoldResult b) { + static IsNeutralElementFunctor isNeutralElement; + + auto exprValVal = [&](mlir::Value a, mlir::Value b) -> mlir::Value { + return builder.create(a, b); + }; + + auto exprAttrVal = [&](mlir::IntegerAttr attr, mlir::Value v) -> mlir::Value { + mlir::Value cst = + builder.create(attr.getInt()); + + return exprValVal(cst, v); + }; + + auto exprValAttr = [&](mlir::Value v, mlir::IntegerAttr attr) -> mlir::Value { + mlir::Value cst = + builder.create(attr.getInt()); + + return exprValVal(v, cst); + }; + + auto exprAttrAttr = [&](mlir::IntegerAttr a, + mlir::IntegerAttr b) -> mlir::IntegerAttr { + static ArithFunctor f; + return builder.getIndexAttr(f(a.getInt(), b.getInt())); + }; + + if (a.is()) { + if (b.is()) { + return exprValVal(a.get(), b.get()); + } else { + mlir::IntegerAttr bAttr = + b.get().cast(); + + if (isNeutralElement(bAttr.getValue().getSExtValue())) { + return a; + } else { + return exprValAttr(a.get(), bAttr); + } + } + } else { + mlir::IntegerAttr aAttr = + a.get().cast(); + + if (b.is()) { + return exprAttrVal(aAttr, b.get()); + } else { + mlir::IntegerAttr bAttr = + b.get().cast(); + + if (isNeutralElement(bAttr.getValue().getSExtValue())) + return a; + else + return exprAttrAttr(aAttr, bAttr); + } + } +} + +/// Helper class whose call operator compares its argument to the +/// constant value `cst`. +template struct comparator { + bool operator()(const T &val) { return cst == val; } +}; + +/// Divides `a` by `b`, where both values can be any combination of +/// `IntegerAttr` and `Value`. +static mlir::OpFoldResult opFoldDiv(mlir::ImplicitLocOpBuilder &builder, + mlir::OpFoldResult a, + mlir::OpFoldResult b) { + return opFoldExpr, + comparator>(builder, a, b); +} + +/// Subtracts `b` from `a`, where both values can be any combination +/// of `IntegerAttr` and `Value`. +static mlir::OpFoldResult opFoldSub(mlir::ImplicitLocOpBuilder &builder, + mlir::OpFoldResult a, + mlir::OpFoldResult b) { + return opFoldExpr, + comparator>(builder, a, b); +} + +/// Convenience class that holds all parameters of a loop +struct BoundsAndStep { + int64_t lb; + int64_t ub; + int64_t step; + + BoundsAndStep operator+(const BoundsAndStep &other) { + return BoundsAndStep{lb + other.lb, ub + other.ub, step + other.step}; + } + BoundsAndStep operator-(const BoundsAndStep &other) { + return BoundsAndStep{lb - other.lb, ub - other.ub, step - other.step}; + } + BoundsAndStep operator*(const BoundsAndStep &other) { + return BoundsAndStep{lb * other.lb, ub * other.ub, step * other.step}; + } + BoundsAndStep operator/(int64_t d) { + return BoundsAndStep{lb / d, ub / d, step / d}; + } +}; + +/// Returns the lower bound, upper bound and step of the quasi-affine +/// expression `expr` on the the induction variable from a for +/// operation. 
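+///
+/// As an illustrative example (not taken from the code base): for a loop
+/// `scf.for %i = %c4 to %c20 step %c2`, the expression `(%i + %c2) / %c2`
+/// yields the bounds {lb = 3, ub = 11, step = 1}, since `%i + %c2`
+/// evaluates to {6, 22, 2} and the division by 2 keeps the step integral.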
+static llvm::Optional +getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) { + // Base case: expression is the induction variable itself -> return + // loop bounds + if (expr == forOp.getInductionVar()) { + return BoundsAndStep{getConstantIndexValue(forOp.getLowerBound()), + getConstantIndexValue(forOp.getUpperBound()), + getConstantIndexValue(forOp.getStep())}; + } + // Arithmetic expression + else if (mlir::Operation *op = expr.getDefiningOp()) { + if (llvm::isa(op)) { + + llvm::Optional lhs = + getBoundsOfQuasiAffineIVExpression(op->getOperand(0), forOp); + llvm::Optional rhs = + getBoundsOfQuasiAffineIVExpression(op->getOperand(1), forOp); + + if (!lhs.hasValue() || !rhs.hasValue()) + return llvm::None; + + if (llvm::isa(op)) + return *lhs + *rhs; + else if (llvm::isa(op)) + return *lhs - *rhs; + else if (llvm::isa(op)) + return (*lhs) * (*rhs); + else if (llvm::isa(op)) { + assert(rhs->ub == rhs->lb && rhs->step == 0 && + "Expression for divisor references IV"); + int64_t rhsVal = rhs->ub; + + assert(rhsVal != 0 && "Division by zero"); + + // If the step value of the subexpression is not a multiple of + // the divisor, there may be two iterations with the same + // value. Conservatively bail out. + if (lhs->step % rhsVal != 0) + return llvm::None; + + return *lhs / rhsVal; + } + } + // Base case: constant -> return constant value + else if (llvm::isa(expr.getDefiningOp())) { + mlir::arith::ConstantIndexOp cst = + llvm::dyn_cast(expr.getDefiningOp()); + return BoundsAndStep{cst.value(), cst.value(), 0}; + } + } + + llvm_unreachable("Expression could not be evaluated statically"); +} + +/// Checks whether the expression `expr` is a quasi-affine expression +/// on a single induction variable. If an induction variable is +/// referenced, the owning for loop is returned in `*owningForOp`. +static bool isQuasiAffineIVExpression(mlir::Value expr, + mlir::scf::ForOp *owningForOp = nullptr) { + if (mlir::Operation *op = expr.getDefiningOp()) { + if (llvm::isa(op)) { + return true; + } else if (llvm::isa(op)) { + mlir::scf::ForOp forLHS; + mlir::scf::ForOp forRHS; + + if (!isQuasiAffineIVExpression(op->getOperand(0), &forLHS) || + !isQuasiAffineIVExpression(op->getOperand(1), &forRHS)) { + return false; + } else { + // Check that appearances of IVs refer to the same IV + if (forLHS && forRHS && forLHS != forRHS) + return false; + } + + // Assume that the expression is already canonicalized, so that + // IVs appear only in numerators and on one side of a + // multiplication subexpression + if ((llvm::isa(op) && forLHS && forRHS) || + (llvm::isa(op) && forRHS)) + return false; + + if (owningForOp != nullptr) { + if (forLHS) + *owningForOp = forLHS; + else if (forRHS) + *owningForOp = forRHS; + } + + return true; + } + + return false; + } + // Base case: Expression is an induction variable + else if (mlir::scf::ForOp forOp = scf::getForInductionVarOwner(expr)) { + if (owningForOp != nullptr) + *owningForOp = forOp; + + return true; + } + + return false; +} + +/// Invokes `callback` for every subexpression of `expr` that is an +/// induction variable with the corresponding for operation as the +/// argument. Stops if the callback function returns `true`. 
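+///
+/// For example (illustrative): for an expression such as `%i + %j * %c2`,
+/// where `%i` and `%j` are induction variables of two different loops,
+/// `callback` is invoked once for each of the two owning `scf.for`
+/// operations, unless it returns `true` for the first one.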
+static void forEveryReferencedInductionVarBreakable( + mlir::Value expr, llvm::function_ref callback) { + if (mlir::scf::ForOp forOp = scf::getForInductionVarOwner(expr)) { + callback(forOp); + } else { + if (expr.getDefiningOp()) { + for (mlir::Value operand : expr.getDefiningOp()->getOperands()) { + forEveryReferencedInductionVarBreakable(operand, callback); + } + } + } +} + +/// Invokes `callback` for every subexpression of `expr` that is an +/// induction variable with the corresponding for operation as the +/// argument. +static void forEveryReferencedInductionVar( + mlir::Value expr, llvm::function_ref callback) { + forEveryReferencedInductionVarBreakable(expr, + [&](mlir::scf::ForOp forOp) -> bool { + callback(forOp); + return false; + }); +} + +/// Returns the loop associated to the first induction variable +/// encountered in a subexpression of `expr`. +static mlir::scf::ForOp findFirstReferencedInductionVar(mlir::Value expr) { + mlir::scf::ForOp ret; + + forEveryReferencedInductionVarBreakable(expr, + [&](mlir::scf::ForOp forOp) -> bool { + ret = forOp; + return true; + }); + + return ret; +} + +/// Checks if `expr` is a quasi affine expression on a single +/// induction variable, for which the increment of the induction +/// variable with the step of the associated for loop results in a +/// constant incrementation of when evaluating the expression. +/// +/// E.g., this is true for the expression `i+1` for any constant step +/// size, since `((i+step)+1) - (i+1)` is constant. This is also true +/// for `(i+5)/7` for a step size that is a multiple of `7`, but false +/// for any other step size. +static bool +isQuasiAffineIVExpressionWithConstantStep(mlir::Value expr, + mlir::scf::ForOp *forOp = nullptr) { + mlir::scf::ForOp tmpForOp; + + if (isQuasiAffineIVExpression(expr, &tmpForOp)) { + llvm::Optional bas = + getBoundsOfQuasiAffineIVExpression(expr, tmpForOp); + + if (bas.hasValue()) { + if (forOp != nullptr) + *forOp = tmpForOp; + return true; + } + } + + return false; +} + +/// Hoists a an operation embedded into a loop nest that and that is +/// indexed using quasi-affine expressions of the loops' IVs as a +/// `tensor.extract_slice` outside of the outermost loop +template +mlir::Value hoistIndexedOp( + mlir::PatternRewriter &rewriter, mlir::scf::ForOp outermostFor, + mlir::Value tensorizedOperands, EltWiseOp eltwiseOp, + llvm::function_ref, llvm::ArrayRef, + llvm::ArrayRef, llvm::ArrayRef)> + tensorOpBuilder) { + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; + llvm::SmallVector ivIndexedDims; + + rewriter.setInsertionPoint(outermostFor); + mlir::ImplicitLocOpBuilder ilob(eltwiseOp.getLoc(), rewriter); + + for (mlir::Value idx : eltwiseOp.getIndices()) { + mlir::scf::ForOp forOp; + bool isAffine = isQuasiAffineIVExpression(idx, &forOp); + + if (isAffine && forOp) { + + llvm::Optional bas = + getBoundsOfQuasiAffineIVExpression(idx, forOp); + + assert(bas.hasValue()); + assert(bas->step != 0); + + offsets.push_back(rewriter.getIndexAttr(bas->lb)); + sizes.push_back(rewriter.getIndexAttr((bas->ub - bas->lb) / bas->step)); + strides.push_back(rewriter.getIndexAttr(bas->step)); + + ivIndexedDims.push_back(true); + } else if (isAffine || outermostFor.isDefinedOutsideOfLoop(idx)) { + offsets.push_back(getValueAsOpFoldResult(idx)); + sizes.push_back(rewriter.getIndexAttr(1)); + strides.push_back(rewriter.getIndexAttr(1)); + ivIndexedDims.push_back(false); + } + } + + return tensorOpBuilder(ilob, tensorizedOperands, offsets, sizes, strides, + 
ivIndexedDims); +} + +/// Hoists a tensor.extract operation embedded into a loop nest as a +/// `tensor.extract_slice` outside of the outermost loop of the nest +static mlir::Value hoistExtractOp(mlir::PatternRewriter &rewriter, + mlir::scf::ForOp outermostFor, + mlir::tensor::ExtractOp extractOp) { + return hoistIndexedOp( + rewriter, outermostFor, extractOp.getTensor(), extractOp, + [](mlir::ImplicitLocOpBuilder &builder, mlir::Value tensorizedOperands, + llvm::ArrayRef offsets, + llvm::ArrayRef sizes, + llvm::ArrayRef strides, + llvm::ArrayRef ivIndexedDims) -> mlir::Value { + mlir::tensor::ExtractSliceOp slice = + builder.create( + tensorizedOperands, offsets, sizes, strides); + + // The extract slice operation above preserves non-IV-indexed + // dimensions of the original extract operation as 1-sized + // dimensions, e.g., a `tensor.extract[cst, i, j, cst, k]` + // results in a slice with the shape `1xMxNx1xK` (where M, N + // and K are the maximum values for i, j and k, assuming a + // loop step of 1). + // + // If there is any non-IV-indexed dimension, add a collapse + // shape operation that collapses the 1-sized dimensions + // originating from non-IV-indexed dimensions of the extract + // operation into the IV-indexed dimensions. I.e., in the + // above example, produce a slice with the shape `MxNxK` + // rather than `1xMxNx1xK`. + if (llvm::all_of(ivIndexedDims, [](bool v) { return v; })) { + return slice; + } else { + llvm::SmallVector collapseGroups; + mlir::ReassociationIndices currCollapseGroup; + + bool prefixDone = false; + for (auto i : llvm::enumerate(ivIndexedDims)) { + // If this is a non-IV-indexed dimension, accumulate + // dimension in the current group of collapsed dimensions + if (!i.value()) { + currCollapseGroup.push_back(i.index()); + } else { + // If there were only non-IV-indexed dimensions before, + // add this first IV-indexed dimension to the current + // group of collapsed dimensions and try to accumulate + // with following, non-IV-indexed dimensions. 
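+ //
+ // For the `tensor.extract[cst, i, j, cst, k]` example above
+ // (illustrative), `ivIndexedDims` is `[false, true, true, false, true]`
+ // and the groups built below are `[[0, 1], [2, 3], [4]]`, collapsing the
+ // `1xMxNx1xK` slice to `MxNxK`.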
+ if (!prefixDone) { + currCollapseGroup.push_back(i.index()); + prefixDone = true; + } else { + collapseGroups.push_back(currCollapseGroup); + currCollapseGroup = mlir::ReassociationIndices(); + currCollapseGroup.push_back(i.index()); + } + } + } + + // Add last collapse group for trailing series of + // non-IV-indexed dimensions + if (!currCollapseGroup.empty()) + collapseGroups.push_back(currCollapseGroup); + + mlir::tensor::CollapseShapeOp cso = + builder.create(slice, + collapseGroups); + + return cso; + } + }); +} + +/// Hoists a tensor.insert operation embedded into a loop nest as a +/// tensor.insert_slice outside of the outermost loop of the nest +static mlir::Value hoistInsertOp(mlir::PatternRewriter &rewriter, + mlir::Value tensorizedOperands, + mlir::Value targetTensor, + mlir::scf::ForOp outermostFor, + mlir::tensor::InsertOp insertOp) { + return hoistIndexedOp( + rewriter, outermostFor, targetTensor, insertOp, + [&](mlir::ImplicitLocOpBuilder &builder, mlir::Value targetTesor, + llvm::ArrayRef offsets, + llvm::ArrayRef sizes, + llvm::ArrayRef strides, + llvm::ArrayRef ivIndexedDims) -> mlir::Value { + return builder.create( + tensorizedOperands, targetTesor, offsets, sizes, strides); + }); +} + +/// Pattern that replaces a batchable operation embedded into a loop +/// nest with the batched version of the operation, e.g., +/// +/// scf.for %i = c0 to %cN step %c1 { +/// scf.for %j = c0 to %cM step %c1 { +/// scf.for %k = c0 to %cK step %c1 { +/// %s = tensor.extract %T[%i, %j, %k] +/// %res = batchable_op %s +/// ... +/// } +/// } +/// } +/// +/// is replaced with: +/// +/// %batchedSlice = tensor.extract_slice +/// %T[%c0, %c0, %c0] [%cN, %cM, %cK] [%c1, %c1, %c1] +/// %flatSlice = tensor.collapse_shape %batchedSlice +/// %resTFlat = batchedOp %flatSlice +/// %resT = tensor.expand_shape %resTFlat +/// +/// scf.for %i = c0 to %cN step %c1 { +/// scf.for %j = c0 to %cM step %c1 { +/// scf.for %k = c0 to %cK step %c1 { +/// %res = tensor.extract %resT[%i, %j, %k] +/// ... +/// } +/// } +/// } +/// +/// Any index may be a quasi-affine expression on a single loop +/// induction variable, but the distance between the result for any +/// two successive values of the IV must be constant. +class BatchingPattern : public mlir::OpRewritePattern { +public: + BatchingPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::func::FuncOp func, + mlir::PatternRewriter &rewriter) const override { + // Operation that will be hoisted out of the loop nest and + // replaced by the batched version of the operation + BatchableOpInterface targetOp; + + // Extract operation producing the scalar operand of the batchable + // operation + mlir::tensor::ExtractOp targetExtractOp; + + // Outermost for loop of the loop nest in which the batchable op + // is located + mlir::scf::ForOp outermostFor; + + // Find a batchable op which is embedded into a loop nest + func.walk([&](BatchableOpInterface scalarOp) { + // Is producer an extract op? + auto extractOp = llvm::dyn_cast( + scalarOp.getBatchableOperand().get().getDefiningOp()); + + if (!extractOp) + return mlir::WalkResult::skip(); + + // Is extract op embedded into a loop? 
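+ // (The immediate parent must be an `scf.for`; the loops between the
+ // extract op and the outermost loop are validated further below.)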
+ if (!isa(extractOp->getParentOp())) + return mlir::WalkResult::skip(); + + // Find outermost for loop whose IV is used as an index + mlir::scf::ForOp currOutermostFor; + + for (mlir::Value idx : extractOp.getIndices()) { + forEveryReferencedInductionVar(idx, [&](mlir::scf::ForOp forOp) { + if (!currOutermostFor || + forOp.getOperation()->isAncestor(currOutermostFor)) { + currOutermostFor = forOp; + } + }); + } + + if (!currOutermostFor) + return mlir::WalkResult::skip(); + + if (!currOutermostFor.isDefinedOutsideOfLoop(extractOp.getTensor())) + return mlir::WalkResult::skip(); + + // Now make sure that each index is either a quasi-affine + // expression on a single Loop IV, with a constant offset for + // all steps, a constant or defined above the outermost loop. + for (mlir::Value idx : extractOp.getIndices()) { + if (!currOutermostFor.isDefinedOutsideOfLoop(idx) && + !(idx.getDefiningOp() && + isa(idx.getDefiningOp())) && + !isQuasiAffineIVExpressionWithConstantStep(idx)) { + return mlir::WalkResult::skip(); + } + } + + // Verify that other args are defined outside the loop nest + if (!llvm::all_of(scalarOp.getNonBatchableOperands(), [&](mlir::Value v) { + return currOutermostFor.isDefinedOutsideOfLoop(v); + })) { + return mlir::WalkResult::skip(); + } + + // Make sure that there are only loops on the way from the + // outermost loop to the extract operation (i.e., loops are not + // embedded in other regions) + for (Operation *op = extractOp->getParentOp(); + op != currOutermostFor->getParentOp(); op = op->getParentOp()) { + if (!isa(op) || + !isStaticLoop(llvm::dyn_cast(op))) + return mlir::WalkResult::skip(); + } + + targetOp = scalarOp; + outermostFor = currOutermostFor; + targetExtractOp = extractOp; + + return mlir::WalkResult::interrupt(); + }); + + if (!targetOp) + return mlir::failure(); + + mlir::Value slice = hoistExtractOp(rewriter, outermostFor, targetExtractOp); + mlir::RankedTensorType sliceType = + slice.getType().cast(); + + // Flatten the tensor with the batched operands, so that they can + // be passed as a one-dimensional tensor to the batched operation + mlir::ReassociationIndices indices; + for (int64_t i = 0; i < sliceType.getRank(); i++) + indices.push_back(i); + + mlir::tensor::CollapseShapeOp flattenedSlice = + rewriter.create( + targetExtractOp.getLoc(), slice, + llvm::SmallVector{indices}); + + // Create the batched operation and pass flattened, batched + // operands + mlir::ImplicitLocOpBuilder ilob(targetExtractOp.getLoc(), rewriter); + mlir::Value batchedOpResult = + targetOp.createBatchedOperation(ilob, flattenedSlice); + + // Restore original shape of the batched operands for the result + // of the batched operation. Dimensions, result from indexing with + // non-loop-IVs are collapsed. + mlir::Type expandedBatchResultType = mlir::RankedTensorType::get( + sliceType.getShape(), batchedOpResult.getType() + .dyn_cast() + .getElementType()); + + mlir::Value expandedBatchResultTensor = + rewriter.create( + targetExtractOp.getLoc(), expandedBatchResultType, batchedOpResult, + llvm::SmallVector{indices}); + + // Collect all loop IVs from the extract op. These will be used to + // index the batched result tensor within the loop for consumers + // of the batchable op + llvm::SmallVector shiftedLoopIVs; + ilob.setInsertionPoint(targetOp); + + for (mlir::Value idx : targetExtractOp.getIndices()) { + mlir::scf::ForOp forOp = findFirstReferencedInductionVar(idx); + + if (forOp) { + // Loop has either a lower bound that is not 0 or a non-unit + // step. 
Shift the index to match the shape of the batched + // results. + if (!isStaticNormalizedLoop(forOp)) { + idx = getOpFoldResultAsValue( + ilob, + opFoldDiv( + ilob, + opFoldSub(ilob, + getValueAsOpFoldResult(forOp.getInductionVar()), + getValueAsOpFoldResult(forOp.getLowerBound())), + getValueAsOpFoldResult(forOp.getStep()))); + } + + shiftedLoopIVs.push_back(idx); + } + } + + rewriter.setInsertionPoint(targetOp); + rewriter.replaceOpWithNewOp( + targetOp, expandedBatchResultTensor, shiftedLoopIVs); + + return mlir::success(); + } +}; + +/// Cleanup pattern that replaces a perfect loop nest resulting from +/// repeated application of `BatchingPattern` that only contains a +/// `tensor.extract`, a `tensor.insert` and a `scf.yield` op in the +/// innermost loop, interleaved with side-effect-free operations, with +/// `tensor.extract_slice` and `tensor.insert_slice` ops. E.g., +/// +/// %res0 = scf.for %i = c0 to %cN step %c1 iter_args(%arg0 = %T1) { +/// %res1 = scf.for %j = c0 to %cM step %c1 iter_args(%arg1 = %arg0) { +/// %res2 = scf.for %k = c0 to %cK step %c1 iter_args(%arg2 = %arg1) { +/// %s = tensor.extract %T2[%i, %j, %k] +/// %TRes = tensor.insert %s into %arg2 +/// scf.yield %arg2 +/// } +/// scf.yield %res2 +/// } +/// scf.yield %res1 +/// } +/// +/// is replaced with: +/// +/// %tmp = tensor.extract_slice %T2 +/// %res0 = tensor.insert_slice %tmp into %T1 +/// +/// Any index may be a quasi-affine expression on a single loop +/// induction variable, but the distance between the result for any +/// two successive values of the IV must be constant. +class CleanupPattern : public mlir::OpRewritePattern { +public: + CleanupPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::func::FuncOp func, + mlir::PatternRewriter &rewriter) const override { + + mlir::scf::ForOp outermostFor; + mlir::tensor::ExtractOp extractOp; + mlir::tensor::InsertOp insertOp; + + func.walk([&](mlir::tensor::ExtractOp currExtractOp) { + // First check that the extract op is embedded in a for loop + mlir::scf::ForOp innermostFor = + llvm::dyn_cast(currExtractOp->getParentOp()); + + if (!innermostFor) + return mlir::WalkResult::skip(); + + // Next, check find a chain of the 3 relevant operations: + // + // %s = tensor.extract %T[...] + // %U' = tensor.insert %s into %U[...] + // scf.yield %U' + // + // All other operations must be side-effect-free. 
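+ //
+ // (Side-effect-free operations are tolerated because `BatchingPattern`
+ // may leave behind index arithmetic for shifted IVs; those ops become
+ // dead once the loop nest is replaced.)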
+ mlir::Block &body = innermostFor.getRegion().front(); + + mlir::scf::YieldOp yield = + llvm::dyn_cast(body.getTerminator()); + + if (yield.getOperands().size() != 1) + return mlir::WalkResult::skip(); + + mlir::Operation *yieldOperandProducer = + yield.getOperand(0).getDefiningOp(); + + if (!yieldOperandProducer) + return mlir::WalkResult::skip(); + + mlir::tensor::InsertOp currInsertOp = + llvm::dyn_cast(yieldOperandProducer); + + if (!currInsertOp || + currInsertOp.getScalar() != currExtractOp.getResult()) + return mlir::WalkResult::skip(); + + if (!llvm::all_of(body, [&](mlir::Operation &op) { + return MemoryEffectOpInterface::hasNoEffect(&op); + })) { + return mlir::WalkResult::skip(); + } + + // Now check that all IVs used for indexation are from perfectly + // nested loops down to the parent loop of the extract op and + // identify the outermost loop of the nest + mlir::scf::ForOp currOutermostFor; + + if (currExtractOp.getIndices().size() != + currInsertOp.getIndices().size()) { + return mlir::WalkResult::skip(); + } + + // Find outermost for loop whose IV is used as an index and make + // sure that IVs are used for the same indexes of the extract + // and insert operations + for (auto it : + llvm::zip(currExtractOp.getIndices(), currInsertOp.getIndices())) { + mlir::Value extractIdx = std::get<0>(it); + mlir::Value insertIdx = std::get<1>(it); + + mlir::scf::ForOp extractForOp; + mlir::scf::ForOp insertForOp; + + if (!isQuasiAffineIVExpressionWithConstantStep(extractIdx, + &extractForOp) || + !isQuasiAffineIVExpressionWithConstantStep(insertIdx, + &insertForOp)) { + return mlir::WalkResult::skip(); + } + + if (insertForOp && extractForOp && + insertForOp.getOperation() == extractForOp.getOperation()) { + if (!currOutermostFor || + extractForOp.getOperation()->isAncestor(currOutermostFor)) { + currOutermostFor = extractForOp; + } + } + } + + if (!currOutermostFor) + return mlir::WalkResult::skip(); + + // Check that the nest from the outermost to the innermost loop + // is perfect and forwards the result of the innermost loop to + // the outermost one + for (mlir::Operation *forOp = innermostFor.getOperation()->getParentOp(); + forOp != currOutermostFor.getOperation()->getParentOp(); + forOp = forOp->getParentOp()) { + mlir::scf::ForOp currentFor = llvm::dyn_cast(forOp); + + // Parent is not a for loop + if (!currentFor) + return mlir::WalkResult::skip(); + + // Body must have exactly two ops: a for loop and a yield + mlir::Block &body = currentFor.getRegion().front(); + + if (body.begin() != std::prev(body.end(), 2)) + return mlir::WalkResult::skip(); + + mlir::scf::ForOp childFor = + llvm::dyn_cast(*body.begin()); + mlir::scf::YieldOp yield = + llvm::dyn_cast(*std::next(body.begin())); + + // Check that result of the for loop is forwarded + if (!childFor || !yield || yield.getOperands().size() != 1 || + childFor.getResults().size() != 1 || + yield.getOperand(0) != childFor.getResult(0)) + return mlir::WalkResult::skip(); + } + + outermostFor = currOutermostFor; + extractOp = currExtractOp; + insertOp = currInsertOp; + + return mlir::WalkResult::interrupt(); + }); + + if (!outermostFor) + return mlir::failure(); + + // Outermost for loop must produce exactly one result + if (outermostFor.getInitArgs().size() != 1) + return mlir::failure(); + + // Original tensor that is carried through the loops + mlir::Value initialTensor = outermostFor.getInitArgs().front(); + + mlir::Value slice = hoistExtractOp(rewriter, outermostFor, extractOp); + + mlir::Value insertedSlice = + 
hoistInsertOp(rewriter, slice, initialTensor, outermostFor, insertOp); + + // Replace the entire loop nest with the result of the insert + // slice op. Since this is a perfect loop nest with the innermost + // body only producing the tensor elements, there cannot be any + // other operations that produce results or that have side + // effects. rewriter.replaceOp(outermostFor, {insertedSlice}); + + return mlir::success(); + } +}; + +class BatchingPass : public BatchingBase<BatchingPass> { +public: + void runOnOperation() override { + mlir::Operation *op = getOperation(); + + mlir::RewritePatternSet patterns(op->getContext()); + patterns.add<BatchingPattern, CleanupPattern>(op->getContext()); + + if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) + this->signalPassFailure(); + } +}; + +std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createBatchingPass() { + return std::make_unique<BatchingPass>(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt index ac35ca335..7aabc1b90 100644 --- a/compiler/lib/Transforms/CMakeLists.txt +++ b/compiler/lib/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(ConcretelangTransforms + Batching.cpp ForLoopToParallel.cpp ADDITIONAL_HEADER_DIRS @@ -7,10 +8,12 @@ add_mlir_library(ConcretelangTransforms DEPENDS MLIRTransforms ConcretelangTransformsBufferizePassIncGen + ConcretelangInterfaces mlir-headers LINK_LIBS PUBLIC MLIRIR MLIRMemRefDialect MLIRTransforms + ConcretelangInterfaces )
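A minimal sketch of how the new pass might be scheduled in a pass pipeline (illustrative only; the helper name `addBatchingToPipeline` and the include style are assumptions and not part of this diff):

#include <concretelang/Transforms/Passes.h>
#include <mlir/Pass/PassManager.h>

// Illustrative only: add the batching pass to a module-level pass manager
// so that batchable ops inside loop nests are hoisted and replaced by
// their batched versions.
void addBatchingToPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::concretelang::createBatchingPass());
}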