// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include namespace mlir { namespace concretelang { /// Checks whether the expression `expr` is a quasi-affine expression /// on a single induction variable. If an induction variable is /// referenced, the owning for loop is returned in `*owningForOp`. bool isQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp *owningForOp) { if (mlir::Operation *op = expr.getDefiningOp()) { if (llvm::isa(op)) { return true; } else if (llvm::isa(op)) { mlir::scf::ForOp forLHS; mlir::scf::ForOp forRHS; if (!isQuasiAffineIVExpression(op->getOperand(0), &forLHS) || !isQuasiAffineIVExpression(op->getOperand(1), &forRHS)) { return false; } else { // Check that appearances of IVs refer to the same IV if (forLHS && forRHS && forLHS != forRHS) return false; } // Assume that the expression is already canonicalized, so that // IVs appear only in numerators and on one side of a // multiplication subexpression if ((llvm::isa(op) && forLHS && forRHS) || (llvm::isa(op) && forRHS)) return false; if (owningForOp != nullptr) { if (forLHS) *owningForOp = forLHS; else if (forRHS) *owningForOp = forRHS; } return true; } else if (mlir::AffineApplyOp applyOp = llvm::dyn_cast(op)) { // Affine.apply: make sure that all operands are either constant // expressions or using IVs of the same loop mlir::scf::ForOp ivOwner; for (mlir::Value operand : applyOp->getOperands()) { mlir::scf::ForOp thisOwner; if (!isQuasiAffineIVExpression(operand, &thisOwner)) return false; if (thisOwner) { if (!ivOwner) { ivOwner = thisOwner; } else { if (thisOwner != ivOwner) return false; } } } if (owningForOp != nullptr) *owningForOp = ivOwner; } return false; } // Base case: Expression is an induction variable else if (mlir::scf::ForOp forOp = scf::getForInductionVarOwner(expr)) { if (owningForOp != nullptr) *owningForOp = forOp; return true; } return false; } bool isQuasiAffineIVExpression(mlir::OpFoldResult expr, mlir::scf::ForOp *owningForOp) { if (mlir::Value dynExpr = expr.dyn_cast()) return isQuasiAffineIVExpression(dynExpr, owningForOp); return true; } /// Checks if `expr` is a quasi affine expression on a single /// induction variable, for which the increment of the induction /// variable with the step of the associated for loop results in a /// constant incrementation when evaluating the expression. /// /// E.g., this is true for the expression `i+1` for any constant step /// size, since `((i+step)+1) - (i+1)` is constant. This is also true /// for `(i+5)/7` for a step size that is a multiple of `7`, but false /// for any other step size. bool isQuasiAffineIVExpressionWithConstantStep(mlir::OpFoldResult expr, mlir::scf::ForOp *forOp, LoopsBoundsAndStep *basOut) { mlir::scf::ForOp tmpForOp; if (isQuasiAffineIVExpression(expr, &tmpForOp)) { std::optional bas = getBoundsOfQuasiAffineIVExpression(expr, tmpForOp); if (bas.has_value()) { if (forOp != nullptr) *forOp = tmpForOp; if (basOut != nullptr) *basOut = *bas; return true; } } return false; } static std::optional getBoundsOfAffineExpression(mlir::AffineExpr expr, llvm::ArrayRef dimBounds); // Returns the static bounds of an affine binary expression `expr` // given the bounds `dimBounds` for any dimension expression appearing // in the affine expression by determining the bounds for the left // hand side and right hand side separately and applying `combinator` // on them. static std::optional getBoundsOfAffineBinaryExpression( mlir::AffineExpr expr, llvm::ArrayRef dimBounds, llvm::function_ref combinator) { mlir::AffineBinaryOpExpr binExpr = expr.cast(); std::optional lhs = getBoundsOfAffineExpression(binExpr.getLHS(), dimBounds); if (!lhs.has_value()) return std::nullopt; std::optional rhs = getBoundsOfAffineExpression(binExpr.getRHS(), dimBounds); if (!rhs.has_value()) return std::nullopt; return combinator(lhs.value(), rhs.value()); } // Returns the static bounds of an affine expression given the bounds // `dimBounds` for any dimension expression appearing in the affine // expression. static std::optional getBoundsOfAffineExpression(mlir::AffineExpr expr, llvm::ArrayRef dimBounds) { // Cannot just use AffineExpr::compose() due to the check on // division switch (expr.getKind()) { case mlir::AffineExprKind::SymbolId: assert(false && "Symbol found in affine expression that should not contain sumbols"); break; case mlir::AffineExprKind::Constant: { int64_t cstVal = expr.cast().getValue(); return LoopsBoundsAndStep{cstVal, cstVal, 0}; } case mlir::AffineExprKind::DimId: { unsigned dimId = expr.cast().getPosition(); assert(dimId < dimBounds.size()); return dimBounds[dimId]; } case AffineExprKind::Add: return getBoundsOfAffineBinaryExpression( expr, dimBounds, [](LoopsBoundsAndStep lhs, LoopsBoundsAndStep rhs) { return lhs + rhs; }); case AffineExprKind::Mul: return getBoundsOfAffineBinaryExpression( expr, dimBounds, [](LoopsBoundsAndStep lhs, LoopsBoundsAndStep rhs) { return lhs * rhs; }); case AffineExprKind::Mod: case AffineExprKind::CeilDiv: case AffineExprKind::FloorDiv: { mlir::AffineBinaryOpExpr binExpr = expr.cast(); std::optional lhs = getBoundsOfAffineExpression(binExpr.getLHS(), dimBounds); std::optional rhs = getBoundsOfAffineExpression(binExpr.getRHS(), dimBounds); assert(rhs->ub == rhs->lb && rhs->step == 0 && "Expression for divisor references IV"); int64_t rhsVal = rhs->ub; assert(rhsVal != 0 && "Division by zero"); // If the step value of the subexpression is not a multiple of // the divisor, there may be two iterations with the same // value. Conservatively bail out. if (lhs->step % rhsVal != 0) return std::nullopt; return *lhs / rhsVal; } } llvm_unreachable("Unknown affine expression kind"); } // Returns the static bounds for the affine map `map` given the static // bounds for all operands on which the map is applied. The map must // not contain any symbols, all of its expressions must be pure affine // expressions and the number of results must be one. static std::optional getBoundsOfAffineMap(mlir::AffineMap map, llvm::ArrayRef mapOperandBounds) { assert(map.getNumResults() == 1 && "Attempting to get bounds for map with multiple result dimensions"); assert(map.getNumSymbols() == 0 && "Attempting to get bounds for map with symbols"); assert(map.getResult(0).isPureAffine() && "Attempting to get bounds for non-pure affine expression"); return getBoundsOfAffineExpression(map.getResult(0), mapOperandBounds); } /// Returns the lower bound, upper bound and step of the quasi-affine /// expression `expr` on the the induction variable from a for /// operation. std::optional getBoundsOfQuasiAffineIVExpression(mlir::Value expr, mlir::scf::ForOp forOp) { // Base case: expression is the induction variable itself -> check // if the bounds are static and return them if (forOp && expr == forOp.getInductionVar() && isConstantIndexValue(forOp.getLowerBound()) && isConstantIndexValue(forOp.getUpperBound()) && isConstantIndexValue(forOp.getStep())) { return LoopsBoundsAndStep{getConstantIndexValue(forOp.getLowerBound()), getConstantIndexValue(forOp.getUpperBound()), getConstantIndexValue(forOp.getStep())}; } // Arithmetic expression else if (mlir::Operation *op = expr.getDefiningOp()) { if (llvm::isa(op)) { std::optional lhs = getBoundsOfQuasiAffineIVExpression(op->getOperand(0), forOp); std::optional rhs = getBoundsOfQuasiAffineIVExpression(op->getOperand(1), forOp); if (!lhs.has_value() || !rhs.has_value()) return std::nullopt; if (llvm::isa(op)) return *lhs + *rhs; else if (llvm::isa(op)) return *lhs - *rhs; else if (llvm::isa(op)) return (*lhs) * (*rhs); else if (llvm::isa(op)) { assert(rhs->ub == rhs->lb && rhs->step == 0 && "Expression for divisor references IV"); int64_t rhsVal = rhs->ub; assert(rhsVal != 0 && "Division by zero"); // If the step value of the subexpression is not a multiple of // the divisor, there may be two iterations with the same // value. Conservatively bail out. if (lhs->step % rhsVal != 0) return std::nullopt; return *lhs / rhsVal; } } // Affine.apply if (mlir::AffineApplyOp applyOp = llvm::dyn_cast(op)) { if (applyOp.getMap().getNumResults() != 1 || applyOp.getMap().getNumSymbols() != 0 || !applyOp.getMap().getResult(0).isPureAffine()) return std::nullopt; llvm::SmallVector bounds; for (mlir::Value operand : applyOp.getMapOperands()) { std::optional operatorBounds = getBoundsOfQuasiAffineIVExpression(operand, forOp); if (!operatorBounds.has_value()) return std::nullopt; bounds.push_back(operatorBounds.value()); } return getBoundsOfAffineMap(applyOp.getMap(), bounds); } // Base case: constant -> return constant value else if (llvm::isa(expr.getDefiningOp())) { mlir::arith::ConstantIndexOp cst = llvm::dyn_cast(expr.getDefiningOp()); return LoopsBoundsAndStep{cst.value(), cst.value(), 0}; } } return std::nullopt; } std::optional getBoundsOfQuasiAffineIVExpression(mlir::OpFoldResult expr, mlir::scf::ForOp forOp) { if (mlir::Value dynExpr = expr.dyn_cast()) return getBoundsOfQuasiAffineIVExpression(dynExpr, forOp); mlir::IntegerAttr exprAttr = expr.dyn_cast().dyn_cast_or_null(); assert(exprAttr && "Expected OpFoldResult to contain either a Value or an " "integer attribute"); return LoopsBoundsAndStep{exprAttr.getInt(), exprAttr.getInt(), 0}; } /// Checks if `forOp` has constant bounds and a constant step /// resulting from quasi affine expressions. bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb, int64_t *iub, int64_t *istep) { std::optional basLB = getBoundsOfQuasiAffineIVExpression(forOp.getLowerBound(), nullptr); std::optional basUB = getBoundsOfQuasiAffineIVExpression(forOp.getUpperBound(), nullptr); std::optional basStep = getBoundsOfQuasiAffineIVExpression(forOp.getStep(), nullptr); if (!basLB.has_value() || !basUB.has_value() || !basStep.has_value()) return false; if ((basLB->lb != basLB->ub || basLB->step != 0) || (basUB->lb != basUB->ub || basUB->step != 0) || (basStep->lb != basStep->ub || basStep->step != 0)) return false; if (ilb) *ilb = basLB->lb; if (iub) *iub = basUB->lb; if (istep) *istep = basStep->lb; return true; } int64_t getStaticTripCount(int64_t lb, int64_t ub, int64_t step) { assert((step == 0 && lb == ub) || (step >= 0 && lb <= ub) || (step < 0 && lb > ub)); if (lb == ub) return 0; if (lb > ub) return getStaticTripCount(ub, lb, -step); assert(ub - lb < std::numeric_limits::max() - step); return (ub - lb + step - 1) / step; } int64_t getStaticTripCount(const LoopsBoundsAndStep &bas) { return getStaticTripCount(bas.lb, bas.ub, bas.step); } // Returns the number of iterations of a static loop int64_t getStaticTripCount(mlir::scf::ForOp forOp) { int64_t lb; int64_t ub; int64_t step; bool isStatic = isStaticLoop(forOp, &lb, &ub, &step); assert(isStatic && "Loop must be static"); return getStaticTripCount(lb, ub, step); } // Returns the trip count of `forOp` if it is a static loop std::optional tryGetStaticTripCount(mlir::scf::ForOp forOp) { int64_t lb; int64_t ub; int64_t step; if (!isStaticLoop(forOp, &lb, &ub, &step)) return std::nullopt; return getStaticTripCount(lb, ub, step); } // Returns the total number of executions of the body of the innermost // loop of a nest of static loops int64_t getNestedStaticTripCount(llvm::ArrayRef nest) { int64_t tripCount = 1; for (mlir::scf::ForOp forOp : nest) { int64_t thisCount = getStaticTripCount(forOp); if (thisCount == 0) return 0; assert(std::numeric_limits::max() / thisCount >= tripCount); tripCount *= thisCount; } return tripCount; } // Checks whether `v` is a constant value of type index bool isConstantIndexValue(mlir::Value v) { return v.getDefiningOp() && llvm::isa(*v.getDefiningOp()); } /// Assumes that `v` is a constant index operation and returns the /// constant value as an `int64_t`. int64_t getConstantIndexValue(mlir::Value v) { assert(isConstantIndexValue(v)); return llvm::dyn_cast(*v.getDefiningOp()) .value(); } // Checks whether `v` is a constant value of type index and its values is `i` bool isConstantIndexValue(mlir::Value v, int64_t i) { return isConstantIndexValue(v) && getConstantIndexValue(v) == i; } // Returns a `Value` corresponding to `iv`, normalized to the lower // bound `lb` and step `step` of a loop (i.e., (iv - lb) / step). mlir::Value normalizeInductionVar(mlir::ImplicitLocOpBuilder &builder, mlir::Value iv, mlir::OpFoldResult lb, mlir::OpFoldResult step) { std::optional lbInt = mlir::getConstantIntValue(lb); std::optional stepInt = mlir::getConstantIntValue(step); mlir::Value idxShifted = lbInt.has_value() && *lbInt == 0 ? iv : builder.create( iv, mlir::getValueOrCreateConstantIndexOp( builder, builder.getLoc(), lb)); mlir::Value normalizedIV = stepInt.has_value() && *stepInt == 1 ? idxShifted : builder.create( idxShifted, mlir::getValueOrCreateConstantIndexOp( builder, builder.getLoc(), step)); return normalizedIV; } llvm::SmallVector normalizeInductionVars(mlir::ImplicitLocOpBuilder &builder, mlir::ValueRange ivs, llvm::ArrayRef lbs, llvm::ArrayRef steps) { llvm::SmallVector normalizedIVs; for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) { normalizedIVs.push_back(normalizeInductionVar(builder, iv, lb, step)); } return normalizedIVs; } } // namespace concretelang } // namespace mlir