diff --git a/compiler/Makefile b/compiler/Makefile index b775d5cdd..7435472f1 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -53,11 +53,20 @@ test-check: concretecompiler file-check not test-python: python-bindings concretecompiler PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/concretelang/python_packages/concretelang_core LD_PRELOAD=$(BUILD_DIR)/lib/libConcretelangRuntime.so pytest -vs tests/python -test: test-check test-end-to-end-jit test-python +test: test-check test-end-to-end-jit test-python support-unit-test test-dataflow: test-end-to-end-jit-dfr test-end-to-end-jit-auto-parallelization -# Unittests +# unit-test + +support-unit-test: build-support-unit-test + $(BUILD_DIR)/bin/support_unit_test + +build-support-unit-test: build-initialized + cmake --build $(BUILD_DIR) --target support_unit_test + + +# test-end-to-end-jit build-end-to-end-jit-test: build-initialized cmake --build $(BUILD_DIR) --target end_to_end_jit_test diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h new file mode 100644 index 000000000..d929aed71 --- /dev/null +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -0,0 +1,165 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license +// information. + +#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_ +#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_ +#include +#include +#include + +#include + +namespace mlir { +namespace concretelang { + +const std::string SMALL_KEY = "small"; +const std::string BIG_KEY = "big"; + +typedef size_t DecompositionLevelCount; +typedef size_t DecompositionBaseLog; +typedef size_t PolynomialSize; +typedef size_t Precision; +typedef double Variance; + +typedef uint64_t LweSize; +typedef uint64_t GlweDimension; + +typedef std::string LweSecretKeyID; +struct LweSecretKeyParam { + LweSize size; + + void hash(size_t &seed); +}; +static bool operator==(const LweSecretKeyParam &lhs, + const LweSecretKeyParam &rhs) { + return lhs.size == rhs.size; +} + +typedef std::string BootstrapKeyID; +struct BootstrapKeyParam { + LweSecretKeyID inputSecretKeyID; + LweSecretKeyID outputSecretKeyID; + DecompositionLevelCount level; + DecompositionBaseLog baseLog; + GlweDimension glweDimension; + Variance variance; + + void hash(size_t &seed); +}; +static inline bool operator==(const BootstrapKeyParam &lhs, + const BootstrapKeyParam &rhs) { + return lhs.inputSecretKeyID == rhs.inputSecretKeyID && + lhs.outputSecretKeyID == rhs.outputSecretKeyID && + lhs.level == rhs.level && lhs.baseLog == rhs.baseLog && + lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance; +} + +typedef std::string KeyswitchKeyID; +struct KeyswitchKeyParam { + LweSecretKeyID inputSecretKeyID; + LweSecretKeyID outputSecretKeyID; + DecompositionLevelCount level; + DecompositionBaseLog baseLog; + Variance variance; + + void hash(size_t &seed); +}; +static inline bool operator==(const KeyswitchKeyParam &lhs, + const KeyswitchKeyParam &rhs) { + return lhs.inputSecretKeyID == rhs.inputSecretKeyID && + lhs.outputSecretKeyID == rhs.outputSecretKeyID && + lhs.level == rhs.level && lhs.baseLog == rhs.baseLog && + lhs.variance == rhs.variance; +} + +struct Encoding { + Precision precision; +}; +static inline bool operator==(const Encoding &lhs, const Encoding &rhs) { + return lhs.precision == rhs.precision; +} + +struct EncryptionGate { + LweSecretKeyID secretKeyID; + Variance variance; + Encoding encoding; +}; +static inline bool operator==(const EncryptionGate &lhs, + const EncryptionGate &rhs) { + return lhs.secretKeyID == rhs.secretKeyID && lhs.variance == rhs.variance && + lhs.encoding == rhs.encoding; +} + +struct CircuitGateShape { + // Width of the scalar value + size_t width; + // Dimensions of the tensor, empty if scalar + std::vector dimensions; + // Size of the buffer containing the tensor + size_t size; +}; +static inline bool operator==(const CircuitGateShape &lhs, + const CircuitGateShape &rhs) { + return lhs.width == rhs.width && lhs.dimensions == rhs.dimensions && + lhs.size == rhs.size; +} + +struct CircuitGate { + llvm::Optional encryption; + CircuitGateShape shape; +}; +static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) { + return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape; +} + +struct ClientParameters { + std::map secretKeys; + std::map bootstrapKeys; + std::map keyswitchKeys; + std::vector inputs; + std::vector outputs; + std::string functionName; + size_t hash(); +}; +static inline bool operator==(const ClientParameters &lhs, + const ClientParameters &rhs) { + return lhs.secretKeys == rhs.secretKeys && + lhs.bootstrapKeys == rhs.bootstrapKeys && + lhs.keyswitchKeys == rhs.keyswitchKeys && lhs.inputs == lhs.inputs && + lhs.outputs == lhs.outputs; +} + +llvm::json::Value toJSON(const LweSecretKeyParam &); +bool fromJSON(const llvm::json::Value, LweSecretKeyParam &, llvm::json::Path); + +llvm::json::Value toJSON(const BootstrapKeyParam &); +bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path); + +llvm::json::Value toJSON(const KeyswitchKeyParam &); +bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path); + +llvm::json::Value toJSON(const Encoding &); +bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path); + +llvm::json::Value toJSON(const EncryptionGate &); +bool fromJSON(const llvm::json::Value, EncryptionGate &, llvm::json::Path); + +llvm::json::Value toJSON(const CircuitGateShape &); +bool fromJSON(const llvm::json::Value, CircuitGateShape &, llvm::json::Path); + +llvm::json::Value toJSON(const CircuitGate &); +bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path); + +llvm::json::Value toJSON(const ClientParameters &); +bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path); +static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + ClientParameters cp) { + return OS << llvm::formatv("{0:2}", toJSON(cp)); +} + +} // namespace concretelang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/concretelang/Support/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h similarity index 92% rename from compiler/include/concretelang/Support/KeySet.h rename to compiler/include/concretelang/ClientLib/KeySet.h index 7bc579822..573224a3c 100644 --- a/compiler/include/concretelang/Support/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -6,7 +6,6 @@ #ifndef CONCRETELANG_SUPPORT_KEYSET_H_ #define CONCRETELANG_SUPPORT_KEYSET_H_ -#include "llvm/Support/Error.h" #include extern "C" { @@ -14,8 +13,8 @@ extern "C" { #include "concretelang/Runtime/context.h" } -#include "concretelang/Support/ClientParameters.h" -#include "concretelang/Support/KeySetCache.h" +#include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/ClientLib/KeySetCache.h" namespace mlir { namespace concretelang { @@ -91,9 +90,9 @@ private: bootstrapKeys; std::map> keyswitchKeys; - std::vector> + std::vector> inputs; - std::vector> + std::vector> outputs; void setKeys( diff --git a/compiler/include/concretelang/Support/KeySetCache.h b/compiler/include/concretelang/ClientLib/KeySetCache.h similarity index 95% rename from compiler/include/concretelang/Support/KeySetCache.h rename to compiler/include/concretelang/ClientLib/KeySetCache.h index 507783eb3..7539d470e 100644 --- a/compiler/include/concretelang/Support/KeySetCache.h +++ b/compiler/include/concretelang/ClientLib/KeySetCache.h @@ -6,7 +6,7 @@ #ifndef CONCRETELANG_SUPPORT_KEYSETCACHE_H_ #define CONCRETELANG_SUPPORT_KEYSETCACHE_H_ -#include "concretelang/Support/KeySet.h" +#include "concretelang/ClientLib/KeySet.h" namespace mlir { namespace concretelang { diff --git a/compiler/include/concretelang/Support/ClientParameters.h b/compiler/include/concretelang/Support/ClientParameters.h deleted file mode 100644 index 1fa4d979c..000000000 --- a/compiler/include/concretelang/Support/ClientParameters.h +++ /dev/null @@ -1,99 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license -// information. - -#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_ -#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_ -#include -#include -#include - -#include -#include - -#include "concretelang/Support/V0Parameters.h" - -namespace mlir { -namespace concretelang { - -typedef size_t DecompositionLevelCount; -typedef size_t DecompositionBaseLog; -typedef size_t PolynomialSize; -typedef size_t Precision; -typedef double Variance; - -typedef uint64_t LweSize; -typedef uint64_t GlweDimension; - -typedef std::string LweSecretKeyID; -struct LweSecretKeyParam { - LweSize size; - - void hash(size_t &seed); -}; - -typedef std::string BootstrapKeyID; -struct BootstrapKeyParam { - LweSecretKeyID inputSecretKeyID; - LweSecretKeyID outputSecretKeyID; - DecompositionLevelCount level; - DecompositionBaseLog baseLog; - GlweDimension glweDimension; - Variance variance; - - void hash(size_t &seed); -}; - -typedef std::string KeyswitchKeyID; -struct KeyswitchKeyParam { - LweSecretKeyID inputSecretKeyID; - LweSecretKeyID outputSecretKeyID; - DecompositionLevelCount level; - DecompositionBaseLog baseLog; - Variance variance; - - void hash(size_t &seed); -}; - -struct Encoding { - Precision precision; -}; - -struct EncryptionGate { - LweSecretKeyID secretKeyID; - Variance variance; - Encoding encoding; -}; - -struct CircuitGateShape { - // Width of the scalar value - size_t width; - // Dimensions of the tensor, empty if scalar - std::vector dimensions; - // Size of the buffer containing the tensor - size_t size; -}; - -struct CircuitGate { - llvm::Optional encryption; - CircuitGateShape shape; -}; - -struct ClientParameters { - std::map secretKeys; - std::map bootstrapKeys; - std::map keyswitchKeys; - std::vector inputs; - std::vector outputs; - size_t hash(); -}; - -llvm::Expected -createClientParametersForV0(V0FHEContext context, llvm::StringRef name, - mlir::ModuleOp module); - -} // namespace concretelang -} // namespace mlir - -#endif \ No newline at end of file diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 31b1e2680..4dd7c7df9 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -7,7 +7,7 @@ #define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H #include -#include +#include #include #include #include @@ -58,6 +58,7 @@ public: class Library { std::string libraryPath; std::vector objectsPath; + std::vector clientParametersList; bool cleanUp; public: @@ -68,21 +69,38 @@ public: : libraryPath(libraryPath), cleanUp(cleanUp) {} /** Add a compilation result to the library */ llvm::Expected addCompilation(CompilationResult &compilation); - /** Emit a shared library with the previously added compilation result */ - llvm::Expected emitShared(); - /** Emit a shared library with the previously added compilation result */ - llvm::Expected emitStatic(); + /** Emit the library artifacts with the previously added compilation result + */ + llvm::Error emitArtifacts(); /** After a shared library has been emitted, its path is here */ std::string sharedLibraryPath; /** After a static library has been emitted, its path is here */ std::string staticLibraryPath; + /** Returns the path of the shared library */ + static std::string getSharedLibraryPath(std::string path); + + /** Returns the path of the static library */ + static std::string getStaticLibraryPath(std::string path); + + /** Returns the path of the static library */ + static std::string getClientParametersPath(std::string path); + // For advanced use - const static std::string OBJECT_EXT, LINKER, LINKER_SHARED_OPT, AR, - AR_STATIC_OPT, DOT_STATIC_LIB_EXT, DOT_SHARED_LIB_EXT; + const static std::string OBJECT_EXT, CLIENT_PARAMETERS_EXT, LINKER, + LINKER_SHARED_OPT, AR, AR_STATIC_OPT, DOT_STATIC_LIB_EXT, + DOT_SHARED_LIB_EXT; void addExtraObjectFilePath(std::string objectFilePath); llvm::Expected emit(std::string dotExt, std::string linker); ~Library(); + + private: + /** Emit a shared library with the previously added compilation result */ + llvm::Expected emitStatic(); + /** Emit a shared library with the previously added compilation result */ + llvm::Expected emitShared(); + /** Emit a shared library with the previously added compilation result */ + llvm::Expected emitClientParametersJSON(); }; // Specification of the exit stage of the compilation pipeline diff --git a/compiler/include/concretelang/Support/Error.h b/compiler/include/concretelang/Support/Error.h index 24149f82f..0697fc648 100644 --- a/compiler/include/concretelang/Support/Error.h +++ b/compiler/include/concretelang/Support/Error.h @@ -50,7 +50,10 @@ protected: llvm::raw_string_ostream os; }; -StreamStringError &operator<<(StreamStringError &se, llvm::Error &err); +inline StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) { + se << llvm::toString(std::move(err)); + return se; +} } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Support/Jit.h b/compiler/include/concretelang/Support/Jit.h index 20b7ea66d..b5aea676c 100644 --- a/compiler/include/concretelang/Support/Jit.h +++ b/compiler/include/concretelang/Support/Jit.h @@ -10,7 +10,7 @@ #include #include -#include +#include namespace mlir { namespace concretelang { diff --git a/compiler/include/concretelang/Support/JitCompilerEngine.h b/compiler/include/concretelang/Support/JitCompilerEngine.h index 11bab7355..7a290040e 100644 --- a/compiler/include/concretelang/Support/JitCompilerEngine.h +++ b/compiler/include/concretelang/Support/JitCompilerEngine.h @@ -6,7 +6,7 @@ #ifndef CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H #define CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H -#include "concretelang/Support/KeySetCache.h" +#include "concretelang/ClientLib/KeySetCache.h" #include #include #include diff --git a/compiler/include/concretelang/Support/V0ClientParameters.h b/compiler/include/concretelang/Support/V0ClientParameters.h new file mode 100644 index 000000000..de8b61c9c --- /dev/null +++ b/compiler/include/concretelang/Support/V0ClientParameters.h @@ -0,0 +1,28 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license +// information. + +#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_ +#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_ + +#include +#include + +#include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/Support/V0Parameters.h" + +namespace mlir { +namespace concretelang { + +ClientParameters emptyClientParametersForV0(llvm::StringRef functionName, + mlir::ModuleOp module); + +llvm::Expected +createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName, + mlir::ModuleOp module); + +} // namespace concretelang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 71ae9f113..7e594c53f 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -6,10 +6,10 @@ #include "llvm/ADT/SmallString.h" #include "concretelang-c/Support/CompilerEngine.h" +#include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Jit.h" #include "concretelang/Support/JitCompilerEngine.h" -#include "concretelang/Support/KeySetCache.h" using mlir::concretelang::JitCompilerEngine; diff --git a/compiler/lib/CMakeLists.txt b/compiler/lib/CMakeLists.txt index a8aede839..e35d64627 100644 --- a/compiler/lib/CMakeLists.txt +++ b/compiler/lib/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(Support) add_subdirectory(Runtime) +add_subdirectory(ClientLib) add_subdirectory(Bindings) # CAPI needed only for python bindings diff --git a/compiler/lib/ClientLib/CMakeLists.txt b/compiler/lib/ClientLib/CMakeLists.txt new file mode 100644 index 000000000..e4e178b46 --- /dev/null +++ b/compiler/lib/ClientLib/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_library( + ConcretelangClientLib + ClientParameters.cpp + KeySet.cpp + KeySetCache.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib +) \ No newline at end of file diff --git a/compiler/lib/ClientLib/ClientParameters.cpp b/compiler/lib/ClientLib/ClientParameters.cpp new file mode 100644 index 000000000..02adbb205 --- /dev/null +++ b/compiler/lib/ClientLib/ClientParameters.cpp @@ -0,0 +1,388 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license +// information. + +#include "concretelang/ClientLib/ClientParameters.h" + +namespace mlir { +namespace concretelang { + +// https://stackoverflow.com/a/38140932 +static inline void hash(std::size_t &seed) {} +template +static inline void hash(std::size_t &seed, const T &v, Rest... rest) { + // See https://softwareengineering.stackexchange.com/a/402543 + const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits + const std::hash hasher; + seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2); + hash(seed, rest...); +} + +void LweSecretKeyParam::hash(size_t &seed) { + mlir::concretelang::hash(seed, size); +} + +void BootstrapKeyParam::hash(size_t &seed) { + mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level, + baseLog, glweDimension, variance); +} + +void KeyswitchKeyParam::hash(size_t &seed) { + mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level, + baseLog, variance); +} + +std::size_t ClientParameters::hash() { + std::size_t currentHash = 1; + for (auto secretKeyParam : secretKeys) { + mlir::concretelang::hash(currentHash, secretKeyParam.first); + secretKeyParam.second.hash(currentHash); + } + for (auto bootstrapKeyParam : bootstrapKeys) { + mlir::concretelang::hash(currentHash, bootstrapKeyParam.first); + bootstrapKeyParam.second.hash(currentHash); + } + for (auto keyswitchParam : keyswitchKeys) { + mlir::concretelang::hash(currentHash, keyswitchParam.first); + keyswitchParam.second.hash(currentHash); + } + return currentHash; +} + +llvm::json::Value toJSON(const LweSecretKeyParam &v) { + llvm::json::Object object{ + {"size", v.size}, + }; + return object; +} + +bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v, + llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + auto size = obj->getInteger("size"); + if (!size.hasValue()) { + p.report("missing size field"); + return false; + } + v.size = *size; + return true; +} + +llvm::json::Value toJSON(const BootstrapKeyParam &v) { + llvm::json::Object object{ + {"inputSecretKeyID", v.inputSecretKeyID}, + {"outputSecretKeyID", v.outputSecretKeyID}, + {"level", v.level}, + {"glweDimension", v.glweDimension}, + {"baseLog", v.baseLog}, + {"variance", v.variance}, + }; + return object; +} + +bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v, + llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + auto inputSecretKeyID = obj->getString("inputSecretKeyID"); + if (!inputSecretKeyID.hasValue()) { + p.report("missing inputSecretKeyID field"); + return false; + } + auto outputSecretKeyID = obj->getString("outputSecretKeyID"); + if (!outputSecretKeyID.hasValue()) { + p.report("missing outputSecretKeyID field"); + return false; + } + auto level = obj->getInteger("level"); + if (!level.hasValue()) { + p.report("missing level field"); + return false; + } + auto baseLog = obj->getInteger("baseLog"); + if (!baseLog.hasValue()) { + p.report("missing baseLog field"); + return false; + } + auto glweDimension = obj->getInteger("glweDimension"); + if (!glweDimension.hasValue()) { + p.report("missing glweDimension field"); + return false; + } + auto variance = obj->getNumber("variance"); + if (!variance.hasValue()) { + p.report("missing variance field"); + return false; + } + v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue(); + v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue(); + v.level = level.getValue(); + v.baseLog = baseLog.getValue(); + v.glweDimension = glweDimension.getValue(); + v.variance = variance.getValue(); + return true; +} + +llvm::json::Value toJSON(const KeyswitchKeyParam &v) { + llvm::json::Object object{ + {"inputSecretKeyID", v.inputSecretKeyID}, + {"outputSecretKeyID", v.outputSecretKeyID}, + {"level", v.level}, + {"baseLog", v.baseLog}, + {"variance", v.variance}, + }; + return object; +} +bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v, + llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + auto inputSecretKeyID = obj->getString("inputSecretKeyID"); + if (!inputSecretKeyID.hasValue()) { + p.report("missing inputSecretKeyID field"); + return false; + } + auto outputSecretKeyID = obj->getString("outputSecretKeyID"); + if (!outputSecretKeyID.hasValue()) { + p.report("missing outputSecretKeyID field"); + return false; + } + auto level = obj->getInteger("level"); + if (!level.hasValue()) { + p.report("missing level field"); + return false; + } + auto baseLog = obj->getInteger("baseLog"); + if (!baseLog.hasValue()) { + p.report("missing baseLog field"); + return false; + } + auto variance = obj->getNumber("variance"); + if (!variance.hasValue()) { + p.report("missing variance field"); + return false; + } + v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue(); + v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue(); + v.level = level.getValue(); + v.baseLog = baseLog.getValue(); + v.variance = variance.getValue(); + return true; +} + +llvm::json::Value toJSON(const CircuitGateShape &v) { + llvm::json::Object object{ + {"width", v.width}, + {"dimensions", v.dimensions}, + {"size", v.size}, + }; + return object; +} +bool fromJSON(const llvm::json::Value j, CircuitGateShape &v, + llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + auto width = obj->getInteger("width"); + if (!width.hasValue()) { + p.report("missing width field"); + return false; + } + auto dimensions = obj->getArray("dimensions"); + if (dimensions == nullptr) { + p.report("missing dimensions field"); + return false; + } + for (auto dim : *dimensions) { + auto iDim = dim.getAsInteger(); + if (!iDim.hasValue()) { + p.report("dimensions must be integer"); + return false; + } + v.dimensions.push_back(iDim.getValue()); + } + auto size = obj->getInteger("size"); + if (!size.hasValue()) { + p.report("missing size field"); + return false; + } + v.width = width.getValue(); + v.size = size.getValue(); + return true; +} + +llvm::json::Value toJSON(const Encoding &v) { + llvm::json::Object object{ + {"precision", v.precision}, + }; + return object; +} +bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + auto precision = obj->getInteger("precision"); + if (!precision.hasValue()) { + p.report("missing precision field"); + return false; + } + v.precision = precision.getValue(); + return true; +} + +llvm::json::Value toJSON(const EncryptionGate &v) { + llvm::json::Object object{ + {"secretKeyID", v.secretKeyID}, + {"variance", v.variance}, + {"encoding", v.encoding}, + }; + return object; +} +bool fromJSON(const llvm::json::Value j, EncryptionGate &v, + llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + auto secretKeyID = obj->getString("secretKeyID"); + if (!secretKeyID.hasValue()) { + p.report("missing secretKeyID field"); + return false; + } + v.secretKeyID = (std::string)secretKeyID.getValue(); + auto variance = obj->getNumber("variance"); + if (!variance.hasValue()) { + p.report("missing variance field"); + return false; + } + v.variance = variance.getValue(); + auto encoding = obj->get("encoding"); + if (encoding == nullptr) { + p.report("missing encoding field"); + return false; + } + if (!fromJSON(*encoding, v.encoding, p.field("encoding"))) { + return false; + } + return true; +} + +llvm::json::Value toJSON(const CircuitGate &v) { + llvm::json::Object object{ + {"encryption", v.encryption}, + {"shape", v.shape}, + }; + return object; +} +bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) { + auto obj = j.getAsObject(); + auto encryption = obj->get("encryption"); + if (encryption == nullptr) { + p.report("missing encryption field"); + return false; + } + if (!fromJSON(*encryption, v.encryption, p.field("encryption"))) { + return false; + } + auto shape = obj->get("shape"); + if (shape == nullptr) { + p.report("missing shape field"); + return false; + } + if (!fromJSON(*shape, v.shape, p.field("shape"))) { + return false; + } + return true; +} + +template llvm::json::Value toJson(std::map map) { + llvm::json::Object obj; + for (auto entry : map) { + obj[entry.first] = entry.second; + } + return obj; +} + +llvm::json::Value toJSON(const ClientParameters &v) { + llvm::json::Object object{ + {"secretKeys", toJson(v.secretKeys)}, + {"bootstrapKeys", toJson(v.bootstrapKeys)}, + {"keyswitchKeys", toJson(v.keyswitchKeys)}, + {"inputs", v.inputs}, + {"outputs", v.outputs}, + {"functionName", v.functionName}, + }; + return object; +} +bool fromJSON(const llvm::json::Value j, ClientParameters &v, + llvm::json::Path p) { + + auto obj = j.getAsObject(); + auto secretkeys = obj->get("secretKeys"); + if (secretkeys == nullptr) { + p.report("missing secretKeys field"); + return false; + } + if (!fromJSON(*secretkeys, v.secretKeys, p.field("secretKeys"))) { + return false; + } + auto bootstrapKeys = obj->get("bootstrapKeys"); + if (bootstrapKeys == nullptr) { + p.report("missing bootstrapKeys field"); + return false; + } + if (!fromJSON(*bootstrapKeys, v.bootstrapKeys, p.field("bootstrapKeys"))) { + return false; + } + auto keyswitchKeys = obj->get("keyswitchKeys"); + if (keyswitchKeys == nullptr) { + p.report("missing keyswitchKeys field"); + return false; + } + if (!fromJSON(*keyswitchKeys, v.keyswitchKeys, p.field("keyswitchKeys"))) { + return false; + } + auto inputs = obj->get("inputs"); + if (inputs == nullptr) { + p.report("missing inputs field"); + return false; + } + if (!fromJSON(*inputs, v.inputs, p.field("inputs"))) { + return false; + } + auto outputs = obj->get("outputs"); + if (outputs == nullptr) { + p.report("missing outputs field"); + return false; + } + if (!fromJSON(*outputs, v.outputs, p.field("outputs"))) { + return false; + } + auto functionName = obj->getString("functionName"); + if (!functionName.hasValue()) { + p.report("missing functionName field"); + return false; + } + v.functionName = (std::string)functionName.getValue(); + + return true; +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp similarity index 92% rename from compiler/lib/Support/KeySet.cpp rename to compiler/lib/ClientLib/KeySet.cpp index e36b71450..596d9d1b6 100644 --- a/compiler/lib/Support/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -3,7 +3,7 @@ // https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license // information. -#include "concretelang/Support/KeySet.h" +#include "concretelang/ClientLib/KeySet.h" #include "concretelang/Support/Error.h" #define CAPI_ERR_TO_LLVM_ERROR(s, msg) \ @@ -37,23 +37,18 @@ llvm::Expected> KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb) { - auto a = uninitialized(); + auto keySet = uninitialized(); - auto fillError = a->generateKeysFromParams(params, seed_msb, seed_lsb); - - if (fillError) { - return StreamStringError() - << "Cannot fill keys from params: " << std::move(fillError); + if (auto error = keySet->generateKeysFromParams(params, seed_msb, seed_lsb)) { + return std::move(error); } - fillError = a->setupEncryptionMaterial(params, seed_msb, seed_lsb); - - if (fillError) { - return StreamStringError() - << "Cannot setup encryption material: " << std::move(fillError); + if (auto error = + keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)) { + return std::move(error); } - return std::move(a); + return std::move(keySet); } std::unique_ptr KeySet::uninitialized() { @@ -66,8 +61,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms, // Set inputs and outputs LWE secret keys { for (auto param : params.inputs) { - std::tuple input = { - param, nullptr, nullptr}; + LweSecretKeyParam secretKeyParam = {0}; + LweSecretKey_u64 *secretKey = nullptr; if (param.encryption.hasValue()) { auto inputSk = this->secretKeys.find(param.encryption->secretKeyID); if (inputSk == this->secretKeys.end()) { @@ -76,14 +71,16 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms, ") does not exist ", llvm::inconvertibleErrorCode()); } - std::get<1>(input) = &inputSk->second.first; - std::get<2>(input) = inputSk->second.second; + secretKeyParam = inputSk->second.first; + secretKey = inputSk->second.second; } + std::tuple input = { + param, secretKeyParam, secretKey}; this->inputs.push_back(input); } for (auto param : params.outputs) { - std::tuple output = - {param, nullptr, nullptr}; + LweSecretKeyParam secretKeyParam = {0}; + LweSecretKey_u64 *secretKey = nullptr; if (param.encryption.hasValue()) { auto outputSk = this->secretKeys.find(param.encryption->secretKeyID); if (outputSk == this->secretKeys.end()) { @@ -91,9 +88,11 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms, "cannot find output key to generate bootstrap key", llvm::inconvertibleErrorCode()); } - std::get<1>(output) = &outputSk->second.first; - std::get<2>(output) = outputSk->second.second; + secretKeyParam = outputSk->second.first; + secretKey = outputSk->second.second; } + std::tuple output = { + param, secretKeyParam, secretKey}; this->outputs.push_back(output); } } @@ -283,7 +282,7 @@ llvm::Error KeySet::allocate_lwe(size_t argPos, } auto inputSk = inputs[argPos]; CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64( - &err, {std::get<1>(inputSk)->size + 1}), + &err, {std::get<1>(inputSk).size + 1}), "cannot allocate ciphertext"); return llvm::Error::success(); } @@ -369,4 +368,4 @@ KeySet::getKeyswitchKeys() { } } // namespace concretelang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/lib/Support/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp similarity index 98% rename from compiler/lib/Support/KeySetCache.cpp rename to compiler/lib/ClientLib/KeySetCache.cpp index b66491166..c060fb905 100644 --- a/compiler/lib/Support/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -3,7 +3,7 @@ // https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license // information. -#include "concretelang/Support/KeySetCache.h" +#include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Support/Error.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/FileSystem.h" @@ -11,6 +11,7 @@ #include #include +#include #include extern "C" { @@ -229,4 +230,4 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb, } } // namespace concretelang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 2b4174a87..52a7bbeb5 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_library(ConcretelangSupport - Error.cpp Pipeline.cpp Jit.cpp CompilerEngine.cpp @@ -7,11 +6,9 @@ add_mlir_library(ConcretelangSupport LambdaArgument.cpp V0Parameters.cpp V0Curves.cpp - ClientParameters.cpp - KeySet.cpp + V0ClientParameters.cpp logging.cpp Jit.cpp - KeySetCache.cpp LLVMEmitFile.cpp ADDITIONAL_HEADER_DIRS @@ -34,4 +31,5 @@ add_mlir_library(ConcretelangSupport ${LLVM_PTHREAD_LIB} ConcretelangRuntime + ConcretelangClientLib ) diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 9a76a7668..70ba671a8 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -3,7 +3,10 @@ // https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license // information. +#include +#include #include +#include #include #include @@ -271,15 +274,23 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return StreamStringError( "Cannot generate client parameters, the fhe context is empty"); } + } + // Generate client parameters if requested + auto funcName = this->clientParametersFuncName.getValueOr("main"); + if (this->generateClientParameters || target == Target::LIBRARY) { + if (!res.fheContext.hasValue()) { + // Some tests can involves a usual function + res.clientParameters = + mlir::concretelang::emptyClientParametersForV0(funcName, module); + } else { + auto clientParametersOrErr = + mlir::concretelang::createClientParametersForV0(*res.fheContext, + funcName, module); + if (!clientParametersOrErr) + return clientParametersOrErr.takeError(); - llvm::Expected clientParametersOrErr = - mlir::concretelang::createClientParametersForV0( - *res.fheContext, *this->clientParametersFuncName, module); - - if (llvm::Error err = clientParametersOrErr.takeError()) - return std::move(err); - - res.clientParameters = clientParametersOrErr.get(); + res.clientParameters = clientParametersOrErr.get(); + } } // MLIR canonical dialects -> LLVM Dialect @@ -334,10 +345,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { llvm::Expected CompilerEngine::compile(llvm::StringRef s, Target target, OptionalLib lib) { std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); - llvm::Expected res = - this->compile(std::move(mb), target, lib); - - return std::move(res); + return this->compile(std::move(mb), target, lib); } // Compile the contained in `buffer` to the target dialect @@ -351,9 +359,7 @@ CompilerEngine::compile(std::unique_ptr buffer, sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); - llvm::Expected res = this->compile(sm, target, lib); - - return std::move(res); + return this->compile(sm, target, lib); } template @@ -371,10 +377,9 @@ CompilerEngine::compile(std::vector inputs, std::string libraryPath) { } } - auto libPath = outputLib->emitShared(); - if (!libPath) { - return StreamStringError("Can't link: ") - << llvm::toString(libPath.takeError()); + if (auto err = outputLib->emitArtifacts()) { + return StreamStringError("Can't emit artifacts: ") + << llvm::toString(std::move(err)); } return *outputLib.get(); } @@ -384,7 +389,24 @@ template llvm::Expected CompilerEngine::compile(std::vector inputs, std::string libraryPath); +/** Returns the path of the shared library */ +std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) { + return path + DOT_SHARED_LIB_EXT; +} + +/** Returns the path of the static library */ +std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) { + return path + DOT_STATIC_LIB_EXT; +} + +/** Returns the path of the static library */ +std::string CompilerEngine::Library::getClientParametersPath(std::string path) { + return path + CLIENT_PARAMETERS_EXT; +} + const std::string CompilerEngine::Library::OBJECT_EXT = ".o"; +const std::string CompilerEngine::Library::CLIENT_PARAMETERS_EXT = + ".concrete.params.json"; const std::string CompilerEngine::Library::LINKER = "ld"; const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o "; const std::string CompilerEngine::Library::AR = "ar"; @@ -396,6 +418,23 @@ void CompilerEngine::Library::addExtraObjectFilePath(std::string path) { objectsPath.push_back(path); } +llvm::Expected +CompilerEngine::Library::emitClientParametersJSON() { + auto clientParamsPath = getClientParametersPath(libraryPath); + llvm::json::Value value(clientParametersList); + std::error_code error; + llvm::raw_fd_ostream out(clientParamsPath, error); + + if (error) { + return StreamStringError("cannot emit client parameters, error: ") + << error.message(); + } + out << llvm::formatv("{0:2}", value); + out.close(); + + return clientParamsPath; +} + llvm::Expected CompilerEngine::Library::addCompilation(CompilationResult &compilation) { llvm::Module *module = compilation.llvmModule.get(); @@ -405,13 +444,14 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) { std::to_string(objectsPath.size()) + ".mlir"; } auto objectPath = sourceName + OBJECT_EXT; - auto error = mlir::concretelang::emitObject(*module, objectPath); - - if (error) { + if (auto error = mlir::concretelang::emitObject(*module, objectPath)) { return std::move(error); } addExtraObjectFilePath(objectPath); + if (compilation.clientParameters.hasValue()) { + clientParametersList.push_back(compilation.clientParameters.getValue()); + } return objectPath; } @@ -437,9 +477,8 @@ llvm::Expected CompilerEngine::Library::emit(std::string dotExt, auto error = mlir::concretelang::emitLibrary(objectsPath, pathDotExt, linker); if (error) { return std::move(error); - } else { - return pathDotExt; } + return pathDotExt; } llvm::Expected CompilerEngine::Library::emitShared() { @@ -458,6 +497,19 @@ llvm::Expected CompilerEngine::Library::emitStatic() { return path; } +llvm::Error CompilerEngine::Library::emitArtifacts() { + if (auto err = emitShared().takeError()) { + return err; + } + if (auto err = emitStatic().takeError()) { + return err; + } + if (auto err = emitClientParametersJSON().takeError()) { + return err; + } + return llvm::Error::success(); +} + CompilerEngine::Library::~Library() { if (cleanUp) { for (auto path : objectsPath) { diff --git a/compiler/lib/Support/Error.cpp b/compiler/lib/Support/Error.cpp deleted file mode 100644 index 40ad794d6..000000000 --- a/compiler/lib/Support/Error.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license -// information. - -#include - -namespace mlir { -namespace concretelang { -// Specialized `operator<<` for `llvm::Error` that marks the error -// as checked through `std::move` and `llvm::toString` -StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) { - se << llvm::toString(std::move(err)); - return se; -} -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp similarity index 64% rename from compiler/lib/Support/ClientParameters.cpp rename to compiler/lib/Support/V0ClientParameters.cpp index caa0e533c..46e2b0621 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -8,8 +8,9 @@ #include +#include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" -#include "concretelang/Support/ClientParameters.h" #include "concretelang/Support/V0Curves.h" namespace mlir { @@ -20,7 +21,7 @@ const auto keyFormat = KEY_FORMAT_BINARY; const auto v0Curve = getV0Curves(securityLevel, keyFormat); // For the v0 the secretKeyID and precision are the same for all gates. -llvm::Expected gateFromMLIRType(std::string secretKeyID, +llvm::Expected gateFromMLIRType(LweSecretKeyID secretKeyID, Precision precision, Variance variance, mlir::Type type) { @@ -46,10 +47,13 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, // TODO - Get the width from the LWECiphertextType instead of global // precision (could be possible after merge concrete-ciphertext-parameter) return CircuitGate{ - .encryption = llvm::Optional({ - .secretKeyID = secretKeyID, - .variance = variance, - .encoding = {.precision = precision}, + /* .encryption = */ llvm::Optional({ + /* .secretKeyID = */ secretKeyID, + /* .variance = */ variance, + /* .encoding = */ + { + /* .precision = */ precision, + }, }), /*.shape = */ { @@ -77,25 +81,33 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, "cannot convert MLIR type to shape", llvm::inconvertibleErrorCode()); } +ClientParameters emptyClientParametersForV0(llvm::StringRef functionName, + mlir::ModuleOp module) { + ClientParameters c; + c.functionName = (std::string)functionName; + return c; +} + llvm::Expected -createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, +createClientParametersForV0(V0FHEContext fheContext, + llvm::StringRef functionName, mlir::ModuleOp module) { auto v0Param = fheContext.parameter; Variance encryptionVariance = v0Curve->getVariance(1, 1 << v0Param.logPolynomialSize, 64); Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64); // Static client parameters from global parameters for v0 - ClientParameters c = {}; + ClientParameters c; c.secretKeys = { - {"small", {/*.size = */ v0Param.nSmall}}, - {"big", {/*.size = */ v0Param.getNBigGlweDimension()}}, + {SMALL_KEY, {/*.size = */ v0Param.nSmall}}, + {BIG_KEY, {/*.size = */ v0Param.getNBigGlweDimension()}}, }; c.bootstrapKeys = { { "bsk_v0", { - /*.inputSecretKeyID = */ "small", - /*.outputSecretKeyID = */ "big", + /*.inputSecretKeyID = */ SMALL_KEY, + /*.outputSecretKeyID = */ BIG_KEY, /*.level = */ v0Param.brLevel, /*.baseLog = */ v0Param.brLogBase, /*.glweDimension = */ v0Param.glweDimension, @@ -107,18 +119,19 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, { "ksk_v0", { - /*.inputSecretKeyID = */ "big", - /*.outputSecretKeyID = */ "small", + /*.inputSecretKeyID = */ BIG_KEY, + /*.outputSecretKeyID = */ SMALL_KEY, /*.level = */ v0Param.ksLevel, /*.baseLog = */ v0Param.ksLogBase, /*.variance = */ keyswitchVariance, }, }, }; + c.functionName = (std::string)functionName; // Find the input function auto rangeOps = module.getOps(); auto funcOp = llvm::find_if( - rangeOps, [&](mlir::FuncOp op) { return op.getName() == name; }); + rangeOps, [&](mlir::FuncOp op) { return op.getName() == functionName; }); if (funcOp == rangeOps.end()) { return llvm::make_error( "cannot find the function for generate client parameters", @@ -135,14 +148,16 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, .isa(); for (auto inType = funcType.getInputs().begin(); inType < funcType.getInputs().end() - hasContext; inType++) { - auto gate = gateFromMLIRType("big", precision, encryptionVariance, *inType); + auto gate = + gateFromMLIRType(BIG_KEY, precision, encryptionVariance, *inType); if (auto err = gate.takeError()) { return std::move(err); } c.inputs.push_back(gate.get()); } for (auto outType : funcType.getResults()) { - auto gate = gateFromMLIRType("big", precision, encryptionVariance, outType); + auto gate = + gateFromMLIRType(BIG_KEY, precision, encryptionVariance, outType); if (auto err = gate.takeError()) { return std::move(err); } @@ -151,46 +166,5 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, return c; } -// https://stackoverflow.com/a/38140932 -static inline void hash(std::size_t &seed) {} -template -static inline void hash(std::size_t &seed, const T &v, Rest... rest) { - // See https://softwareengineering.stackexchange.com/a/402543 - const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits - const std::hash hasher; - seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2); - hash(seed, rest...); -} - -void LweSecretKeyParam::hash(size_t &seed) { - mlir::concretelang::hash(seed, size); -} - -void BootstrapKeyParam::hash(size_t &seed) { - mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level, - baseLog, glweDimension, variance); -} - -void KeyswitchKeyParam::hash(size_t &seed) { - mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level, - baseLog, variance); -} - -std::size_t ClientParameters::hash() { - std::size_t currentHash = 1; - for (auto secretKeyParam : secretKeys) { - mlir::concretelang::hash(currentHash, secretKeyParam.first); - secretKeyParam.second.hash(currentHash); - } - for (auto bootstrapKeyParam : bootstrapKeys) { - mlir::concretelang::hash(currentHash, bootstrapKeyParam.first); - bootstrapKeyParam.second.hash(currentHash); - } - for (auto keyswitchParam : keyswitchKeys) { - mlir::concretelang::hash(currentHash, keyswitchParam.first); - keyswitchParam.second.hash(currentHash); - } - return currentHash; -} } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Support/V0Curves.cpp b/compiler/lib/Support/V0Curves.cpp index c7e1da42e..d951e9424 100644 --- a/compiler/lib/Support/V0Curves.cpp +++ b/compiler/lib/Support/V0Curves.cpp @@ -27,4 +27,4 @@ V0Curves *getV0Curves(int securityLevel, int keyFormat) { return &curves[securityLevel][keyFormat]; } } // namespace concretelang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index c066ff933..4c7f68523 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -19,6 +19,7 @@ #include #include +#include "concretelang/ClientLib/KeySet.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" @@ -30,7 +31,6 @@ #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" #include "concretelang/Support/Error.h" #include "concretelang/Support/JitCompilerEngine.h" -#include "concretelang/Support/KeySet.h" #include "concretelang/Support/LLVMEmitFile.h" #include "concretelang/Support/Pipeline.h" #include "concretelang/Support/logging.h" @@ -112,7 +112,7 @@ static llvm::cl::opt action( "Lower to LLVM-IR, optimize and dump result")), llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke", "Lower and JIT-compile input module and invoke " - "function specified with --jit-funcname")), + "function specified with --funcname")), llvm::cl::values(clEnumValN(Action::COMPILE, "compile", "Lower to LLVM-IR, compile to a file"))); @@ -133,10 +133,10 @@ llvm::cl::opt autoParallelize( llvm::cl::desc("Generate (and execute if JIT) parallel code"), llvm::cl::init(false)); -llvm::cl::opt jitFuncName( - "jit-funcname", - llvm::cl::desc("Name of the function to execute, default 'main'"), - llvm::cl::init("main")); +llvm::cl::opt + funcName("funcname", + llvm::cl::desc("Name of the function to compile, default 'main'"), + llvm::cl::init("main")); llvm::cl::list jitArgs("jit-args", @@ -216,7 +216,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, // The parameter `action` specifies how the buffer should be processed // and thus defines the output. // -// If the specified action involves JIT compilation, `jitFuncName` +// If the specified action involves JIT compilation, `funcName` // designates the function to JIT compile. This function is invoked // using the parameters given in `jitArgs`. // @@ -239,7 +239,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, // Compilation output is written to the stream specified by `os`. mlir::LogicalResult processInputBuffer( std::unique_ptr buffer, std::string sourceFileName, - enum Action action, const std::string &jitFuncName, + enum Action action, const std::string &funcName, llvm::ArrayRef jitArgs, llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, @@ -269,17 +269,18 @@ mlir::LogicalResult processInputBuffer( if (overrideMaxMANP.hasValue()) ce.setMaxMANP(overrideMaxMANP.getValue()); + ce.setClientParametersFuncName(funcName); if (fhelinalgTileSizes.hasValue()) ce.setFHELinalgTileSizes(*fhelinalgTileSizes); if (action == Action::JIT_INVOKE) { llvm::Expected lambdaOrErr = - ce.buildLambda(std::move(buffer), jitFuncName, keySetCache); + ce.buildLambda(std::move(buffer), funcName, keySetCache); if (!lambdaOrErr) { mlir::concretelang::log_error() - << "Failed to JIT-compile " << jitFuncName << ": " - << llvm::toString(std::move(lambdaOrErr.takeError())); + << "Failed to JIT-compile " << funcName << ": " + << llvm::toString(lambdaOrErr.takeError()); return mlir::failure(); } @@ -287,7 +288,7 @@ mlir::LogicalResult processInputBuffer( if (!resOrErr) { mlir::concretelang::log_error() - << "Failed to JIT-invoke " << jitFuncName << " with arguments " + << "Failed to JIT-invoke " << funcName << " with arguments " << jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError())); return mlir::failure(); } @@ -425,11 +426,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { auto process = [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { return processInputBuffer( - std::move(inputBuffer), fileName, cmdline::action, - cmdline::jitFuncName, cmdline::jitArgs, - cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, fhelinalgTileSizes, - cmdline::autoParallelize, jitKeySetCache, os, outputLib); + std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName, + cmdline::jitArgs, cmdline::assumeMaxEintPrecision, + cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, + fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os, + outputLib); }; auto &os = output->os(); auto res = mlir::failure(); @@ -446,12 +447,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { } if (cmdline::action == Action::COMPILE) { - auto libPath = outputLib->emitShared(); - if (!libPath) { - return mlir::failure(); - } - libPath = outputLib->emitStatic(); - if (!libPath) { + auto err = outputLib->emitArtifacts(); + if (err) { return mlir::failure(); } } diff --git a/compiler/tests/CMakeLists.txt b/compiler/tests/CMakeLists.txt index 32ddb74cd..9b726791c 100644 --- a/compiler/tests/CMakeLists.txt +++ b/compiler/tests/CMakeLists.txt @@ -1,3 +1,4 @@ if (CONCRETELANG_UNIT_TESTS) add_subdirectory(unittest) + add_subdirectory(Support) endif() diff --git a/compiler/tests/Support/CMakeLists.txt b/compiler/tests/Support/CMakeLists.txt new file mode 100644 index 000000000..6f5435add --- /dev/null +++ b/compiler/tests/Support/CMakeLists.txt @@ -0,0 +1,23 @@ +enable_testing() + +include_directories(${PROJECT_SOURCE_DIR}/include) + +add_executable( + support_unit_test + support_unit_test.cpp +) + +set_source_files_properties( + support_unit_test.cpp + + PROPERTIES COMPILE_FLAGS "-fno-rtti" +) + +target_link_libraries( + support_unit_test + gtest_main + ConcretelangSupport +) + +include(GoogleTest) +gtest_discover_tests(support_unit_test) diff --git a/compiler/tests/Support/support_unit_test.cpp b/compiler/tests/Support/support_unit_test.cpp new file mode 100644 index 000000000..b8975782a --- /dev/null +++ b/compiler/tests/Support/support_unit_test.cpp @@ -0,0 +1,69 @@ +#include + +#include "../unittest/end_to_end_jit_test.h" +#include "concretelang/ClientLib/ClientParameters.h" + +namespace CL = mlir::concretelang; + +TEST(Support, client_parameters_json_serde) { + mlir::concretelang::ClientParameters params0; + params0.secretKeys = { + {mlir::concretelang::SMALL_KEY, {/*.size = */ 12}}, + {mlir::concretelang::BIG_KEY, {/*.size = */ 14}}, + }; + params0.bootstrapKeys = { + { + "bsk_v0", { + /*.inputSecretKeyID = */ mlir::concretelang::SMALL_KEY, + /*.outputSecretKeyID = */ mlir::concretelang::BIG_KEY, + /*.level = */ 1, + /*.baseLog = */ 2, + /*.k = */ 3, + /*.variance = */ 0.001 + } + },{ + "wtf_bsk_v0", { + /*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY, + /*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY, + /*.level = */ 3, + /*.baseLog = */ 2, + /*.k = */ 1, + /*.variance = */ 0.0001, + } + }, + }; + params0.keyswitchKeys = { + { + "ksk_v0", { + /*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY, + /*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY, + /*.level = */ 1, + /*.baseLog = */ 2, + /*.variance = */ 3, + } + } + }; + params0.inputs = { + { + /*.encryption = */ {{CL::SMALL_KEY, 0.01, {4}}}, + /*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4}, + }, + { + /*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}}, + /*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4}, + }, + }; + params0.outputs = { + { + /*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}}, + /*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4}, + }, + }; + auto json = mlir::concretelang::toJSON(params0); + std::string jsonStr; + llvm::raw_string_ostream os(jsonStr); + os << json; + auto parseResult = + llvm::json::parse(jsonStr); + ASSERT_EXPECTED_VALUE(parseResult, params0); +} diff --git a/compiler/tests/unittest/end_to_end_jit_test.h b/compiler/tests/unittest/end_to_end_jit_test.h index 18c774d07..a6e377b06 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.h +++ b/compiler/tests/unittest/end_to_end_jit_test.h @@ -3,9 +3,9 @@ #include +#include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/JitCompilerEngine.h" -#include "concretelang/Support/KeySetCache.h" #include "llvm/Support/Path.h" #include "globals.h"