From 7a295f89bdc3d5354f5ced0592d1ff22293697bd Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 22 Sep 2023 06:38:49 +0200 Subject: [PATCH] feat(compiler): Make trip counts and memory usage optional in statistics passes This makes the trip counts of operations in the TFHE statistics pass as well as the per-location memory usage statistics in the memory usage statistics pass optional. These values are unset if the trip count could not be determined statically. --- .../include/concretelang/Analysis/Utils.h | 38 +++++++++++ .../Support/CompilationFeedback.h | 4 +- .../Dialect/Concrete/Analysis/MemoryUsage.cpp | 67 +++++++++++++------ .../TFHE/Analysis/ExtractStatistics.cpp | 40 +++++------ .../lib/Support/CompilationFeedback.cpp | 5 +- 5 files changed, 105 insertions(+), 49 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h index 252e61ab8..378270923 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Analysis/Utils.h @@ -8,11 +8,49 @@ #include #include +#include #include #include namespace mlir { namespace concretelang { +class TripCountTracker { +public: + void pushTripCount(mlir::Operation *op, std::optional n) { + if (tripCount.has_value()) { + if (n.has_value()) { + assert(std::numeric_limits::max() / n.value() > + tripCount.value()); + + tripCount = tripCount.value() * n.value(); + } else { + savedTripCount = *tripCount; + tripCount = std::nullopt; + firstDynamicTripCountOp = op; + } + } + } + + void popTripCount(mlir::Operation *op, std::optional n) { + if (n.has_value()) { + if (tripCount.has_value()) { + tripCount = tripCount.value() / n.value(); + } + } else { + if (firstDynamicTripCountOp == op) { + tripCount = savedTripCount; + } + firstDynamicTripCountOp = nullptr; + } + } + + std::optional getTripCount() { return tripCount; } + +protected: + std::optional tripCount = 1; + size_t savedTripCount; + mlir::Operation *firstDynamicTripCountOp = nullptr; +}; /// Get the string representation of a location std::string locationString(mlir::Location loc); diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h index a4084926b..13a5b5760 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h @@ -45,7 +45,7 @@ struct Statistic { std::string location; PrimitiveOperation operation; std::vector> keys; - int64_t count; + std::optional count; }; struct CircuitCompilationFeedback { @@ -65,7 +65,7 @@ struct CircuitCompilationFeedback { std::vector statistics; /// @brief memory usage per location - std::map memoryUsagePerLoc; + std::map> memoryUsagePerLoc; /// Fill the sizes from the program info. void fillFromCircuitInfo(concreteprotocol::CircuitInfo::Reader params); diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index 15087690b..4e5d0bf13 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -11,6 +11,7 @@ #include #include #include +#include using namespace mlir::concretelang; using namespace mlir; @@ -65,8 +66,21 @@ namespace mlir { namespace concretelang { namespace Concrete { +namespace { +// Adds the values of `a` and `b` if they are both defined, +// otherwise returns `std::nullopt` +std::optional addOrNullopt(std::optional a, + std::optional b) { + if (a.has_value() && b.has_value()) + return a.value() + b.value(); + else + return std::nullopt; +} +} // namespace + struct MemoryUsagePass - : public PassWrapper> { + : public PassWrapper>, + public TripCountTracker { ProgramCompilationFeedback &feedback; CircuitCompilationFeedback *circuitFeedback; @@ -148,27 +162,22 @@ struct MemoryUsagePass static std::optional on_enter(scf::ForOp &op, MemoryUsagePass &pass) { - std::optional numberOfIterations = tryGetStaticTripCount(op); + std::optional tripCount = tryGetStaticTripCount(op); - if (!numberOfIterations.has_value()) { - return StringError("only static loops can be analyzed"); + if (!tripCount.has_value()) { + emitWarning(op.getLoc(), "Cannot determine static trip count"); } - assert(numberOfIterations.value() > 0); - pass.iterations *= (uint64_t)numberOfIterations.value(); + pass.pushTripCount(op, tripCount); + return std::nullopt; } static std::optional on_exit(scf::ForOp &op, MemoryUsagePass &pass) { - std::optional numberOfIterations = tryGetStaticTripCount(op); + std::optional tripCount = tryGetStaticTripCount(op); + pass.popTripCount(op, tripCount); - if (!numberOfIterations.has_value()) { - return StringError("only static loops can be analyzed"); - } - - assert(numberOfIterations.value() > 0); - pass.iterations /= (uint64_t)numberOfIterations.value(); return std::nullopt; } @@ -179,18 +188,28 @@ struct MemoryUsagePass if (!maybeBufferSize) { return maybeBufferSize.error(); } + + std::optional memoryUsage = maybeBufferSize.value(); + // if the allocated buffer is being deallocated then count it as one. // Otherwise (and there must be a problem) multiply it by the number of // iterations - int64_t numberOfAlloc = - isBufferDeallocated(op.getResult()) ? 1 : pass.iterations; + if (!isBufferDeallocated(op.getResult())) { + if (pass.getTripCount().has_value()) + memoryUsage = memoryUsage.value() * pass.getTripCount().value(); + else + memoryUsage = std::nullopt; + } auto location = locationString(op.getLoc()); - // pass.iterations number of allocation of size: shape_1 * ... * shape_n * - // element_size - auto memoryUsage = numberOfAlloc * maybeBufferSize.value(); - pass.circuitFeedback->memoryUsagePerLoc[location] += memoryUsage; + if (pass.circuitFeedback->memoryUsagePerLoc.find(location) != + pass.circuitFeedback->memoryUsagePerLoc.end()) { + pass.circuitFeedback->memoryUsagePerLoc[location] = addOrNullopt( + pass.circuitFeedback->memoryUsagePerLoc[location], memoryUsage); + } else { + pass.circuitFeedback->memoryUsagePerLoc[location] = memoryUsage; + } return std::nullopt; } @@ -236,7 +255,13 @@ struct MemoryUsagePass } auto bufferSize = maybeBufferSize.value(); - pass.circuitFeedback->memoryUsagePerLoc[location] += bufferSize; + if (pass.circuitFeedback->memoryUsagePerLoc.find(location) != + pass.circuitFeedback->memoryUsagePerLoc.end()) { + pass.circuitFeedback->memoryUsagePerLoc[location] = addOrNullopt( + pass.circuitFeedback->memoryUsagePerLoc[location], bufferSize); + } else { + pass.circuitFeedback->memoryUsagePerLoc[location] = bufferSize; + } } } @@ -244,8 +269,6 @@ struct MemoryUsagePass } std::map> visitedValuesPerLoc; - - size_t iterations = 1; }; } // namespace Concrete diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp index f417c5003..a68b3d015 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp @@ -34,7 +34,8 @@ namespace TFHE { } struct ExtractTFHEStatisticsPass - : public PassWrapper> { + : public PassWrapper>, + public TripCountTracker { ProgramCompilationFeedback &feedback; CircuitCompilationFeedback *circuitFeedback; @@ -101,27 +102,22 @@ struct ExtractTFHEStatisticsPass static std::optional on_enter(scf::ForOp &op, ExtractTFHEStatisticsPass &pass) { - std::optional numberOfIterations = tryGetStaticTripCount(op); + std::optional tripCount = tryGetStaticTripCount(op); - if (!numberOfIterations.has_value()) { - return StringError("only static loops can be analyzed"); + if (!tripCount.has_value()) { + emitWarning(op.getLoc(), "Cannot determine static trip count"); } - assert(numberOfIterations.value() > 0); - pass.iterations *= (uint64_t)numberOfIterations.value(); + pass.pushTripCount(op, tripCount); + return std::nullopt; } static std::optional on_exit(scf::ForOp &op, ExtractTFHEStatisticsPass &pass) { - std::optional numberOfIterations = tryGetStaticTripCount(op); + std::optional tripCount = tryGetStaticTripCount(op); + pass.popTripCount(op, tripCount); - if (!numberOfIterations.has_value()) { - return StringError("only static loops can be analyzed"); - } - - assert(numberOfIterations.value() > 0); - pass.iterations /= (uint64_t)numberOfIterations.value(); return std::nullopt; } @@ -136,7 +132,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::ENCRYPTED_ADDITION; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); @@ -163,7 +159,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::CLEAR_ADDITION; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); @@ -190,7 +186,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::PBS; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex()); @@ -217,7 +213,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::KEY_SWITCH; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex()); @@ -244,7 +240,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); @@ -271,7 +267,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); @@ -297,7 +293,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index); @@ -337,7 +333,7 @@ struct ExtractTFHEStatisticsPass auto location = locationString(op.getLoc()); auto operation = PrimitiveOperation::WOP_PBS; auto keys = std::vector>(); - auto count = pass.iterations; + auto count = pass.getTripCount(); std::pair key = std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex()); @@ -358,8 +354,6 @@ struct ExtractTFHEStatisticsPass return std::nullopt; } - - int64_t iterations = 1; }; } // namespace TFHE diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp index 48d685e83..f5395fa8f 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp @@ -136,12 +136,13 @@ ProgramCompilationFeedback::load(std::string jsonPath) { return expectedCompFeedback.get(); } -llvm::json::Object -memoryUsageToJson(const std::map &memoryUsagePerLoc) { +llvm::json::Object memoryUsageToJson( + const std::map> &memoryUsagePerLoc) { auto object = llvm::json::Object(); for (auto key : memoryUsagePerLoc) { object.insert({key.first, key.second}); } + return object; }