refactor(compiler): clean statistic passes

This commit is contained in:
youben11
2023-08-29 15:49:09 +01:00
committed by Ayoub Benaissa
parent 4e8b9a199c
commit 530bacb2e3
18 changed files with 655 additions and 634 deletions

View File

@@ -0,0 +1,30 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_ANALYSIS_UTILS_H
#define CONCRETELANG_ANALYSIS_UTILS_H
#include <boost/outcome.h>
#include <concretelang/Common/Error.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/Location.h>
namespace mlir {
namespace concretelang {
/// Get the string representation of a location
std::string locationString(mlir::Location loc);
/// Compute the number of iterations based on loop info
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step);
/// Compute the number of iterations of an scf for loop
outcome::checked<int64_t, ::concretelang::error::StringError>
calculateNumberOfIterations(scf::ForOp &op);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,13 @@
#ifndef CONCRETELANG_DIALECT_CONCRETE_ANALYSIS
#define CONCRETELANG_DIALECT_CONCRETE_ANALYSIS
include "mlir/Pass/PassBase.td"
def MemoryUsage : Pass<"MemoryUsage", "::mlir::ModuleOp"> {
let summary = "Compute memory usage";
let description = [{
Computes memory usage per location, and provides those numbers throught the CompilationFeedback.
}];
}
#endif

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Analysis.td)
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
add_public_tablegen_target(ConcretelangConcreteAnalysisPassIncGen)
add_dependencies(mlir-headers ConcretelangConcreteAnalysisPassIncGen)

View File

@@ -7,59 +7,16 @@
#define CONCRETELANG_DIALECT_CONCRETE_MEMORY_USAGE_H
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>
#include <concretelang/Support/CompilationFeedback.h>
namespace mlir {
namespace concretelang {
namespace Concrete {
struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createMemoryUsagePass(CompilationFeedback &feedback);
CompilationFeedback &feedback;
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};
void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walk.wasInterrupted()) {
signalPassFailure();
}
}
std::optional<StringError> enter(Operation *op);
std::optional<StringError> exit(Operation *op);
std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;
size_t iterations = 1;
};
} // namespace Concrete
} // namespace concretelang
} // namespace mlir

View File

@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,14 @@
#ifndef CONCRETELANG_DIALECT_TFHE_ANALYSIS
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS
include "mlir/Pass/PassBase.td"
def ExtractStatistics : Pass<"ExtractStatistics", "::mlir::ModuleOp"> {
let summary = "Extracts statistics";
let description = [{
Extracts different statistics (e.g. number of certain crypto operations),
and provides those numbers throught the CompilationFeedback.
}];
}
#endif

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Analysis.td)
mlir_tablegen(Analysis.h.inc -gen-pass-decls -name Analysis)
add_public_tablegen_target(ConcretelangTFHEAnalysisPassIncGen)
add_dependencies(mlir-headers ConcretelangTFHEAnalysisPassIncGen)

View File

@@ -7,58 +7,15 @@
#define CONCRETELANG_DIALECT_TFHE_ANALYSIS_EXTRACT_STATISTICS_H
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Pass/Pass.h>
#include <concretelang/Support/CompilationFeedback.h>
namespace mlir {
namespace concretelang {
namespace TFHE {
struct ExtractTFHEStatisticsPass
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
CompilationFeedback &feedback;
ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
: feedback{feedback} {};
void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walk.wasInterrupted()) {
signalPassFailure();
}
}
std::optional<StringError> enter(Operation *op);
std::optional<StringError> exit(Operation *op);
size_t iterations = 1;
};
} // namespace TFHE
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createStatisticExtractionPass(CompilationFeedback &feedback);
} // namespace concretelang
} // namespace mlir

View File

@@ -1,2 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,8 @@
add_mlir_library(
AnalysisUtils
Utils.cpp
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR)

View File

