feat(compiler): Make trip counts and memory usage optional in statistics passes

This makes the trip counts of operations in the TFHE statistics pass
as well as the per-location memory usage statistics in the memory
usage statistics pass optional. These values are unset if the trip
count could not be determined statically.
This commit is contained in:
Andi Drebes
2023-09-22 06:38:49 +02:00
parent e6e5db6f51
commit 7a295f89bd
5 changed files with 105 additions and 49 deletions

View File

@@ -8,11 +8,49 @@
#include <boost/outcome.h>
#include <concretelang/Common/Error.h>
#include <limits>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Location.h>
namespace mlir {
namespace concretelang {
class TripCountTracker {
public:
void pushTripCount(mlir::Operation *op, std::optional<int64_t> n) {
if (tripCount.has_value()) {
if (n.has_value()) {
assert(std::numeric_limits<int64_t>::max() / n.value() >
tripCount.value());
tripCount = tripCount.value() * n.value();
} else {
savedTripCount = *tripCount;
tripCount = std::nullopt;
firstDynamicTripCountOp = op;
}
}
}
void popTripCount(mlir::Operation *op, std::optional<int64_t> n) {
if (n.has_value()) {
if (tripCount.has_value()) {
tripCount = tripCount.value() / n.value();
}
} else {
if (firstDynamicTripCountOp == op) {
tripCount = savedTripCount;
}
firstDynamicTripCountOp = nullptr;
}
}
std::optional<int64_t> getTripCount() { return tripCount; }
protected:
std::optional<int64_t> tripCount = 1;
size_t savedTripCount;
mlir::Operation *firstDynamicTripCountOp = nullptr;
};
/// Get the string representation of a location
std::string locationString(mlir::Location loc);

View File

@@ -45,7 +45,7 @@ struct Statistic {
std::string location;
PrimitiveOperation operation;
std::vector<std::pair<KeyType, int64_t>> keys;
int64_t count;
std::optional<int64_t> count;
};
struct CircuitCompilationFeedback {
@@ -65,7 +65,7 @@ struct CircuitCompilationFeedback {
std::vector<Statistic> statistics;
/// @brief memory usage per location
std::map<std::string, int64_t> memoryUsagePerLoc;
std::map<std::string, std::optional<int64_t>> memoryUsagePerLoc;
/// Fill the sizes from the program info.
void fillFromCircuitInfo(concreteprotocol::CircuitInfo::Reader params);

View File

@@ -11,6 +11,7 @@
#include <mlir/IR/Operation.h>
#include <mlir/Interfaces/ViewLikeInterface.h>
#include <numeric>
#include <optional>
using namespace mlir::concretelang;
using namespace mlir;
@@ -65,8 +66,21 @@ namespace mlir {
namespace concretelang {
namespace Concrete {
namespace {
// Adds the values of `a` and `b` if they are both defined,
// otherwise returns `std::nullopt`
std::optional<int64_t> addOrNullopt(std::optional<int64_t> a,
std::optional<int64_t> b) {
if (a.has_value() && b.has_value())
return a.value() + b.value();
else
return std::nullopt;
}
} // namespace
struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>>,
public TripCountTracker {
ProgramCompilationFeedback &feedback;
CircuitCompilationFeedback *circuitFeedback;
@@ -148,27 +162,22 @@ struct MemoryUsagePass
static std::optional<StringError> on_enter(scf::ForOp &op,
MemoryUsagePass &pass) {
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
if (!numberOfIterations.has_value()) {
return StringError("only static loops can be analyzed");
if (!tripCount.has_value()) {
emitWarning(op.getLoc(), "Cannot determine static trip count");
}
assert(numberOfIterations.value() > 0);
pass.iterations *= (uint64_t)numberOfIterations.value();
pass.pushTripCount(op, tripCount);
return std::nullopt;
}
static std::optional<StringError> on_exit(scf::ForOp &op,
MemoryUsagePass &pass) {
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
pass.popTripCount(op, tripCount);
if (!numberOfIterations.has_value()) {
return StringError("only static loops can be analyzed");
}
assert(numberOfIterations.value() > 0);
pass.iterations /= (uint64_t)numberOfIterations.value();
return std::nullopt;
}
@@ -179,18 +188,28 @@ struct MemoryUsagePass
if (!maybeBufferSize) {
return maybeBufferSize.error();
}
std::optional<int64_t> memoryUsage = maybeBufferSize.value();
// if the allocated buffer is being deallocated then count it as one.
// Otherwise (and there must be a problem) multiply it by the number of
// iterations
int64_t numberOfAlloc =
isBufferDeallocated(op.getResult()) ? 1 : pass.iterations;
if (!isBufferDeallocated(op.getResult())) {
if (pass.getTripCount().has_value())
memoryUsage = memoryUsage.value() * pass.getTripCount().value();
else
memoryUsage = std::nullopt;
}
auto location = locationString(op.getLoc());
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
// element_size
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
pass.circuitFeedback->memoryUsagePerLoc[location] += memoryUsage;
if (pass.circuitFeedback->memoryUsagePerLoc.find(location) !=
pass.circuitFeedback->memoryUsagePerLoc.end()) {
pass.circuitFeedback->memoryUsagePerLoc[location] = addOrNullopt(
pass.circuitFeedback->memoryUsagePerLoc[location], memoryUsage);
} else {
pass.circuitFeedback->memoryUsagePerLoc[location] = memoryUsage;
}
return std::nullopt;
}
@@ -236,7 +255,13 @@ struct MemoryUsagePass
}
auto bufferSize = maybeBufferSize.value();
pass.circuitFeedback->memoryUsagePerLoc[location] += bufferSize;
if (pass.circuitFeedback->memoryUsagePerLoc.find(location) !=
pass.circuitFeedback->memoryUsagePerLoc.end()) {
pass.circuitFeedback->memoryUsagePerLoc[location] = addOrNullopt(
pass.circuitFeedback->memoryUsagePerLoc[location], bufferSize);
} else {
pass.circuitFeedback->memoryUsagePerLoc[location] = bufferSize;
}
}
}
@@ -244,8 +269,6 @@ struct MemoryUsagePass
}
std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;
size_t iterations = 1;
};
} // namespace Concrete

View File

@@ -34,7 +34,8 @@ namespace TFHE {
}
struct ExtractTFHEStatisticsPass
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>>,
public TripCountTracker {
ProgramCompilationFeedback &feedback;
CircuitCompilationFeedback *circuitFeedback;
@@ -101,27 +102,22 @@ struct ExtractTFHEStatisticsPass
static std::optional<StringError> on_enter(scf::ForOp &op,
ExtractTFHEStatisticsPass &pass) {
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
if (!numberOfIterations.has_value()) {
return StringError("only static loops can be analyzed");
if (!tripCount.has_value()) {
emitWarning(op.getLoc(), "Cannot determine static trip count");
}
assert(numberOfIterations.value() > 0);
pass.iterations *= (uint64_t)numberOfIterations.value();
pass.pushTripCount(op, tripCount);
return std::nullopt;
}
static std::optional<StringError> on_exit(scf::ForOp &op,
ExtractTFHEStatisticsPass &pass) {
std::optional<int64_t> numberOfIterations = tryGetStaticTripCount(op);
std::optional<int64_t> tripCount = tryGetStaticTripCount(op);
pass.popTripCount(op, tripCount);
if (!numberOfIterations.has_value()) {
return StringError("only static loops can be analyzed");
}
assert(numberOfIterations.value() > 0);
pass.iterations /= (uint64_t)numberOfIterations.value();
return std::nullopt;
}
@@ -136,7 +132,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
@@ -163,7 +159,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::CLEAR_ADDITION;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
@@ -190,7 +186,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::PBS;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
@@ -217,7 +213,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::KEY_SWITCH;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex());
@@ -244,7 +240,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
@@ -271,7 +267,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
@@ -297,7 +293,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
@@ -337,7 +333,7 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::WOP_PBS;
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
auto count = pass.getTripCount();
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
@@ -358,8 +354,6 @@ struct ExtractTFHEStatisticsPass
return std::nullopt;
}
int64_t iterations = 1;
};
} // namespace TFHE

View File

@@ -136,12 +136,13 @@ ProgramCompilationFeedback::load(std::string jsonPath) {
return expectedCompFeedback.get();
}
llvm::json::Object
memoryUsageToJson(const std::map<std::string, int64_t> &memoryUsagePerLoc) {
llvm::json::Object memoryUsageToJson(
const std::map<std::string, std::optional<int64_t>> &memoryUsagePerLoc) {
auto object = llvm::json::Object();
for (auto key : memoryUsagePerLoc) {
object.insert({key.first, key.second});
}
return object;
}