mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): First draft of client parameters generation, runtime support for encrypting and decrypting circuit gates, integration of fhe parameters for the v0 (#65, #66, #56)
This commit is contained in:
@@ -6,7 +6,6 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")
|
||||
|
||||
find_package(MLIR REQUIRED CONFIG)
|
||||
message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}")
|
||||
@@ -27,6 +26,12 @@ include_directories(${PROJECT_BINARY_DIR}/include)
|
||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Concrete FFI Configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
include_directories(${CONCRETE_FFI_RELEASE})
|
||||
add_library(Concrete SHARED IMPORTED)
|
||||
set_target_properties(Concrete PROPERTIES IMPORTED_LOCATION ${CONCRETE_FFI_RELEASE}/libconcrete_ffi.so )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Python Configuration
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
build:
|
||||
cmake -B build . -DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm -DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir
|
||||
cmake -B build . \
|
||||
-DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm \
|
||||
-DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir \
|
||||
-DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release
|
||||
|
||||
zamacompiler:
|
||||
make -C build/ zamacompiler
|
||||
|
||||
81
compiler/include/zamalang/Support/ClientParameters.h
Normal file
81
compiler/include/zamalang/Support/ClientParameters.h
Normal file
@@ -0,0 +1,81 @@
|
||||
#ifndef ZAMALANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
#define ZAMALANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
#include "zamalang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
typedef size_t DecompositionLevelCount;
|
||||
typedef size_t DecompositionBaseLog;
|
||||
typedef size_t PolynomialSize;
|
||||
typedef size_t Precision;
|
||||
typedef double Variance;
|
||||
|
||||
typedef uint64_t LweSize;
|
||||
typedef uint64_t GLWESize;
|
||||
|
||||
typedef std::string LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweSize size;
|
||||
};
|
||||
|
||||
typedef std::string BootstrapKeyID;
|
||||
struct BootstrapKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GLWESize k;
|
||||
Variance variance;
|
||||
};
|
||||
|
||||
typedef std::string KeyswitchKeyID;
|
||||
struct KeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
Variance variance;
|
||||
};
|
||||
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
};
|
||||
|
||||
struct EncryptionGate {
|
||||
LweSecretKeyID secretKeyID;
|
||||
Variance variance;
|
||||
Encoding encoding;
|
||||
};
|
||||
|
||||
struct CircuitGateShape {
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
};
|
||||
|
||||
struct ClientParameters {
|
||||
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
|
||||
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
|
||||
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
};
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
|
||||
llvm::StringRef name, mlir::ModuleOp module);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -5,21 +5,43 @@
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
|
||||
#include "zamalang/Support/ClientParameters.h"
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
#include "zamalang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
/// For the v0 we compute a global constraint, this is defined here as the
|
||||
/// high-level verification pass is not yet implemented.
|
||||
struct FHECircuitConstraint {
|
||||
size_t norm2;
|
||||
size_t p;
|
||||
};
|
||||
|
||||
class CompilerTools {
|
||||
public:
|
||||
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
|
||||
/// LLVM dialect.
|
||||
static mlir::LogicalResult lowerHLFHEToMlirLLVMDialect(
|
||||
/// lowerable to llvm dialect.
|
||||
/// The given module MLIR operation would be modified and the constraints set.
|
||||
static mlir::LogicalResult lowerHLFHEToMlirStdsDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
FHECircuitConstraint &constraint,
|
||||
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
|
||||
return true;
|
||||
});
|
||||
|
||||
/// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR
|
||||
/// dialects to MLIR LLVM dialect. The given module MLIR operation would be
|
||||
/// modified.
|
||||
static mlir::LogicalResult lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
|
||||
return true;
|
||||
});
|
||||
|
||||
static llvm::Expected<std::unique_ptr<llvm::Module>>
|
||||
toLLVMModule(llvm::LLVMContext &context, mlir::ModuleOp &module,
|
||||
toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
};
|
||||
|
||||
@@ -27,6 +49,28 @@ public:
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
class Argument {
|
||||
public:
|
||||
Argument(KeySet &keySet);
|
||||
~Argument();
|
||||
|
||||
// Create lambda Argument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static llvm::Expected<std::unique_ptr<Argument>> create(KeySet &keySet);
|
||||
|
||||
// Set the argument at the given pos as a uint64_t.
|
||||
llvm::Error setArg(size_t pos, uint64_t arg);
|
||||
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
private:
|
||||
friend JITLambda;
|
||||
std::vector<void *> rawArg;
|
||||
std::vector<void *> inputs;
|
||||
std::vector<void *> results;
|
||||
KeySet &keySet;
|
||||
};
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
@@ -35,7 +79,7 @@ public:
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
|
||||
/// invokeRaw execute the jit lambda with a lits of arguments, the last one is
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
/// uin64_t arg0 = 1;
|
||||
@@ -44,6 +88,9 @@ public:
|
||||
/// lambda.invokeRaw(args);
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
/// invoke the jit lambda with the Argument.
|
||||
llvm::Error invoke(Argument &args);
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
llvm::StringRef name;
|
||||
|
||||
65
compiler/include/zamalang/Support/KeySet.h
Normal file
65
compiler/include/zamalang/Support/KeySet.h
Normal file
@@ -0,0 +1,65 @@
|
||||
#ifndef ZAMALANG_SUPPORT_KEYSET_H_
|
||||
#define ZAMALANG_SUPPORT_KEYSET_H_
|
||||
|
||||
#include "llvm/Support/Error.h"
|
||||
#include <memory>
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
#include "zamalang/Support/ClientParameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
class KeySet {
|
||||
public:
|
||||
~KeySet();
|
||||
// allocate a KeySet according the ClientParameters.
|
||||
static llvm::Expected<std::unique_ptr<KeySet>>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
bool isInputEncrypted(size_t pos);
|
||||
// allocate a lwe ciphertext for the argument at argPos.
|
||||
llvm::Error allocate_lwe(size_t argPos, LweCiphertext_u64 **ciphertext);
|
||||
// encrypt the input to the ciphertext for the argument at argPos.
|
||||
llvm::Error encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
uint64_t input);
|
||||
|
||||
// isOuputEncrypted return true if the output at the given pos is encrypted.
|
||||
bool isOutputEncrypted(size_t pos);
|
||||
// decrypt the ciphertext to the output for the argument at argPos.
|
||||
llvm::Error decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
uint64_t &output);
|
||||
|
||||
size_t numInputs() { return inputs.size(); }
|
||||
size_t numOutputs() { return outputs.size(); }
|
||||
|
||||
protected:
|
||||
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
llvm::Error generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
|
||||
private:
|
||||
EncryptionRandomGenerator *encryptionRandomGenerator;
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
keyswitchKeys;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
|
||||
inputs;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
|
||||
outputs;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
29
compiler/include/zamalang/Support/V0Parameters.h
Normal file
29
compiler/include/zamalang/Support/V0Parameters.h
Normal file
@@ -0,0 +1,29 @@
|
||||
#ifndef ZAMALANG_SUPPORT_V0Parameter_H_
|
||||
#define ZAMALANG_SUPPORT_V0Parameter_H_
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
typedef struct V0Parameter {
|
||||
size_t k;
|
||||
size_t polynomialSize;
|
||||
size_t nSmall;
|
||||
size_t brLevel;
|
||||
size_t brLogBase;
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel,
|
||||
size_t brLogBase, size_t ksLevel, size_t ksLogBase)
|
||||
: k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel),
|
||||
brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {}
|
||||
|
||||
} V0Parameter;
|
||||
|
||||
V0Parameter *getV0Parameter(size_t norm, size_t p);
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
#endif
|
||||
@@ -3,9 +3,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
@@ -19,10 +16,16 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
|
||||
namespace {
|
||||
struct MLIRLowerableDialectsToLLVMPass
|
||||
: public MLIRLowerableDialectsToLLVMBase<MLIRLowerableDialectsToLLVMPass> {
|
||||
void runOnOperation() final;
|
||||
|
||||
/// Convert types to the LLVM dialect-compatible type
|
||||
static llvm::Optional<mlir::Type> convertTypes(mlir::Type type);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -35,6 +38,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
|
||||
// add our types conversion to `llvm` compatible type.
|
||||
mlir::LLVMTypeConverter typeConverter(&getContext());
|
||||
typeConverter.addConversion(convertTypes);
|
||||
|
||||
// Setup the set of the patterns rewriter. At this point we want to
|
||||
// convert the `scf` operations to `std` and `std` operations to `llvm`.
|
||||
@@ -49,6 +53,15 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::Type>
|
||||
MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
mlir::IntegerType::get(type.getContext(), 8));
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
/// Create a pass for lowering operations the remaining mlir dialects
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
add_mlir_library(ZamalangSupport
|
||||
CompilerTools.cpp
|
||||
V0Parameters.cpp
|
||||
ClientParameters.cpp
|
||||
KeySet.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/zamalang/Support
|
||||
@@ -13,4 +16,7 @@ add_mlir_library(ZamalangSupport
|
||||
MLIRLowerableDialectsToLLVM
|
||||
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB})
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
Concrete
|
||||
)
|
||||
|
||||
117
compiler/lib/Support/ClientParameters.cpp
Normal file
117
compiler/lib/Support/ClientParameters.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Support/ClientParameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// For the v0 the secretKeyID and precision are the same for all gates.
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
Precision precision,
|
||||
mlir::Type type) {
|
||||
if (type.isInteger(64)) {
|
||||
return CircuitGate{
|
||||
.encryption = llvm::None,
|
||||
.shape = {.size = 0},
|
||||
};
|
||||
}
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
return CircuitGate{
|
||||
.encryption = llvm::Optional<EncryptionGate>({
|
||||
.secretKeyID = secretKeyID,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0.,
|
||||
.encoding = {.precision = precision},
|
||||
}),
|
||||
.shape = {.size = 0},
|
||||
};
|
||||
}
|
||||
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
|
||||
if (memref != nullptr) {
|
||||
auto gate =
|
||||
gateFromMLIRType(secretKeyID, precision, memref.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
gate->shape.size = memref.getDimSize(0);
|
||||
return gate;
|
||||
}
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0Parameter *v0Param, Precision precision,
|
||||
llvm::StringRef name, mlir::ModuleOp module) {
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c{
|
||||
.secretKeys{
|
||||
{"small", {.size = v0Param->nSmall}},
|
||||
{"big", {.size = v0Param->k * (1 << v0Param->polynomialSize)}},
|
||||
},
|
||||
.bootstrapKeys{
|
||||
{
|
||||
"bsk_v0",
|
||||
{
|
||||
.inputSecretKeyID = "small",
|
||||
.outputSecretKeyID = "big",
|
||||
.level = v0Param->brLevel,
|
||||
.baseLog = v0Param->brLogBase,
|
||||
.k = v0Param->k,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
.keyswitchKeys{
|
||||
{
|
||||
"ksk_v0",
|
||||
{
|
||||
.inputSecretKeyID = "big",
|
||||
.outputSecretKeyID = "small",
|
||||
.level = v0Param->ksLevel,
|
||||
.baseLog = v0Param->ksLogBase,
|
||||
// TODO - Compute variance, wait for security estimator
|
||||
.variance = 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
// Find the input function
|
||||
auto rangeOps = module.getOps<mlir::FuncOp>();
|
||||
auto funcOp = llvm::find_if(
|
||||
rangeOps, [&](mlir::FuncOp op) { return op.getName() == name; });
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find the function for generate client parameters",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
// For the v0 the precision is global
|
||||
Encoding encoding = {.precision = precision};
|
||||
|
||||
// Create input and output circuit gate parameters
|
||||
auto funcType = (*funcOp).getType();
|
||||
for (auto inType : funcType.getInputs()) {
|
||||
auto gate = gateFromMLIRType("big", precision, inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.inputs.push_back(gate.get());
|
||||
}
|
||||
for (auto outType : funcType.getResults()) {
|
||||
auto gate = gateFromMLIRType("big", precision, outType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.outputs.push_back(gate.get());
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -8,6 +8,9 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// This is temporary while we doesn't yet have the high-level verification pass
|
||||
FHECircuitConstraint defaultGlobalFHECircuitConstraint{.norm2 = 20, .p = 6};
|
||||
|
||||
void initLLVMNativeTarget() {
|
||||
// Initialize LLVM targets.
|
||||
llvm::InitializeNativeTarget();
|
||||
@@ -27,8 +30,9 @@ void addFilteredPassToPassManager(
|
||||
}
|
||||
};
|
||||
|
||||
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
FHECircuitConstraint &constraint,
|
||||
llvm::function_ref<bool(std::string)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
@@ -37,11 +41,7 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
enablePass);
|
||||
constraint = defaultGlobalFHECircuitConstraint;
|
||||
|
||||
// Run the passes
|
||||
if (pm.run(module).failed()) {
|
||||
@@ -51,14 +51,31 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
mlir::MLIRContext &context, mlir::Operation *module,
|
||||
llvm::function_ref<bool(std::string)> enablePass) {
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
|
||||
addFilteredPassToPassManager(
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
enablePass);
|
||||
|
||||
if (pm.run(module).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<llvm::Module>> CompilerTools::toLLVMModule(
|
||||
llvm::LLVMContext &context, mlir::ModuleOp &module,
|
||||
llvm::LLVMContext &llvmContext, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
|
||||
|
||||
initLLVMNativeTarget();
|
||||
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||
|
||||
auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
|
||||
auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
|
||||
if (!llvmModule) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode());
|
||||
@@ -113,5 +130,69 @@ llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invoke(Argument &args) { return invokeRaw(args.rawArg); }
|
||||
|
||||
JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
|
||||
inputs = std::vector<void *>(keySet.numInputs());
|
||||
results = std::vector<void *>(keySet.numOutputs());
|
||||
// The raw argument contains pointers to inputs and pointers to store the
|
||||
// results
|
||||
rawArg =
|
||||
std::vector<void *>(keySet.numInputs() + keySet.numOutputs(), nullptr);
|
||||
// Set the results pointer on the rawArg
|
||||
for (auto i = keySet.numInputs(); i < rawArg.size(); i++) {
|
||||
rawArg[i] = &results[i - keySet.numInputs()];
|
||||
}
|
||||
}
|
||||
|
||||
JITLambda::Argument::~Argument() {
|
||||
int err;
|
||||
for (auto i = 0; i < keySet.numInputs(); i++) {
|
||||
if (keySet.isInputEncrypted(i)) {
|
||||
free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
|
||||
JITLambda::Argument::create(KeySet &keySet) {
|
||||
auto args = std::make_unique<JITLambda::Argument>(keySet);
|
||||
return std::move(args);
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
// If argument is not encrypted, just save.
|
||||
if (!keySet.isInputEncrypted(pos)) {
|
||||
inputs[pos] = (void *)arg;
|
||||
rawArg[pos] = &inputs[pos];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, allocate ciphertext.
|
||||
LweCiphertext_u64 *ctArg;
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
inputs[pos] = ctArg;
|
||||
rawArg[pos] = &inputs[pos];
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
// If result is not encrypted, just set the result
|
||||
if (!keySet.isOutputEncrypted(pos)) {
|
||||
res = (uint64_t)(results[pos]);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// Else if is encryted, decrypt
|
||||
LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]);
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
276
compiler/lib/Support/KeySet.cpp
Normal file
276
compiler/lib/Support/KeySet.cpp
Normal file
@@ -0,0 +1,276 @@
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
|
||||
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
|
||||
{ \
|
||||
int err; \
|
||||
s; \
|
||||
if (err != 0) { \
|
||||
return llvm::make_error<llvm::StringError>( \
|
||||
msg, llvm::inconvertibleErrorCode()); \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
KeySet::~KeySet() {
|
||||
int err;
|
||||
for (auto it : secretKeys) {
|
||||
free_lwe_secret_key_u64(&err, it.second.second);
|
||||
}
|
||||
for (auto it : bootstrapKeys) {
|
||||
free_lwe_bootstrap_key_u64(&err, it.second.second);
|
||||
}
|
||||
for (auto it : keyswitchKeys) {
|
||||
free_lwe_keyswitch_key_u64(&err, it.second.second);
|
||||
}
|
||||
free_encryption_generator(&err, encryptionRandomGenerator);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto keySet = std::make_unique<KeySet>();
|
||||
|
||||
{
|
||||
// Generate LWE secret keys
|
||||
SecretRandomGenerator *generator;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
generator = allocate_secret_generator(&err, seed_msb, seed_lsb),
|
||||
"cannot allocate random generator");
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
auto e = keySet->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator);
|
||||
if (e) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
|
||||
"cannot free random generator");
|
||||
}
|
||||
// Allocate the encryption random generator
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
keySet->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(&err, seed_msb, seed_lsb),
|
||||
"cannot allocate encryption generator");
|
||||
// Generate bootstrap and keyswitch keys
|
||||
{
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
auto e = keySet->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second,
|
||||
keySet->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
auto e = keySet->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second,
|
||||
keySet->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
|
||||
param, nullptr, nullptr};
|
||||
if (param.encryption.hasValue()) {
|
||||
auto inputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (inputSk == keySet->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(input) = &inputSk->second.first;
|
||||
std::get<2>(input) = inputSk->second.second;
|
||||
}
|
||||
keySet->inputs.push_back(input);
|
||||
}
|
||||
for (auto param : params.outputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
|
||||
{param, nullptr, nullptr};
|
||||
if (param.encryption.hasValue()) {
|
||||
auto outputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (outputSk == keySet->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(output) = &outputSk->second.first;
|
||||
std::get<2>(output) = outputSk->second.second;
|
||||
}
|
||||
keySet->outputs.push_back(output);
|
||||
}
|
||||
}
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
|
||||
LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator) {
|
||||
LweSecretKey_u64 *sk;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}),
|
||||
"cannot allocate secret key");
|
||||
CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
|
||||
"cannot fill secret key with random generator")
|
||||
secretKeys[id] = {param, sk};
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
|
||||
BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Allocate the bootstrap key
|
||||
LweBootstrapKey_u64 *bsk;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
bsk = allocate_lwe_bootstrap_key_u64(
|
||||
&err, {param.level}, {param.baseLog}, {param.k},
|
||||
{inputSk->second.first.size},
|
||||
{outputSk->second.first.size /*TODO: size / k ?*/}),
|
||||
"cannot allocate bootstrap key");
|
||||
// Store the bootstrap key
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
// Convert the output lwe key to glwe key
|
||||
GlweSecretKey_u64 *glwe_sk;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
glwe_sk = allocate_glwe_secret_key_u64(&err, {param.k},
|
||||
{outputSk->second.first.size}),
|
||||
"cannot allocate glwe key for initiliazation of bootstrap key");
|
||||
// Initialize the bootstrap key
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
fill_lwe_bootstrap_key_u64(&err, bsk, inputSk->second.second, glwe_sk,
|
||||
generator, {param.variance}),
|
||||
"cannot fill bootstrap key");
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
free_glwe_secret_key_u64(&err, glwe_sk),
|
||||
"cannot free glwe key for initiliazation of bootstrap key")
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
|
||||
KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate keyswitch key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate keyswitch key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Allocate the keyswitch key
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
ksk = allocate_lwe_keyswitch_key_u64(&err, {param.level}, {param.baseLog},
|
||||
{inputSk->second.first.size},
|
||||
{outputSk->second.first.size}),
|
||||
"cannot allocate keyswitch key");
|
||||
// Store the keyswitch key
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
// Initialize the keyswitch key
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
fill_lwe_keyswitch_key_u64(&err, ksk, inputSk->second.second,
|
||||
outputSk->second.second, generator,
|
||||
{param.variance}),
|
||||
"cannot fill bootsrap key");
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::allocate_lwe(size_t argPos,
|
||||
LweCiphertext_u64 **ciphertext) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"allocate_lwe position of argument is too high",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64(
|
||||
&err, {std::get<1>(inputSk)->size}),
|
||||
"cannot allocate ciphertext");
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
bool KeySet::isInputEncrypted(size_t argPos) {
|
||||
return argPos < inputs.size() &&
|
||||
std::get<0>(inputs[argPos]).encryption.hasValue();
|
||||
}
|
||||
|
||||
bool KeySet::isOutputEncrypted(size_t argPos) {
|
||||
return argPos < outputs.size() &&
|
||||
std::get<0>(outputs[argPos]).encryption.hasValue();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
uint64_t input) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"encrypt_lwe position of argument is too high",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
if (!std::get<0>(inputSk).encryption.hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"encrypt_lwe the positional argument is not encrypted",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Encode - TODO we could check if the input value is in the right range
|
||||
Plaintext_u64 plaintext = {
|
||||
input << (64 -
|
||||
(std::get<0>(inputSk).encryption->encoding.precision + 1))};
|
||||
// Encrypt
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
encrypt_lwe_u64(&err, std::get<2>(inputSk), ciphertext, plaintext,
|
||||
encryptionRandomGenerator,
|
||||
{std::get<0>(inputSk).encryption->variance}),
|
||||
"cannot encrypt");
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
uint64_t &output) {
|
||||
if (argPos >= outputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"decrypt_lwe: position of argument is too high",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto outputSk = outputs[argPos];
|
||||
if (!std::get<0>(outputSk).encryption.hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"decrypt_lwe: the positional argument is not encrypted",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Decrypt
|
||||
Plaintext_u64 plaintext;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
|
||||
"cannot decrypt");
|
||||
// Decode
|
||||
output = plaintext._0 >>
|
||||
(64 - (std::get<0>(outputSk).encryption->encoding.precision + 1));
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
134
compiler/lib/Support/V0Parameters.cpp
Normal file
134
compiler/lib/Support/V0Parameters.cpp
Normal file
@@ -0,0 +1,134 @@
|
||||
/// DO NOT MANUALLY EDIT THIS FILE.
|
||||
/// This file was generated thanks the "parameters optimizer".
|
||||
/// We should include this in our build system, but for moment it is just a cc
|
||||
/// from the optimizer output.
|
||||
|
||||
#include "zamalang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
using namespace std;
|
||||
|
||||
const int NORM2_MAX = 25;
|
||||
const int P_MAX = 7;
|
||||
|
||||
V0Parameter parameters[NORM2_MAX][P_MAX] = {
|
||||
{V0Parameter(1, 10, 481, 2, 8, 4, 2), V0Parameter(1, 10, 483, 2, 8, 4, 2),
|
||||
V0Parameter(1, 10, 491, 2, 8, 4, 2), V0Parameter(1, 10, 484, 3, 6, 4, 2),
|
||||
V0Parameter(1, 10, 497, 3, 6, 4, 2), V0Parameter(1, 11, 506, 1, 24, 4, 2),
|
||||
V0Parameter(1, 11, 506, 1, 24, 4, 2)},
|
||||
{V0Parameter(1, 11, 506, 1, 23, 4, 2), V0Parameter(1, 11, 508, 1, 23, 4, 2),
|
||||
V0Parameter(1, 11, 533, 1, 23, 3, 3), V0Parameter(1, 11, 506, 2, 13, 4, 2),
|
||||
V0Parameter(1, 11, 506, 2, 15, 4, 2), V0Parameter(1, 11, 506, 2, 17, 4, 2),
|
||||
V0Parameter(1, 11, 506, 2, 15, 4, 2)},
|
||||
{V0Parameter(1, 11, 506, 2, 15, 4, 2), V0Parameter(1, 11, 508, 2, 15, 4, 2),
|
||||
V0Parameter(1, 11, 513, 2, 15, 4, 2), V0Parameter(1, 11, 530, 2, 15, 5, 2),
|
||||
V0Parameter(1, 11, 507, 3, 12, 4, 2), V0Parameter(1, 11, 511, 3, 11, 4, 2),
|
||||
V0Parameter(1, 11, 517, 3, 12, 5, 2)},
|
||||
{V0Parameter(1, 11, 508, 4, 9, 4, 2), V0Parameter(1, 11, 509, 4, 9, 5, 2),
|
||||
V0Parameter(1, 11, 507, 5, 8, 5, 2), V0Parameter(1, 11, 541, 5, 8, 5, 2),
|
||||
V0Parameter(1, 10, 549, 2, 8, 3, 3), V0Parameter(1, 10, 530, 2, 8, 5, 2),
|
||||
V0Parameter(1, 10, 524, 3, 6, 5, 2)},
|
||||
{V0Parameter(1, 10, 536, 3, 6, 5, 2), V0Parameter(1, 11, 571, 1, 22, 3, 3),
|
||||
V0Parameter(1, 11, 572, 1, 22, 3, 3), V0Parameter(1, 11, 572, 1, 22, 3, 3),
|
||||
V0Parameter(1, 11, 574, 1, 23, 3, 3), V0Parameter(1, 11, 585, 1, 23, 3, 3),
|
||||
V0Parameter(1, 11, 541, 2, 14, 5, 2)},
|
||||
{V0Parameter(1, 11, 541, 2, 14, 5, 2), V0Parameter(1, 11, 541, 2, 17, 5, 2),
|
||||
V0Parameter(1, 11, 541, 2, 15, 5, 2), V0Parameter(1, 11, 541, 2, 15, 5, 2),
|
||||
V0Parameter(1, 11, 542, 2, 15, 5, 2), V0Parameter(1, 11, 547, 2, 15, 5, 2),
|
||||
V0Parameter(1, 11, 580, 2, 15, 5, 2)},
|
||||
{V0Parameter(1, 11, 542, 3, 11, 5, 2), V0Parameter(1, 11, 545, 3, 12, 5, 2),
|
||||
V0Parameter(1, 11, 561, 3, 12, 5, 2), V0Parameter(1, 11, 543, 4, 9, 5, 2),
|
||||
V0Parameter(1, 11, 549, 4, 9, 5, 2), V0Parameter(1, 11, 547, 5, 8, 5, 2),
|
||||
V0Parameter(1, 11, 591, 5, 8, 6, 2)},
|
||||
{V0Parameter(1, 11, 550, 7, 6, 5, 2), V0Parameter(1, 10, 576, 2, 8, 5, 2),
|
||||
V0Parameter(1, 10, 567, 3, 6, 5, 2), V0Parameter(1, 10, 584, 3, 6, 5, 2),
|
||||
V0Parameter(1, 11, 607, 1, 20, 4, 3), V0Parameter(1, 11, 607, 1, 22, 4, 3),
|
||||
V0Parameter(1, 11, 607, 1, 22, 4, 3)},
|
||||
{V0Parameter(1, 11, 609, 1, 23, 4, 3), V0Parameter(1, 11, 616, 1, 23, 4, 3),
|
||||
V0Parameter(1, 11, 588, 2, 13, 5, 2), V0Parameter(1, 11, 588, 2, 17, 5, 2),
|
||||
V0Parameter(1, 11, 588, 2, 14, 5, 2), V0Parameter(1, 11, 588, 2, 16, 5, 2),
|
||||
V0Parameter(1, 11, 588, 2, 15, 5, 2)},
|
||||
{V0Parameter(1, 11, 590, 2, 15, 5, 2), V0Parameter(1, 11, 596, 2, 15, 5, 2),
|
||||
V0Parameter(1, 11, 622, 2, 15, 6, 2), V0Parameter(1, 11, 589, 3, 11, 5, 2),
|
||||
V0Parameter(1, 11, 594, 3, 12, 5, 2), V0Parameter(1, 11, 602, 3, 12, 6, 2),
|
||||
V0Parameter(1, 11, 591, 4, 9, 5, 2)},
|
||||
{V0Parameter(1, 11, 591, 4, 9, 6, 2), V0Parameter(1, 11, 589, 5, 8, 6, 2),
|
||||
V0Parameter(1, 11, 655, 5, 8, 6, 2), V0Parameter(1, 11, 592, 7, 6, 6, 2),
|
||||
V0Parameter(1, 11, 598, 8, 5, 6, 2), V0Parameter(1, 10, 608, 3, 6, 6, 2),
|
||||
V0Parameter(1, 10, 625, 3, 6, 6, 2)},
|
||||
{V0Parameter(1, 11, 647, 1, 24, 4, 3), V0Parameter(1, 11, 647, 1, 22, 4, 3),
|
||||
V0Parameter(1, 11, 647, 1, 23, 4, 3), V0Parameter(1, 11, 649, 1, 23, 4, 3),
|
||||
V0Parameter(1, 11, 658, 1, 23, 4, 3), V0Parameter(1, 11, 647, 2, 19, 4, 3),
|
||||
V0Parameter(1, 11, 647, 2, 14, 4, 3)},
|
||||
{V0Parameter(1, 11, 647, 2, 17, 4, 3), V0Parameter(1, 11, 647, 2, 15, 4, 3),
|
||||
V0Parameter(1, 11, 647, 2, 15, 4, 3), V0Parameter(1, 11, 648, 2, 15, 4, 3),
|
||||
V0Parameter(1, 11, 654, 2, 15, 4, 3), V0Parameter(1, 11, 674, 2, 15, 7, 2),
|
||||
V0Parameter(1, 11, 624, 3, 12, 6, 2)},
|
||||
{V0Parameter(1, 11, 627, 3, 11, 6, 2), V0Parameter(1, 11, 648, 3, 12, 6, 2),
|
||||
V0Parameter(1, 11, 625, 4, 9, 6, 2), V0Parameter(1, 11, 633, 4, 9, 6, 2),
|
||||
V0Parameter(1, 11, 630, 5, 8, 6, 2), V0Parameter(1, 11, 632, 6, 7, 6, 2),
|
||||
V0Parameter(1, 11, 634, 7, 6, 6, 2)},
|
||||
{V0Parameter(1, 11, 642, 8, 5, 6, 2), V0Parameter(1, 11, 644, 11, 4, 6, 2),
|
||||
V0Parameter(1, 11, 698, 1, 20, 4, 3), V0Parameter(1, 11, 698, 1, 21, 4, 3),
|
||||
V0Parameter(1, 11, 698, 1, 22, 4, 3), V0Parameter(1, 11, 699, 1, 22, 4, 3),
|
||||
V0Parameter(1, 11, 702, 1, 23, 4, 3)},
|
||||
{V0Parameter(1, 11, 721, 1, 23, 4, 3), V0Parameter(1, 11, 698, 2, 13, 4, 3),
|
||||
V0Parameter(1, 11, 698, 2, 17, 4, 3), V0Parameter(1, 11, 698, 2, 14, 4, 3),
|
||||
V0Parameter(1, 11, 698, 2, 16, 4, 3), V0Parameter(1, 11, 698, 2, 15, 4, 3),
|
||||
V0Parameter(1, 11, 700, 2, 15, 4, 3)},
|
||||
{V0Parameter(1, 11, 711, 2, 15, 4, 3), V0Parameter(1, 11, 674, 3, 12, 6, 2),
|
||||
V0Parameter(1, 11, 675, 3, 11, 6, 2), V0Parameter(1, 11, 681, 3, 12, 6, 2),
|
||||
V0Parameter(1, 11, 695, 3, 12, 7, 2), V0Parameter(1, 11, 677, 4, 9, 6, 2),
|
||||
V0Parameter(1, 11, 677, 4, 9, 7, 2)},
|
||||
{V0Parameter(1, 11, 674, 5, 8, 7, 2), V0Parameter(1, 11, 676, 6, 7, 7, 2),
|
||||
V0Parameter(1, 11, 678, 7, 6, 7, 2), V0Parameter(1, 11, 688, 8, 5, 7, 2),
|
||||
V0Parameter(1, 11, 690, 11, 4, 7, 2),
|
||||
V0Parameter(1, 11, 731, 14, 3, 15, 1),
|
||||
V0Parameter(1, 11, 745, 1, 21, 5, 3)},
|
||||
{V0Parameter(1, 11, 745, 1, 22, 5, 3), V0Parameter(1, 11, 746, 1, 23, 5, 3),
|
||||
V0Parameter(1, 11, 750, 1, 23, 5, 3), V0Parameter(1, 11, 781, 1, 23, 5, 3),
|
||||
V0Parameter(1, 11, 745, 2, 19, 5, 3), V0Parameter(1, 11, 745, 2, 14, 5, 3),
|
||||
V0Parameter(1, 11, 745, 2, 17, 5, 3)},
|
||||
{V0Parameter(1, 11, 745, 2, 15, 5, 3), V0Parameter(1, 11, 746, 2, 15, 5, 3),
|
||||
V0Parameter(1, 11, 748, 2, 15, 5, 3), V0Parameter(1, 11, 763, 2, 15, 5, 3),
|
||||
V0Parameter(1, 11, 745, 3, 12, 5, 3), V0Parameter(1, 11, 747, 3, 11, 5, 3),
|
||||
V0Parameter(1, 11, 756, 3, 11, 5, 3)},
|
||||
{V0Parameter(1, 11, 722, 4, 9, 7, 2), V0Parameter(1, 11, 727, 4, 9, 7, 2),
|
||||
V0Parameter(1, 11, 750, 4, 9, 8, 2), V0Parameter(1, 11, 747, 5, 8, 7, 2),
|
||||
V0Parameter(1, 11, 747, 6, 7, 8, 2), V0Parameter(1, 11, 757, 7, 6, 8, 2),
|
||||
V0Parameter(1, 11, 739, 10, 4, 8, 2)},
|
||||
{V0Parameter(1, 11, 735, 14, 3, 8, 2), V0Parameter(1, 11, 767, 21, 2, 8, 2),
|
||||
V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(1, 12, 839, 1, 22, 4, 4),
|
||||
V0Parameter(1, 12, 844, 1, 23, 4, 4), V0Parameter(1, 12, 847, 1, 23, 6, 3),
|
||||
V0Parameter(1, 12, 815, 2, 13, 5, 3)},
|
||||
{V0Parameter(1, 12, 815, 2, 17, 5, 3), V0Parameter(1, 12, 815, 2, 14, 5, 3),
|
||||
V0Parameter(1, 12, 815, 2, 15, 5, 3), V0Parameter(1, 12, 816, 2, 15, 5, 3),
|
||||
V0Parameter(1, 12, 820, 2, 15, 5, 3), V0Parameter(1, 12, 858, 2, 15, 4, 4),
|
||||
V0Parameter(1, 12, 815, 3, 11, 5, 3)},
|
||||
{V0Parameter(1, 12, 818, 3, 11, 5, 3), V0Parameter(1, 12, 790, 3, 11, 8, 2),
|
||||
V0Parameter(1, 12, 783, 4, 9, 8, 2), V0Parameter(1, 12, 787, 4, 9, 8, 2),
|
||||
V0Parameter(1, 12, 823, 4, 9, 8, 2), V0Parameter(1, 12, 822, 5, 8, 8, 2),
|
||||
V0Parameter(1, 12, 843, 6, 7, 9, 2)},
|
||||
{V0Parameter(1, 12, 792, 8, 5, 8, 2), V0Parameter(1, 12, 802, 10, 4, 8, 2),
|
||||
V0Parameter(1, 12, 804, 14, 3, 8, 2), V0Parameter(1, 12, 825, 22, 2, 9, 2),
|
||||
V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0),
|
||||
V0Parameter(0, 0, 0, 0, 0, 0, 0)}};
|
||||
|
||||
V0Parameter *getV0Parameter(size_t norm, size_t p) {
|
||||
if (norm > NORM2_MAX) {
|
||||
return nullptr;
|
||||
}
|
||||
if (p >= P_MAX) {
|
||||
return nullptr;
|
||||
}
|
||||
// - 1 is an offset as norm and p are in [1, ...] and not [0, ...]
|
||||
auto param = ¶meters[norm - 1][p - 1];
|
||||
if (param->k == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return param;
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -32,6 +32,9 @@ llvm::cl::opt<std::string> output("o",
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::list<std::string> passes(
|
||||
"passes",
|
||||
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
|
||||
@@ -53,8 +56,19 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> generateKeySet(
|
||||
"generate-keyset",
|
||||
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<std::string> jitFuncname(
|
||||
"jit-funcname",
|
||||
llvm::cl::desc("Name of the function to execute, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
|
||||
llvm::cl::list<int>
|
||||
jitArgs("jit-args",
|
||||
llvm::cl::desc("Value of arguments to pass to the main func"),
|
||||
@@ -64,6 +78,12 @@ llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
|
||||
llvm::cl::init<bool>(false));
|
||||
}; // namespace cmdline
|
||||
|
||||
#define LOG_VERBOSE(expr) \
|
||||
if (cmdline::verbose) \
|
||||
llvm::errs() << expr;
|
||||
|
||||
#define LOG_ERROR(expr) llvm::errs() << expr;
|
||||
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
@@ -77,32 +97,42 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
|
||||
mlir::LogicalResult runJit(mlir::ModuleOp module,
|
||||
mlir::zamalang::KeySet &keySet,
|
||||
llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
auto maybeLambda = mlir::zamalang::JITLambda::create(
|
||||
cmdline::jitFuncname, module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = maybeLambda.get().get();
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create buffer to copy argument
|
||||
std::vector<int64_t> dummy(cmdline::jitArgs.size());
|
||||
llvm::SmallVector<void *> llvmArgs;
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
dummy[i] = cmdline::jitArgs[i];
|
||||
llvmArgs.push_back(&dummy[i]);
|
||||
}
|
||||
// Add the result pointer
|
||||
uint64_t res = 0;
|
||||
llvmArgs.push_back(&res);
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
|
||||
// Invoke the lambda
|
||||
if (lambda->invokeRaw(llvmArgs)) {
|
||||
LOG_ERROR("Cannot create lambda arguments: " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
std::cerr << res << "\n";
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
|
||||
LOG_ERROR("Cannot push argument " << i << ": " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (lambda->invoke(*arguments)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
LOG_ERROR("Cannot get result : " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
llvm::errs() << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -137,20 +167,67 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
|
||||
context, *module,
|
||||
[](std::string passName) {
|
||||
return cmdline::passes.size() == 0 ||
|
||||
std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
auto enablePass = [](std::string passName) {
|
||||
return cmdline::passes.size() == 0 ||
|
||||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return passName == p; });
|
||||
})
|
||||
};
|
||||
|
||||
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
|
||||
mlir::zamalang::FHECircuitConstraint constraint;
|
||||
LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n");
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
context, *module, constraint, enablePass)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:"
|
||||
<< constraint.p << "}\n");
|
||||
|
||||
// Retreive the parameters for the v0 approach
|
||||
mlir::zamalang::V0Parameter *fheParameter =
|
||||
mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p);
|
||||
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
|
||||
<< fheParameter->k
|
||||
<< ", polynomialSize: " << fheParameter->polynomialSize
|
||||
<< ", nSmall: " << fheParameter->nSmall
|
||||
<< ", brLevel: " << fheParameter->brLevel
|
||||
<< ", brLogBase: " << fheParameter->brLogBase
|
||||
<< ", ksLevel: " << fheParameter->ksLevel
|
||||
<< ", polynomialSize: " << fheParameter->ksLogBase << "}\n");
|
||||
|
||||
// Generate the keySet
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
if (cmdline::generateKeySet || cmdline::runJit) {
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
fheParameter, constraint.p, cmdline::jitFuncname, *module);
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
LOG_ERROR("cannot generate client parameters: " << err << "\n");
|
||||
return mlir::failure();
|
||||
}
|
||||
LOG_VERBOSE("### Generate the key set\n");
|
||||
auto maybeKeySet =
|
||||
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
|
||||
0); // TODO: seed
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
llvm::errs() << err;
|
||||
return mlir::failure();
|
||||
}
|
||||
keySet = std::move(maybeKeySet.get());
|
||||
}
|
||||
|
||||
// Lower to MLIR LLVM Dialect
|
||||
LOG_VERBOSE("### Lower from MLIR standards to LLVM\n");
|
||||
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
context, *module, enablePass)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (cmdline::runJit) {
|
||||
return runJit(module.get(), os);
|
||||
LOG_VERBOSE("### JIT compile & running\n");
|
||||
return runJit(module.get(), *keySet, os);
|
||||
}
|
||||
if (cmdline::toLLVM) {
|
||||
return dumpLLVMIR(module.get(), os);
|
||||
|
||||
6
compiler/tests/RunJit/lowlfhe_identity.mlir
Normal file
6
compiler/tests/RunJit/lowlfhe_identity.mlir
Normal file
@@ -0,0 +1,6 @@
|
||||
// RUN: zamacompiler %s --run-jit --jit-args 42 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: 42
|
||||
func @main(%arg0: !LowLFHE.lwe_ciphertext) -> !LowLFHE.lwe_ciphertext {
|
||||
return %arg0 : !LowLFHE.lwe_ciphertext
|
||||
}
|
||||
Reference in New Issue
Block a user