@@ -0,0 +1,67 @@
#include <concretelang/Analysis/Utils.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
using ::concretelang::error::StringError;
namespace mlir {
namespace concretelang {
std::string locationString(mlir::Location loc) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
loc->print(locationStream);
return location;
}
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
int64_t high;
int64_t low;
if (step > 0) {
low = start;
high = stop;
} else {
low = stop;
high = start;
step = -step;
}
if (low >= high) {
return 0;
}
return ((high - low - 1) / step) + 1;
}
outcome::checked<int64_t, StringError>
calculateNumberOfIterations(scf::ForOp &op) {
mlir::Value startValue = op.getLowerBound();
mlir::Value stopValue = op.getUpperBound();
mlir::Value stepValue = op.getStep();
auto startOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
auto stopOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
auto stepOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
if (!startOp || !stopOp || !stepOp) {
return StringError("only static loops can be analyzed");
}
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
if (!startOp || !stopOp || !stepOp) {
return StringError("only integer loops can be analyzed");
}
int64_t start = startAttr.getInt();
int64_t stop = stopAttr.getInt();
int64_t step = stepAttr.getInt();
return calculateNumberOfIterations(start, stop, step);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -262,6 +262,7 @@ const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &[
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &["LLVMX86CodeGen", "LLVMX86Desc", "LLVMX86Info"];
const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
"AnalysisUtils",
"RTDialect",
"RTDialectTransforms",
"ConcretelangSupport",

View File

@@ -1,3 +1,4 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Transforms)

View File

@@ -9,4 +9,5 @@ add_mlir_library(
LINK_LIBS
PUBLIC
MLIRIR
ConcreteDialect)
ConcreteDialect
AnalysisUtils)

View File

