mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(compiler): First draft of client parameters generation, runtime support for encrypting and decrypting circuit gates, integration of fhe parameters for the v0 (#65, #66, #56)
This commit is contained in:
@@ -32,6 +32,9 @@ llvm::cl::opt<std::string> output("o",
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::list<std::string> passes(
|
||||
"passes",
|
||||
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
|
||||
@@ -53,8 +56,19 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> generateKeySet(
|
||||
"generate-keyset",
|
||||
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<std::string> jitFuncname(
|
||||
"jit-funcname",
|
||||
llvm::cl::desc("Name of the function to execute, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
|
||||
llvm::cl::list<int>
|
||||
jitArgs("jit-args",
|
||||
llvm::cl::desc("Value of arguments to pass to the main func"),
|
||||
@@ -64,6 +78,12 @@ llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
|
||||
llvm::cl::init<bool>(false));
|
||||
}; // namespace cmdline
|
||||
|
||||
#define LOG_VERBOSE(expr) \
|
||||
if (cmdline::verbose) \
|
||||
llvm::errs() << expr;
|
||||
|
||||
#define LOG_ERROR(expr) llvm::errs() << expr;
|
||||
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
@@ -77,32 +97,42 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module,
|
||||
mlir::zamalang::KeySet &keySet,
|
||||
llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
auto maybeLambda = mlir::zamalang::JITLambda::create(
|
||||
cmdline::jitFuncname, module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = maybeLambda.get().get();
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create buffer to copy argument
|
||||
std::vector<int64_t> dummy(cmdline::jitArgs.size());
|
||||
llvm::SmallVector<void *> llvmArgs;
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
dummy[i] = cmdline::jitArgs[i];
|
||||
llvmArgs.push_back(&dummy[i]);
|
||||
}
|
||||
// Add the result pointer
|
||||
uint64_t res = 0;
|
||||
llvmArgs.push_back(&res);
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
|
||||
// Invoke the lambda
|
||||
if (lambda->invokeRaw(llvmArgs)) {
|
||||
LOG_ERROR("Cannot create lambda arguments: " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
std::cerr << res << "\n";
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
|
||||
LOG_ERROR("Cannot push argument " << i << ": " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (lambda->invoke(*arguments)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
LOG_ERROR("Cannot get result : " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
llvm::errs() << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -137,20 +167,67 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
context, *module,
|
||||
[](std::string passName) {
|
||||
return cmdline::passes.size() == 0 ||
|
||||
std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
auto enablePass = [](std::string passName) {
|
||||
return cmdline::passes.size() == 0 ||
|
||||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return passName == p; });
|
||||
})
|
||||
};
|
||||
|
||||
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
|
||||
mlir::zamalang::FHECircuitConstraint constraint;
|
||||
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
context, *module, constraint, enablePass)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
|
||||
<< constraint.p << "}\n");
|
||||
|
||||
// Retreive the parameters for the v0 approach
|
||||
mlir::zamalang::V0Parameter *fheParameter =
|
||||
mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p);
|
||||
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
|
||||
<< fheParameter->k
|
||||
<< ", polynomialSize: " << fheParameter->polynomialSize
|
||||
<< ", nSmall: " << fheParameter->nSmall
|
||||
<< ", brLevel: " << fheParameter->brLevel
|
||||
<< ", brLogBase: " << fheParameter->brLogBase
|
||||
<< ", ksLevel: " << fheParameter->ksLevel
|
||||
<< ", polynomialSize: " << fheParameter->ksLogBase << "}\n");
|
||||
|
||||
// Generate the keySet
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
if (cmdline::generateKeySet || cmdline::runJit) {
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
fheParameter, constraint.p, cmdline::jitFuncname, *module);
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
LOG_ERROR("cannot generate client parameters: " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
LOG_VERBOSE("### Generate the key set\n");
|
||||
auto maybeKeySet =
|
||||
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
|
||||
0); // TODO: seed
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
llvm::errs() << err;
|
||||
return mlir::failure();
|
||||
}
|
||||
keySet = std::move(maybeKeySet.get());
|
||||
}
|
||||
|
||||
// Lower to MLIR LLVM Dialect
|
||||
LOG_VERBOSE("### Lower from MLIR standards to LLVM\n");
|
||||
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
context, *module, enablePass)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (cmdline::runJit) {
|
||||
return runJit(module.get(), os);
|
||||
LOG_VERBOSE("### JIT compile & running\n");
|
||||
return runJit(module.get(), *keySet, os);
|
||||
}
|
||||
if (cmdline::toLLVM) {
|
||||
return dumpLLVMIR(module.get(), os);
|
||||
|
||||
Reference in New Issue
Block a user