feat(compiler): Output client parameters when compile to a library

close #198
This commit is contained in:
rudy
2021-12-29 11:34:54 +01:00
committed by Quentin Bourgerie
parent a4e8227692
commit b8bd38dd6c
26 changed files with 889 additions and 271 deletions

View File

@@ -53,11 +53,20 @@ test-check: concretecompiler file-check not
test-python: python-bindings concretecompiler
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python
test: test-check test-end-to-end-jit test-python
test: test-check test-end-to-end-jit test-python support-unit-test
test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization
# Unittests
# unit-test
support-unit-test: build-support-unit-test
$(BUILD_DIR)/bin/support_unit_test
build-support-unit-test: build-initialized
cmake --build $(BUILD_DIR) --target support_unit_test
# test-end-to-end-jit
build-end-to-end-jit-test: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_test

View File

@@ -0,0 +1,165 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#include <map>
#include <string>
#include <vector>
#include <llvm/Support/JSON.h>
namespace mlir {
namespace concretelang {
const std::string SMALL_KEY = "small";
const std::string BIG_KEY = "big";
typedef size_t DecompositionLevelCount;
typedef size_t DecompositionBaseLog;
typedef size_t PolynomialSize;
typedef size_t Precision;
typedef double Variance;
typedef uint64_t LweSize;
typedef uint64_t GlweDimension;
typedef std::string LweSecretKeyID;
struct LweSecretKeyParam {
LweSize size;
void hash(size_t &seed);
};
static bool operator==(const LweSecretKeyParam &lhs,
const LweSecretKeyParam &rhs) {
return lhs.size == rhs.size;
}
typedef std::string BootstrapKeyID;
struct BootstrapKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
Variance variance;
void hash(size_t &seed);
};
static inline bool operator==(const BootstrapKeyParam &lhs,
const BootstrapKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
}
typedef std::string KeyswitchKeyID;
struct KeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
Variance variance;
void hash(size_t &seed);
};
static inline bool operator==(const KeyswitchKeyParam &lhs,
const KeyswitchKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.variance == rhs.variance;
}
struct Encoding {
Precision precision;
};
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
return lhs.precision == rhs.precision;
}
struct EncryptionGate {
LweSecretKeyID secretKeyID;
Variance variance;
Encoding encoding;
};
static inline bool operator==(const EncryptionGate &lhs,
const EncryptionGate &rhs) {
return lhs.secretKeyID == rhs.secretKeyID && lhs.variance == rhs.variance &&
lhs.encoding == rhs.encoding;
}
struct CircuitGateShape {
// Width of the scalar value
size_t width;
// Dimensions of the tensor, empty if scalar
std::vector<int64_t> dimensions;
// Size of the buffer containing the tensor
size_t size;
};
static inline bool operator==(const CircuitGateShape &lhs,
const CircuitGateShape &rhs) {
return lhs.width == rhs.width && lhs.dimensions == rhs.dimensions &&
lhs.size == rhs.size;
}
struct CircuitGate {
llvm::Optional<EncryptionGate> encryption;
CircuitGateShape shape;
};
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;
}
struct ClientParameters {
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
std::string functionName;
size_t hash();
};
static inline bool operator==(const ClientParameters &lhs,
const ClientParameters &rhs) {
return lhs.secretKeys == rhs.secretKeys &&
lhs.bootstrapKeys == rhs.bootstrapKeys &&
lhs.keyswitchKeys == rhs.keyswitchKeys && lhs.inputs == lhs.inputs &&
lhs.outputs == lhs.outputs;
}
llvm::json::Value toJSON(const LweSecretKeyParam &);
bool fromJSON(const llvm::json::Value, LweSecretKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const BootstrapKeyParam &);
bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const KeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const Encoding &);
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
llvm::json::Value toJSON(const EncryptionGate &);
bool fromJSON(const llvm::json::Value, EncryptionGate &, llvm::json::Path);
llvm::json::Value toJSON(const CircuitGateShape &);
bool fromJSON(const llvm::json::Value, CircuitGateShape &, llvm::json::Path);
llvm::json::Value toJSON(const CircuitGate &);
bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
llvm::json::Value toJSON(const ClientParameters &);
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
ClientParameters cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -6,7 +6,6 @@
#ifndef CONCRETELANG_SUPPORT_KEYSET_H_
#define CONCRETELANG_SUPPORT_KEYSET_H_
#include "llvm/Support/Error.h"
#include <memory>
extern "C" {
@@ -14,8 +13,8 @@ extern "C" {
#include "concretelang/Runtime/context.h"
}
#include "concretelang/Support/ClientParameters.h"
#include "concretelang/Support/KeySetCache.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
namespace mlir {
namespace concretelang {
@@ -91,9 +90,9 @@ private:
bootstrapKeys;
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
keyswitchKeys;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *>>
inputs;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *>>
outputs;
void setKeys(

View File

@@ -6,7 +6,7 @@
#ifndef CONCRETELANG_SUPPORT_KEYSETCACHE_H_
#define CONCRETELANG_SUPPORT_KEYSETCACHE_H_
#include "concretelang/Support/KeySet.h"
#include "concretelang/ClientLib/KeySet.h"
namespace mlir {
namespace concretelang {

View File

@@ -1,99 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
#include <map>
#include <string>
#include <vector>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/IR/BuiltinOps.h>
#include "concretelang/Support/V0Parameters.h"
namespace mlir {
namespace concretelang {
typedef size_t DecompositionLevelCount;
typedef size_t DecompositionBaseLog;
typedef size_t PolynomialSize;
typedef size_t Precision;
typedef double Variance;
typedef uint64_t LweSize;
typedef uint64_t GlweDimension;
typedef std::string LweSecretKeyID;
struct LweSecretKeyParam {
LweSize size;
void hash(size_t &seed);
};
typedef std::string BootstrapKeyID;
struct BootstrapKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
Variance variance;
void hash(size_t &seed);
};
typedef std::string KeyswitchKeyID;
struct KeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
Variance variance;
void hash(size_t &seed);
};
struct Encoding {
Precision precision;
};
struct EncryptionGate {
LweSecretKeyID secretKeyID;
Variance variance;
Encoding encoding;
};
struct CircuitGateShape {
// Width of the scalar value
size_t width;
// Dimensions of the tensor, empty if scalar
std::vector<int64_t> dimensions;
// Size of the buffer containing the tensor
size_t size;
};
struct CircuitGate {
llvm::Optional<EncryptionGate> encryption;
CircuitGateShape shape;
};
struct ClientParameters {
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
size_t hash();
};
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef name,
mlir::ModuleOp module);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -7,7 +7,7 @@
#define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
#include <concretelang/Conversion/Utils/GlobalFHEContext.h>
#include <concretelang/Support/ClientParameters.h>
#include <concretelang/Support/V0ClientParameters.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/SourceMgr.h>
@@ -58,6 +58,7 @@ public:
class Library {
std::string libraryPath;
std::vector<std::string> objectsPath;
std::vector<mlir::concretelang::ClientParameters> clientParametersList;
bool cleanUp;
public:
@@ -68,21 +69,38 @@ public:
: libraryPath(libraryPath), cleanUp(cleanUp) {}
/** Add a compilation result to the library */
llvm::Expected<std::string> addCompilation(CompilationResult &compilation);
/** Emit a shared library with the previously added compilation result */
llvm::Expected<std::string> emitShared();
/** Emit a shared library with the previously added compilation result */
llvm::Expected<std::string> emitStatic();
/** Emit the library artifacts with the previously added compilation result
*/
llvm::Error emitArtifacts();
/** After a shared library has been emitted, its path is here */
std::string sharedLibraryPath;
/** After a static library has been emitted, its path is here */
std::string staticLibraryPath;
/** Returns the path of the shared library */
static std::string getSharedLibraryPath(std::string path);
/** Returns the path of the static library */
static std::string getStaticLibraryPath(std::string path);
/** Returns the path of the static library */
static std::string getClientParametersPath(std::string path);
// For advanced use
const static std::string OBJECT_EXT, LINKER, LINKER_SHARED_OPT, AR,
AR_STATIC_OPT, DOT_STATIC_LIB_EXT, DOT_SHARED_LIB_EXT;
const static std::string OBJECT_EXT, CLIENT_PARAMETERS_EXT, LINKER,
LINKER_SHARED_OPT, AR, AR_STATIC_OPT, DOT_STATIC_LIB_EXT,
DOT_SHARED_LIB_EXT;
void addExtraObjectFilePath(std::string objectFilePath);
llvm::Expected<std::string> emit(std::string dotExt, std::string linker);
~Library();
private:
/** Emit a shared library with the previously added compilation result */
llvm::Expected<std::string> emitStatic();
/** Emit a shared library with the previously added compilation result */
llvm::Expected<std::string> emitShared();
/** Emit a shared library with the previously added compilation result */
llvm::Expected<std::string> emitClientParametersJSON();
};
// Specification of the exit stage of the compilation pipeline

View File

@@ -50,7 +50,10 @@ protected:
llvm::raw_string_ostream os;
};
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err);
inline StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) {
se << llvm::toString(std::move(err));
return se;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -10,7 +10,7 @@
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Support/LogicalResult.h>
#include <concretelang/Support/KeySet.h>
#include <concretelang/ClientLib/KeySet.h>
namespace mlir {
namespace concretelang {

View File

@@ -6,7 +6,7 @@
#ifndef CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H
#define CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H
#include "concretelang/Support/KeySetCache.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>

View File

@@ -0,0 +1,28 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/IR/BuiltinOps.h>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Support/V0Parameters.h"
namespace mlir {
namespace concretelang {
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
mlir::ModuleOp module);
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
mlir::ModuleOp module);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -6,10 +6,10 @@
#include "llvm/ADT/SmallString.h"
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/KeySetCache.h"
using mlir::concretelang::JitCompilerEngine;

View File

@@ -2,6 +2,7 @@ add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Support)
add_subdirectory(Runtime)
add_subdirectory(ClientLib)
add_subdirectory(Bindings)
# CAPI needed only for python bindings

View File

@@ -0,0 +1,9 @@
add_mlir_library(
ConcretelangClientLib
ClientParameters.cpp
KeySet.cpp
KeySetCache.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
)

View File

@@ -0,0 +1,388 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include "concretelang/ClientLib/ClientParameters.h"
namespace mlir {
namespace concretelang {
// https://stackoverflow.com/a/38140932
static inline void hash(std::size_t &seed) {}
template <typename T, typename... Rest>
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
// See https://softwareengineering.stackexchange.com/a/402543
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
const std::hash<T> hasher;
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
hash(seed, rest...);
}
void LweSecretKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, size);
}
void BootstrapKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, glweDimension, variance);
}
void KeyswitchKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, variance);
}
std::size_t ClientParameters::hash() {
std::size_t currentHash = 1;
for (auto secretKeyParam : secretKeys) {
mlir::concretelang::hash(currentHash, secretKeyParam.first);
secretKeyParam.second.hash(currentHash);
}
for (auto bootstrapKeyParam : bootstrapKeys) {
mlir::concretelang::hash(currentHash, bootstrapKeyParam.first);
bootstrapKeyParam.second.hash(currentHash);
}
for (auto keyswitchParam : keyswitchKeys) {
mlir::concretelang::hash(currentHash, keyswitchParam.first);
keyswitchParam.second.hash(currentHash);
}
return currentHash;
}
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
llvm::json::Object object{
{"size", v.size},
};
return object;
}
bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto size = obj->getInteger("size");
if (!size.hasValue()) {
p.report("missing size field");
return false;
}
v.size = *size;
return true;
}
llvm::json::Value toJSON(const BootstrapKeyParam &v) {
llvm::json::Object object{
{"inputSecretKeyID", v.inputSecretKeyID},
{"outputSecretKeyID", v.outputSecretKeyID},
{"level", v.level},
{"glweDimension", v.glweDimension},
{"baseLog", v.baseLog},
{"variance", v.variance},
};
return object;
}
bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
if (!inputSecretKeyID.hasValue()) {
p.report("missing inputSecretKeyID field");
return false;
}
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
if (!outputSecretKeyID.hasValue()) {
p.report("missing outputSecretKeyID field");
return false;
}
auto level = obj->getInteger("level");
if (!level.hasValue()) {
p.report("missing level field");
return false;
}
auto baseLog = obj->getInteger("baseLog");
if (!baseLog.hasValue()) {
p.report("missing baseLog field");
return false;
}
auto glweDimension = obj->getInteger("glweDimension");
if (!glweDimension.hasValue()) {
p.report("missing glweDimension field");
return false;
}
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
v.level = level.getValue();
v.baseLog = baseLog.getValue();
v.glweDimension = glweDimension.getValue();
v.variance = variance.getValue();
return true;
}
llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
llvm::json::Object object{
{"inputSecretKeyID", v.inputSecretKeyID},
{"outputSecretKeyID", v.outputSecretKeyID},
{"level", v.level},
{"baseLog", v.baseLog},
{"variance", v.variance},
};
return object;
}
bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
if (!inputSecretKeyID.hasValue()) {
p.report("missing inputSecretKeyID field");
return false;
}
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
if (!outputSecretKeyID.hasValue()) {
p.report("missing outputSecretKeyID field");
return false;
}
auto level = obj->getInteger("level");
if (!level.hasValue()) {
p.report("missing level field");
return false;
}
auto baseLog = obj->getInteger("baseLog");
if (!baseLog.hasValue()) {
p.report("missing baseLog field");
return false;
}
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
v.level = level.getValue();
v.baseLog = baseLog.getValue();
v.variance = variance.getValue();
return true;
}
llvm::json::Value toJSON(const CircuitGateShape &v) {
llvm::json::Object object{
{"width", v.width},
{"dimensions", v.dimensions},
{"size", v.size},
};
return object;
}
bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto width = obj->getInteger("width");
if (!width.hasValue()) {
p.report("missing width field");
return false;
}
auto dimensions = obj->getArray("dimensions");
if (dimensions == nullptr) {
p.report("missing dimensions field");
return false;
}
for (auto dim : *dimensions) {
auto iDim = dim.getAsInteger();
if (!iDim.hasValue()) {
p.report("dimensions must be integer");
return false;
}
v.dimensions.push_back(iDim.getValue());
}
auto size = obj->getInteger("size");
if (!size.hasValue()) {
p.report("missing size field");
return false;
}
v.width = width.getValue();
v.size = size.getValue();
return true;
}
llvm::json::Value toJSON(const Encoding &v) {
llvm::json::Object object{
{"precision", v.precision},
};
return object;
}
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto precision = obj->getInteger("precision");
if (!precision.hasValue()) {
p.report("missing precision field");
return false;
}
v.precision = precision.getValue();
return true;
}
llvm::json::Value toJSON(const EncryptionGate &v) {
llvm::json::Object object{
{"secretKeyID", v.secretKeyID},
{"variance", v.variance},
{"encoding", v.encoding},
};
return object;
}
bool fromJSON(const llvm::json::Value j, EncryptionGate &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto secretKeyID = obj->getString("secretKeyID");
if (!secretKeyID.hasValue()) {
p.report("missing secretKeyID field");
return false;
}
v.secretKeyID = (std::string)secretKeyID.getValue();
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.variance = variance.getValue();
auto encoding = obj->get("encoding");
if (encoding == nullptr) {
p.report("missing encoding field");
return false;
}
if (!fromJSON(*encoding, v.encoding, p.field("encoding"))) {
return false;
}
return true;
}
llvm::json::Value toJSON(const CircuitGate &v) {
llvm::json::Object object{
{"encryption", v.encryption},
{"shape", v.shape},
};
return object;
}
bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) {
auto obj = j.getAsObject();
auto encryption = obj->get("encryption");
if (encryption == nullptr) {
p.report("missing encryption field");
return false;
}
if (!fromJSON(*encryption, v.encryption, p.field("encryption"))) {
return false;
}
auto shape = obj->get("shape");
if (shape == nullptr) {
p.report("missing shape field");
return false;
}
if (!fromJSON(*shape, v.shape, p.field("shape"))) {
return false;
}
return true;
}
template <typename T> llvm::json::Value toJson(std::map<std::string, T> map) {
llvm::json::Object obj;
for (auto entry : map) {
obj[entry.first] = entry.second;
}
return obj;
}
llvm::json::Value toJSON(const ClientParameters &v) {
llvm::json::Object object{
{"secretKeys", toJson(v.secretKeys)},
{"bootstrapKeys", toJson(v.bootstrapKeys)},
{"keyswitchKeys", toJson(v.keyswitchKeys)},
{"inputs", v.inputs},
{"outputs", v.outputs},
{"functionName", v.functionName},
};
return object;
}
bool fromJSON(const llvm::json::Value j, ClientParameters &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
auto secretkeys = obj->get("secretKeys");
if (secretkeys == nullptr) {
p.report("missing secretKeys field");
return false;
}
if (!fromJSON(*secretkeys, v.secretKeys, p.field("secretKeys"))) {
return false;
}
auto bootstrapKeys = obj->get("bootstrapKeys");
if (bootstrapKeys == nullptr) {
p.report("missing bootstrapKeys field");
return false;
}
if (!fromJSON(*bootstrapKeys, v.bootstrapKeys, p.field("bootstrapKeys"))) {
return false;
}
auto keyswitchKeys = obj->get("keyswitchKeys");
if (keyswitchKeys == nullptr) {
p.report("missing keyswitchKeys field");
return false;
}
if (!fromJSON(*keyswitchKeys, v.keyswitchKeys, p.field("keyswitchKeys"))) {
return false;
}
auto inputs = obj->get("inputs");
if (inputs == nullptr) {
p.report("missing inputs field");
return false;
}
if (!fromJSON(*inputs, v.inputs, p.field("inputs"))) {
return false;
}
auto outputs = obj->get("outputs");
if (outputs == nullptr) {
p.report("missing outputs field");
return false;
}
if (!fromJSON(*outputs, v.outputs, p.field("outputs"))) {
return false;
}
auto functionName = obj->getString("functionName");
if (!functionName.hasValue()) {
p.report("missing functionName field");
return false;
}
v.functionName = (std::string)functionName.getValue();
return true;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include "concretelang/Support/KeySet.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/Support/Error.h"
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
@@ -37,23 +37,18 @@ llvm::Expected<std::unique_ptr<KeySet>>
KeySet::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto a = uninitialized();
auto keySet = uninitialized();
auto fillError = a->generateKeysFromParams(params, seed_msb, seed_lsb);
if (fillError) {
return StreamStringError()
<< "Cannot fill keys from params: " << std::move(fillError);
if (auto error = keySet->generateKeysFromParams(params, seed_msb, seed_lsb)) {
return std::move(error);
}
fillError = a->setupEncryptionMaterial(params, seed_msb, seed_lsb);
if (fillError) {
return StreamStringError()
<< "Cannot setup encryption material: " << std::move(fillError);
if (auto error =
keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)) {
return std::move(error);
}
return std::move(a);
return std::move(keySet);
}
std::unique_ptr<KeySet> KeySet::uninitialized() {
@@ -66,8 +61,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
param, nullptr, nullptr};
LweSecretKeyParam secretKeyParam = {0};
LweSecretKey_u64 *secretKey = nullptr;
if (param.encryption.hasValue()) {
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (inputSk == this->secretKeys.end()) {
@@ -76,14 +71,16 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
") does not exist ",
llvm::inconvertibleErrorCode());
}
std::get<1>(input) = &inputSk->second.first;
std::get<2>(input) = inputSk->second.second;
secretKeyParam = inputSk->second.first;
secretKey = inputSk->second.second;
}
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *> input = {
param, secretKeyParam, secretKey};
this->inputs.push_back(input);
}
for (auto param : params.outputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
{param, nullptr, nullptr};
LweSecretKeyParam secretKeyParam = {0};
LweSecretKey_u64 *secretKey = nullptr;
if (param.encryption.hasValue()) {
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (outputSk == this->secretKeys.end()) {
@@ -91,9 +88,11 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
std::get<1>(output) = &outputSk->second.first;
std::get<2>(output) = outputSk->second.second;
secretKeyParam = outputSk->second.first;
secretKey = outputSk->second.second;
}
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *> output = {
param, secretKeyParam, secretKey};
this->outputs.push_back(output);
}
}
@@ -283,7 +282,7 @@ llvm::Error KeySet::allocate_lwe(size_t argPos,
}
auto inputSk = inputs[argPos];
CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64(
&err, {std::get<1>(inputSk)->size + 1}),
&err, {std::get<1>(inputSk).size + 1}),
"cannot allocate ciphertext");
return llvm::Error::success();
}
@@ -369,4 +368,4 @@ KeySet::getKeyswitchKeys() {
}
} // namespace concretelang
} // namespace mlir
} // namespace mlir

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include "concretelang/Support/KeySetCache.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Support/Error.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FileSystem.h"
@@ -11,6 +11,7 @@
#include <fstream>
#include <functional>
#include <sstream>
#include <string>
extern "C" {
@@ -229,4 +230,4 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
}
} // namespace concretelang
} // namespace mlir
} // namespace mlir

