feat(compiler): First draft of client parameters generation, runtime support for encrypting and decrypting circuit gates, integration of fhe parameters for the v0 (#65, #66, #56)

This commit is contained in:
Quentin Bourgerie
2021-08-04 15:12:48 +02:00
parent e290447389
commit d0877536ed
14 changed files with 984 additions and 44 deletions

View File

@@ -6,7 +6,6 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")
find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}")
@@ -27,6 +26,12 @@ include_directories(${PROJECT_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
#-------------------------------------------------------------------------------
# Concrete FFI Configuration
#-------------------------------------------------------------------------------
include_directories(${CONCRETE_FFI_RELEASE})
add_library(Concrete SHARED IMPORTED)
set_target_properties(Concrete PROPERTIES IMPORTED_LOCATION ${CONCRETE_FFI_RELEASE}/libconcrete_ffi.so )
#-------------------------------------------------------------------------------
# Python Configuration

View File

@@ -1,5 +1,8 @@
build:
cmake -B build . -DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm -DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir
cmake -B build . \
-DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm \
-DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir \
-DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release
zamacompiler:
make -C build/ zamacompiler

View File

@@ -0,0 +1,81 @@
#ifndef ZAMALANG_SUPPORT_CLIENTPARAMETERS_H_
#define ZAMALANG_SUPPORT_CLIENTPARAMETERS_H_
#include <map>
#include <string>
#include <vector>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/IR/BuiltinOps.h>
#include "zamalang/Support/V0Parameters.h"
namespace mlir {
namespace zamalang {
typedef size_t DecompositionLevelCount;
typedef size_t DecompositionBaseLog;
typedef size_t PolynomialSize;
typedef size_t Precision;
typedef double Variance;
typedef uint64_t LweSize;
typedef uint64_t GLWESize;
typedef std::string LweSecretKeyID;
struct LweSecretKeyParam {
LweSize size;
};
typedef std::string BootstrapKeyID;
struct BootstrapKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GLWESize k;
Variance variance;
};
typedef std::string KeyswitchKeyID;
struct KeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
Variance variance;
};
struct Encoding {
Precision precision;
};
struct EncryptionGate {
LweSecretKeyID secretKeyID;
Variance variance;
Encoding encoding;
};
struct CircuitGateShape {
uint64_t size;
};
struct CircuitGate {
llvm::Optional<EncryptionGate> encryption;
CircuitGateShape shape;
};
struct ClientParameters {
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
};
llvm::Expected<ClientParameters>
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
llvm::StringRef name, mlir::ModuleOp module);
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -5,21 +5,43 @@
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/Pass/PassManager.h>
#include "zamalang/Support/ClientParameters.h"
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/V0Parameters.h"
namespace mlir {
namespace zamalang {
/// For the v0 we compute a global constraint, this is defined here as the
/// high-level verification pass is not yet implemented.
struct FHECircuitConstraint {
size_t norm2;
size_t p;
};
class CompilerTools {
public:
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
/// LLVM dialect.
static mlir::LogicalResult lowerHLFHEToMlirLLVMDialect(
/// lowerable to llvm dialect.
/// The given module MLIR operation would be modified and the constraints set.
static mlir::LogicalResult lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
FHECircuitConstraint &constraint,
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
return true;
});
/// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR
/// dialects to MLIR LLVM dialect. The given module MLIR operation would be
/// modified.
static mlir::LogicalResult lowerMlirStdsDialectToMlirLLVMDialect(
mlir::MLIRContext &context, mlir::Operation *module,
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
return true;
});
static llvm::Expected<std::unique_ptr<llvm::Module>>
toLLVMModule(llvm::LLVMContext &context, mlir::ModuleOp &module,
toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
};
@@ -27,6 +49,28 @@ public:
/// of the module.
class JITLambda {
public:
class Argument {
public:
Argument(KeySet &keySet);
~Argument();
// Create lambda Argument that use the given KeySet to perform encryption
// and decryption operations.
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
// Set the argument at the given pos as a uint64_t.
llvm::Error setArg(size_t pos, uint64_t arg);
// Get the result at the given pos as an uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
private:
friend JITLambda;
std::vector<void *> rawArg;
std::vector<void *> inputs;
std::vector<void *> results;
KeySet &keySet;
};
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};
@@ -35,7 +79,7 @@ public:
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
/// invokeRaw execute the jit lambda with a lits of arguments, the last one is
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
/// used to store the result of the computation.
/// Example:
/// uin64_t arg0 = 1;
@@ -44,6 +88,9 @@ public:
/// lambda.invokeRaw(args);
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
/// invoke the jit lambda with the Argument.
llvm::Error invoke(Argument &args);
private:
mlir::LLVM::LLVMFunctionType type;
llvm::StringRef name;

View File

@@ -0,0 +1,65 @@
#ifndef ZAMALANG_SUPPORT_KEYSET_H_
#define ZAMALANG_SUPPORT_KEYSET_H_
#include "llvm/Support/Error.h"
#include <memory>
extern "C" {
#include "concrete-ffi.h"
}
#include "zamalang/Support/ClientParameters.h"
namespace mlir {
namespace zamalang {
class KeySet {
public:
~KeySet();
// allocate a KeySet according the ClientParameters.
static llvm::Expected<std::unique_ptr<KeySet>>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
// isInputEncrypted return true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
// allocate a lwe ciphertext for the argument at argPos.
llvm::Error allocate_lwe(size_t argPos, LweCiphertext_u64 **ciphertext);
// encrypt the input to the ciphertext for the argument at argPos.
llvm::Error encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t input);
// isOuputEncrypted return true if the output at the given pos is encrypted.
bool isOutputEncrypted(size_t pos);
// decrypt the ciphertext to the output for the argument at argPos.
llvm::Error decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t &output);
size_t numInputs() { return inputs.size(); }
size_t numOutputs() { return outputs.size(); }
protected:
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);
llvm::Error generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
EncryptionRandomGenerator *generator);
llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator);
private:
EncryptionRandomGenerator *encryptionRandomGenerator;
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys;
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
bootstrapKeys;
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
keyswitchKeys;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
inputs;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
outputs;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -0,0 +1,29 @@
#ifndef ZAMALANG_SUPPORT_V0Parameter_H_
#define ZAMALANG_SUPPORT_V0Parameter_H_
#include <cstddef>
namespace mlir {
namespace zamalang {
typedef struct V0Parameter {
size_t k;
size_t polynomialSize;
size_t nSmall;
size_t brLevel;
size_t brLogBase;
size_t ksLevel;
size_t ksLogBase;
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {}
} V0Parameter;
V0Parameter *getV0Parameter(size_t norm, size_t p);
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -3,9 +3,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -19,10 +16,16 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
namespace {
struct MLIRLowerableDialectsToLLVMPass
: public MLIRLowerableDialectsToLLVMBase<MLIRLowerableDialectsToLLVMPass> {
void runOnOperation() final;
/// Convert types to the LLVM dialect-compatible type
static llvm::Optional<mlir::Type> convertTypes(mlir::Type type);
};
} // namespace
@@ -35,6 +38,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
// add our types conversion to `llvm` compatible type.
mlir::LLVMTypeConverter typeConverter(&getContext());
typeConverter.addConversion(convertTypes);
// Setup the set of the patterns rewriter. At this point we want to
// convert the `scf` operations to `std` and `std` operations to `llvm`.
@@ -49,6 +53,15 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
}
}
llvm::Optional<mlir::Type>
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(type.getContext(), 8));
}
return llvm::None;
}
namespace mlir {
namespace zamalang {
/// Create a pass for lowering operations the remaining mlir dialects

View File

@@ -1,5 +1,8 @@
add_mlir_library(ZamalangSupport
CompilerTools.cpp
V0Parameters.cpp
ClientParameters.cpp
KeySet.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Support
@@ -13,4 +16,7 @@ add_mlir_library(ZamalangSupport
MLIRLowerableDialectsToLLVM
MLIRExecutionEngine
${LLVM_PTHREAD_LIB})
${LLVM_PTHREAD_LIB}
Concrete
)

View File

@@ -0,0 +1,117 @@
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/Error.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Support/ClientParameters.h"
namespace mlir {
namespace zamalang {
// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
Precision precision,
mlir::Type type) {
if (type.isInteger(64)) {
return CircuitGate{
.encryption = llvm::None,
.shape = {.size = 0},
};
}
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
return CircuitGate{
.encryption = llvm::Optional<EncryptionGate>({
.secretKeyID = secretKeyID,
// TODO - Compute variance, wait for security estimator
.variance = 0.,
.encoding = {.precision = precision},
}),
.shape = {.size = 0},
};
}
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
if (memref != nullptr) {
auto gate =
gateFromMLIRType(secretKeyID, precision, memref.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
gate->shape.size = memref.getDimSize(0);
return gate;
}
return llvm::make_error<llvm::StringError>(
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
}
llvm::Expected<ClientParameters>
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
llvm::StringRef name, mlir::ModuleOp module) {
// Static client parameters from global parameters for v0
ClientParameters c{
.secretKeys{
{"small", {.size = v0Param->nSmall}},
{"big", {.size = v0Param->k * (1 << v0Param->polynomialSize)}},
},
.bootstrapKeys{
{
"bsk_v0",
{
.inputSecretKeyID = "small",
.outputSecretKeyID = "big",
.level = v0Param->brLevel,
.baseLog = v0Param->brLogBase,
.k = v0Param->k,
// TODO - Compute variance, wait for security estimator
.variance = 0,
},
},
},
.keyswitchKeys{
{
"ksk_v0",
{
.inputSecretKeyID = "big",
.outputSecretKeyID = "small",
.level = v0Param->ksLevel,
.baseLog = v0Param->ksLogBase,
// TODO - Compute variance, wait for security estimator
.variance = 0,
},
},
},
};
// Find the input function
auto rangeOps = module.getOps<mlir::FuncOp>();
auto funcOp = llvm::find_if(
rangeOps, [&](mlir::FuncOp op) { return op.getName() == name; });
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function for generate client parameters",
llvm::inconvertibleErrorCode());
}
// For the v0 the precision is global
Encoding encoding = {.precision = precision};
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getType();
for (auto inType : funcType.getInputs()) {
auto gate = gateFromMLIRType("big", precision, inType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.inputs.push_back(gate.get());
}
for (auto outType : funcType.getResults()) {
auto gate = gateFromMLIRType("big", precision, outType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.outputs.push_back(gate.get());
}
return c;
}
} // namespace zamalang
} // namespace mlir

View File

@@ -8,6 +8,9 @@
namespace mlir {
namespace zamalang {
// This is temporary while we doesn't yet have the high-level verification pass
FHECircuitConstraint defaultGlobalFHECircuitConstraint{.norm2 = 20, .p = 6};
void initLLVMNativeTarget() {
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
@@ -27,8 +30,9 @@ void addFilteredPassToPassManager(
}
};
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
mlir::MLIRContext &context, mlir::Operation *module,
FHECircuitConstraint &constraint,
llvm::function_ref<bool(std::string)> enablePass) {
mlir::PassManager pm(&context);
@@ -37,11 +41,7 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
enablePass);
constraint = defaultGlobalFHECircuitConstraint;
// Run the passes
if (pm.run(module).failed()) {
@@ -51,14 +51,31 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
return mlir::success();
}
mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
mlir::MLIRContext &context, mlir::Operation *module,
llvm::function_ref<bool(std::string)> enablePass) {
mlir::PassManager pm(&context);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
enablePass);
if (pm.run(module).failed()) {
return mlir::failure();
}
return mlir::success();
}
llvm::Expected<std::unique_ptr<llvm::Module>> CompilerTools::toLLVMModule(
llvm::LLVMContext &context, mlir::ModuleOp &module,
llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
initLLVMNativeTarget();
mlir::registerLLVMDialectTranslation(*module->getContext());
auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
if (!llvmModule) {
return llvm::make_error<llvm::StringError>(
"failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode());
@@ -113,5 +130,69 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
llvm::inconvertibleErrorCode());
}
llvm::Error JITLambda::invoke(Argument &args) { return invokeRaw(args.rawArg); }
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
inputs = std::vector<void *>(keySet.numInputs());
results = std::vector<void *>(keySet.numOutputs());
// The raw argument contains pointers to inputs and pointers to store the
// results
rawArg =
std::vector<void *>(keySet.numInputs() + keySet.numOutputs(), nullptr);
// Set the results pointer on the rawArg
for (auto i = keySet.numInputs(); i < rawArg.size(); i++) {
rawArg[i] = &results[i - keySet.numInputs()];
}
}
JITLambda::Argument::~Argument() {
int err;
for (auto i = 0; i < keySet.numInputs(); i++) {
if (keySet.isInputEncrypted(i)) {
free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i]));
}
}
}
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
JITLambda::Argument::create(KeySet &keySet) {
auto args = std::make_unique<JITLambda::Argument>(keySet);
return std::move(args);
}
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
// If argument is not encrypted, just save.
if (!keySet.isInputEncrypted(pos)) {
inputs[pos] = (void *)arg;
rawArg[pos] = &inputs[pos];
return llvm::Error::success();
}
// Else if is encryted, allocate ciphertext.
LweCiphertext_u64 *ctArg;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
return std::move(err);
}
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
return std::move(err);
}
inputs[pos] = ctArg;
rawArg[pos] = &inputs[pos];
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
// If result is not encrypted, just set the result
if (!keySet.isOutputEncrypted(pos)) {
res = (uint64_t)(results[pos]);
return llvm::Error::success();
}
// Else if is encryted, decrypt
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]);
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
return std::move(err);
}
return llvm::Error::success();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,276 @@
#include "zamalang/Support/KeySet.h"
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
{ \
int err; \
s; \
if (err != 0) { \
return llvm::make_error<llvm::StringError>( \
msg, llvm::inconvertibleErrorCode()); \
} \
}
namespace mlir {
namespace zamalang {
KeySet::~KeySet() {
int err;
for (auto it : secretKeys) {
free_lwe_secret_key_u64(&err, it.second.second);
}
for (auto it : bootstrapKeys) {
free_lwe_bootstrap_key_u64(&err, it.second.second);
}
for (auto it : keyswitchKeys) {
free_lwe_keyswitch_key_u64(&err, it.second.second);
}
free_encryption_generator(&err, encryptionRandomGenerator);
}
llvm::Expected<std::unique_ptr<KeySet>>
KeySet::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto keySet = std::make_unique<KeySet>();
{
// Generate LWE secret keys
SecretRandomGenerator *generator;
CAPI_ERR_TO_LLVM_ERROR(
generator = allocate_secret_generator(&err, seed_msb, seed_lsb),
"cannot allocate random generator");
for (auto secretKeyParam : params.secretKeys) {
auto e = keySet->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
if (e) {
return e;
}
}
CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
"cannot free random generator");
}
// Allocate the encryption random generator
CAPI_ERR_TO_LLVM_ERROR(
keySet->encryptionRandomGenerator =
allocate_encryption_generator(&err, seed_msb, seed_lsb),
"cannot allocate encryption generator");
// Generate bootstrap and keyswitch keys
{
for (auto bootstrapKeyParam : params.bootstrapKeys) {
auto e = keySet->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second,
keySet->encryptionRandomGenerator);
if (e) {
return e;
}
}
for (auto keyswitchParam : params.keyswitchKeys) {
auto e = keySet->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second,
keySet->encryptionRandomGenerator);
if (e) {
return e;
}
}
}
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
param, nullptr, nullptr};
if (param.encryption.hasValue()) {
auto inputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
if (inputSk == keySet->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
std::get<1>(input) = &inputSk->second.first;
std::get<2>(input) = inputSk->second.second;
}
keySet->inputs.push_back(input);
}
for (auto param : params.outputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
{param, nullptr, nullptr};
if (param.encryption.hasValue()) {
auto outputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
if (outputSk == keySet->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
std::get<1>(output) = &outputSk->second.first;
std::get<2>(output) = outputSk->second.second;
}
keySet->outputs.push_back(output);
}
}
return std::move(keySet);
}
llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
LweSecretKeyParam param,
SecretRandomGenerator *generator) {
LweSecretKey_u64 *sk;
CAPI_ERR_TO_LLVM_ERROR(
sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}),
"cannot allocate secret key");
CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
"cannot fill secret key with random generator")
secretKeys[id] = {param, sk};
return llvm::Error::success();
}
llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
BootstrapKeyParam param,
EncryptionRandomGenerator *generator) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
// Allocate the bootstrap key
LweBootstrapKey_u64 *bsk;
CAPI_ERR_TO_LLVM_ERROR(
bsk = allocate_lwe_bootstrap_key_u64(
&err, {param.level}, {param.baseLog}, {param.k},
{inputSk->second.first.size},
{outputSk->second.first.size /*TODO: size / k ?*/}),
"cannot allocate bootstrap key");
// Store the bootstrap key
bootstrapKeys[id] = {param, bsk};
// Convert the output lwe key to glwe key
GlweSecretKey_u64 *glwe_sk;
CAPI_ERR_TO_LLVM_ERROR(
glwe_sk = allocate_glwe_secret_key_u64(&err, {param.k},
{outputSk->second.first.size}),
"cannot allocate glwe key for initiliazation of bootstrap key");
// Initialize the bootstrap key
CAPI_ERR_TO_LLVM_ERROR(
fill_lwe_bootstrap_key_u64(&err, bsk, inputSk->second.second, glwe_sk,
generator, {param.variance}),
"cannot fill bootstrap key");
CAPI_ERR_TO_LLVM_ERROR(
free_glwe_secret_key_u64(&err, glwe_sk),
"cannot free glwe key for initiliazation of bootstrap key")
return llvm::Error::success();
}
llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
KeyswitchKeyParam param,
EncryptionRandomGenerator *generator) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate keyswitch key",
llvm::inconvertibleErrorCode());
}
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate keyswitch key",
llvm::inconvertibleErrorCode());
}
// Allocate the keyswitch key
LweKeyswitchKey_u64 *ksk;
CAPI_ERR_TO_LLVM_ERROR(
ksk = allocate_lwe_keyswitch_key_u64(&err, {param.level}, {param.baseLog},
{inputSk->second.first.size},
{outputSk->second.first.size}),
"cannot allocate keyswitch key");
// Store the keyswitch key
keyswitchKeys[id] = {param, ksk};
// Initialize the keyswitch key
CAPI_ERR_TO_LLVM_ERROR(
fill_lwe_keyswitch_key_u64(&err, ksk, inputSk->second.second,
outputSk->second.second, generator,
{param.variance}),
"cannot fill bootsrap key");
return llvm::Error::success();
}
llvm::Error KeySet::allocate_lwe(size_t argPos,
LweCiphertext_u64 **ciphertext) {
if (argPos >= inputs.size()) {
return llvm::make_error<llvm::StringError>(
"allocate_lwe position of argument is too high",
llvm::inconvertibleErrorCode());
}
auto inputSk = inputs[argPos];
CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64(
&err, {std::get<1>(inputSk)->size}),
"cannot allocate ciphertext");
return llvm::Error::success();
}
bool KeySet::isInputEncrypted(size_t argPos) {
return argPos < inputs.size() &&
std::get<0>(inputs[argPos]).encryption.hasValue();
}
bool KeySet::isOutputEncrypted(size_t argPos) {
return argPos < outputs.size() &&
std::get<0>(outputs[argPos]).encryption.hasValue();
}
llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t input) {
if (argPos >= inputs.size()) {
return llvm::make_error<llvm::StringError>(
"encrypt_lwe position of argument is too high",
llvm::inconvertibleErrorCode());
}
auto inputSk = inputs[argPos];
if (!std::get<0>(inputSk).encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"encrypt_lwe the positional argument is not encrypted",
llvm::inconvertibleErrorCode());
}
// Encode - TODO we could check if the input value is in the right range
Plaintext_u64 plaintext = {
input << (64 -
(std::get<0>(inputSk).encryption->encoding.precision + 1))};
// Encrypt
CAPI_ERR_TO_LLVM_ERROR(
encrypt_lwe_u64(&err, std::get<2>(inputSk), ciphertext, plaintext,
encryptionRandomGenerator,
{std::get<0>(inputSk).encryption->variance}),
"cannot encrypt");
return llvm::Error::success();
}
llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
uint64_t &output) {
if (argPos >= outputs.size()) {
return llvm::make_error<llvm::StringError>(
"decrypt_lwe: position of argument is too high",
llvm::inconvertibleErrorCode());
}
auto outputSk = outputs[argPos];
if (!std::get<0>(outputSk).encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"decrypt_lwe: the positional argument is not encrypted",
llvm::inconvertibleErrorCode());
}
// Decrypt
Plaintext_u64 plaintext;
CAPI_ERR_TO_LLVM_ERROR(
decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
"cannot decrypt");
// Decode
output = plaintext._0 >>
(64 - (std::get<0>(outputSk).encryption->encoding.precision + 1));
return llvm::Error::success();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,134 @@
/// DO NOT MANUALLY EDIT THIS FILE.
/// This file was generated thanks the "parameters optimizer".
/// We should include this in our build system, but for moment it is just a cc
/// from the optimizer output.
#include "zamalang/Support/V0Parameters.h"
namespace mlir {
namespace zamalang {
using namespace std;
const int NORM2_MAX = 25;
const int P_MAX = 7;
V0Parameter parameters[NORM2_MAX][P_MAX] = {
{V0Parameter(1, 10, 481, 2, 8, 4, 2), V0Parameter(1, 10, 483, 2, 8, 4, 2),
V0Parameter(1, 10, 491, 2, 8, 4, 2), V0Parameter(1, 10, 484, 3, 6, 4, 2),
V0Parameter(1, 10, 497, 3, 6, 4, 2), V0Parameter(1, 11, 506, 1, 24, 4, 2),
V0Parameter(1, 11, 506, 1, 24, 4, 2)},
{V0Parameter(1, 11, 506, 1, 23, 4, 2), V0Parameter(1, 11, 508, 1, 23, 4, 2),
V0Parameter(1, 11, 533, 1, 23, 3, 3), V0Parameter(1, 11, 506, 2, 13, 4, 2),
V0Parameter(1, 11, 506, 2, 15, 4, 2), V0Parameter(1, 11, 506, 2, 17, 4, 2),
V0Parameter(1, 11, 506, 2, 15, 4, 2)},
{V0Parameter(1, 11, 506, 2, 15, 4, 2), V0Parameter(1, 11, 508, 2, 15, 4, 2),
V0Parameter(1, 11, 513, 2, 15, 4, 2), V0Parameter(1, 11, 530, 2, 15, 5, 2),
V0Parameter(1, 11, 507, 3, 12, 4, 2), V0Parameter(1, 11, 511, 3, 11, 4, 2),
V0Parameter(1, 11, 517, 3, 12, 5, 2)},
{V0Parameter(1, 11, 508, 4, 9, 4, 2), V0Parameter(1, 11, 509, 4, 9, 5, 2),
V0Parameter(1, 11, 507, 5, 8, 5, 2), V0Parameter(1, 11, 541, 5, 8, 5, 2),
V0Parameter(1, 10, 549, 2, 8, 3, 3), V0Parameter(1, 10, 530, 2, 8, 5, 2),
V0Parameter(1, 10, 524, 3, 6, 5, 2)},
{V0Parameter(1, 10, 536, 3, 6, 5, 2), V0Parameter(1, 11, 571, 1, 22, 3, 3),
V0Parameter(1, 11, 572, 1, 22, 3, 3), V0Parameter(1, 11, 572, 1, 22, 3, 3),
V0Parameter(1, 11, 574, 1, 23, 3, 3), V0Parameter(1, 11, 585, 1, 23, 3, 3),
V0Parameter(1, 11, 541, 2, 14, 5, 2)},
{V0Parameter(1, 11, 541, 2, 14, 5, 2), V0Parameter(1, 11, 541, 2, 17, 5, 2),
V0Parameter(1, 11, 541, 2, 15, 5, 2), V0Parameter(1, 11, 541, 2, 15, 5, 2),
V0Parameter(1, 11, 542, 2, 15, 5, 2), V0Parameter(1, 11, 547, 2, 15, 5, 2),
V0Parameter(1, 11, 580, 2, 15, 5, 2)},
{V0Parameter(1, 11, 542, 3, 11, 5, 2), V0Parameter(1, 11, 545, 3, 12, 5, 2),
V0Parameter(1, 11, 561, 3, 12, 5, 2), V0Parameter(1, 11, 543, 4, 9, 5, 2),
V0Parameter(1, 11, 549, 4, 9, 5, 2), V0Parameter(1, 11, 547, 5, 8, 5, 2),
V0Parameter(1, 11, 591, 5, 8, 6, 2)},
{V0Parameter(1, 11, 550, 7, 6, 5, 2), V0Parameter(1, 10, 576, 2, 8, 5, 2),
V0Parameter(1, 10, 567, 3, 6, 5, 2), V0Parameter(1, 10, 584, 3, 6, 5, 2),
V0Parameter(1, 11, 607, 1, 20, 4, 3), V0Parameter(1, 11, 607, 1, 22, 4, 3),
V0Parameter(1, 11, 607, 1, 22, 4, 3)},
{V0Parameter(1, 11, 609, 1, 23, 4, 3), V0Parameter(1, 11, 616, 1, 23, 4, 3),
V0Parameter(1, 11, 588, 2, 13, 5, 2), V0Parameter(1, 11, 588, 2, 17, 5, 2),
V0Parameter(1, 11, 588, 2, 14, 5, 2), V0Parameter(1, 11, 588, 2, 16, 5, 2),
V0Parameter(1, 11, 588, 2, 15, 5, 2)},
{V0Parameter(1, 11, 590, 2, 15, 5, 2), V0Parameter(1, 11, 596, 2, 15, 5, 2),
V0Parameter(1, 11, 622, 2, 15, 6, 2), V0Parameter(1, 11, 589, 3, 11, 5, 2),
V0Parameter(1, 11, 594, 3, 12, 5, 2), V0Parameter(1, 11, 602, 3, 12, 6, 2),
V0Parameter(1, 11, 591, 4, 9, 5, 2)},
{V0Parameter(1, 11, 591, 4, 9, 6, 2), V0Parameter(1, 11, 589, 5, 8, 6, 2),
V0Parameter(1, 11, 655, 5, 8, 6, 2), V0Parameter(1, 11, 592, 7, 6, 6, 2),
V0Parameter(1, 11, 598, 8, 5, 6, 2), V0Parameter(1, 10, 608, 3, 6, 6, 2),
V0Parameter(1, 10, 625, 3, 6, 6, 2)},
{V0Parameter(1, 11, 647, 1, 24, 4, 3), V0Parameter(1, 11, 647, 1, 22, 4, 3),
V0Parameter(1, 11, 647, 1, 23, 4, 3), V0Parameter(1, 11, 649, 1, 23, 4, 3),
V0Parameter(1, 11, 658, 1, 23, 4, 3), V0Parameter(1, 11, 647, 2, 19, 4, 3),
V0Parameter(1, 11, 647, 2, 14, 4, 3)},
{V0Parameter(1, 11, 647, 2, 17, 4, 3), V0Parameter(1, 11, 647, 2, 15, 4, 3),
V0Parameter(1, 11, 647, 2, 15, 4, 3), V0Parameter(1, 11, 648, 2, 15, 4, 3),
V0Parameter(1, 11, 654, 2, 15, 4, 3), V0Parameter(1, 11, 674, 2, 15, 7, 2),
V0Parameter(1, 11, 624, 3, 12, 6, 2)},
{V0Parameter(1, 11, 627, 3, 11, 6, 2), V0Parameter(1, 11, 648, 3, 12, 6, 2),
V0Parameter(1, 11, 625, 4, 9, 6, 2), V0Parameter(1, 11, 633, 4, 9, 6, 2),
V0Parameter(1, 11, 630, 5, 8, 6, 2), V0Parameter(1, 11, 632, 6, 7, 6, 2),
V0Parameter(1, 11, 634, 7, 6, 6, 2)},
{V0Parameter(1, 11, 642, 8, 5, 6, 2), V0Parameter(1, 11, 644, 11, 4, 6, 2),
V0Parameter(1, 11, 698, 1, 20, 4, 3), V0Parameter(1, 11, 698, 1, 21, 4, 3),
V0Parameter(1, 11, 698, 1, 22, 4, 3), V0Parameter(1, 11, 699, 1, 22, 4, 3),
V0Parameter(1, 11, 702, 1, 23, 4, 3)},
{V0Parameter(1, 11, 721, 1, 23, 4, 3), V0Parameter(1, 11, 698, 2, 13, 4, 3),
V0Parameter(1, 11, 698, 2, 17, 4, 3), V0Parameter(1, 11, 698, 2, 14, 4, 3),
V0Parameter(1, 11, 698, 2, 16, 4, 3), V0Parameter(1, 11, 698, 2, 15, 4, 3),
V0Parameter(1, 11, 700, 2, 15, 4, 3)},
{V0Parameter(1, 11, 711, 2, 15, 4, 3), V0Parameter(1, 11, 674, 3, 12, 6, 2),
V0Parameter(1, 11, 675, 3, 11, 6, 2), V0Parameter(1, 11, 681, 3, 12, 6, 2),
V0Parameter(1, 11, 695, 3, 12, 7, 2), V0Parameter(1, 11, 677, 4, 9, 6, 2),
V0Parameter(1, 11, 677, 4, 9, 7, 2)},
{V0Parameter(1, 11, 674, 5, 8, 7, 2), V0Parameter(1, 11, 676, 6, 7, 7, 2),
V0Parameter(1, 11, 678, 7, 6, 7, 2), V0Parameter(1, 11, 688, 8, 5, 7, 2),
V0Parameter(1, 11, 690, 11, 4, 7, 2),
V0Parameter(1, 11, 731, 14, 3, 15, 1),
V0Parameter(1, 11, 745, 1, 21, 5, 3)},
{V0Parameter(1, 11, 745, 1, 22, 5, 3), V0Parameter(1, 11, 746, 1, 23, 5, 3),
V0Parameter(1, 11, 750, 1, 23, 5, 3), V0Parameter(1, 11, 781, 1, 23, 5, 3),
V0Parameter(1, 11, 745, 2, 19, 5, 3), V0Parameter(1, 11, 745, 2, 14, 5, 3),
V0Parameter(1, 11, 745, 2, 17, 5, 3)},
{V0Parameter(1, 11, 745, 2, 15, 5, 3), V0Parameter(1, 11, 746, 2, 15, 5, 3),
V0Parameter(1, 11, 748, 2, 15, 5, 3), V0Parameter(1, 11, 763, 2, 15, 5, 3),
V0Parameter(1, 11, 745, 3, 12, 5, 3), V0Parameter(1, 11, 747, 3, 11, 5, 3),
V0Parameter(1, 11, 756, 3, 11, 5, 3)},
{V0Parameter(1, 11, 722, 4, 9, 7, 2), V0Parameter(1, 11, 727, 4, 9, 7, 2),
V0Parameter(1, 11, 750, 4, 9, 8, 2), V0Parameter(1, 11, 747, 5, 8, 7, 2),
V0Parameter(1, 11, 747, 6, 7, 8, 2), V0Parameter(1, 11, 757, 7, 6, 8, 2),
V0Parameter(1, 11, 739, 10, 4, 8, 2)},
{V0Parameter(1, 11, 735, 14, 3, 8, 2), V0Parameter(1, 11, 767, 21, 2, 8, 2),
V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(1, 12, 839, 1, 22, 4, 4),
V0Parameter(1, 12, 844, 1, 23, 4, 4), V0Parameter(1, 12, 847, 1, 23, 6, 3),
V0Parameter(1, 12, 815, 2, 13, 5, 3)},
{V0Parameter(1, 12, 815, 2, 17, 5, 3), V0Parameter(1, 12, 815, 2, 14, 5, 3),
V0Parameter(1, 12, 815, 2, 15, 5, 3), V0Parameter(1, 12, 816, 2, 15, 5, 3),
V0Parameter(1, 12, 820, 2, 15, 5, 3), V0Parameter(1, 12, 858, 2, 15, 4, 4),
V0Parameter(1, 12, 815, 3, 11, 5, 3)},
{V0Parameter(1, 12, 818, 3, 11, 5, 3), V0Parameter(1, 12, 790, 3, 11, 8, 2),
V0Parameter(1, 12, 783, 4, 9, 8, 2), V0Parameter(1, 12, 787, 4, 9, 8, 2),
V0Parameter(1, 12, 823, 4, 9, 8, 2), V0Parameter(1, 12, 822, 5, 8, 8, 2),
V0Parameter(1, 12, 843, 6, 7, 9, 2)},
{V0Parameter(1, 12, 792, 8, 5, 8, 2), V0Parameter(1, 12, 802, 10, 4, 8, 2),
V0Parameter(1, 12, 804, 14, 3, 8, 2), V0Parameter(1, 12, 825, 22, 2, 9, 2),
V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0),
V0Parameter(0, 0, 0, 0, 0, 0, 0)}};
V0Parameter *getV0Parameter(size_t norm, size_t p) {
if (norm > NORM2_MAX) {
return nullptr;
}
if (p >= P_MAX) {
return nullptr;
}
// - 1 is an offset as norm and p are in [1, ...] and not [0, ...]
auto param = &parameters[norm - 1][p - 1];
if (param->k == 0) {
return nullptr;
}
return param;
}
} // namespace zamalang
} // namespace mlir