@@ -1,109 +1,20 @@
#include <concretelang/Analysis/Utils.h>
#include <concretelang/Dialect/Concrete/Analysis/MemoryUsage.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Support/logging.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/Interfaces/ViewLikeInterface.h>
#include <numeric>
using namespace mlir::concretelang;
using namespace mlir;
using Concrete::MemoryUsagePass;
namespace {
std::string locationString(mlir::Location loc) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
loc->print(locationStream);
return location;
}
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
int64_t high;
int64_t low;
if (step > 0) {
low = start;
high = stop;
} else {
low = stop;
high = start;
step = -step;
}
if (low >= high) {
return 0;
}
return ((high - low - 1) / step) + 1;
}
std::optional<StringError> calculateNumberOfIterations(scf::ForOp &op,
int64_t &result) {
mlir::Value startValue = op.getLowerBound();
mlir::Value stopValue = op.getUpperBound();
mlir::Value stepValue = op.getStep();
auto startOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
auto stopOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
auto stepOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
if (!startOp || !stopOp || !stepOp) {
return StringError("only static loops can be analyzed");
}
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
if (!startOp || !stopOp || !stepOp) {
return StringError("only integer loops can be analyzed");
}
int64_t start = startAttr.getInt();
int64_t stop = stopAttr.getInt();
int64_t step = stepAttr.getInt();
result = calculateNumberOfIterations(start, stop, step);
return std::nullopt;
}
static std::optional<StringError> on_enter(scf::ForOp &op,
MemoryUsagePass &pass) {
int64_t numberOfIterations;
std::optional<StringError> error =
calculateNumberOfIterations(op, numberOfIterations);
if (error.has_value()) {
return error;
}
assert(numberOfIterations > 0);
pass.iterations *= (uint64_t)numberOfIterations;
return std::nullopt;
}
static std::optional<StringError> on_exit(scf::ForOp &op,
MemoryUsagePass &pass) {
int64_t numberOfIterations;
std::optional<StringError> error =
calculateNumberOfIterations(op, numberOfIterations);
if (error.has_value()) {
return error;
}
assert(numberOfIterations > 0);
pass.iterations /= (uint64_t)numberOfIterations;
return std::nullopt;
}
int64_t getElementTypeSize(mlir::Type elementType) {
if (auto integerType = mlir::dyn_cast<mlir::IntegerType>(elementType)) {
auto width = integerType.getWidth();
@@ -146,108 +57,185 @@ bool isBufferDeallocated(mlir::Value buffer) {
return false;
}
static std::optional<StringError> on_enter(memref::AllocOp &op,
MemoryUsagePass &pass) {
auto maybeBufferSize = getBufferSize(op.getResult().getType());
if (!maybeBufferSize) {
return maybeBufferSize.error();
}
// if the allocated buffer is being deallocated then count it as one.
// Otherwise (and there must be a problem) multiply it by the number of
// iterations
int64_t numberOfAlloc =
isBufferDeallocated(op.getResult()) ? 1 : pass.iterations;
auto location = locationString(op.getLoc());
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
// element_size
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
pass.feedback.memoryUsagePerLoc[location] += memoryUsage;
return std::nullopt;
}
static std::optional<StringError> on_enter(mlir::Operation *op,
MemoryUsagePass &pass) {
for (auto operand : op->getOperands()) {
// we only consider buffers
if (!mlir::isa<mlir::MemRefType>(operand.getType()))
continue;
// find the origin of the buffer
auto definingOp = operand.getDefiningOp();
mlir::Value lastVisitedBuffer = operand;
while (definingOp) {
mlir::ViewLikeOpInterface viewLikeOp =
mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp);
if (viewLikeOp) {
lastVisitedBuffer = viewLikeOp.getViewSource();
definingOp = lastVisitedBuffer.getDefiningOp();
} else {
break;
}
}
// we already count allocations separately
if (definingOp && mlir::isa<memref::AllocOp>(definingOp) &&
definingOp->getLoc() == op->getLoc())
continue;
auto location = locationString(op->getLoc());
std::vector<mlir::Value> &visited = pass.visitedValuesPerLoc[location];
// the search would be faster if we use an unsorted_set, but we need a hash
// function for mlir::Value
if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) ==
visited.end()) {
visited.push_back(lastVisitedBuffer);
auto maybeBufferSize =
getBufferSize(lastVisitedBuffer.getType().cast<mlir::MemRefType>());
if (!maybeBufferSize) {
return maybeBufferSize.error();
}
auto bufferSize = maybeBufferSize.value();
pass.feedback.memoryUsagePerLoc[location] += bufferSize;
}
}
return std::nullopt;
}
} // namespace
std::optional<StringError> MemoryUsagePass::enter(mlir::Operation *op) {
// specialized calls
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
std::optional<StringError> error = on_enter(typedOp, *this);
if (error.has_value()) {
return error;
}
}
if (auto typedOp = llvm::dyn_cast<memref::AllocOp>(op)) {
std::optional<StringError> error = on_enter(typedOp, *this);
if (error.has_value()) {
return error;
namespace mlir {
namespace concretelang {
namespace Concrete {
struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
CompilationFeedback &feedback;
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};
void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walk.wasInterrupted()) {
signalPassFailure();
}
}
// call generic enter
std::optional<StringError> error = on_enter(op, *this);
if (error.has_value()) {
return error;
}
return std::nullopt;
}
std::optional<StringError> enter(mlir::Operation *op) {
// specialized calls
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
std::optional<StringError> error = on_enter(typedOp, *this);
if (error.has_value()) {
return error;
}
}
if (auto typedOp = llvm::dyn_cast<memref::AllocOp>(op)) {
std::optional<StringError> error = on_enter(typedOp, *this);
if (error.has_value()) {
return error;
}
}
std::optional<StringError> MemoryUsagePass::exit(mlir::Operation *op) {
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
std::optional<StringError> error = on_exit(typedOp, *this);
// call generic enter
std::optional<StringError> error = on_enter(op, *this);
if (error.has_value()) {
return error;
}
return std::nullopt;
}
return std::nullopt;
std::optional<StringError> exit(mlir::Operation *op) {
if (auto typedOp = llvm::dyn_cast<scf::ForOp>(op)) {
std::optional<StringError> error = on_exit(typedOp, *this);
if (error.has_value()) {
return error;
}
}
return std::nullopt;
}
static std::optional<StringError> on_enter(scf::ForOp &op,
MemoryUsagePass &pass) {
auto numberOfIterations = calculateNumberOfIterations(op);
if (!numberOfIterations) {
return numberOfIterations.error();
}
assert(numberOfIterations.value() > 0);
pass.iterations *= (uint64_t)numberOfIterations.value();
return std::nullopt;
}
static std::optional<StringError> on_exit(scf::ForOp &op,
MemoryUsagePass &pass) {
auto numberOfIterations = calculateNumberOfIterations(op);
if (!numberOfIterations) {
return numberOfIterations.error();
}
assert(numberOfIterations.value() > 0);
pass.iterations /= (uint64_t)numberOfIterations.value();
return std::nullopt;
}
static std::optional<StringError> on_enter(memref::AllocOp &op,
MemoryUsagePass &pass) {
auto maybeBufferSize = getBufferSize(op.getResult().getType());
if (!maybeBufferSize) {
return maybeBufferSize.error();
}
// if the allocated buffer is being deallocated then count it as one.
// Otherwise (and there must be a problem) multiply it by the number of
// iterations
int64_t numberOfAlloc =
isBufferDeallocated(op.getResult()) ? 1 : pass.iterations;
auto location = locationString(op.getLoc());
// pass.iterations number of allocation of size: shape_1 * ... * shape_n *
// element_size
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
pass.feedback.memoryUsagePerLoc[location] += memoryUsage;
return std::nullopt;
}
static std::optional<StringError> on_enter(mlir::Operation *op,
MemoryUsagePass &pass) {
for (auto operand : op->getOperands()) {
// we only consider buffers
if (!mlir::isa<mlir::MemRefType>(operand.getType()))
continue;
// find the origin of the buffer
auto definingOp = operand.getDefiningOp();
mlir::Value lastVisitedBuffer = operand;
while (definingOp) {
mlir::ViewLikeOpInterface viewLikeOp =
mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp);
if (viewLikeOp) {
lastVisitedBuffer = viewLikeOp.getViewSource();
definingOp = lastVisitedBuffer.getDefiningOp();
} else {
break;
}
}
// we already count allocations separately
if (definingOp && mlir::isa<memref::AllocOp>(definingOp) &&
definingOp->getLoc() == op->getLoc())
continue;
auto location = locationString(op->getLoc());
std::vector<mlir::Value> &visited = pass.visitedValuesPerLoc[location];
// the search would be faster if we use an unsorted_set, but we need a
// hash function for mlir::Value
if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) ==
visited.end()) {
visited.push_back(lastVisitedBuffer);
auto maybeBufferSize =
getBufferSize(lastVisitedBuffer.getType().cast<mlir::MemRefType>());
if (!maybeBufferSize) {
return maybeBufferSize.error();
}
auto bufferSize = maybeBufferSize.value();
pass.feedback.memoryUsagePerLoc[location] += bufferSize;
}
}
return std::nullopt;
}
std::map<std::string, std::vector<mlir::Value>> visitedValuesPerLoc;
size_t iterations = 1;
};
} // namespace Concrete
std::unique_ptr<OperationPass<ModuleOp>>
createMemoryUsagePass(CompilationFeedback &feedback) {
return std::make_unique<Concrete::MemoryUsagePass>(feedback);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -9,4 +9,5 @@ add_mlir_library(
LINK_LIBS
PUBLIC
MLIRIR
TFHEDialect)
TFHEDialect
AnalysisUtils)