View File

@@ -1,5 +1,4 @@
add_mlir_library(ConcretelangSupport
Error.cpp
Pipeline.cpp
Jit.cpp
CompilerEngine.cpp
@@ -7,11 +6,9 @@ add_mlir_library(ConcretelangSupport
LambdaArgument.cpp
V0Parameters.cpp
V0Curves.cpp
ClientParameters.cpp
KeySet.cpp
V0ClientParameters.cpp
logging.cpp
Jit.cpp
KeySetCache.cpp
LLVMEmitFile.cpp
ADDITIONAL_HEADER_DIRS
@@ -34,4 +31,5 @@ add_mlir_library(ConcretelangSupport
${LLVM_PTHREAD_LIB}
ConcretelangRuntime
ConcretelangClientLib
)

View File

@@ -3,7 +3,10 @@
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include <fstream>
#include <iostream>
#include <stdio.h>
#include <string>
#include <llvm/Support/Error.h>
#include <llvm/Support/SMLoc.h>
@@ -271,15 +274,23 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty");
}
}
// Generate client parameters if requested
auto funcName = this->clientParametersFuncName.getValueOr("main");
if (this->generateClientParameters || target == Target::LIBRARY) {
if (!res.fheContext.hasValue()) {
// Some tests can involves a usual function
res.clientParameters =
mlir::concretelang::emptyClientParametersForV0(funcName, module);
} else {
auto clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(*res.fheContext,
funcName, module);
if (!clientParametersOrErr)
return clientParametersOrErr.takeError();
llvm::Expected<mlir::concretelang::ClientParameters> clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(
*res.fheContext, *this->clientParametersFuncName, module);
if (llvm::Error err = clientParametersOrErr.takeError())
return std::move(err);
res.clientParameters = clientParametersOrErr.get();
res.clientParameters = clientParametersOrErr.get();
}
}
// MLIR canonical dialects -> LLVM Dialect
@@ -334,10 +345,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::StringRef s, Target target, OptionalLib lib) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<CompilationResult> res =
this->compile(std::move(mb), target, lib);
return std::move(res);
return this->compile(std::move(mb), target, lib);
}
// Compile the contained in `buffer` to the target dialect
@@ -351,9 +359,7 @@ CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
llvm::Expected<CompilationResult> res = this->compile(sm, target, lib);
return std::move(res);
return this->compile(sm, target, lib);
}
template <class T>
@@ -371,10 +377,9 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
}
}
auto libPath = outputLib->emitShared();
if (!libPath) {
return StreamStringError("Can't link: ")
<< llvm::toString(libPath.takeError());
if (auto err = outputLib->emitArtifacts()) {
return StreamStringError("Can't emit artifacts: ")
<< llvm::toString(std::move(err));
}
return *outputLib.get();
}
@@ -384,7 +389,24 @@ template llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(std::vector<std::string> inputs,
std::string libraryPath);
/** Returns the path of the shared library */
std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) {
return path + DOT_SHARED_LIB_EXT;
}
/** Returns the path of the static library */
std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) {
return path + DOT_STATIC_LIB_EXT;
}
/** Returns the path of the static library */
std::string CompilerEngine::Library::getClientParametersPath(std::string path) {
return path + CLIENT_PARAMETERS_EXT;
}
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
const std::string CompilerEngine::Library::CLIENT_PARAMETERS_EXT =
".concrete.params.json";
const std::string CompilerEngine::Library::LINKER = "ld";
const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o ";
const std::string CompilerEngine::Library::AR = "ar";
@@ -396,6 +418,23 @@ void CompilerEngine::Library::addExtraObjectFilePath(std::string path) {
objectsPath.push_back(path);
}
llvm::Expected<std::string>
CompilerEngine::Library::emitClientParametersJSON() {
auto clientParamsPath = getClientParametersPath(libraryPath);
llvm::json::Value value(clientParametersList);
std::error_code error;
llvm::raw_fd_ostream out(clientParamsPath, error);
if (error) {
return StreamStringError("cannot emit client parameters, error: ")
<< error.message();
}
out << llvm::formatv("{0:2}", value);
out.close();
return clientParamsPath;
}
llvm::Expected<std::string>
CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
llvm::Module *module = compilation.llvmModule.get();
@@ -405,13 +444,14 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
std::to_string(objectsPath.size()) + ".mlir";
}
auto objectPath = sourceName + OBJECT_EXT;
auto error = mlir::concretelang::emitObject(*module, objectPath);
if (error) {
if (auto error = mlir::concretelang::emitObject(*module, objectPath)) {
return std::move(error);
}
addExtraObjectFilePath(objectPath);
if (compilation.clientParameters.hasValue()) {
clientParametersList.push_back(compilation.clientParameters.getValue());
}
return objectPath;
}
@@ -437,9 +477,8 @@ llvm::Expected<std::string> CompilerEngine::Library::emit(std::string dotExt,
auto error = mlir::concretelang::emitLibrary(objectsPath, pathDotExt, linker);
if (error) {
return std::move(error);
} else {
return pathDotExt;
}
return pathDotExt;
}
llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
@@ -458,6 +497,19 @@ llvm::Expected<std::string> CompilerEngine::Library::emitStatic() {
return path;
}
llvm::Error CompilerEngine::Library::emitArtifacts() {
if (auto err = emitShared().takeError()) {
return err;
}
if (auto err = emitStatic().takeError()) {
return err;
}
if (auto err = emitClientParametersJSON().takeError()) {
return err;
}
return llvm::Error::success();
}
CompilerEngine::Library::~Library() {
if (cleanUp) {
for (auto path : objectsPath) {

View File

@@ -1,17 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include <concretelang/Support/Error.h>
namespace mlir {
namespace concretelang {
// Specialized `operator<<` for `llvm::Error` that marks the error
// as checked through `std::move` and `llvm::toString`
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) {
se << llvm::toString(std::move(err));
return se;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -8,8 +8,9 @@
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Support/ClientParameters.h"
#include "concretelang/Support/V0Curves.h"
namespace mlir {
@@ -20,7 +21,7 @@ const auto keyFormat = KEY_FORMAT_BINARY;
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
Precision precision,
Variance variance,
mlir::Type type) {
@@ -46,10 +47,13 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
// TODO - Get the width from the LWECiphertextType instead of global
// precision (could be possible after merge concrete-ciphertext-parameter)
return CircuitGate{
.encryption = llvm::Optional<EncryptionGate>({
.secretKeyID = secretKeyID,
.variance = variance,
.encoding = {.precision = precision},
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ precision,
},
}),
/*.shape = */
{
@@ -77,25 +81,33 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
}
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
mlir::ModuleOp module) {
ClientParameters c;
c.functionName = (std::string)functionName;
return c;
}
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
createClientParametersForV0(V0FHEContext fheContext,
llvm::StringRef functionName,
mlir::ModuleOp module) {
auto v0Param = fheContext.parameter;
Variance encryptionVariance =
v0Curve->getVariance(1, 1 << v0Param.logPolynomialSize, 64);
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
// Static client parameters from global parameters for v0
ClientParameters c = {};
ClientParameters c;
c.secretKeys = {
{"small", {/*.size = */ v0Param.nSmall}},
{"big", {/*.size = */ v0Param.getNBigGlweDimension()}},
{SMALL_KEY, {/*.size = */ v0Param.nSmall}},
{BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}},
};
c.bootstrapKeys = {
{
"bsk_v0",
{
/*.inputSecretKeyID = */ "small",
/*.outputSecretKeyID = */ "big",
/*.inputSecretKeyID = */ SMALL_KEY,
/*.outputSecretKeyID = */ BIG_KEY,
/*.level = */ v0Param.brLevel,
/*.baseLog = */ v0Param.brLogBase,
/*.glweDimension = */ v0Param.glweDimension,
@@ -107,18 +119,19 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
{
"ksk_v0",
{
/*.inputSecretKeyID = */ "big",
/*.outputSecretKeyID = */ "small",
/*.inputSecretKeyID = */ BIG_KEY,
/*.outputSecretKeyID = */ SMALL_KEY,
/*.level = */ v0Param.ksLevel,
/*.baseLog = */ v0Param.ksLogBase,
/*.variance = */ keyswitchVariance,
},
},
};
c.functionName = (std::string)functionName;
// Find the input function
auto rangeOps = module.getOps<mlir::FuncOp>();
auto funcOp = llvm::find_if(
rangeOps, [&](mlir::FuncOp op) { return op.getName() == name; });
rangeOps, [&](mlir::FuncOp op) { return op.getName() == functionName; });
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function for generate client parameters",
@@ -135,14 +148,16 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
.isa<mlir::concretelang::Concrete::ContextType>();
for (auto inType = funcType.getInputs().begin();
inType < funcType.getInputs().end() - hasContext; inType++) {
auto gate = gateFromMLIRType("big", precision, encryptionVariance, *inType);
auto gate =
gateFromMLIRType(BIG_KEY, precision, encryptionVariance, *inType);
if (auto err = gate.takeError()) {
return std::move(err);
}
c.inputs.push_back(gate.get());
}
for (auto outType : funcType.getResults()) {
auto gate = gateFromMLIRType("big", precision, encryptionVariance, outType);
auto gate =
gateFromMLIRType(BIG_KEY, precision, encryptionVariance, outType);
if (auto err = gate.takeError()) {
return std::move(err);
}
@@ -151,46 +166,5 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
return c;
}
// https://stackoverflow.com/a/38140932
static inline void hash(std::size_t &seed) {}
template <typename T, typename... Rest>
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
// See https://softwareengineering.stackexchange.com/a/402543
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
const std::hash<T> hasher;
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
hash(seed, rest...);
}
void LweSecretKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, size);
}
void BootstrapKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, glweDimension, variance);
}
void KeyswitchKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, variance);
}
std::size_t ClientParameters::hash() {
std::size_t currentHash = 1;
for (auto secretKeyParam : secretKeys) {
mlir::concretelang::hash(currentHash, secretKeyParam.first);
secretKeyParam.second.hash(currentHash);
}
for (auto bootstrapKeyParam : bootstrapKeys) {
mlir::concretelang::hash(currentHash, bootstrapKeyParam.first);
bootstrapKeyParam.second.hash(currentHash);
}
for (auto keyswitchParam : keyswitchKeys) {
mlir::concretelang::hash(currentHash, keyswitchParam.first);
keyswitchParam.second.hash(currentHash);
}
return currentHash;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -27,4 +27,4 @@ V0Curves *getV0Curves(int securityLevel, int keyFormat) {
return &curves[securityLevel][keyFormat];
}
} // namespace concretelang
} // namespace mlir
} // namespace mlir

