mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 00:21:36 -05:00
refactor(compiler): Use common static loop analysis in TFHE / memory statistics
This commit is contained in:
@@ -51,6 +51,7 @@ getBoundsOfQuasiAffineIVExpression(mlir::OpFoldResult expr,
|
||||
int64_t getStaticTripCount(int64_t lb, int64_t ub, int64_t step);
|
||||
int64_t getStaticTripCount(const LoopsBoundsAndStep &bas);
|
||||
int64_t getStaticTripCount(mlir::scf::ForOp forOp);
|
||||
std::optional<int64_t> tryGetStaticTripCount(mlir::scf::ForOp forOp);
|
||||
int64_t getNestedStaticTripCount(llvm::ArrayRef<mlir::scf::ForOp> nest);
|
||||
bool isStaticLoop(mlir::scf::ForOp forOp, int64_t *ilb = nullptr,
|
||||
int64_t *iub = nullptr, int64_t *istep = nullptr);
|
||||
|
||||
@@ -17,13 +17,6 @@ namespace concretelang {
|
||||
/// Get the string representation of a location
|
||||
std::string locationString(mlir::Location loc);
|
||||
|
||||
/// Compute the number of iterations based on loop info
|
||||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step);
|
||||
|
||||
/// Compute the number of iterations of an scf for loop
|
||||
outcome::checked<int64_t, ::concretelang::error::StringError>
|
||||
calculateNumberOfIterations(scf::ForOp &op);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
|
||||
#include <concretelang/Analysis/StaticLoops.h>
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -395,6 +396,18 @@ int64_t getStaticTripCount(mlir::scf::ForOp forOp) {
|
||||
return getStaticTripCount(lb, ub, step);
|
||||
}
|
||||
|
||||
// Returns the trip count of `forOp` if it is a static loop
|
||||
std::optional<int64_t> tryGetStaticTripCount(mlir::scf::ForOp forOp) {
|
||||
int64_t lb;
|
||||
int64_t ub;
|
||||
int64_t step;
|
||||
|
||||
if (!isStaticLoop(forOp, &lb, &ub, &step))
|
||||
return std::nullopt;
|
||||
|
||||
return getStaticTripCount(lb, ub, step);
|
||||
}
|
||||
|
||||
// Returns the total number of executions of the body of the innermost
|
||||
// loop of a nest of static loops
|
||||
int64_t getNestedStaticTripCount(llvm::ArrayRef<mlir::scf::ForOp> nest) {
|
||||
|
||||
@@ -11,57 +11,5 @@ std::string locationString(mlir::Location loc) {
|
||||
loc->print(locationStream);
|
||||
return location;
|
||||
}
|
||||
|
||||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
|
||||
int64_t high;
|
||||
int64_t low;
|
||||
|
||||
if (step > 0) {
|
||||
low = start;
|
||||
high = stop;
|
||||
} else {
|
||||
low = stop;
|
||||
high = start;
|
||||
step = -step;
|
||||
}
|
||||
|
||||
if (low >= high) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return ((high - low - 1) / step) + 1;
|
||||
}
|
||||
|
||||
outcome::checked<int64_t, StringError>
|
||||
calculateNumberOfIterations(scf::ForOp &op) {
|
||||
mlir::Value startValue = op.getLowerBound();
|
||||
mlir::Value stopValue = op.getUpperBound();
|
||||
mlir::Value stepValue = op.getStep();
|
||||
|
||||
auto startOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
|
||||
auto stopOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
|
||||
auto stepOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only integer loops can be analyzed");
|
||||
}
|
||||
|
||||
int64_t start = startAttr.getInt();
|
||||
int64_t stop = stopAttr.getInt();
|
||||
int64_t step = stepAttr.getInt();
|
||||
|
||||
return calculateNumberOfIterations(start, stop, step);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <concretelang/Analysis/StaticLoops.h>
|
||||
#include <concretelang/Analysis/Utils.h>
|
||||
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
|
||||
@@ -147,9 +148,10 @@ struct MemoryUsagePass
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
@@ -159,9 +161,10 @@ struct MemoryUsagePass
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
#include <concretelang/Analysis/StaticLoops.h>
|
||||
#include <concretelang/Analysis/Utils.h>
|
||||
#include <concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h>
|
||||
|
||||
@@ -100,9 +101,10 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
@@ -112,9 +114,10 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
|
||||
Reference in New Issue
Block a user