fix(compiler): Fix tfhe global parametrization to handle k>1

Co-authored-by: Mayeul@Zama <mayeul.debellabre@zama.ai>
This commit is contained in:
Quentin Bourgerie
2022-06-23 15:21:10 +02:00
parent ef9d11c16f
commit 2de76e9c4e
22 changed files with 228 additions and 154 deletions

View File

@@ -21,6 +21,7 @@
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
@@ -175,16 +176,6 @@ llvm::cl::opt<std::string> jitKeySetCachePath(
"jit-keyset-cache-path",
llvm::cl::desc("Path to cache KeySet content (unsecure)"));
llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser>
assumeMaxEintPrecision(
"assume-max-eint-precision",
llvm::cl::desc("Assume a maximum precision for encrypted integers"));
llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
"assume-max-manp",
llvm::cl::desc(
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
llvm::cl::opt<double> pbsErrorProbability(
"pbs-error-probability",
llvm::cl::desc("Change the default probability of error for all pbs"),
@@ -200,6 +191,20 @@ llvm::cl::list<int64_t> fhelinalgTileSizes(
llvm::cl::desc(
"Force tiling of FHELinalg operation with the given tile sizes"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::list<size_t> v0Constraint(
"v0-constraint",
llvm::cl::desc(
"Force the compiler to use the given v0 constraint [p, norm2]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::list<int64_t> v0Parameter(
"v0-parameter",
llvm::cl::desc(
"Force to apply the given v0 parameters [glweDimension, "
"logPolynomialSize, nSmall, brLevel, brLobBase, ksLevel, ksLogBase]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
} // namespace cmdline
namespace llvm {
@@ -219,7 +224,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
}
} // namespace llvm
mlir::concretelang::CompilationOptions cmdlineCompilationOptions() {
llvm::Expected<mlir::concretelang::CompilationOptions>
cmdlineCompilationOptions() {
mlir::concretelang::CompilationOptions options;
options.verifyDiagnostics = cmdline::verifyDiagnostics;
@@ -228,12 +234,14 @@ mlir::concretelang::CompilationOptions cmdlineCompilationOptions() {
options.dataflowParallelize = cmdline::dataflowParallelize;
options.optimizeConcrete = cmdline::optimizeConcrete;
if (cmdline::assumeMaxEintPrecision.hasValue() &&
cmdline::assumeMaxMANP.hasValue()) {
if (!cmdline::v0Constraint.empty()) {
if (cmdline::v0Constraint.size() != 2) {
return llvm::make_error<llvm::StringError>(
"The v0-constraint option expect a list of size 2",
llvm::inconvertibleErrorCode());
}
options.v0FHEConstraints = mlir::concretelang::V0FHEConstraint{
cmdline::assumeMaxMANP.getValue().getValue(),
cmdline::assumeMaxEintPrecision.getValue().getValue(),
};
cmdline::v0Constraint[1], cmdline::v0Constraint[0]};
}
if (!cmdline::funcName.empty()) {
@@ -244,6 +252,19 @@ mlir::concretelang::CompilationOptions cmdlineCompilationOptions() {
if (!cmdline::fhelinalgTileSizes.empty())
options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
if (!cmdline::v0Parameter.empty()) {
if (cmdline::v0Parameter.size() != 7) {
return llvm::make_error<llvm::StringError>(
"The v0-parameter option expect a list of size 7",
llvm::inconvertibleErrorCode());
}
options.v0Parameter = mlir::concretelang::V0Parameter(
cmdline::v0Parameter[0], cmdline::v0Parameter[1],
cmdline::v0Parameter[2], cmdline::v0Parameter[3],
cmdline::v0Parameter[4], cmdline::v0Parameter[5],
cmdline::v0Parameter[6]);
}
options.optimizerConfig.p_error = cmdline::pbsErrorProbability;
options.optimizerConfig.display = cmdline::displayOptimizerChoice;
@@ -410,6 +431,10 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
}
auto compilerOptions = cmdlineCompilationOptions();
if (auto err = compilerOptions.takeError()) {
llvm::errs() << err << "\n";
return mlir::failure();
}
llvm::Optional<clientlib::KeySetCache> jitKeySetCache;
if (!cmdline::jitKeySetCachePath.empty()) {
@@ -447,7 +472,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
auto process = [&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
std::move(inputBuffer), fileName, compilerOptions, cmdline::action,
std::move(inputBuffer), fileName, *compilerOptions, cmdline::action,
cmdline::jitArgs, jitKeySetCache, os, outputLib);
};
auto &os = output->os();