View File

@@ -1,350 +1,19 @@
#include <concretelang/Analysis/Utils.h>
#include <concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <concretelang/Dialect/TFHE/IR/TFHEOps.h>
using namespace mlir::concretelang;
using namespace mlir;
using TFHE::ExtractTFHEStatisticsPass;
// #########
// Utilities
// #########
template <typename Op> std::string locationOf(Op op) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
op.getLoc()->print(locationStream);
return location.substr(5, location.size() - 2 - 5); // remove loc(" and ")
}
// #######
// scf.for
// #######
int64_t calculateNumberOfIterations(int64_t start, int64_t stop, int64_t step) {
int64_t high;
int64_t low;
if (step > 0) {
low = start;
high = stop;
} else {
low = stop;
high = start;
step = -step;
}
if (low >= high) {
return 0;
}
return ((high - low - 1) / step) + 1;
}
std::optional<StringError> calculateNumberOfIterations(scf::ForOp &op,
int64_t &result) {
mlir::Value startValue = op.getLowerBound();
mlir::Value stopValue = op.getUpperBound();
mlir::Value stepValue = op.getStep();
auto startOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(startValue.getDefiningOp());
auto stopOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stopValue.getDefiningOp());
auto stepOp =
llvm::dyn_cast_or_null<arith::ConstantOp>(stepValue.getDefiningOp());
if (!startOp || !stopOp || !stepOp) {
return StringError("only static loops can be analyzed");
}
auto startAttr = startOp.getValue().cast<mlir::IntegerAttr>();
auto stopAttr = stopOp.getValue().cast<mlir::IntegerAttr>();
auto stepAttr = stepOp.getValue().cast<mlir::IntegerAttr>();
if (!startOp || !stopOp || !stepOp) {
return StringError("only integer loops can be analyzed");
}
int64_t start = startAttr.getInt();
int64_t stop = stopAttr.getInt();
int64_t step = stepAttr.getInt();
result = calculateNumberOfIterations(start, stop, step);
return std::nullopt;
}
static std::optional<StringError> on_enter(scf::ForOp &op,
ExtractTFHEStatisticsPass &pass) {
int64_t numberOfIterations;
std::optional<StringError> error =
calculateNumberOfIterations(op, numberOfIterations);
if (error.has_value()) {
return error;
}
assert(numberOfIterations > 0);
pass.iterations *= (uint64_t)numberOfIterations;
return std::nullopt;
}
static std::optional<StringError> on_exit(scf::ForOp &op,
ExtractTFHEStatisticsPass &pass) {
int64_t numberOfIterations;
std::optional<StringError> error =
calculateNumberOfIterations(op, numberOfIterations);
if (error.has_value()) {
return error;
}
assert(numberOfIterations > 0);
pass.iterations /= (uint64_t)numberOfIterations;
return std::nullopt;
}
// #############
// TFHE.add_glwe
// #############
static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.add_glwe_int
// #################
static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::CLEAR_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// ###################
// TFHE.bootstrap_glwe
// ###################
static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto bsk = op.getKey();
auto location = locationOf(op);
auto operation = PrimitiveOperation::PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// ###################
// TFHE.keyswitch_glwe
// ###################
static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto ksk = op.getKey();
auto location = locationOf(op);
auto operation = PrimitiveOperation::KEY_SWITCH;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.mul_glwe_int
// #################
static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #############
// TFHE.neg_glwe
// #############
static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.sub_int_glwe
// #################
static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
// clear - encrypted = clear + neg(encrypted)
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
operation = PrimitiveOperation::CLEAR_ADDITION;
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.wop_pbs_glwe
// #################
static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto bsk = op.getBsk();
auto ksk = op.getKsk();
auto pksk = op.getPksk();
auto location = locationOf(op);
auto operation = PrimitiveOperation::WOP_PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// ########
// Dispatch
// ########
namespace mlir {
namespace concretelang {
namespace TFHE {
#define DISPATCH_ENTER(type) \
if (auto typedOp = llvm::dyn_cast<type>(op)) { \
@@ -362,22 +31,326 @@ static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
} \
}
std::optional<StringError>
ExtractTFHEStatisticsPass::enter(mlir::Operation *op) {
DISPATCH_ENTER(scf::ForOp)
DISPATCH_ENTER(TFHE::AddGLWEOp)
DISPATCH_ENTER(TFHE::AddGLWEIntOp)
DISPATCH_ENTER(TFHE::BootstrapGLWEOp)
DISPATCH_ENTER(TFHE::KeySwitchGLWEOp)
DISPATCH_ENTER(TFHE::MulGLWEIntOp)
DISPATCH_ENTER(TFHE::NegGLWEOp)
DISPATCH_ENTER(TFHE::SubGLWEIntOp)
DISPATCH_ENTER(TFHE::WopPBSGLWEOp)
return std::nullopt;
struct ExtractTFHEStatisticsPass
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
CompilationFeedback &feedback;
ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
: feedback{feedback} {};
void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walk.wasInterrupted()) {
signalPassFailure();
}
}
std::optional<StringError> enter(mlir::Operation *op) {
DISPATCH_ENTER(scf::ForOp)
DISPATCH_ENTER(TFHE::AddGLWEOp)
DISPATCH_ENTER(TFHE::AddGLWEIntOp)
DISPATCH_ENTER(TFHE::BootstrapGLWEOp)
DISPATCH_ENTER(TFHE::KeySwitchGLWEOp)
DISPATCH_ENTER(TFHE::MulGLWEIntOp)
DISPATCH_ENTER(TFHE::NegGLWEOp)
DISPATCH_ENTER(TFHE::SubGLWEIntOp)
DISPATCH_ENTER(TFHE::WopPBSGLWEOp)
return std::nullopt;
}
std::optional<StringError> exit(mlir::Operation *op) {
DISPATCH_EXIT(scf::ForOp)
return std::nullopt;
}
static std::optional<StringError> on_enter(scf::ForOp &op,
ExtractTFHEStatisticsPass &pass) {
auto numberOfIterations = calculateNumberOfIterations(op);
if (!numberOfIterations) {
return numberOfIterations.error();
}
assert(numberOfIterations.value() > 0);
pass.iterations *= (uint64_t)numberOfIterations.value();
return std::nullopt;
}
static std::optional<StringError> on_exit(scf::ForOp &op,
ExtractTFHEStatisticsPass &pass) {
auto numberOfIterations = calculateNumberOfIterations(op);
if (!numberOfIterations) {
return numberOfIterations.error();
}
assert(numberOfIterations.value() > 0);
pass.iterations /= (uint64_t)numberOfIterations.value();
return std::nullopt;
}
// #############
// TFHE.add_glwe
// #############
static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.add_glwe_int
// #################
static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::CLEAR_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// ###################
// TFHE.bootstrap_glwe
// ###################
static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto bsk = op.getKey();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// ###################
// TFHE.keyswitch_glwe
// ###################
static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto ksk = op.getKey();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::KEY_SWITCH;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.mul_glwe_int
// #################
static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #############
// TFHE.neg_glwe
// #############
static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.sub_int_glwe
// #################
static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationString(op.getLoc());
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
// clear - encrypted = clear + neg(encrypted)
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
operation = PrimitiveOperation::CLEAR_ADDITION;
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
// #################
// TFHE.wop_pbs_glwe
// #################
static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
auto bsk = op.getBsk();
auto ksk = op.getKsk();
auto pksk = op.getPksk();
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::WOP_PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
size_t iterations = 1;
};
} // namespace TFHE
std::unique_ptr<OperationPass<ModuleOp>>
createStatisticExtractionPass(CompilationFeedback &feedback) {
return std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback);
}
std::optional<StringError>
ExtractTFHEStatisticsPass::exit(mlir::Operation *op) {
DISPATCH_EXIT(scf::ForOp)
return std::nullopt;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -332,7 +332,7 @@ extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module,
pipelinePrinting("TFHEStatistics", pm, context);
addPotentiallyNestedPass(
pm, std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback),
pm, mlir::concretelang::createStatisticExtractionPass(feedback),
enablePass);
return pm.run(module.getOperation());
@@ -358,7 +358,7 @@ computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
pipelinePrinting("Computing Memory Usage", pm, context);
addPotentiallyNestedPass(
pm, std::make_unique<Concrete::MemoryUsagePass>(feedback), enablePass);
pm, mlir::concretelang::createMemoryUsagePass(feedback), enablePass);
return pm.run(module.getOperation());
}