View File

@@ -32,6 +32,9 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
llvm::cl::init<bool>(false));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
@@ -53,8 +56,19 @@ llvm::cl::opt<bool> splitInputFile(
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> generateKeySet(
"generate-keyset",
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
llvm::cl::init<bool>(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::opt<std::string> jitFuncname(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::list<int>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
@@ -64,6 +78,12 @@ llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
#define LOG_VERBOSE(expr) \
if (cmdline::verbose) \
llvm::errs() << expr;
#define LOG_ERROR(expr) llvm::errs() << expr;
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
@@ -77,32 +97,42 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
return mlir::success();
}
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
mlir::LogicalResult runJit(mlir::ModuleOp module,
mlir::zamalang::KeySet &keySet,
llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
auto maybeLambda = mlir::zamalang::JITLambda::create(
cmdline::jitFuncname, module, defaultOptPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = maybeLambda.get().get();
auto lambda = std::move(maybeLambda.get());
// Create buffer to copy argument
std::vector<int64_t> dummy(cmdline::jitArgs.size());
llvm::SmallVector<void *> llvmArgs;
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
dummy[i] = cmdline::jitArgs[i];
llvmArgs.push_back(&dummy[i]);
}
// Add the result pointer
uint64_t res = 0;
llvmArgs.push_back(&res);
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
if (auto err = maybeArguments.takeError()) {
// Invoke the lambda
if (lambda->invokeRaw(llvmArgs)) {
LOG_ERROR("Cannot create lambda arguments: " << err << "\n");
return mlir::failure();
}
std::cerr << res << "\n";
// Set the arguments
auto arguments = std::move(maybeArguments.get());
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
LOG_ERROR("Cannot push argument " << i << ": " << err << "\n");
return mlir::failure();
}
}
// Invoke the lambda
if (lambda->invoke(*arguments)) {
return mlir::failure();
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
LOG_ERROR("Cannot get result : " << err << "\n");
return mlir::failure();
}
llvm::errs() << res << "\n";
return mlir::success();
}
@@ -137,20 +167,67 @@ processInputBuffer(mlir::MLIRContext &context,
return mlir::success();
}
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
context, *module,
[](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
auto enablePass = [](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
})
};
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
mlir::zamalang::FHECircuitConstraint constraint;
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, constraint, enablePass)
.failed()) {
return mlir::failure();
}
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
<< constraint.p << "}\n");
// Retreive the parameters for the v0 approach
mlir::zamalang::V0Parameter *fheParameter =
mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p);
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
<< fheParameter->k
<< ", polynomialSize: " << fheParameter->polynomialSize
<< ", nSmall: " << fheParameter->nSmall
<< ", brLevel: " << fheParameter->brLevel
<< ", brLogBase: " << fheParameter->brLogBase
<< ", ksLevel: " << fheParameter->ksLevel
<< ", polynomialSize: " << fheParameter->ksLogBase << "}\n");
// Generate the keySet
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheParameter, constraint.p, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
LOG_ERROR("cannot generate client parameters: " << err << "\n");
return mlir::failure();
}
LOG_VERBOSE("### Generate the key set\n");
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return mlir::failure();
}
keySet = std::move(maybeKeySet.get());
}
// Lower to MLIR LLVM Dialect
LOG_VERBOSE("### Lower from MLIR standards to LLVM\n");
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
context, *module, enablePass)
.failed()) {
return mlir::failure();
}
if (cmdline::runJit) {
return runJit(module.get(), os);
LOG_VERBOSE("### JIT compile & running\n");
return runJit(module.get(), *keySet, os);
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);

View File

@@ -0,0 +1,6 @@
// RUN: zamacompiler %s --run-jit --jit-args 42 2>&1| FileCheck %s
// CHECK-LABEL: 42
func @main(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext {
return %arg0 : !LowLFHE.lwe_ciphertext
}