mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Output client parameters when compile to a library
close #198
This commit is contained in:
@@ -53,11 +53,20 @@ test-check: concretecompiler file-check not
|
||||
test-python: python-bindings concretecompiler
|
||||
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python
|
||||
|
||||
test: test-check test-end-to-end-jit test-python
|
||||
test: test-check test-end-to-end-jit test-python support-unit-test
|
||||
|
||||
test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization
|
||||
|
||||
# Unittests
|
||||
# unit-test
|
||||
|
||||
support-unit-test: build-support-unit-test
|
||||
$(BUILD_DIR)/bin/support_unit_test
|
||||
|
||||
build-support-unit-test: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target support_unit_test
|
||||
|
||||
|
||||
# test-end-to-end-jit
|
||||
|
||||
build-end-to-end-jit-test: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_test
|
||||
|
||||
165
compiler/include/concretelang/ClientLib/ClientParameters.h
Normal file
165
compiler/include/concretelang/ClientLib/ClientParameters.h
Normal file
@@ -0,0 +1,165 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <llvm/Support/JSON.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
const std::string SMALL_KEY = "small";
|
||||
const std::string BIG_KEY = "big";
|
||||
|
||||
typedef size_t DecompositionLevelCount;
|
||||
typedef size_t DecompositionBaseLog;
|
||||
typedef size_t PolynomialSize;
|
||||
typedef size_t Precision;
|
||||
typedef double Variance;
|
||||
|
||||
typedef uint64_t LweSize;
|
||||
typedef uint64_t GlweDimension;
|
||||
|
||||
typedef std::string LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweSize size;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
static bool operator==(const LweSecretKeyParam &lhs,
|
||||
const LweSecretKeyParam &rhs) {
|
||||
return lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
typedef std::string BootstrapKeyID;
|
||||
struct BootstrapKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
static inline bool operator==(const BootstrapKeyParam &lhs,
|
||||
const BootstrapKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef std::string KeyswitchKeyID;
|
||||
struct KeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
static inline bool operator==(const KeyswitchKeyParam &lhs,
|
||||
const KeyswitchKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
};
|
||||
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
|
||||
return lhs.precision == rhs.precision;
|
||||
}
|
||||
|
||||
struct EncryptionGate {
|
||||
LweSecretKeyID secretKeyID;
|
||||
Variance variance;
|
||||
Encoding encoding;
|
||||
};
|
||||
static inline bool operator==(const EncryptionGate &lhs,
|
||||
const EncryptionGate &rhs) {
|
||||
return lhs.secretKeyID == rhs.secretKeyID && lhs.variance == rhs.variance &&
|
||||
lhs.encoding == rhs.encoding;
|
||||
}
|
||||
|
||||
struct CircuitGateShape {
|
||||
// Width of the scalar value
|
||||
size_t width;
|
||||
// Dimensions of the tensor, empty if scalar
|
||||
std::vector<int64_t> dimensions;
|
||||
// Size of the buffer containing the tensor
|
||||
size_t size;
|
||||
};
|
||||
static inline bool operator==(const CircuitGateShape &lhs,
|
||||
const CircuitGateShape &rhs) {
|
||||
return lhs.width == rhs.width && lhs.dimensions == rhs.dimensions &&
|
||||
lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
};
|
||||
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
|
||||
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;
|
||||
}
|
||||
|
||||
struct ClientParameters {
|
||||
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
|
||||
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
|
||||
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
std::string functionName;
|
||||
size_t hash();
|
||||
};
|
||||
static inline bool operator==(const ClientParameters &lhs,
|
||||
const ClientParameters &rhs) {
|
||||
return lhs.secretKeys == rhs.secretKeys &&
|
||||
lhs.bootstrapKeys == rhs.bootstrapKeys &&
|
||||
lhs.keyswitchKeys == rhs.keyswitchKeys && lhs.inputs == lhs.inputs &&
|
||||
lhs.outputs == lhs.outputs;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, LweSecretKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const BootstrapKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &);
|
||||
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const EncryptionGate &);
|
||||
bool fromJSON(const llvm::json::Value, EncryptionGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGateShape &);
|
||||
bool fromJSON(const llvm::json::Value, CircuitGateShape &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGate &);
|
||||
bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &);
|
||||
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
ClientParameters cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,7 +6,6 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_KEYSET_H_
|
||||
#define CONCRETELANG_SUPPORT_KEYSET_H_
|
||||
|
||||
#include "llvm/Support/Error.h"
|
||||
#include <memory>
|
||||
|
||||
extern "C" {
|
||||
@@ -14,8 +13,8 @@ extern "C" {
|
||||
#include "concretelang/Runtime/context.h"
|
||||
}
|
||||
|
||||
#include "concretelang/Support/ClientParameters.h"
|
||||
#include "concretelang/Support/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -91,9 +90,9 @@ private:
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
keyswitchKeys;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
inputs;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
outputs;
|
||||
|
||||
void setKeys(
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_KEYSETCACHE_H_
|
||||
#define CONCRETELANG_SUPPORT_KEYSETCACHE_H_
|
||||
|
||||
#include "concretelang/Support/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -1,99 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef size_t DecompositionLevelCount;
|
||||
typedef size_t DecompositionBaseLog;
|
||||
typedef size_t PolynomialSize;
|
||||
typedef size_t Precision;
|
||||
typedef double Variance;
|
||||
|
||||
typedef uint64_t LweSize;
|
||||
typedef uint64_t GlweDimension;
|
||||
|
||||
typedef std::string LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweSize size;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
|
||||
typedef std::string BootstrapKeyID;
|
||||
struct BootstrapKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
|
||||
typedef std::string KeyswitchKeyID;
|
||||
struct KeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
};
|
||||
|
||||
struct EncryptionGate {
|
||||
LweSecretKeyID secretKeyID;
|
||||
Variance variance;
|
||||
Encoding encoding;
|
||||
};
|
||||
|
||||
struct CircuitGateShape {
|
||||
// Width of the scalar value
|
||||
size_t width;
|
||||
// Dimensions of the tensor, empty if scalar
|
||||
std::vector<int64_t> dimensions;
|
||||
// Size of the buffer containing the tensor
|
||||
size_t size;
|
||||
};
|
||||
|
||||
struct CircuitGate {
|
||||
llvm::Optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
};
|
||||
|
||||
struct ClientParameters {
|
||||
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
|
||||
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
|
||||
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
size_t hash();
|
||||
};
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef name,
|
||||
mlir::ModuleOp module);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -7,7 +7,7 @@
|
||||
#define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include <concretelang/Conversion/Utils/GlobalFHEContext.h>
|
||||
#include <concretelang/Support/ClientParameters.h>
|
||||
#include <concretelang/Support/V0ClientParameters.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
@@ -58,6 +58,7 @@ public:
|
||||
class Library {
|
||||
std::string libraryPath;
|
||||
std::vector<std::string> objectsPath;
|
||||
std::vector<mlir::concretelang::ClientParameters> clientParametersList;
|
||||
bool cleanUp;
|
||||
|
||||
public:
|
||||
@@ -68,21 +69,38 @@ public:
|
||||
: libraryPath(libraryPath), cleanUp(cleanUp) {}
|
||||
/** Add a compilation result to the library */
|
||||
llvm::Expected<std::string> addCompilation(CompilationResult &compilation);
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
llvm::Expected<std::string> emitShared();
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
llvm::Expected<std::string> emitStatic();
|
||||
/** Emit the library artifacts with the previously added compilation result
|
||||
*/
|
||||
llvm::Error emitArtifacts();
|
||||
/** After a shared library has been emitted, its path is here */
|
||||
std::string sharedLibraryPath;
|
||||
/** After a static library has been emitted, its path is here */
|
||||
std::string staticLibraryPath;
|
||||
|
||||
/** Returns the path of the shared library */
|
||||
static std::string getSharedLibraryPath(std::string path);
|
||||
|
||||
/** Returns the path of the static library */
|
||||
static std::string getStaticLibraryPath(std::string path);
|
||||
|
||||
/** Returns the path of the static library */
|
||||
static std::string getClientParametersPath(std::string path);
|
||||
|
||||
// For advanced use
|
||||
const static std::string OBJECT_EXT, LINKER, LINKER_SHARED_OPT, AR,
|
||||
AR_STATIC_OPT, DOT_STATIC_LIB_EXT, DOT_SHARED_LIB_EXT;
|
||||
const static std::string OBJECT_EXT, CLIENT_PARAMETERS_EXT, LINKER,
|
||||
LINKER_SHARED_OPT, AR, AR_STATIC_OPT, DOT_STATIC_LIB_EXT,
|
||||
DOT_SHARED_LIB_EXT;
|
||||
void addExtraObjectFilePath(std::string objectFilePath);
|
||||
llvm::Expected<std::string> emit(std::string dotExt, std::string linker);
|
||||
~Library();
|
||||
|
||||
private:
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
llvm::Expected<std::string> emitStatic();
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
llvm::Expected<std::string> emitShared();
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
llvm::Expected<std::string> emitClientParametersJSON();
|
||||
};
|
||||
|
||||
// Specification of the exit stage of the compilation pipeline
|
||||
|
||||
@@ -50,7 +50,10 @@ protected:
|
||||
llvm::raw_string_ostream os;
|
||||
};
|
||||
|
||||
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err);
|
||||
inline StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) {
|
||||
se << llvm::toString(std::move(err));
|
||||
return se;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
#include <concretelang/Support/KeySet.h>
|
||||
#include <concretelang/ClientLib/KeySet.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
|
||||
#include "concretelang/Support/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
|
||||
28
compiler/include/concretelang/Support/V0ClientParameters.h
Normal file
28
compiler/include/concretelang/Support/V0ClientParameters.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
|
||||
mlir::ModuleOp module);
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
|
||||
mlir::ModuleOp module);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,10 +6,10 @@
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/KeySetCache.h"
|
||||
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Support)
|
||||
add_subdirectory(Runtime)
|
||||
add_subdirectory(ClientLib)
|
||||
add_subdirectory(Bindings)
|
||||
|
||||
# CAPI needed only for python bindings
|
||||
|
||||
9
compiler/lib/ClientLib/CMakeLists.txt
Normal file
9
compiler/lib/ClientLib/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
add_mlir_library(
|
||||
ConcretelangClientLib
|
||||
ClientParameters.cpp
|
||||
KeySet.cpp
|
||||
KeySetCache.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
|
||||
)
|
||||
388
compiler/lib/ClientLib/ClientParameters.cpp
Normal file
388
compiler/lib/ClientLib/ClientParameters.cpp
Normal file
@@ -0,0 +1,388 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
// https://stackoverflow.com/a/38140932
|
||||
static inline void hash(std::size_t &seed) {}
|
||||
template <typename T, typename... Rest>
|
||||
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
|
||||
// See https://softwareengineering.stackexchange.com/a/402543
|
||||
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
|
||||
const std::hash<T> hasher;
|
||||
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
|
||||
hash(seed, rest...);
|
||||
}
|
||||
|
||||
void LweSecretKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, size);
|
||||
}
|
||||
|
||||
void BootstrapKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, glweDimension, variance);
|
||||
}
|
||||
|
||||
void KeyswitchKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, variance);
|
||||
}
|
||||
|
||||
std::size_t ClientParameters::hash() {
|
||||
std::size_t currentHash = 1;
|
||||
for (auto secretKeyParam : secretKeys) {
|
||||
mlir::concretelang::hash(currentHash, secretKeyParam.first);
|
||||
secretKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto bootstrapKeyParam : bootstrapKeys) {
|
||||
mlir::concretelang::hash(currentHash, bootstrapKeyParam.first);
|
||||
bootstrapKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto keyswitchParam : keyswitchKeys) {
|
||||
mlir::concretelang::hash(currentHash, keyswitchParam.first);
|
||||
keyswitchParam.second.hash(currentHash);
|
||||
}
|
||||
return currentHash;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"size", v.size},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto size = obj->getInteger("size");
|
||||
if (!size.hasValue()) {
|
||||
p.report("missing size field");
|
||||
return false;
|
||||
}
|
||||
v.size = *size;
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const BootstrapKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"inputSecretKeyID", v.inputSecretKeyID},
|
||||
{"outputSecretKeyID", v.outputSecretKeyID},
|
||||
{"level", v.level},
|
||||
{"glweDimension", v.glweDimension},
|
||||
{"baseLog", v.baseLog},
|
||||
{"variance", v.variance},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
|
||||
if (!inputSecretKeyID.hasValue()) {
|
||||
p.report("missing inputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
|
||||
if (!outputSecretKeyID.hasValue()) {
|
||||
p.report("missing outputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto level = obj->getInteger("level");
|
||||
if (!level.hasValue()) {
|
||||
p.report("missing level field");
|
||||
return false;
|
||||
}
|
||||
auto baseLog = obj->getInteger("baseLog");
|
||||
if (!baseLog.hasValue()) {
|
||||
p.report("missing baseLog field");
|
||||
return false;
|
||||
}
|
||||
auto glweDimension = obj->getInteger("glweDimension");
|
||||
if (!glweDimension.hasValue()) {
|
||||
p.report("missing glweDimension field");
|
||||
return false;
|
||||
}
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
|
||||
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
|
||||
v.level = level.getValue();
|
||||
v.baseLog = baseLog.getValue();
|
||||
v.glweDimension = glweDimension.getValue();
|
||||
v.variance = variance.getValue();
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"inputSecretKeyID", v.inputSecretKeyID},
|
||||
{"outputSecretKeyID", v.outputSecretKeyID},
|
||||
{"level", v.level},
|
||||
{"baseLog", v.baseLog},
|
||||
{"variance", v.variance},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
|
||||
if (!inputSecretKeyID.hasValue()) {
|
||||
p.report("missing inputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
|
||||
if (!outputSecretKeyID.hasValue()) {
|
||||
p.report("missing outputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto level = obj->getInteger("level");
|
||||
if (!level.hasValue()) {
|
||||
p.report("missing level field");
|
||||
return false;
|
||||
}
|
||||
auto baseLog = obj->getInteger("baseLog");
|
||||
if (!baseLog.hasValue()) {
|
||||
p.report("missing baseLog field");
|
||||
return false;
|
||||
}
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
|
||||
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
|
||||
v.level = level.getValue();
|
||||
v.baseLog = baseLog.getValue();
|
||||
v.variance = variance.getValue();
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGateShape &v) {
|
||||
llvm::json::Object object{
|
||||
{"width", v.width},
|
||||
{"dimensions", v.dimensions},
|
||||
{"size", v.size},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto width = obj->getInteger("width");
|
||||
if (!width.hasValue()) {
|
||||
p.report("missing width field");
|
||||
return false;
|
||||
}
|
||||
auto dimensions = obj->getArray("dimensions");
|
||||
if (dimensions == nullptr) {
|
||||
p.report("missing dimensions field");
|
||||
return false;
|
||||
}
|
||||
for (auto dim : *dimensions) {
|
||||
auto iDim = dim.getAsInteger();
|
||||
if (!iDim.hasValue()) {
|
||||
p.report("dimensions must be integer");
|
||||
return false;
|
||||
}
|
||||
v.dimensions.push_back(iDim.getValue());
|
||||
}
|
||||
auto size = obj->getInteger("size");
|
||||
if (!size.hasValue()) {
|
||||
p.report("missing size field");
|
||||
return false;
|
||||
}
|
||||
v.width = width.getValue();
|
||||
v.size = size.getValue();
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &v) {
|
||||
llvm::json::Object object{
|
||||
{"precision", v.precision},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto precision = obj->getInteger("precision");
|
||||
if (!precision.hasValue()) {
|
||||
p.report("missing precision field");
|
||||
return false;
|
||||
}
|
||||
v.precision = precision.getValue();
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const EncryptionGate &v) {
|
||||
llvm::json::Object object{
|
||||
{"secretKeyID", v.secretKeyID},
|
||||
{"variance", v.variance},
|
||||
{"encoding", v.encoding},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, EncryptionGate &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto secretKeyID = obj->getString("secretKeyID");
|
||||
if (!secretKeyID.hasValue()) {
|
||||
p.report("missing secretKeyID field");
|
||||
return false;
|
||||
}
|
||||
v.secretKeyID = (std::string)secretKeyID.getValue();
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.variance = variance.getValue();
|
||||
auto encoding = obj->get("encoding");
|
||||
if (encoding == nullptr) {
|
||||
p.report("missing encoding field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*encoding, v.encoding, p.field("encoding"))) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGate &v) {
|
||||
llvm::json::Object object{
|
||||
{"encryption", v.encryption},
|
||||
{"shape", v.shape},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
auto encryption = obj->get("encryption");
|
||||
if (encryption == nullptr) {
|
||||
p.report("missing encryption field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*encryption, v.encryption, p.field("encryption"))) {
|
||||
return false;
|
||||
}
|
||||
auto shape = obj->get("shape");
|
||||
if (shape == nullptr) {
|
||||
p.report("missing shape field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*shape, v.shape, p.field("shape"))) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T> llvm::json::Value toJson(std::map<std::string, T> map) {
|
||||
llvm::json::Object obj;
|
||||
for (auto entry : map) {
|
||||
obj[entry.first] = entry.second;
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &v) {
|
||||
llvm::json::Object object{
|
||||
{"secretKeys", toJson(v.secretKeys)},
|
||||
{"bootstrapKeys", toJson(v.bootstrapKeys)},
|
||||
{"keyswitchKeys", toJson(v.keyswitchKeys)},
|
||||
{"inputs", v.inputs},
|
||||
{"outputs", v.outputs},
|
||||
{"functionName", v.functionName},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, ClientParameters &v,
|
||||
llvm::json::Path p) {
|
||||
|
||||
auto obj = j.getAsObject();
|
||||
auto secretkeys = obj->get("secretKeys");
|
||||
if (secretkeys == nullptr) {
|
||||
p.report("missing secretKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*secretkeys, v.secretKeys, p.field("secretKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto bootstrapKeys = obj->get("bootstrapKeys");
|
||||
if (bootstrapKeys == nullptr) {
|
||||
p.report("missing bootstrapKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*bootstrapKeys, v.bootstrapKeys, p.field("bootstrapKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto keyswitchKeys = obj->get("keyswitchKeys");
|
||||
if (keyswitchKeys == nullptr) {
|
||||
p.report("missing keyswitchKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*keyswitchKeys, v.keyswitchKeys, p.field("keyswitchKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto inputs = obj->get("inputs");
|
||||
if (inputs == nullptr) {
|
||||
p.report("missing inputs field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*inputs, v.inputs, p.field("inputs"))) {
|
||||
return false;
|
||||
}
|
||||
auto outputs = obj->get("outputs");
|
||||
if (outputs == nullptr) {
|
||||
p.report("missing outputs field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*outputs, v.outputs, p.field("outputs"))) {
|
||||
return false;
|
||||
}
|
||||
auto functionName = obj->getString("functionName");
|
||||
if (!functionName.hasValue()) {
|
||||
p.report("missing functionName field");
|
||||
return false;
|
||||
}
|
||||
v.functionName = (std::string)functionName.getValue();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include "concretelang/Support/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
|
||||
@@ -37,23 +37,18 @@ llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
auto a = uninitialized();
|
||||
auto keySet = uninitialized();
|
||||
|
||||
auto fillError = a->generateKeysFromParams(params, seed_msb, seed_lsb);
|
||||
|
||||
if (fillError) {
|
||||
return StreamStringError()
|
||||
<< "Cannot fill keys from params: " << std::move(fillError);
|
||||
if (auto error = keySet->generateKeysFromParams(params, seed_msb, seed_lsb)) {
|
||||
return std::move(error);
|
||||
}
|
||||
|
||||
fillError = a->setupEncryptionMaterial(params, seed_msb, seed_lsb);
|
||||
|
||||
if (fillError) {
|
||||
return StreamStringError()
|
||||
<< "Cannot setup encryption material: " << std::move(fillError);
|
||||
if (auto error =
|
||||
keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)) {
|
||||
return std::move(error);
|
||||
}
|
||||
|
||||
return std::move(a);
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
std::unique_ptr<KeySet> KeySet::uninitialized() {
|
||||
@@ -66,8 +61,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
|
||||
param, nullptr, nullptr};
|
||||
LweSecretKeyParam secretKeyParam = {0};
|
||||
LweSecretKey_u64 *secretKey = nullptr;
|
||||
if (param.encryption.hasValue()) {
|
||||
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (inputSk == this->secretKeys.end()) {
|
||||
@@ -76,14 +71,16 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
") does not exist ",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(input) = &inputSk->second.first;
|
||||
std::get<2>(input) = inputSk->second.second;
|
||||
secretKeyParam = inputSk->second.first;
|
||||
secretKey = inputSk->second.second;
|
||||
}
|
||||
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *> input = {
|
||||
param, secretKeyParam, secretKey};
|
||||
this->inputs.push_back(input);
|
||||
}
|
||||
for (auto param : params.outputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
|
||||
{param, nullptr, nullptr};
|
||||
LweSecretKeyParam secretKeyParam = {0};
|
||||
LweSecretKey_u64 *secretKey = nullptr;
|
||||
if (param.encryption.hasValue()) {
|
||||
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (outputSk == this->secretKeys.end()) {
|
||||
@@ -91,9 +88,11 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(output) = &outputSk->second.first;
|
||||
std::get<2>(output) = outputSk->second.second;
|
||||
secretKeyParam = outputSk->second.first;
|
||||
secretKey = outputSk->second.second;
|
||||
}
|
||||
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *> output = {
|
||||
param, secretKeyParam, secretKey};
|
||||
this->outputs.push_back(output);
|
||||
}
|
||||
}
|
||||
@@ -283,7 +282,7 @@ llvm::Error KeySet::allocate_lwe(size_t argPos,
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64(
|
||||
&err, {std::get<1>(inputSk)->size + 1}),
|
||||
&err, {std::get<1>(inputSk).size + 1}),
|
||||
"cannot allocate ciphertext");
|
||||
return llvm::Error::success();
|
||||
}
|
||||
@@ -369,4 +368,4 @@ KeySet::getKeyswitchKeys() {
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include "concretelang/Support/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
extern "C" {
|
||||
@@ -229,4 +230,4 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
@@ -1,5 +1,4 @@
|
||||
add_mlir_library(ConcretelangSupport
|
||||
Error.cpp
|
||||
Pipeline.cpp
|
||||
Jit.cpp
|
||||
CompilerEngine.cpp
|
||||
@@ -7,11 +6,9 @@ add_mlir_library(ConcretelangSupport
|
||||
LambdaArgument.cpp
|
||||
V0Parameters.cpp
|
||||
V0Curves.cpp
|
||||
ClientParameters.cpp
|
||||
KeySet.cpp
|
||||
V0ClientParameters.cpp
|
||||
logging.cpp
|
||||
Jit.cpp
|
||||
KeySetCache.cpp
|
||||
LLVMEmitFile.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
@@ -34,4 +31,5 @@ add_mlir_library(ConcretelangSupport
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
ConcretelangRuntime
|
||||
ConcretelangClientLib
|
||||
)
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SMLoc.h>
|
||||
@@ -271,15 +274,23 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return StreamStringError(
|
||||
"Cannot generate client parameters, the fhe context is empty");
|
||||
}
|
||||
}
|
||||
// Generate client parameters if requested
|
||||
auto funcName = this->clientParametersFuncName.getValueOr("main");
|
||||
if (this->generateClientParameters || target == Target::LIBRARY) {
|
||||
if (!res.fheContext.hasValue()) {
|
||||
// Some tests can involves a usual function
|
||||
res.clientParameters =
|
||||
mlir::concretelang::emptyClientParametersForV0(funcName, module);
|
||||
} else {
|
||||
auto clientParametersOrErr =
|
||||
mlir::concretelang::createClientParametersForV0(*res.fheContext,
|
||||
funcName, module);
|
||||
if (!clientParametersOrErr)
|
||||
return clientParametersOrErr.takeError();
|
||||
|
||||
llvm::Expected<mlir::concretelang::ClientParameters> clientParametersOrErr =
|
||||
mlir::concretelang::createClientParametersForV0(
|
||||
*res.fheContext, *this->clientParametersFuncName, module);
|
||||
|
||||
if (llvm::Error err = clientParametersOrErr.takeError())
|
||||
return std::move(err);
|
||||
|
||||
res.clientParameters = clientParametersOrErr.get();
|
||||
res.clientParameters = clientParametersOrErr.get();
|
||||
}
|
||||
}
|
||||
|
||||
// MLIR canonical dialects -> LLVM Dialect
|
||||
@@ -334,10 +345,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
llvm::Expected<CompilerEngine::CompilationResult>
|
||||
CompilerEngine::compile(llvm::StringRef s, Target target, OptionalLib lib) {
|
||||
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
|
||||
llvm::Expected<CompilationResult> res =
|
||||
this->compile(std::move(mb), target, lib);
|
||||
|
||||
return std::move(res);
|
||||
return this->compile(std::move(mb), target, lib);
|
||||
}
|
||||
|
||||
// Compile the contained in `buffer` to the target dialect
|
||||
@@ -351,9 +359,7 @@ CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
|
||||
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
llvm::Expected<CompilationResult> res = this->compile(sm, target, lib);
|
||||
|
||||
return std::move(res);
|
||||
return this->compile(sm, target, lib);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -371,10 +377,9 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
|
||||
}
|
||||
}
|
||||
|
||||
auto libPath = outputLib->emitShared();
|
||||
if (!libPath) {
|
||||
return StreamStringError("Can't link: ")
|
||||
<< llvm::toString(libPath.takeError());
|
||||
if (auto err = outputLib->emitArtifacts()) {
|
||||
return StreamStringError("Can't emit artifacts: ")
|
||||
<< llvm::toString(std::move(err));
|
||||
}
|
||||
return *outputLib.get();
|
||||
}
|
||||
@@ -384,7 +389,24 @@ template llvm::Expected<CompilerEngine::Library>
|
||||
CompilerEngine::compile(std::vector<std::string> inputs,
|
||||
std::string libraryPath);
|
||||
|
||||
/** Returns the path of the shared library */
|
||||
std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) {
|
||||
return path + DOT_SHARED_LIB_EXT;
|
||||
}
|
||||
|
||||
/** Returns the path of the static library */
|
||||
std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) {
|
||||
return path + DOT_STATIC_LIB_EXT;
|
||||
}
|
||||
|
||||
/** Returns the path of the static library */
|
||||
std::string CompilerEngine::Library::getClientParametersPath(std::string path) {
|
||||
return path + CLIENT_PARAMETERS_EXT;
|
||||
}
|
||||
|
||||
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
|
||||
const std::string CompilerEngine::Library::CLIENT_PARAMETERS_EXT =
|
||||
".concrete.params.json";
|
||||
const std::string CompilerEngine::Library::LINKER = "ld";
|
||||
const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o ";
|
||||
const std::string CompilerEngine::Library::AR = "ar";
|
||||
@@ -396,6 +418,23 @@ void CompilerEngine::Library::addExtraObjectFilePath(std::string path) {
|
||||
objectsPath.push_back(path);
|
||||
}
|
||||
|
||||
llvm::Expected<std::string>
|
||||
CompilerEngine::Library::emitClientParametersJSON() {
|
||||
auto clientParamsPath = getClientParametersPath(libraryPath);
|
||||
llvm::json::Value value(clientParametersList);
|
||||
std::error_code error;
|
||||
llvm::raw_fd_ostream out(clientParamsPath, error);
|
||||
|
||||
if (error) {
|
||||
return StreamStringError("cannot emit client parameters, error: ")
|
||||
<< error.message();
|
||||
}
|
||||
out << llvm::formatv("{0:2}", value);
|
||||
out.close();
|
||||
|
||||
return clientParamsPath;
|
||||
}
|
||||
|
||||
llvm::Expected<std::string>
|
||||
CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
|
||||
llvm::Module *module = compilation.llvmModule.get();
|
||||
@@ -405,13 +444,14 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
|
||||
std::to_string(objectsPath.size()) + ".mlir";
|
||||
}
|
||||
auto objectPath = sourceName + OBJECT_EXT;
|
||||
auto error = mlir::concretelang::emitObject(*module, objectPath);
|
||||
|
||||
if (error) {
|
||||
if (auto error = mlir::concretelang::emitObject(*module, objectPath)) {
|
||||
return std::move(error);
|
||||
}
|
||||
|
||||
addExtraObjectFilePath(objectPath);
|
||||
if (compilation.clientParameters.hasValue()) {
|
||||
clientParametersList.push_back(compilation.clientParameters.getValue());
|
||||
}
|
||||
return objectPath;
|
||||
}
|
||||
|
||||
@@ -437,9 +477,8 @@ llvm::Expected<std::string> CompilerEngine::Library::emit(std::string dotExt,
|
||||
auto error = mlir::concretelang::emitLibrary(objectsPath, pathDotExt, linker);
|
||||
if (error) {
|
||||
return std::move(error);
|
||||
} else {
|
||||
return pathDotExt;
|
||||
}
|
||||
return pathDotExt;
|
||||
}
|
||||
|
||||
llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
|
||||
@@ -458,6 +497,19 @@ llvm::Expected<std::string> CompilerEngine::Library::emitStatic() {
|
||||
return path;
|
||||
}
|
||||
|
||||
llvm::Error CompilerEngine::Library::emitArtifacts() {
|
||||
if (auto err = emitShared().takeError()) {
|
||||
return err;
|
||||
}
|
||||
if (auto err = emitStatic().takeError()) {
|
||||
return err;
|
||||
}
|
||||
if (auto err = emitClientParametersJSON().takeError()) {
|
||||
return err;
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
CompilerEngine::Library::~Library() {
|
||||
if (cleanUp) {
|
||||
for (auto path : objectsPath) {
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
|
||||
// information.
|
||||
|
||||
#include <concretelang/Support/Error.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
// Specialized `operator<<` for `llvm::Error` that marks the error
|
||||
// as checked through `std::move` and `llvm::toString`
|
||||
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) {
|
||||
se << llvm::toString(std::move(err));
|
||||
return se;
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -8,8 +8,9 @@
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Support/ClientParameters.h"
|
||||
#include "concretelang/Support/V0Curves.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -20,7 +21,7 @@ const auto keyFormat = KEY_FORMAT_BINARY;
|
||||
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
|
||||
|
||||
// For the v0 the secretKeyID and precision are the same for all gates.
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
|
||||
Precision precision,
|
||||
Variance variance,
|
||||
mlir::Type type) {
|
||||
@@ -46,10 +47,13 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
// TODO - Get the width from the LWECiphertextType instead of global
|
||||
// precision (could be possible after merge concrete-ciphertext-parameter)
|
||||
return CircuitGate{
|
||||
.encryption = llvm::Optional<EncryptionGate>({
|
||||
.secretKeyID = secretKeyID,
|
||||
.variance = variance,
|
||||
.encoding = {.precision = precision},
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ precision,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
@@ -77,25 +81,33 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
|
||||
mlir::ModuleOp module) {
|
||||
ClientParameters c;
|
||||
c.functionName = (std::string)functionName;
|
||||
return c;
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
createClientParametersForV0(V0FHEContext fheContext,
|
||||
llvm::StringRef functionName,
|
||||
mlir::ModuleOp module) {
|
||||
auto v0Param = fheContext.parameter;
|
||||
Variance encryptionVariance =
|
||||
v0Curve->getVariance(1, 1 << v0Param.logPolynomialSize, 64);
|
||||
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c = {};
|
||||
ClientParameters c;
|
||||
c.secretKeys = {
|
||||
{"small", {/*.size = */ v0Param.nSmall}},
|
||||
{"big", {/*.size = */ v0Param.getNBigGlweDimension()}},
|
||||
{SMALL_KEY, {/*.size = */ v0Param.nSmall}},
|
||||
{BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}},
|
||||
};
|
||||
c.bootstrapKeys = {
|
||||
{
|
||||
"bsk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ "small",
|
||||
/*.outputSecretKeyID = */ "big",
|
||||
/*.inputSecretKeyID = */ SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ BIG_KEY,
|
||||
/*.level = */ v0Param.brLevel,
|
||||
/*.baseLog = */ v0Param.brLogBase,
|
||||
/*.glweDimension = */ v0Param.glweDimension,
|
||||
@@ -107,18 +119,19 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
{
|
||||
"ksk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ "big",
|
||||
/*.outputSecretKeyID = */ "small",
|
||||
/*.inputSecretKeyID = */ BIG_KEY,
|
||||
/*.outputSecretKeyID = */ SMALL_KEY,
|
||||
/*.level = */ v0Param.ksLevel,
|
||||
/*.baseLog = */ v0Param.ksLogBase,
|
||||
/*.variance = */ keyswitchVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
c.functionName = (std::string)functionName;
|
||||
// Find the input function
|
||||
auto rangeOps = module.getOps<mlir::FuncOp>();
|
||||
auto funcOp = llvm::find_if(
|
||||
rangeOps, [&](mlir::FuncOp op) { return op.getName() == name; });
|
||||
rangeOps, [&](mlir::FuncOp op) { return op.getName() == functionName; });
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find the function for generate client parameters",
|
||||
@@ -135,14 +148,16 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
.isa<mlir::concretelang::Concrete::ContextType>();
|
||||
for (auto inType = funcType.getInputs().begin();
|
||||
inType < funcType.getInputs().end() - hasContext; inType++) {
|
||||
auto gate = gateFromMLIRType("big", precision, encryptionVariance, *inType);
|
||||
auto gate =
|
||||
gateFromMLIRType(BIG_KEY, precision, encryptionVariance, *inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
c.inputs.push_back(gate.get());
|
||||
}
|
||||
for (auto outType : funcType.getResults()) {
|
||||
auto gate = gateFromMLIRType("big", precision, encryptionVariance, outType);
|
||||
auto gate =
|
||||
gateFromMLIRType(BIG_KEY, precision, encryptionVariance, outType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -151,46 +166,5 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
return c;
|
||||
}
|
||||
|
||||
// https://stackoverflow.com/a/38140932
|
||||
static inline void hash(std::size_t &seed) {}
|
||||
template <typename T, typename... Rest>
|
||||
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
|
||||
// See https://softwareengineering.stackexchange.com/a/402543
|
||||
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
|
||||
const std::hash<T> hasher;
|
||||
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
|
||||
hash(seed, rest...);
|
||||
}
|
||||
|
||||
void LweSecretKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, size);
|
||||
}
|
||||
|
||||
void BootstrapKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, glweDimension, variance);
|
||||
}
|
||||
|
||||
void KeyswitchKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, variance);
|
||||
}
|
||||
|
||||
std::size_t ClientParameters::hash() {
|
||||
std::size_t currentHash = 1;
|
||||
for (auto secretKeyParam : secretKeys) {
|
||||
mlir::concretelang::hash(currentHash, secretKeyParam.first);
|
||||
secretKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto bootstrapKeyParam : bootstrapKeys) {
|
||||
mlir::concretelang::hash(currentHash, bootstrapKeyParam.first);
|
||||
bootstrapKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto keyswitchParam : keyswitchKeys) {
|
||||
mlir::concretelang::hash(currentHash, keyswitchParam.first);
|
||||
keyswitchParam.second.hash(currentHash);
|
||||
}
|
||||
return currentHash;
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -27,4 +27,4 @@ V0Curves *getV0Curves(int securityLevel, int keyFormat) {
|
||||
return &curves[securityLevel][keyFormat];
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
#include <sstream>
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
@@ -30,7 +31,6 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/KeySet.h"
|
||||
#include "concretelang/Support/LLVMEmitFile.h"
|
||||
#include "concretelang/Support/Pipeline.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
@@ -112,7 +112,7 @@ static llvm::cl::opt<enum Action> action(
|
||||
"Lower to LLVM-IR, optimize and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke",
|
||||
"Lower and JIT-compile input module and invoke "
|
||||
"function specified with --jit-funcname")),
|
||||
"function specified with --funcname")),
|
||||
llvm::cl::values(clEnumValN(Action::COMPILE, "compile",
|
||||
"Lower to LLVM-IR, compile to a file")));
|
||||
|
||||
@@ -133,10 +133,10 @@ llvm::cl::opt<bool> autoParallelize(
|
||||
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<std::string> jitFuncName(
|
||||
"jit-funcname",
|
||||
llvm::cl::desc("Name of the function to execute, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
llvm::cl::opt<std::string>
|
||||
funcName("funcname",
|
||||
llvm::cl::desc("Name of the function to compile, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
|
||||
llvm::cl::list<uint64_t>
|
||||
jitArgs("jit-args",
|
||||
@@ -216,7 +216,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
// The parameter `action` specifies how the buffer should be processed
|
||||
// and thus defines the output.
|
||||
//
|
||||
// If the specified action involves JIT compilation, `jitFuncName`
|
||||
// If the specified action involves JIT compilation, `funcName`
|
||||
// designates the function to JIT compile. This function is invoked
|
||||
// using the parameters given in `jitArgs`.
|
||||
//
|
||||
@@ -239,7 +239,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
// Compilation output is written to the stream specified by `os`.
|
||||
mlir::LogicalResult processInputBuffer(
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer, std::string sourceFileName,
|
||||
enum Action action, const std::string &jitFuncName,
|
||||
enum Action action, const std::string &funcName,
|
||||
llvm::ArrayRef<uint64_t> jitArgs,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
@@ -269,17 +269,18 @@ mlir::LogicalResult processInputBuffer(
|
||||
if (overrideMaxMANP.hasValue())
|
||||
ce.setMaxMANP(overrideMaxMANP.getValue());
|
||||
|
||||
ce.setClientParametersFuncName(funcName);
|
||||
if (fhelinalgTileSizes.hasValue())
|
||||
ce.setFHELinalgTileSizes(*fhelinalgTileSizes);
|
||||
|
||||
if (action == Action::JIT_INVOKE) {
|
||||
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
ce.buildLambda(std::move(buffer), jitFuncName, keySetCache);
|
||||
ce.buildLambda(std::move(buffer), funcName, keySetCache);
|
||||
|
||||
if (!lambdaOrErr) {
|
||||
mlir::concretelang::log_error()
|
||||
<< "Failed to JIT-compile " << jitFuncName << ": "
|
||||
<< llvm::toString(std::move(lambdaOrErr.takeError()));
|
||||
<< "Failed to JIT-compile " << funcName << ": "
|
||||
<< llvm::toString(lambdaOrErr.takeError());
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
@@ -287,7 +288,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
|
||||
if (!resOrErr) {
|
||||
mlir::concretelang::log_error()
|
||||
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
|
||||
<< "Failed to JIT-invoke " << funcName << " with arguments "
|
||||
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -425,11 +426,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
auto process = [&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
|
||||
llvm::raw_ostream &os) {
|
||||
return processInputBuffer(
|
||||
std::move(inputBuffer), fileName, cmdline::action,
|
||||
cmdline::jitFuncName, cmdline::jitArgs,
|
||||
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
|
||||
cmdline::verifyDiagnostics, fhelinalgTileSizes,
|
||||
cmdline::autoParallelize, jitKeySetCache, os, outputLib);
|
||||
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
|
||||
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
|
||||
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
|
||||
fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os,
|
||||
outputLib);
|
||||
};
|
||||
auto &os = output->os();
|
||||
auto res = mlir::failure();
|
||||
@@ -446,12 +447,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
}
|
||||
|
||||
if (cmdline::action == Action::COMPILE) {
|
||||
auto libPath = outputLib->emitShared();
|
||||
if (!libPath) {
|
||||
return mlir::failure();
|
||||
}
|
||||
libPath = outputLib->emitStatic();
|
||||
if (!libPath) {
|
||||
auto err = outputLib->emitArtifacts();
|
||||
if (err) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if (CONCRETELANG_UNIT_TESTS)
|
||||
add_subdirectory(unittest)
|
||||
add_subdirectory(Support)
|
||||
endif()
|
||||
|
||||
23
compiler/tests/Support/CMakeLists.txt
Normal file
23
compiler/tests/Support/CMakeLists.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
enable_testing()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
add_executable(
|
||||
support_unit_test
|
||||
support_unit_test.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(
|
||||
support_unit_test.cpp
|
||||
|
||||
PROPERTIES COMPILE_FLAGS "-fno-rtti"
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
support_unit_test
|
||||
gtest_main
|
||||
ConcretelangSupport
|
||||
)
|
||||
|
||||
include(GoogleTest)
|
||||
gtest_discover_tests(support_unit_test)
|
||||
69
compiler/tests/Support/support_unit_test.cpp
Normal file
69
compiler/tests/Support/support_unit_test.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../unittest/end_to_end_jit_test.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
namespace CL = mlir::concretelang;
|
||||
|
||||
TEST(Support, client_parameters_json_serde) {
|
||||
mlir::concretelang::ClientParameters params0;
|
||||
params0.secretKeys = {
|
||||
{mlir::concretelang::SMALL_KEY, {/*.size = */ 12}},
|
||||
{mlir::concretelang::BIG_KEY, {/*.size = */ 14}},
|
||||
};
|
||||
params0.bootstrapKeys = {
|
||||
{
|
||||
"bsk_v0", {
|
||||
/*.inputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ mlir::concretelang::BIG_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.k = */ 3,
|
||||
/*.variance = */ 0.001
|
||||
}
|
||||
},{
|
||||
"wtf_bsk_v0", {
|
||||
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
|
||||
/*.level = */ 3,
|
||||
/*.baseLog = */ 2,
|
||||
/*.k = */ 1,
|
||||
/*.variance = */ 0.0001,
|
||||
}
|
||||
},
|
||||
};
|
||||
params0.keyswitchKeys = {
|
||||
{
|
||||
"ksk_v0", {
|
||||
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.variance = */ 3,
|
||||
}
|
||||
}
|
||||
};
|
||||
params0.inputs = {
|
||||
{
|
||||
/*.encryption = */ {{CL::SMALL_KEY, 0.01, {4}}},
|
||||
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
|
||||
},
|
||||
{
|
||||
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
|
||||
},
|
||||
};
|
||||
params0.outputs = {
|
||||
{
|
||||
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
|
||||
},
|
||||
};
|
||||
auto json = mlir::concretelang::toJSON(params0);
|
||||
std::string jsonStr;
|
||||
llvm::raw_string_ostream os(jsonStr);
|
||||
os << json;
|
||||
auto parseResult =
|
||||
llvm::json::parse<mlir::concretelang::ClientParameters>(jsonStr);
|
||||
ASSERT_EXPECTED_VALUE(parseResult, params0);
|
||||
}
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/KeySetCache.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
|
||||
#include "globals.h"
|
||||
|
||||
Reference in New Issue
Block a user