mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 08:01:20 -05:00
refactor(compiler): clean statistic passes
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_ANALYSIS_UTILS_H
|
||||
#define CONCRETELANG_ANALYSIS_UTILS_H
|
||||
|
||||
#include <boost/outcome.h>
|
||||
#include <concretelang/Common/Error.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/IR/Location.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
/// Get the string representation of a location
|
||||
std::string locationString(mlir::Location loc);
|
||||
|
||||
/// Compute the number of iterations based on loop info
|
||||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step);
|
||||
|
||||
/// Compute the number of iterations of an scf for loop
|
||||
outcome::checked<int64_t, ::concretelang::error::StringError>
|
||||
calculateNumberOfIterations(scf::ForOp &op);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#ifndef CONCRETELANG_DIALECT_CONCRETE_ANALYSIS
|
||||
#define CONCRETELANG_DIALECT_CONCRETE_ANALYSIS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MemoryUsage : Pass<"MemoryUsage", "::mlir::ModuleOp"> {
|
||||
let summary = "Compute memory usage";
|
||||
let description = [{
|
||||
Computes memory usage per location, and provides those numbers throught the CompilationFeedback.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Analysis.td)
|
||||
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
|
||||
add_public_tablegen_target(ConcretelangConcreteAnalysisPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangConcreteAnalysisPassIncGen)
|
||||
@@ -7,59 +7,16 @@
|
||||
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#include <concretelang/Support/CompilationFeedback.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace Concrete {
|
||||
|
||||
struct MemoryUsagePass
|
||||
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
createMemoryUsagePass(CompilationFeedback &feedback);
|
||||
|
||||
CompilationFeedback &feedback;
|
||||
|
||||
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};
|
||||
|
||||
void runOnOperation() override {
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<StringError> enter(Operation *op);
|
||||
|
||||
std::optional<StringError> exit(Operation *op);
|
||||
|
||||
std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;
|
||||
|
||||
size_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace Concrete
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_ANALYSIS
|
||||
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ExtractStatistics : Pass<"ExtractStatistics", "::mlir::ModuleOp"> {
|
||||
let summary = "Extracts statistics";
|
||||
let description = [{
|
||||
Extracts different statistics (e.g. number of certain crypto operations),
|
||||
and provides those numbers throught the CompilationFeedback.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,4 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Analysis.td)
|
||||
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
|
||||
add_public_tablegen_target(ConcretelangTFHEAnalysisPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangTFHEAnalysisPassIncGen)
|
||||
@@ -7,58 +7,15 @@
|
||||
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS_EXTRACT_STATISTICS_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
#include <concretelang/Support/CompilationFeedback.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
struct ExtractTFHEStatisticsPass
|
||||
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
|
||||
|
||||
CompilationFeedback &feedback;
|
||||
|
||||
ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
|
||||
: feedback{feedback} {};
|
||||
|
||||
void runOnOperation() override {
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<StringError> enter(Operation *op);
|
||||
|
||||
std::optional<StringError> exit(Operation *op);
|
||||
|
||||
size_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace TFHE
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
createStatisticExtractionPass(CompilationFeedback &feedback);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
add_mlir_library(
|
||||
AnalysisUtils
|
||||
Utils.cpp
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR)
|
||||
67
compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp
Normal file
67
compilers/concrete-compiler/compiler/lib/Analysis/Utils.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#include <concretelang/Analysis/Utils.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
|
||||
using ::concretelang::error::StringError;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::string locationString(mlir::Location loc) {
|
||||
auto location = std::string();
|
||||
auto locationStream = llvm::raw_string_ostream(location);
|
||||
loc->print(locationStream);
|
||||
return location;
|
||||
}
|
||||
|
||||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
|
||||
int64_t high;
|
||||
int64_t low;
|
||||
|
||||
if (step > 0) {
|
||||
low = start;
|
||||
high = stop;
|
||||
} else {
|
||||
low = stop;
|
||||
high = start;
|
||||
step = -step;
|
||||
}
|
||||
|
||||
if (low >= high) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return ((high - low - 1) / step) + 1;
|
||||
}
|
||||
|
||||
outcome::checked<int64_t, StringError>
|
||||
calculateNumberOfIterations(scf::ForOp &op) {
|
||||
mlir::Value startValue = op.getLowerBound();
|
||||
mlir::Value stopValue = op.getUpperBound();
|
||||
mlir::Value stepValue = op.getStep();
|
||||
|
||||
auto startOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
|
||||
auto stopOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
|
||||
auto stepOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only integer loops can be analyzed");
|
||||
}
|
||||
|
||||
int64_t start = startAttr.getInt();
|
||||
int64_t stop = stopAttr.getInt();
|
||||
int64_t step = stepAttr.getInt();
|
||||
|
||||
return calculateNumberOfIterations(start, stop, step);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -262,6 +262,7 @@ const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &[
|
||||
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &["LLVMX86CodeGen", "LLVMX86Desc", "LLVMX86Info"];
|
||||
|
||||
const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
|
||||
"AnalysisUtils",
|
||||
"RTDialect",
|
||||
"RTDialectTransforms",
|
||||
"ConcretelangSupport",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -9,4 +9,5 @@ add_mlir_library(
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
ConcreteDialect)
|
||||
ConcreteDialect
|
||||
AnalysisUtils)
|
||||
|
||||
@@ -1,109 +1,20 @@
|
||||
#include <concretelang/Analysis/Utils.h>
|
||||
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
|
||||
#include <concretelang/Support/logging.h>
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
#include <mlir/Interfaces/ViewLikeInterface.h>
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir::concretelang;
|
||||
using namespace mlir;
|
||||
|
||||
using Concrete::MemoryUsagePass;
|
||||
|
||||
namespace {
|
||||
|
||||
std::string locationString(mlir::Location loc) {
|
||||
auto location = std::string();
|
||||
auto locationStream = llvm::raw_string_ostream(location);
|
||||
loc->print(locationStream);
|
||||
return location;
|
||||
}
|
||||
|
||||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
|
||||
int64_t high;
|
||||
int64_t low;
|
||||
|
||||
if (step > 0) {
|
||||
low = start;
|
||||
high = stop;
|
||||
} else {
|
||||
low = stop;
|
||||
high = start;
|
||||
step = -step;
|
||||
}
|
||||
|
||||
if (low >= high) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return ((high - low - 1) / step) + 1;
|
||||
}
|
||||
|
||||
std::optional<StringError> calculateNumberOfIterations(scf::ForOp &op,
|
||||
int64_t &result) {
|
||||
mlir::Value startValue = op.getLowerBound();
|
||||
mlir::Value stopValue = op.getUpperBound();
|
||||
mlir::Value stepValue = op.getStep();
|
||||
|
||||
auto startOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
|
||||
auto stopOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
|
||||
auto stepOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only integer loops can be analyzed");
|
||||
}
|
||||
|
||||
int64_t start = startAttr.getInt();
|
||||
int64_t stop = stopAttr.getInt();
|
||||
int64_t step = stepAttr.getInt();
|
||||
|
||||
result = calculateNumberOfIterations(start, stop, step);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
int64_t numberOfIterations;
|
||||
|
||||
std::optional<StringError> error =
|
||||
calculateNumberOfIterations(op, numberOfIterations);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
|
||||
assert(numberOfIterations > 0);
|
||||
pass.iterations *= (uint64_t)numberOfIterations;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
int64_t numberOfIterations;
|
||||
|
||||
std::optional<StringError> error =
|
||||
calculateNumberOfIterations(op, numberOfIterations);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
|
||||
assert(numberOfIterations > 0);
|
||||
pass.iterations /= (uint64_t)numberOfIterations;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
int64_t getElementTypeSize(mlir::Type elementType) {
|
||||
if (auto integerType = mlir::dyn_cast<mlir::IntegerType>(elementType)) {
|
||||
auto width = integerType.getWidth();
|
||||
@@ -146,108 +57,185 @@ bool isBufferDeallocated(mlir::Value buffer) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(memref::AllocOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
|
||||
auto maybeBufferSize = getBufferSize(op.getResult().getType());
|
||||
if (!maybeBufferSize) {
|
||||
return maybeBufferSize.error();
|
||||
}
|
||||
// if the allocated buffer is being deallocated then count it as one.
|
||||
// Otherwise (and there must be a problem) multiply it by the number of
|
||||
// iterations
|
||||
int64_t numberOfAlloc =
|
||||
isBufferDeallocated(op.getResult()) ? 1 : pass.iterations;
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
|
||||
// element_size
|
||||
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
|
||||
|
||||
pass.feedback.memoryUsagePerLoc[location] += memoryUsage;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(mlir::Operation *op,
|
||||
MemoryUsagePass &pass) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
// we only consider buffers
|
||||
if (!mlir::isa<mlir::MemRefType>(operand.getType()))
|
||||
continue;
|
||||
// find the origin of the buffer
|
||||
auto definingOp = operand.getDefiningOp();
|
||||
mlir::Value lastVisitedBuffer = operand;
|
||||
while (definingOp) {
|
||||
mlir::ViewLikeOpInterface viewLikeOp =
|
||||
mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp);
|
||||
if (viewLikeOp) {
|
||||
lastVisitedBuffer = viewLikeOp.getViewSource();
|
||||
definingOp = lastVisitedBuffer.getDefiningOp();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// we already count allocations separately
|
||||
if (definingOp && mlir::isa<memref::AllocOp>(definingOp) &&
|
||||
definingOp->getLoc() == op->getLoc())
|
||||
continue;
|
||||
|
||||
auto location = locationString(op->getLoc());
|
||||
|
||||
std::vector<mlir::Value> &visited = pass.visitedValuesPerLoc[location];
|
||||
|
||||
// the search would be faster if we use an unsorted_set, but we need a hash
|
||||
// function for mlir::Value
|
||||
if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) ==
|
||||
visited.end()) {
|
||||
visited.push_back(lastVisitedBuffer);
|
||||
|
||||
auto maybeBufferSize =
|
||||
getBufferSize(lastVisitedBuffer.getType().cast<mlir::MemRefType>());
|
||||
if (!maybeBufferSize) {
|
||||
return maybeBufferSize.error();
|
||||
}
|
||||
auto bufferSize = maybeBufferSize.value();
|
||||
|
||||
pass.feedback.memoryUsagePerLoc[location] += bufferSize;
|
||||
}
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<StringError> MemoryUsagePass::enter(mlir::Operation *op) {
|
||||
// specialized calls
|
||||
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
|
||||
std::optional<StringError> error = on_enter(typedOp, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
if (auto typedOp = llvm::dyn_cast<memref::AllocOp>(op)) {
|
||||
std::optional<StringError> error = on_enter(typedOp, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace Concrete {
|
||||
|
||||
struct MemoryUsagePass
|
||||
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
|
||||
|
||||
CompilationFeedback &feedback;
|
||||
|
||||
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};
|
||||
|
||||
void runOnOperation() override {
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
// call generic enter
|
||||
std::optional<StringError> error = on_enter(op, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
std::optional<StringError> enter(mlir::Operation *op) {
|
||||
// specialized calls
|
||||
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
|
||||
std::optional<StringError> error = on_enter(typedOp, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
if (auto typedOp = llvm::dyn_cast<memref::AllocOp>(op)) {
|
||||
std::optional<StringError> error = on_enter(typedOp, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<StringError> MemoryUsagePass::exit(mlir::Operation *op) {
|
||||
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
|
||||
std::optional<StringError> error = on_exit(typedOp, *this);
|
||||
// call generic enter
|
||||
std::optional<StringError> error = on_enter(op, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::nullopt;
|
||||
|
||||
std::optional<StringError> exit(mlir::Operation *op) {
|
||||
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
|
||||
std::optional<StringError> error = on_exit(typedOp, *this);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations *= (uint64_t)numberOfIterations.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations /= (uint64_t)numberOfIterations.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(memref::AllocOp &op,
|
||||
MemoryUsagePass &pass) {
|
||||
|
||||
auto maybeBufferSize = getBufferSize(op.getResult().getType());
|
||||
if (!maybeBufferSize) {
|
||||
return maybeBufferSize.error();
|
||||
}
|
||||
// if the allocated buffer is being deallocated then count it as one.
|
||||
// Otherwise (and there must be a problem) multiply it by the number of
|
||||
// iterations
|
||||
int64_t numberOfAlloc =
|
||||
isBufferDeallocated(op.getResult()) ? 1 : pass.iterations;
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
|
||||
// element_size
|
||||
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
|
||||
|
||||
pass.feedback.memoryUsagePerLoc[location] += memoryUsage;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(mlir::Operation *op,
|
||||
MemoryUsagePass &pass) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
// we only consider buffers
|
||||
if (!mlir::isa<mlir::MemRefType>(operand.getType()))
|
||||
continue;
|
||||
// find the origin of the buffer
|
||||
auto definingOp = operand.getDefiningOp();
|
||||
mlir::Value lastVisitedBuffer = operand;
|
||||
while (definingOp) {
|
||||
mlir::ViewLikeOpInterface viewLikeOp =
|
||||
mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp);
|
||||
if (viewLikeOp) {
|
||||
lastVisitedBuffer = viewLikeOp.getViewSource();
|
||||
definingOp = lastVisitedBuffer.getDefiningOp();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// we already count allocations separately
|
||||
if (definingOp && mlir::isa<memref::AllocOp>(definingOp) &&
|
||||
definingOp->getLoc() == op->getLoc())
|
||||
continue;
|
||||
|
||||
auto location = locationString(op->getLoc());
|
||||
|
||||
std::vector<mlir::Value> &visited = pass.visitedValuesPerLoc[location];
|
||||
|
||||
// the search would be faster if we use an unsorted_set, but we need a
|
||||
// hash function for mlir::Value
|
||||
if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) ==
|
||||
visited.end()) {
|
||||
visited.push_back(lastVisitedBuffer);
|
||||
|
||||
auto maybeBufferSize =
|
||||
getBufferSize(lastVisitedBuffer.getType().cast<mlir::MemRefType>());
|
||||
if (!maybeBufferSize) {
|
||||
return maybeBufferSize.error();
|
||||
}
|
||||
auto bufferSize = maybeBufferSize.value();
|
||||
|
||||
pass.feedback.memoryUsagePerLoc[location] += bufferSize;
|
||||
}
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;
|
||||
|
||||
size_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace Concrete
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createMemoryUsagePass(CompilationFeedback &feedback) {
|
||||
return std::make_unique<Concrete::MemoryUsagePass>(feedback);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -9,4 +9,5 @@ add_mlir_library(
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
TFHEDialect)
|
||||
TFHEDialect
|
||||
AnalysisUtils)
|
||||
|
||||
@@ -1,350 +1,19 @@
|
||||
#include <concretelang/Analysis/Utils.h>
|
||||
#include <concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h>
|
||||
|
||||
#include <mlir/Dialect/Arith/IR/Arith.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include <concretelang/Dialect/TFHE/IR/TFHEOps.h>
|
||||
|
||||
using namespace mlir::concretelang;
|
||||
using namespace mlir;
|
||||
|
||||
using TFHE::ExtractTFHEStatisticsPass;
|
||||
|
||||
// #########
|
||||
// Utilities
|
||||
// #########
|
||||
|
||||
template <typename Op> std::string locationOf(Op op) {
|
||||
auto location = std::string();
|
||||
auto locationStream = llvm::raw_string_ostream(location);
|
||||
op.getLoc()->print(locationStream);
|
||||
return location.substr(5, location.size() - 2 - 5); // remove loc(" and ")
|
||||
}
|
||||
|
||||
// #######
|
||||
// scf.for
|
||||
// #######
|
||||
|
||||
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
|
||||
int64_t high;
|
||||
int64_t low;
|
||||
|
||||
if (step > 0) {
|
||||
low = start;
|
||||
high = stop;
|
||||
} else {
|
||||
low = stop;
|
||||
high = start;
|
||||
step = -step;
|
||||
}
|
||||
|
||||
if (low >= high) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return ((high - low - 1) / step) + 1;
|
||||
}
|
||||
|
||||
std::optional<StringError> calculateNumberOfIterations(scf::ForOp &op,
|
||||
int64_t &result) {
|
||||
mlir::Value startValue = op.getLowerBound();
|
||||
mlir::Value stopValue = op.getUpperBound();
|
||||
mlir::Value stepValue = op.getStep();
|
||||
|
||||
auto startOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
|
||||
auto stopOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
|
||||
auto stepOp =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only static loops can be analyzed");
|
||||
}
|
||||
|
||||
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
|
||||
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
|
||||
|
||||
if (!startOp || !stopOp || !stepOp) {
|
||||
return StringError("only integer loops can be analyzed");
|
||||
}
|
||||
|
||||
int64_t start = startAttr.getInt();
|
||||
int64_t stop = stopAttr.getInt();
|
||||
int64_t step = stepAttr.getInt();
|
||||
|
||||
result = calculateNumberOfIterations(start, stop, step);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
int64_t numberOfIterations;
|
||||
|
||||
std::optional<StringError> error =
|
||||
calculateNumberOfIterations(op, numberOfIterations);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
|
||||
assert(numberOfIterations > 0);
|
||||
pass.iterations *= (uint64_t)numberOfIterations;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
int64_t numberOfIterations;
|
||||
|
||||
std::optional<StringError> error =
|
||||
calculateNumberOfIterations(op, numberOfIterations);
|
||||
if (error.has_value()) {
|
||||
return error;
|
||||
}
|
||||
|
||||
assert(numberOfIterations > 0);
|
||||
pass.iterations /= (uint64_t)numberOfIterations;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #############
|
||||
// TFHE.add_glwe
|
||||
// #############
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.add_glwe_int
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// ###################
|
||||
// TFHE.bootstrap_glwe
|
||||
// ###################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto bsk = op.getKey();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// ###################
|
||||
// TFHE.keyswitch_glwe
|
||||
// ###################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto ksk = op.getKey();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::KEY_SWITCH;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.mul_glwe_int
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #############
|
||||
// TFHE.neg_glwe
|
||||
// #############
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.sub_int_glwe
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
// clear - encrypted = clear + neg(encrypted)
|
||||
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.wop_pbs_glwe
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto bsk = op.getBsk();
|
||||
auto ksk = op.getKsk();
|
||||
auto pksk = op.getPksk();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::WOP_PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// ########
|
||||
// Dispatch
|
||||
// ########
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
#define DISPATCH_ENTER(type) \
|
||||
if (auto typedOp = llvm::dyn_cast<type>(op)) { \
|
||||
@@ -362,22 +31,326 @@ static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
|
||||
} \
|
||||
}
|
||||
|
||||
std::optional<StringError>
|
||||
ExtractTFHEStatisticsPass::enter(mlir::Operation *op) {
|
||||
DISPATCH_ENTER(scf::ForOp)
|
||||
DISPATCH_ENTER(TFHE::AddGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::AddGLWEIntOp)
|
||||
DISPATCH_ENTER(TFHE::BootstrapGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::KeySwitchGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::MulGLWEIntOp)
|
||||
DISPATCH_ENTER(TFHE::NegGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::SubGLWEIntOp)
|
||||
DISPATCH_ENTER(TFHE::WopPBSGLWEOp)
|
||||
return std::nullopt;
|
||||
struct ExtractTFHEStatisticsPass
|
||||
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
|
||||
|
||||
CompilationFeedback &feedback;
|
||||
|
||||
ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
|
||||
: feedback{feedback} {};
|
||||
|
||||
void runOnOperation() override {
|
||||
WalkResult walk =
|
||||
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
|
||||
if (stage.isBeforeAllRegions()) {
|
||||
std::optional<StringError> error = this->enter(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
if (stage.isAfterAllRegions()) {
|
||||
std::optional<StringError> error = this->exit(op);
|
||||
if (error.has_value()) {
|
||||
op->emitError() << error->mesg;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walk.wasInterrupted()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<StringError> enter(mlir::Operation *op) {
|
||||
DISPATCH_ENTER(scf::ForOp)
|
||||
DISPATCH_ENTER(TFHE::AddGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::AddGLWEIntOp)
|
||||
DISPATCH_ENTER(TFHE::BootstrapGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::KeySwitchGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::MulGLWEIntOp)
|
||||
DISPATCH_ENTER(TFHE::NegGLWEOp)
|
||||
DISPATCH_ENTER(TFHE::SubGLWEIntOp)
|
||||
DISPATCH_ENTER(TFHE::WopPBSGLWEOp)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<StringError> exit(mlir::Operation *op) {
|
||||
DISPATCH_EXIT(scf::ForOp)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_enter(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations *= (uint64_t)numberOfIterations.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto numberOfIterations = calculateNumberOfIterations(op);
|
||||
if (!numberOfIterations) {
|
||||
return numberOfIterations.error();
|
||||
}
|
||||
|
||||
assert(numberOfIterations.value() > 0);
|
||||
pass.iterations /= (uint64_t)numberOfIterations.value();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #############
|
||||
// TFHE.add_glwe
|
||||
// #############
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.add_glwe_int
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// ###################
|
||||
// TFHE.bootstrap_glwe
|
||||
// ###################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto bsk = op.getKey();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// ###################
|
||||
// TFHE.keyswitch_glwe
|
||||
// ###################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto ksk = op.getKey();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::KEY_SWITCH;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.mul_glwe_int
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #############
|
||||
// TFHE.neg_glwe
|
||||
// #############
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.sub_int_glwe
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
// clear - encrypted = clear + neg(encrypted)
|
||||
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// #################
|
||||
// TFHE.wop_pbs_glwe
|
||||
// #################
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto bsk = op.getBsk();
|
||||
auto ksk = op.getKsk();
|
||||
auto pksk = op.getPksk();
|
||||
|
||||
auto location = locationString(op.getLoc());
|
||||
auto operation = PrimitiveOperation::WOP_PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(concretelang::Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
size_t iterations = 1;
|
||||
};
|
||||
|
||||
} // namespace TFHE
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createStatisticExtractionPass(CompilationFeedback &feedback) {
|
||||
return std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback);
|
||||
}
|
||||
|
||||
std::optional<StringError>
|
||||
ExtractTFHEStatisticsPass::exit(mlir::Operation *op) {
|
||||
DISPATCH_EXIT(scf::ForOp)
|
||||
return std::nullopt;
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -332,7 +332,7 @@ extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
pipelinePrinting("TFHEStatistics", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback),
|
||||
pm, mlir::concretelang::createStatisticExtractionPass(feedback),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
@@ -358,7 +358,7 @@ computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
pipelinePrinting("Computing Memory Usage", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, std::make_unique<Concrete::MemoryUsagePass>(feedback), enablePass);
|
||||
pm, mlir::concretelang::createMemoryUsagePass(feedback), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user