diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h index ad45519e1..559e7136b 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/StaticLoops.h @@ -51,6 +51,7 @@ getBoundsOfQuasiAffineIVExpression(mlir::OpFoldResult expr, int64_t getStaticTripCount(int64_t lb, int64_t ub, int64_t step); int64_t getStaticTripCount(const LoopsBoundsAndStep &bas); int64_t getStaticTripCount(mlir::scf::ForOp forOp); +std::optional tryGetStaticTripCount(mlir::scf::ForOp forOp); int64_t getNestedStaticTripCount(llvm::ArrayRef nest); bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb = nullptr, int64_t *iub = nullptr, int64_t *istep = nullptr); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h index a332837c2..252e61ab8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h @@ -17,13 +17,6 @@ namespace concretelang { /// Get the string representation of a location std::string locationString(mlir::Location loc); -/// Compute the number of iterations based on loop info -int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step); - -/// Compute the number of iterations of an scf for loop -outcome::checked -calculateNumberOfIterations(scf::ForOp &op); - } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp index 953427f89..4a22a588c 100644 --- a/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp +++ b/compilers/concrete-compiler/compiler/lib/Analysis/StaticLoops.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace mlir { namespace concretelang { @@ -395,6 +396,18 @@ int64_t getStaticTripCount(mlir::scf::ForOp forOp) { return getStaticTripCount(lb, ub, step); } +// Returns the trip count of `forOp` if it is a static loop +std::optional tryGetStaticTripCount(mlir::scf::ForOp forOp) { + int64_t lb; + int64_t ub; + int64_t step; + + if (!isStaticLoop(forOp, &lb, &ub, &step)) + return std::nullopt; + + return getStaticTripCount(lb, ub, step); +} + // Returns the total number of executions of the body of the innermost // loop of a nest of static loops int64_t getNestedStaticTripCount(llvm::ArrayRef nest) { diff --git a/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp b/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp index fd795cc83..48852b9ac 100644 --- a/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp +++ b/compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp @@ -11,57 +11,5 @@ std::string locationString(mlir::Location loc) { loc->print(locationStream); return location; } - -int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) { - int64_t high; - int64_t low; - - if (step > 0) { - low = start; - high = stop; - } else { - low = stop; - high = start; - step = -step; - } - - if (low >= high) { - return 0; - } - - return ((high - low - 1) / step) + 1; -} - -outcome::checked -calculateNumberOfIterations(scf::ForOp &op) { - mlir::Value startValue = op.getLowerBound(); - mlir::Value stopValue = op.getUpperBound(); - mlir::Value stepValue = op.getStep(); - - auto startOp = - llvm::dyn_cast_or_null(startValue.getDefiningOp()); - auto stopOp = - llvm::dyn_cast_or_null(stopValue.getDefiningOp()); - auto stepOp = - llvm::dyn_cast_or_null(stepValue.getDefiningOp()); - - if (!startOp || !stopOp || !stepOp) { - return StringError("only static loops can be analyzed"); - } - - auto startAttr = startOp.getValue().cast(); - auto stopAttr = stopOp.getValue().cast(); - auto stepAttr = stepOp.getValue().cast(); - - if (!startOp || !stopOp || !stepOp) { - return StringError("only integer loops can be analyzed"); - } - - int64_t start = startAttr.getInt(); - int64_t stop = stopAttr.getInt(); - int64_t step = stepAttr.getInt(); - - return calculateNumberOfIterations(start, stop, step); -} } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index b4958735e..15087690b 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -147,9 +148,10 @@ struct MemoryUsagePass static std::optional on_enter(scf::ForOp &op, MemoryUsagePass &pass) { - auto numberOfIterations = calculateNumberOfIterations(op); - if (!numberOfIterations) { - return numberOfIterations.error(); + std::optional numberOfIterations = tryGetStaticTripCount(op); + + if (!numberOfIterations.has_value()) { + return StringError("only static loops can be analyzed"); } assert(numberOfIterations.value() > 0); @@ -159,9 +161,10 @@ struct MemoryUsagePass static std::optional on_exit(scf::ForOp &op, MemoryUsagePass &pass) { - auto numberOfIterations = calculateNumberOfIterations(op); - if (!numberOfIterations) { - return numberOfIterations.error(); + std::optional numberOfIterations = tryGetStaticTripCount(op); + + if (!numberOfIterations.has_value()) { + return StringError("only static loops can be analyzed"); } assert(numberOfIterations.value() > 0); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp index 4b7aeaad3..f417c5003 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp @@ -1,4 +1,5 @@ #include "concretelang/Support/CompilationFeedback.h" +#include #include #include @@ -100,9 +101,10 @@ struct ExtractTFHEStatisticsPass static std::optional on_enter(scf::ForOp &op, ExtractTFHEStatisticsPass &pass) { - auto numberOfIterations = calculateNumberOfIterations(op); - if (!numberOfIterations) { - return numberOfIterations.error(); + std::optional numberOfIterations = tryGetStaticTripCount(op); + + if (!numberOfIterations.has_value()) { + return StringError("only static loops can be analyzed"); } assert(numberOfIterations.value() > 0); @@ -112,9 +114,10 @@ struct ExtractTFHEStatisticsPass static std::optional on_exit(scf::ForOp &op, ExtractTFHEStatisticsPass &pass) { - auto numberOfIterations = calculateNumberOfIterations(op); - if (!numberOfIterations) { - return numberOfIterations.error(); + std::optional numberOfIterations = tryGetStaticTripCount(op); + + if (!numberOfIterations.has_value()) { + return StringError("only static loops can be analyzed"); } assert(numberOfIterations.value() > 0);