feat(compiler): First draft of client parameters generation, runtime support for encrypting and decrypting circuit gates, integration of fhe parameters for the v0 (#65, #66, #56)

This commit is contained in:
Quentin Bourgerie
2021-08-04 15:12:48 +02:00
parent e290447389
commit d0877536ed
14 changed files with 984 additions and 44 deletions

View File

@@ -32,6 +32,9 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
llvm::cl::init<bool>(false));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
@@ -53,8 +56,19 @@ llvm::cl::opt<bool> splitInputFile(
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> generateKeySet(
"generate-keyset",
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
llvm::cl::init<bool>(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::opt<std::string> jitFuncname(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::list<int>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
@@ -64,6 +78,12 @@ llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
#define LOG_VERBOSE(expr) \
if (cmdline::verbose) \
llvm::errs() << expr;
#define LOG_ERROR(expr) llvm::errs() << expr;
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
@@ -77,32 +97,42 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
return mlir::success();
}
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
mlir::LogicalResult runJit(mlir::ModuleOp module,
mlir::zamalang::KeySet &keySet,
llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
auto maybeLambda = mlir::zamalang::JITLambda::create(
cmdline::jitFuncname, module, defaultOptPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = maybeLambda.get().get();
auto lambda = std::move(maybeLambda.get());
// Create buffer to copy argument
std::vector<int64_t> dummy(cmdline::jitArgs.size());
llvm::SmallVector<void *> llvmArgs;
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
dummy[i] = cmdline::jitArgs[i];
llvmArgs.push_back(&dummy[i]);
}
// Add the result pointer
uint64_t res = 0;
llvmArgs.push_back(&res);
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
if (auto err = maybeArguments.takeError()) {
// Invoke the lambda
if (lambda->invokeRaw(llvmArgs)) {
LOG_ERROR("Cannot create lambda arguments: " << err << "\n");
return mlir::failure();
}
std::cerr << res << "\n";
// Set the arguments
auto arguments = std::move(maybeArguments.get());
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
LOG_ERROR("Cannot push argument " << i << ": " << err << "\n");
return mlir::failure();
}
}
// Invoke the lambda
if (lambda->invoke(*arguments)) {
return mlir::failure();
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
LOG_ERROR("Cannot get result : " << err << "\n");
return mlir::failure();
}
llvm::errs() << res << "\n";
return mlir::success();
}
@@ -137,20 +167,67 @@ processInputBuffer(mlir::MLIRContext &context,
return mlir::success();
}
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
context, *module,
[](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
auto enablePass = [](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
})
};
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
mlir::zamalang::FHECircuitConstraint constraint;
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, constraint, enablePass)
.failed()) {
return mlir::failure();
}
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
<< constraint.p << "}\n");
// Retreive the parameters for the v0 approach
mlir::zamalang::V0Parameter *fheParameter =
mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p);
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
<< fheParameter->k
<< ", polynomialSize: " << fheParameter->polynomialSize
<< ", nSmall: " << fheParameter->nSmall
<< ", brLevel: " << fheParameter->brLevel
<< ", brLogBase: " << fheParameter->brLogBase
<< ", ksLevel: " << fheParameter->ksLevel
<< ", polynomialSize: " << fheParameter->ksLogBase << "}\n");
// Generate the keySet
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheParameter, constraint.p, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
LOG_ERROR("cannot generate client parameters: " << err << "\n");
return mlir::failure();
}
LOG_VERBOSE("### Generate the key set\n");
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return mlir::failure();
}
keySet = std::move(maybeKeySet.get());
}
// Lower to MLIR LLVM Dialect
LOG_VERBOSE("### Lower from MLIR standards to LLVM\n");
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
context, *module, enablePass)
.failed()) {
return mlir::failure();
}
if (cmdline::runJit) {
return runJit(module.get(), os);
LOG_VERBOSE("### JIT compile & running\n");
return runJit(module.get(), *keySet, os);
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);