mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 00:21:36 -05:00
feat(compiler): Make trip counts and memory usage optional in statistics passes
This makes the trip counts of operations in the TFHE statistics pass as well as the per-location memory usage statistics in the memory usage statistics pass optional. These values are unset if the trip count could not be determined statically.
This commit is contained in:
@@ -8,11 +8,49 @@
|
||||
|
||||
#include <boost/outcome.h>
|
||||
#include <concretelang/Common/Error.h>
|
||||
#include <limits>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/IR/Location.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
class TripCountTracker {
|
||||
public:
|
||||
void pushTripCount(mlir::Operation *op, std::optional<int64_t> n) {
|
||||
if (tripCount.has_value()) {
|
||||
if (n.has_value()) {
|
||||
assert(std::numeric_limits<int64_t>::max() / n.value() >
|
||||
tripCount.value());
|
||||
|
||||
tripCount = tripCount.value() * n.value();
|
||||
} else {
|
||||
savedTripCount = *tripCount;
|
||||
tripCount = std::nullopt;
|
||||
firstDynamicTripCountOp = op;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void popTripCount(mlir::Operation *op, std::optional<int64_t> n) {
|
||||
if (n.has_value()) {
|
||||
if (tripCount.has_value()) {
|
||||
tripCount = tripCount.value() / n.value();
|
||||
}
|
||||
} else {
|
||||
if (firstDynamicTripCountOp == op) {
|
||||
tripCount = savedTripCount;
|
||||
}
|
||||
firstDynamicTripCountOp = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<int64_t> getTripCount() { return tripCount; }
|
||||
|
||||
protected:
|
||||
std::optional<int64_t> tripCount = 1;
|
||||
size_t savedTripCount;
|
||||
mlir::Operation *firstDynamicTripCountOp = nullptr;
|
||||
};
|
||||
|
||||
/// Get the string representation of a location
|
||||
std::string locationString(mlir::Location loc);
|
||||
|
||||
@@ -45,7 +45,7 @@ struct Statistic {
|
||||
std::string location;
|
||||
PrimitiveOperation operation;
|
||||
std::vector<std::pair<KeyType, int64_t>> keys;
|
||||
int64_t count;
|
||||
std::optional<int64_t> count;
|
||||
};
|
||||
|
||||
struct CircuitCompilationFeedback {
|
||||
@@ -65,7 +65,7 @@ struct CircuitCompilationFeedback {
|
||||
std::vector<Statistic> statistics;
|
||||
|
||||
/// @brief memory usage per location
|
||||
std::map<std::string, int64_t> memoryUsagePerLoc;
|
||||
std::map<std::string, std::optional<int64_t>> memoryUsagePerLoc;
|
||||
|
||||
/// Fill the sizes from the program info.
|
||||
void fillFromCircuitInfo(concreteprotocol::CircuitInfo::Reader params);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <mlir/IR/Operation.h>
|
||||
#include <mlir/Interfaces/ViewLikeInterface.h>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
|
||||
using namespace mlir::concretelang;
|
||||
using namespace mlir;
|
||||
@@ -65,8 +66,21 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace Concrete {
|
||||
|
||||
namespace {
|
||||
// Adds the values of `a` and `b` if they are both defined,
|
||||
// otherwise returns `std::nullopt`
|
||||
std::optional<int64_t> addOrNullopt(std::optional<int64_t> a,
|
||||
std::optional<int64_t> b) {
|
||||
if (a.has_value() && b.has_value())
|
||||
return a.value() + b.value();
|
||||
else
|
||||
return std::nullopt;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
struct MemoryUsagePass
|
||||
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
|
||||
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>>,
|
||||
public TripCountTracker {
|
||||
|
||||
ProgramCompilationFeedback &feedback;
|
||||
CircuitCompilationFeedback *circuitFeedback;
|
||||
@@ -148,27 +162,22 @@ struct MemoryUsagePass
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
if (!tripCount.has_value()) {
|
||||
emitWarning(op.getLoc(), "Cannot determine static trip count");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations *= (uint64_t)numberOfIterations.value();
|
||||
pass.pushTripCount(op, tripCount);
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
|
||||
pass.popTripCount(op, tripCount);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations /= (uint64_t)numberOfIterations.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -179,18 +188,28 @@ struct MemoryUsagePass
|
||||
if (!maybeBufferSize) {
|
||||
return maybeBufferSize.error();
|
||||
}
|
||||
|
||||
std::optional<int64_t> memoryUsage = maybeBufferSize.value();
|
||||
|
||||
// if the allocated buffer is being deallocated then count it as one.
|
||||
// Otherwise (and there must be a problem) multiply it by the number of
|
||||
// iterations
|
||||
int64_t numberOfAlloc =
|
||||
isBufferDeallocated(op.getResult()) ? 1 : pass.iterations;
|
||||
if (!isBufferDeallocated(op.getResult())) {
|
||||
if (pass.getTripCount().has_value())
|
||||
memoryUsage = memoryUsage.value() * pass.getTripCount().value();
|
||||
else
|
||||
memoryUsage = std::nullopt;
|
||||
}
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
|
||||
// element_size
|
||||
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
|
||||
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] += memoryUsage;
|
||||
if (pass.circuitFeedback->memoryUsagePerLoc.find(location) !=
|
||||
pass.circuitFeedback->memoryUsagePerLoc.end()) {
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] = addOrNullopt(
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location], memoryUsage);
|
||||
} else {
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] = memoryUsage;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -236,7 +255,13 @@ struct MemoryUsagePass
|
||||
}
|
||||
auto bufferSize = maybeBufferSize.value();
|
||||
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] += bufferSize;
|
||||
if (pass.circuitFeedback->memoryUsagePerLoc.find(location) !=
|
||||
pass.circuitFeedback->memoryUsagePerLoc.end()) {
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] = addOrNullopt(
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location], bufferSize);
|
||||
} else {
|
||||
pass.circuitFeedback->memoryUsagePerLoc[location] = bufferSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,8 +269,6 @@ struct MemoryUsagePass
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;
|
||||
|
||||
size_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace Concrete
|
||||
|
||||
@@ -34,7 +34,8 @@ namespace TFHE {
|
||||
}
|
||||
|
||||
struct ExtractTFHEStatisticsPass
|
||||
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
|
||||
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>>,
|
||||
public TripCountTracker {
|
||||
|
||||
ProgramCompilationFeedback &feedback;
|
||||
CircuitCompilationFeedback *circuitFeedback;
|
||||
@@ -101,27 +102,22 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
if (!tripCount.has_value()) {
|
||||
emitWarning(op.getLoc(), "Cannot determine static trip count");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations *= (uint64_t)numberOfIterations.value();
|
||||
pass.pushTripCount(op, tripCount);
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
|
||||
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
|
||||
pass.popTripCount(op, tripCount);
|
||||
|
||||
if (!numberOfIterations.has_value()) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations /= (uint64_t)numberOfIterations.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -136,7 +132,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
@@ -163,7 +159,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
@@ -190,7 +186,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
|
||||
@@ -217,7 +213,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::KEY_SWITCH;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex());
|
||||
@@ -244,7 +240,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
@@ -271,7 +267,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
@@ -297,7 +293,7 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
|
||||
@@ -337,7 +333,7 @@ struct ExtractTFHEStatisticsPass
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::WOP_PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, int64_t>>();
|
||||
auto count = pass.iterations;
|
||||
auto count = pass.getTripCount();
|
||||
|
||||
std::pair<KeyType, int64_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
|
||||
@@ -358,8 +354,6 @@ struct ExtractTFHEStatisticsPass
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
int64_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace TFHE
|
||||
|
||||
@@ -136,12 +136,13 @@ ProgramCompilationFeedback::load(std::string jsonPath) {
|
||||
return expectedCompFeedback.get();
|
||||
}
|
||||
|
||||
llvm::json::Object
|
||||
memoryUsageToJson(const std::map<std::string, int64_t> &memoryUsagePerLoc) {
|
||||
llvm::json::Object memoryUsageToJson(
|
||||
const std::map<std::string, std::optional<int64_t>> &memoryUsagePerLoc) {
|
||||
auto object = llvm::json::Object();
|
||||
for (auto key : memoryUsagePerLoc) {
|
||||
object.insert({key.first, key.second});
|
||||
}
|
||||
|
||||
return object;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user