View File

@@ -19,6 +19,7 @@
#include <mlir/Support/ToolUtilities.h>
#include <sstream>
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
@@ -30,7 +31,6 @@
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/KeySet.h"
#include "concretelang/Support/LLVMEmitFile.h"
#include "concretelang/Support/Pipeline.h"
#include "concretelang/Support/logging.h"
@@ -112,7 +112,7 @@ static llvm::cl::opt<enum Action> action(
"Lower to LLVM-IR, optimize and dump result")),
llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke",
"Lower and JIT-compile input module and invoke "
"function specified with --jit-funcname")),
"function specified with --funcname")),
llvm::cl::values(clEnumValN(Action::COMPILE, "compile",
"Lower to LLVM-IR, compile to a file")));
@@ -133,10 +133,10 @@ llvm::cl::opt<bool> autoParallelize(
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
llvm::cl::init(false));
llvm::cl::opt<std::string> jitFuncName(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::opt<std::string>
funcName("funcname",
llvm::cl::desc("Name of the function to compile, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::list<uint64_t>
jitArgs("jit-args",
@@ -216,7 +216,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
// If the specified action involves JIT compilation, `jitFuncName`
// If the specified action involves JIT compilation, `funcName`
// designates the function to JIT compile. This function is invoked
// using the parameters given in `jitArgs`.
//
@@ -239,7 +239,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
std::unique_ptr<llvm::MemoryBuffer> buffer, std::string sourceFileName,
enum Action action, const std::string &jitFuncName,
enum Action action, const std::string &funcName,
llvm::ArrayRef<uint64_t> jitArgs,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
@@ -269,17 +269,18 @@ mlir::LogicalResult processInputBuffer(
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());
ce.setClientParametersFuncName(funcName);
if (fhelinalgTileSizes.hasValue())
ce.setFHELinalgTileSizes(*fhelinalgTileSizes);
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName, keySetCache);
ce.buildLambda(std::move(buffer), funcName, keySetCache);
if (!lambdaOrErr) {
mlir::concretelang::log_error()
<< "Failed to JIT-compile " << jitFuncName << ": "
<< llvm::toString(std::move(lambdaOrErr.takeError()));
<< "Failed to JIT-compile " << funcName << ": "
<< llvm::toString(lambdaOrErr.takeError());
return mlir::failure();
}
@@ -287,7 +288,7 @@ mlir::LogicalResult processInputBuffer(
if (!resOrErr) {
mlir::concretelang::log_error()
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
<< "Failed to JIT-invoke " << funcName << " with arguments "
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
return mlir::failure();
}
@@ -425,11 +426,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
auto process = [&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
std::move(inputBuffer), fileName, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, fhelinalgTileSizes,
cmdline::autoParallelize, jitKeySetCache, os, outputLib);
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os,
outputLib);
};
auto &os = output->os();
auto res = mlir::failure();
@@ -446,12 +447,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
}
if (cmdline::action == Action::COMPILE) {
auto libPath = outputLib->emitShared();
if (!libPath) {
return mlir::failure();
}
libPath = outputLib->emitStatic();
if (!libPath) {
auto err = outputLib->emitArtifacts();
if (err) {
return mlir::failure();
}
}

View File

@@ -1,3 +1,4 @@
if (CONCRETELANG_UNIT_TESTS)
add_subdirectory(unittest)
add_subdirectory(Support)
endif()

View File

@@ -0,0 +1,23 @@
enable_testing()
include_directories(${PROJECT_SOURCE_DIR}/include)
add_executable(
support_unit_test
support_unit_test.cpp
)
set_source_files_properties(
support_unit_test.cpp
PROPERTIES COMPILE_FLAGS "-fno-rtti"
)
target_link_libraries(
support_unit_test
gtest_main
ConcretelangSupport
)
include(GoogleTest)
gtest_discover_tests(support_unit_test)

View File

@@ -0,0 +1,69 @@
#include <gtest/gtest.h>
#include "../unittest/end_to_end_jit_test.h"
#include "concretelang/ClientLib/ClientParameters.h"
namespace CL = mlir::concretelang;
TEST(Support, client_parameters_json_serde) {
mlir::concretelang::ClientParameters params0;
params0.secretKeys = {
{mlir::concretelang::SMALL_KEY, {/*.size = */ 12}},
{mlir::concretelang::BIG_KEY, {/*.size = */ 14}},
};
params0.bootstrapKeys = {
{
"bsk_v0", {
/*.inputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
/*.outputSecretKeyID = */ mlir::concretelang::BIG_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.k = */ 3,
/*.variance = */ 0.001
}
},{
"wtf_bsk_v0", {
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.k = */ 1,
/*.variance = */ 0.0001,
}
},
};
params0.keyswitchKeys = {
{
"ksk_v0", {
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
}
}
};
params0.inputs = {
{
/*.encryption = */ {{CL::SMALL_KEY, 0.01, {4}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
},
{
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
params0.outputs = {
{
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
auto json = mlir::concretelang::toJSON(params0);
std::string jsonStr;
llvm::raw_string_ostream os(jsonStr);
os << json;
auto parseResult =
llvm::json::parse<mlir::concretelang::ClientParameters>(jsonStr);
ASSERT_EXPECTED_VALUE(parseResult, params0);
}

View File

@@ -3,9 +3,9 @@
#include <gtest/gtest.h>
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/KeySetCache.h"
#include "llvm/Support/Path.h"
#include "globals.h"