From b052157fae7dc9a4774da0687f4aa0e33fa345e3 Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 18 May 2022 17:32:42 +0200 Subject: [PATCH] refactor: separate runtime context from public arguments --- .../concretelang-c/Support/CompilerEngine.h | 12 +- .../concretelang/ClientLib/EvaluationKeys.h | 108 +++++++++++++ .../include/concretelang/ClientLib/KeySet.h | 32 ++-- .../concretelang/ClientLib/PublicArguments.h | 5 - .../concretelang/ClientLib/Serializers.h | 13 ++ .../include/concretelang/Runtime/context.h | 15 +- .../concretelang/ServerLib/ServerLambda.h | 5 +- .../include/concretelang/Support/JITSupport.h | 5 +- compiler/include/concretelang/Support/Jit.h | 3 +- .../concretelang/Support/LambdaSupport.h | 22 ++- .../concretelang/Support/LibrarySupport.h | 6 +- .../concretelang/TestLib/TestTypedLambda.h | 3 +- compiler/lib/Bindings/Python/CMakeLists.txt | 3 +- .../lib/Bindings/Python/CompilerAPIModule.cpp | 26 ++- .../Python/concrete/compiler/__init__.py | 1 + .../concrete/compiler/evaluation_keys.py | 63 ++++++++ .../Python/concrete/compiler/jit_support.py | 16 +- .../Python/concrete/compiler/key_set.py | 11 ++ .../concrete/compiler/library_support.py | 18 ++- compiler/lib/CAPI/Support/CompilerEngine.cpp | 34 +++- compiler/lib/ClientLib/EncryptedArguments.cpp | 5 +- compiler/lib/ClientLib/KeySet.cpp | 18 +-- compiler/lib/ClientLib/KeySetCache.cpp | 15 +- compiler/lib/ClientLib/PublicArguments.cpp | 27 +--- compiler/lib/ClientLib/Serializers.cpp | 57 ++++++- compiler/lib/Runtime/context.cpp | 4 +- compiler/lib/ServerLib/ServerLambda.cpp | 10 +- compiler/lib/Support/Jit.cpp | 8 +- compiler/tests/python/test_client_server.py | 112 +++++++++++++ compiler/tests/python/test_compilation.py | 3 +- compiler/tests/python/test_serialization.py | 152 ------------------ compiler/tests/unittest/end_to_end_jit_fhe.cc | 5 +- 32 files changed, 548 insertions(+), 269 deletions(-) create mode 100644 compiler/include/concretelang/ClientLib/EvaluationKeys.h create mode 100644 compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py create mode 100644 compiler/tests/python/test_client_server.py delete mode 100644 compiler/tests/python/test_serialization.py diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 53751d8a0..88464232e 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -53,7 +53,8 @@ jit_load_server_lambda(JITSupport_C support, MLIR_CAPI_EXPORTED std::unique_ptr jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda, - concretelang::clientlib::PublicArguments &args); + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys); // Library Support bindings /////////////////////////////////////////////////// @@ -82,7 +83,8 @@ library_load_server_lambda(LibrarySupport_C support, MLIR_CAPI_EXPORTED std::unique_ptr library_server_call(LibrarySupport_C support, concretelang::serverlib::ServerLambda lambda, - concretelang::clientlib::PublicArguments &args); + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys); MLIR_CAPI_EXPORTED std::string library_get_shared_lib_path(LibrarySupport_C support); @@ -128,6 +130,12 @@ publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, MLIR_CAPI_EXPORTED std::string publicResultSerialize(concretelang::clientlib::PublicResult &publicResult); +MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys +evaluationKeysUnserialize(const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( + concretelang::clientlib::EvaluationKeys &evaluationKeys); + // Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); diff --git a/compiler/include/concretelang/ClientLib/EvaluationKeys.h b/compiler/include/concretelang/ClientLib/EvaluationKeys.h new file mode 100644 index 000000000..9022daed0 --- /dev/null +++ b/compiler/include/concretelang/ClientLib/EvaluationKeys.h @@ -0,0 +1,108 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_ +#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_ + +#include + +extern "C" { +#include "concrete-ffi.h" +} + +namespace concretelang { +namespace clientlib { + +// ============================================= + +/// Wrapper for `LweKeyswitchKey_u64` so that it cleans up properly. +class LweKeyswitchKey { +private: + LweKeyswitchKey_u64 *ksk; + +protected: + friend std::ostream &operator<<(std::ostream &ostream, + const LweKeyswitchKey &wrappedKsk); + friend std::istream &operator>>(std::istream &istream, + LweKeyswitchKey &wrappedKsk); + +public: + LweKeyswitchKey(LweKeyswitchKey_u64 *ksk) : ksk{ksk} {} + LweKeyswitchKey(LweKeyswitchKey &other) = delete; + LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} { + other.ksk = nullptr; + } + ~LweKeyswitchKey() { + if (this->ksk != nullptr) { + free_lwe_keyswitch_key_u64(this->ksk); + this->ksk = nullptr; + } + } + + LweKeyswitchKey_u64 *get() { return this->ksk; } +}; + +// ============================================= + +/// Wrapper for `LweBootstrapKey_u64` so that it cleans up properly. +class LweBootstrapKey { +private: + LweBootstrapKey_u64 *bsk; + +protected: + friend std::ostream &operator<<(std::ostream &ostream, + const LweBootstrapKey &wrappedBsk); + friend std::istream &operator>>(std::istream &istream, + LweBootstrapKey &wrappedBsk); + +public: + LweBootstrapKey(LweBootstrapKey_u64 *bsk) : bsk{bsk} {} + LweBootstrapKey(LweBootstrapKey &other) = delete; + LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} { + other.bsk = nullptr; + } + ~LweBootstrapKey() { + if (this->bsk != nullptr) { + free_lwe_bootstrap_key_u64(this->bsk); + this->bsk = nullptr; + } + } + + LweBootstrapKey_u64 *get() { return this->bsk; } +}; + +// ============================================= + +/// Evalution keys required for execution. +class EvaluationKeys { +private: + std::shared_ptr sharedKsk; + std::shared_ptr sharedBsk; + +protected: + friend std::ostream &operator<<(std::ostream &ostream, + const EvaluationKeys &evaluationKeys); + friend std::istream &operator>>(std::istream &istream, + EvaluationKeys &evaluationKeys); + +public: + EvaluationKeys() + : sharedKsk{std::shared_ptr(nullptr)}, + sharedBsk{std::shared_ptr(nullptr)} {} + + EvaluationKeys(std::shared_ptr sharedKsk, + std::shared_ptr sharedBsk) + : sharedKsk{sharedKsk}, sharedBsk{sharedBsk} {} + + LweKeyswitchKey_u64 *getKsk() { return this->sharedKsk->get(); } + LweBootstrapKey_u64 *getBsk() { return this->sharedBsk->get(); } +}; + +// ============================================= + +} // namespace clientlib +} // namespace concretelang + +#endif diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index adee94de0..23dcc198e 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -16,6 +16,7 @@ extern "C" { #include "concretelang/Runtime/context.h" #include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Common/Error.h" @@ -77,28 +78,29 @@ public: CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); } CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); } - void setRuntimeContext(RuntimeContext &context) { - context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]); - context.bsk = std::get<1>(this->bootstrapKeys.at("bsk_v0")); - } - RuntimeContext runtimeContext() { RuntimeContext context; - this->setRuntimeContext(context); + context.evaluationKeys = this->evaluationKeys(); return context; } + EvaluationKeys evaluationKeys() { + auto sharedKsk = std::get<1>(this->keyswitchKeys.at("ksk_v0")); + auto sharedBsk = std::get<1>(this->bootstrapKeys.at("bsk_v0")); + return EvaluationKeys(sharedKsk, sharedBsk); + } + const std::map> & getSecretKeys(); const std::map> & - getBootstrapKeys(); + std::pair>> + &getBootstrapKeys(); const std::map> & - getKeyswitchKeys(); + std::pair>> + &getKeyswitchKeys(); protected: outcome::checked @@ -124,9 +126,11 @@ private: Engine *engine; std::map> secretKeys; - std::map> + std::map>> bootstrapKeys; - std::map> + std::map>> keyswitchKeys; std::vector> inputs; @@ -137,10 +141,10 @@ private: std::map> secretKeys, std::map> + std::pair>> bootstrapKeys, std::map> + std::pair>> keyswitchKeys); }; diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 3d79ec63a..c0474876d 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -37,7 +37,6 @@ class PublicArguments { /// arguments and public keys. public: PublicArguments(const ClientParameters &clientParameters, - RuntimeContext runtimeContext, bool clearRuntimeContext, std::vector &&preparedArgs, std::vector &&ciphertextBuffers); ~PublicArguments(); @@ -56,13 +55,9 @@ private: outcome::checked unserializeArgs(std::istream &istream); ClientParameters clientParameters; - RuntimeContext runtimeContext; std::vector preparedArgs; // Store buffers of ciphertexts std::vector ciphertextBuffers; - - // Indicates if this public argument own the runtime keys. - bool clearRuntimeContext; }; struct PublicResult { diff --git a/compiler/include/concretelang/ClientLib/Serializers.h b/compiler/include/concretelang/ClientLib/Serializers.h index d850949d8..9099d13d7 100644 --- a/compiler/include/concretelang/ClientLib/Serializers.h +++ b/compiler/include/concretelang/ClientLib/Serializers.h @@ -9,6 +9,7 @@ #include #include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Runtime/context.h" @@ -67,6 +68,18 @@ TensorData unserializeTensorData( // accomodate non static sizes std::istream &istream); +std::ostream &operator<<(std::ostream &ostream, + const LweKeyswitchKey &wrappedKsk); +std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk); + +std::ostream &operator<<(std::ostream &ostream, + const LweBootstrapKey &wrappedBsk); +std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk); + +std::ostream &operator<<(std::ostream &ostream, + const EvaluationKeys &evaluationKeys); +std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys); + } // namespace clientlib } // namespace concretelang diff --git a/compiler/include/concretelang/Runtime/context.h b/compiler/include/concretelang/Runtime/context.h index 1bd4edf9d..2401783fa 100644 --- a/compiler/include/concretelang/Runtime/context.h +++ b/compiler/include/concretelang/Runtime/context.h @@ -10,6 +10,8 @@ #include #include +#include "concretelang/ClientLib/EvaluationKeys.h" + extern "C" { #include "concrete-ffi.h" } @@ -18,16 +20,18 @@ namespace mlir { namespace concretelang { typedef struct RuntimeContext { - LweKeyswitchKey_u64 *ksk; - LweBootstrapKey_u64 *bsk; + ::concretelang::clientlib::EvaluationKeys evaluationKeys; std::map engines; std::mutex engines_map_guard; RuntimeContext() {} + // Ensure that the engines map is not copied - RuntimeContext(const RuntimeContext &ctx) : ksk(ctx.ksk), bsk(ctx.bsk) {} + RuntimeContext(const RuntimeContext &ctx) + : evaluationKeys(ctx.evaluationKeys) {} RuntimeContext(const RuntimeContext &&other) - : ksk(other.ksk), bsk(other.bsk) {} + : evaluationKeys(other.evaluationKeys) {} + ~RuntimeContext() { for (const auto &key : engines) { free_engine(key.second); @@ -35,8 +39,7 @@ typedef struct RuntimeContext { } RuntimeContext &operator=(const RuntimeContext &rhs) { - ksk = rhs.ksk; - bsk = rhs.bsk; + this->evaluationKeys = rhs.evaluationKeys; return *this; } } RuntimeContext; diff --git a/compiler/include/concretelang/ServerLib/ServerLambda.h b/compiler/include/concretelang/ServerLib/ServerLambda.h index ad01e9c83..7c7da59cc 100644 --- a/compiler/include/concretelang/ServerLib/ServerLambda.h +++ b/compiler/include/concretelang/ServerLib/ServerLambda.h @@ -39,7 +39,8 @@ public: /// Call the ServerLambda with public arguments. std::unique_ptr - call(clientlib::PublicArguments &args); + call(clientlib::PublicArguments &args, + clientlib::EvaluationKeys &evaluationKeys); protected: ClientParameters clientParameters; @@ -51,4 +52,4 @@ protected: } // namespace serverlib } // namespace concretelang -#endif \ No newline at end of file +#endif diff --git a/compiler/include/concretelang/Support/JITSupport.h b/compiler/include/concretelang/Support/JITSupport.h index 318063f67..af355dc9f 100644 --- a/compiler/include/concretelang/Support/JITSupport.h +++ b/compiler/include/concretelang/Support/JITSupport.h @@ -51,8 +51,9 @@ public: llvm::Expected> serverCall(std::shared_ptr lambda, - clientlib::PublicArguments &args) override { - return lambda->call(args); + clientlib::PublicArguments &args, + clientlib::EvaluationKeys &evaluationKeys) override { + return lambda->call(args, evaluationKeys); } private: diff --git a/compiler/include/concretelang/Support/Jit.h b/compiler/include/concretelang/Support/Jit.h index ec9477d90..f7ac2f5a9 100644 --- a/compiler/include/concretelang/Support/Jit.h +++ b/compiler/include/concretelang/Support/Jit.h @@ -36,7 +36,8 @@ public: /// Call the JIT lambda with the public arguments. llvm::Expected> - call(clientlib::PublicArguments &args); + call(clientlib::PublicArguments &args, + clientlib::EvaluationKeys &evaluationKeys); void setUseDataflow(bool option) { this->useDataflow = option; } diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index 18b70fe24..18decf151 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -275,7 +275,8 @@ public: /// Call the lambda with the public arguments. llvm::Expected> virtual serverCall( - Lambda lambda, clientlib::PublicArguments &args) = 0; + Lambda lambda, clientlib::PublicArguments &args, + clientlib::EvaluationKeys &evaluationKeys) = 0; /// Build the client KeySet from the client parameters. static llvm::Expected> @@ -302,11 +303,12 @@ public: } template - static llvm::Expected - call(Lambda lambda, clientlib::PublicArguments &publicArguments) { + static llvm::Expected call(Lambda lambda, + clientlib::PublicArguments &publicArguments, + clientlib::EvaluationKeys &evaluationKeys) { // Call the lambda auto publicResult = LambdaSupport().serverCall( - lambda, publicArguments); + lambda, publicArguments, evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } @@ -357,7 +359,9 @@ public: return std::move(err); } - auto publicResult = support.serverCall(lambda, **publicArguments); + auto evaluationKeys = this->keySet->evaluationKeys(); + auto publicResult = + support.serverCall(lambda, **publicArguments, evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } @@ -375,7 +379,9 @@ public: if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); } - auto publicResult = support.serverCall(lambda, *publicArguments.value()); + auto evaluationKeys = keySet->evaluationKeys(); + auto publicResult = + support.serverCall(lambda, *publicArguments.value(), evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } @@ -394,7 +400,9 @@ public: if (publicArguments.has_error()) { return StreamStringError(publicArguments.error().mesg); } - auto publicResult = support.serverCall(lambda, *publicArguments.value()); + auto evaluationKeys = keySet->evaluationKeys(); + auto publicResult = + support.serverCall(lambda, *publicArguments.value(), evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } diff --git a/compiler/include/concretelang/Support/LibrarySupport.h b/compiler/include/concretelang/Support/LibrarySupport.h index a6d92a9aa..00278094a 100644 --- a/compiler/include/concretelang/Support/LibrarySupport.h +++ b/compiler/include/concretelang/Support/LibrarySupport.h @@ -103,9 +103,9 @@ public: /// Call the lambda with the public arguments. llvm::Expected> - serverCall(serverlib::ServerLambda lambda, - clientlib::PublicArguments &args) override { - return lambda.call(args); + serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args, + clientlib::EvaluationKeys &evaluationKeys) override { + return lambda.call(args, evaluationKeys); } /// Get path to shared library diff --git a/compiler/include/concretelang/TestLib/TestTypedLambda.h b/compiler/include/concretelang/TestLib/TestTypedLambda.h index 858e9fe72..665365c8c 100644 --- a/compiler/include/concretelang/TestLib/TestTypedLambda.h +++ b/compiler/include/concretelang/TestLib/TestTypedLambda.h @@ -89,7 +89,8 @@ public: // serverInput)); // server function call - auto publicResult = serverLambda.call(*publicArgument); + auto evaluationKeys = keySet->evaluationKeys(); + auto publicResult = serverLambda.call(*publicArgument, evaluationKeys); // client result decryption return this->decryptResult(*keySet, *publicResult); diff --git a/compiler/lib/Bindings/Python/CMakeLists.txt b/compiler/lib/Bindings/Python/CMakeLists.txt index e7ab161ad..3051ff6fa 100644 --- a/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compiler/lib/Bindings/Python/CMakeLists.txt @@ -41,6 +41,7 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources concrete/compiler/library_lambda.py concrete/compiler/public_arguments.py concrete/compiler/public_result.py + concrete/compiler/evaluation_keys.py concrete/compiler/utils.py concrete/compiler/wrapper.py concrete/__init__.py @@ -119,4 +120,4 @@ add_mlir_python_modules(ConcretelangPythonModules ConcretelangBindingsPythonSources COMMON_CAPI_LINK_LIBS ConcretelangBindingsPythonCAPI -) \ No newline at end of file +) diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 461422ead..8280542c5 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -90,8 +90,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::return_value_policy::reference) .def("server_call", [](JITSupport_C &support, concretelang::JITLambda &lambda, - clientlib::PublicArguments &publicArguments) { - return jit_server_call(support, lambda, publicArguments); + clientlib::PublicArguments &publicArguments, + clientlib::EvaluationKeys &evaluationKeys) { + return jit_server_call(support, lambda, publicArguments, + evaluationKeys); }); pybind11::class_( @@ -132,8 +134,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::return_value_policy::reference) .def("server_call", [](LibrarySupport_C &support, serverlib::ServerLambda lambda, - clientlib::PublicArguments &publicArguments) { - return library_server_call(support, lambda, publicArguments); + clientlib::PublicArguments &publicArguments, + clientlib::EvaluationKeys &evaluationKeys) { + return library_server_call(support, lambda, publicArguments, + evaluationKeys); }) .def("get_shared_lib_path", [](LibrarySupport_C &support) { @@ -185,7 +189,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( clientParametersSerialize(clientParameters)); }); - pybind11::class_(m, "KeySet"); + pybind11::class_(m, "KeySet") + .def("get_evaluation_keys", + [](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); }); + pybind11::class_>( m, "PublicArguments") @@ -207,6 +214,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return pybind11::bytes(publicResultSerialize(publicResult)); }); + pybind11::class_(m, "EvaluationKeys") + .def_static("unserialize", + [](const pybind11::bytes &buffer) { + return evaluationKeysUnserialize(buffer); + }) + .def("serialize", [](clientlib::EvaluationKeys &evaluationKeys) { + return pybind11::bytes(evaluationKeysSerialize(evaluationKeys)); + }); + pybind11::class_(m, "LambdaArgument") .def_static("from_tensor", [](std::vector tensor, std::vector dims) { diff --git a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index 0c0e85864..011f86eb5 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -26,6 +26,7 @@ from .library_lambda import LibraryLambda from .client_support import ClientSupport from .jit_support import JITSupport from .library_support import LibrarySupport +from .evaluation_keys import EvaluationKeys # Terminate parallelization in the compiler (if init) during cleanup diff --git a/compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py b/compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py new file mode 100644 index 000000000..eba12a6c9 --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/evaluation_keys.py @@ -0,0 +1,63 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information. + +"""EvaluationKeys.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + EvaluationKeys as _EvaluationKeys, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp + + +class EvaluationKeys(WrapperCpp): + """ + EvaluationKeys required for execution. + """ + + def __init__(self, evaluation_keys: _EvaluationKeys): + """Wrap the native Cpp object. + + Args: + evaluation_keys (_EvaluationKeys): object to wrap + + Raises: + TypeError: if evaluation_keys is not of type _EvaluationKeys + """ + if not isinstance(evaluation_keys, _EvaluationKeys): + raise TypeError( + f"evaluation_keys must be of type _EvaluationKeys, not {type(evaluation_keys)}" + ) + super().__init__(evaluation_keys) + + def serialize(self) -> bytes: + """Serialize the EvaluationKeys. + + Returns: + bytes: serialized object + """ + return self.cpp().serialize() + + @staticmethod + def unserialize(serialized_evaluation_keys: bytes) -> "EvaluationKeys": + """Unserialize EvaluationKeys from bytes. + + Args: + serialized_evaluation_keys (bytes): previously serialized EvaluationKeys + + Raises: + TypeError: if serialized_evaluation_keys is not of type bytes + + Returns: + EvaluationKeys: unserialized object + """ + if not isinstance(serialized_evaluation_keys, bytes): + raise TypeError( + f"serialized_evaluation_keys must be of type bytes, " + f"not {type(serialized_evaluation_keys)}" + ) + return EvaluationKeys.wrap( + _EvaluationKeys.unserialize(serialized_evaluation_keys) + ) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py b/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py index aedf35103..6dc3738d6 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py @@ -23,6 +23,7 @@ from .jit_lambda import JITLambda from .public_arguments import PublicArguments from .public_result import PublicResult from .wrapper import WrapperCpp +from .evaluation_keys import EvaluationKeys class JITSupport(WrapperCpp): @@ -139,17 +140,22 @@ class JITSupport(WrapperCpp): return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp())) def server_call( - self, jit_lambda: JITLambda, public_arguments: PublicArguments + self, + jit_lambda: JITLambda, + public_arguments: PublicArguments, + evaluation_keys: EvaluationKeys, ) -> PublicResult: """Call the JITLambda with public_arguments. Args: jit_lambda (JITLambda): A server lambda to call. public_arguments (PublicArguments): The arguments of the call. + evaluation_keys (EvaluationKeys): Evalutation keys of the call. Raises: TypeError: if jit_lambda is not of type JITLambda TypeError: if public_arguments is not of type PublicArguments + TypeError: if evaluation_keys is not of type EvaluationKeys Returns: PublicResult: the result of the call of the server lambda. @@ -162,6 +168,12 @@ class JITSupport(WrapperCpp): raise TypeError( f"public_arguments must be of type PublicArguments, not {type(public_arguments)}" ) + if not isinstance(evaluation_keys, EvaluationKeys): + raise TypeError( + f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}" + ) return PublicResult.wrap( - self.cpp().server_call(jit_lambda.cpp(), public_arguments.cpp()) + self.cpp().server_call( + jit_lambda.cpp(), public_arguments.cpp(), evaluation_keys.cpp() + ) ) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/key_set.py b/compiler/lib/Bindings/Python/concrete/compiler/key_set.py index 815ef5f45..b07657087 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/key_set.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/key_set.py @@ -14,6 +14,7 @@ from mlir._mlir_libs._concretelang._compiler import ( # pylint: enable=no-name-in-module,import-error from .wrapper import WrapperCpp +from .evaluation_keys import EvaluationKeys class KeySet(WrapperCpp): @@ -34,3 +35,13 @@ class KeySet(WrapperCpp): if not isinstance(keyset, _KeySet): raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}") super().__init__(keyset) + + def get_evaluation_keys(self) -> EvaluationKeys: + """ + Get evaluation keys for execution. + + Returns: + EvaluationKeys: + evaluation keys for execution + """ + return EvaluationKeys(self.cpp().get_evaluation_keys()) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/library_support.py b/compiler/lib/Bindings/Python/concrete/compiler/library_support.py index 678981acc..9bd9b4551 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/library_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/library_support.py @@ -23,6 +23,7 @@ from .public_result import PublicResult from .client_parameters import ClientParameters from .wrapper import WrapperCpp from .utils import lookup_runtime_lib +from .evaluation_keys import EvaluationKeys # Default output path for compilation artifacts @@ -211,17 +212,22 @@ class LibrarySupport(WrapperCpp): ) def server_call( - self, library_lambda: LibraryLambda, public_arguments: PublicArguments + self, + library_lambda: LibraryLambda, + public_arguments: PublicArguments, + evaluation_keys: EvaluationKeys, ) -> PublicResult: """Call the library with public_arguments. Args: library_lambda (LibraryLambda): reference to the compiled library public_arguments (PublicArguments): arguments to use for execution + evaluation_keys (EvaluationKeys): evaluation keys to use for execution Raises: TypeError: if library_lambda is not of type LibraryLambda TypeError: if public_arguments is not of type PublicArguments + TypeError: if evaluation_keys is not of type EvaluationKeys Returns: PublicResult: result of the execution @@ -234,8 +240,16 @@ class LibrarySupport(WrapperCpp): raise TypeError( f"public_arguments must be of type PublicArguments, not {type(public_arguments)}" ) + if not isinstance(evaluation_keys, EvaluationKeys): + raise TypeError( + f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}" + ) return PublicResult.wrap( - self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp()) + self.cpp().server_call( + library_lambda.cpp(), + public_arguments.cpp(), + evaluation_keys.cpp(), + ) ) def get_shared_lib_path(self) -> str: diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 1d3769bbb..b4824cb3e 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -7,6 +7,7 @@ #include "concretelang-c/Support/CompilerEngine.h" #include "concretelang/ClientLib/KeySetCache.h" +#include "concretelang/ClientLib/Serializers.h" #include "concretelang/Runtime/runtime_api.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/JITSupport.h" @@ -53,8 +54,9 @@ jit_load_server_lambda(JITSupport_C support, MLIR_CAPI_EXPORTED std::unique_ptr jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda, - concretelang::clientlib::PublicArguments &args) { - GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args)); + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys) { + GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys)); return std::move(*publicResult); } @@ -97,9 +99,10 @@ library_load_server_lambda( MLIR_CAPI_EXPORTED std::unique_ptr library_server_call(LibrarySupport_C support, concretelang::serverlib::ServerLambda lambda, - concretelang::clientlib::PublicArguments &args) { - GET_OR_THROW_LLVM_EXPECTED(publicResult, - support.support.serverCall(lambda, args)); + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys) { + GET_OR_THROW_LLVM_EXPECTED( + publicResult, support.support.serverCall(lambda, args, evaluationKeys)); return std::move(*publicResult); } @@ -192,6 +195,27 @@ publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { return buffer.str(); } +MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys +evaluationKeysUnserialize(const std::string &buffer) { + std::stringstream istream(buffer); + + concretelang::clientlib::EvaluationKeys evaluationKeys; + concretelang::clientlib::operator>>(istream, evaluationKeys); + + if (istream.fail()) { + throw std::runtime_error("Cannot read evaluation keys"); + } + + return evaluationKeys; +} + +MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( + concretelang::clientlib::EvaluationKeys &evaluationKeys) { + std::ostringstream buffer(std::ios::binary); + concretelang::clientlib::operator<<(buffer, evaluationKeys); + return buffer.str(); +} + MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters clientParametersUnserialize(const std::string &json) { GET_OR_THROW_LLVM_EXPECTED( diff --git a/compiler/lib/ClientLib/EncryptedArguments.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp index 358972885..dbfbda729 100644 --- a/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -25,11 +25,8 @@ size_t bitWidthAsWord(size_t exactBitWidth) { outcome::checked, StringError> EncryptedArguments::exportPublicArguments(ClientParameters clientParameters, RuntimeContext runtimeContext) { - // On client side the runtimeContext is hold by the KeySet - bool clearContext = false; return std::make_unique( - clientParameters, runtimeContext, clearContext, std::move(preparedArgs), - std::move(ciphertextBuffers)); + clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers)); } outcome::checked diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index 2bfbbbc72..15d4696ca 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -24,12 +24,6 @@ KeySet::~KeySet() { for (auto it : secretKeys) { free_lwe_secret_key_u64(it.second.second); } - for (auto it : bootstrapKeys) { - free_lwe_bootstrap_key_u64(it.second.second); - } - for (auto it : keyswitchKeys) { - free_lwe_keyswitch_key_u64(it.second.second); - } free_engine(engine); } @@ -115,10 +109,10 @@ void KeySet::setKeys( std::map> secretKeys, std::map> + std::pair>> bootstrapKeys, std::map> + std::pair>> keyswitchKeys) { this->secretKeys = secretKeys; this->bootstrapKeys = bootstrapKeys; @@ -160,7 +154,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) { param.level, param.variance, param.glweDimension, polynomialSize); // Store the bootstrap key - bootstrapKeys[id] = {param, bsk}; + bootstrapKeys[id] = {param, std::make_shared(bsk)}; return outcome::success(); } @@ -184,7 +178,7 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) { param.baseLog, param.variance); // Store the keyswitch key - keyswitchKeys[id] = {param, ksk}; + keyswitchKeys[id] = {param, std::make_shared(ksk)}; return outcome::success(); } @@ -253,13 +247,13 @@ const std::map> } const std::map> & + std::pair>> & KeySet::getBootstrapKeys() { return bootstrapKeys; } const std::map> & + std::pair>> & KeySet::getKeyswitchKeys() { return keyswitchKeys; } diff --git a/compiler/lib/ClientLib/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp index eab7850c1..ce3fc95f3 100644 --- a/compiler/lib/ClientLib/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -5,6 +5,7 @@ #include "boost/outcome.h" +#include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/KeySetCache.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/FileSystem.h" @@ -96,9 +97,11 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, std::map> secretKeys; - std::map> + std::map>> bootstrapKeys; - std::map> + std::map>> keyswitchKeys; // Load LWE secret keys @@ -117,7 +120,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, llvm::SmallString<0> path(folderPath); llvm::sys::path::append(path, "pbsKey_" + id); OUTCOME_TRY(LweBootstrapKey_u64 * bsk, loadBootstrapKey(path)); - bootstrapKeys[id] = {param, bsk}; + bootstrapKeys[id] = {param, std::make_shared(bsk)}; } // Load keyswitch keys for (auto keyswitchParam : params.keyswitchKeys) { @@ -126,7 +129,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, llvm::SmallString<0> path(folderPath); llvm::sys::path::append(path, "ksKey_" + id); OUTCOME_TRY(LweKeyswitchKey_u64 * ksk, loadKeyswitchKey(path)); - keyswitchKeys[id] = {param, ksk}; + keyswitchKeys[id] = {param, std::make_shared(ksk)}; } key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys); @@ -162,7 +165,7 @@ outcome::checked saveKeys(KeySet &key_set, auto key = bootstrapKeyParam.second.second; llvm::SmallString<0> path = folderIncompletePath; llvm::sys::path::append(path, "pbsKey_" + id); - saveBootstrapKey(path, key); + saveBootstrapKey(path, key->get()); } // Save keyswitch keys for (auto keyswitchParam : key_set.getKeyswitchKeys()) { @@ -170,7 +173,7 @@ outcome::checked saveKeys(KeySet &key_set, auto key = keyswitchParam.second.second; llvm::SmallString<0> path = folderIncompletePath; llvm::sys::path::append(path, "ksKey_" + id); - saveKeyswitchKey(path, key); + saveKeyswitchKey(path, key->get()); } err = llvm::sys::fs::rename(folderIncompletePath, folderPath); diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp index ad334d136..e19451be5 100644 --- a/compiler/lib/ClientLib/PublicArguments.cpp +++ b/compiler/lib/ClientLib/PublicArguments.cpp @@ -20,28 +20,14 @@ using concretelang::error::StringError; // TODO: optimize the move PublicArguments::PublicArguments(const ClientParameters &clientParameters, - RuntimeContext runtimeContext, - bool clearRuntimeContext, std::vector &&preparedArgs_, std::vector &&ciphertextBuffers_) - : clientParameters(clientParameters), runtimeContext(runtimeContext), - clearRuntimeContext(clearRuntimeContext) { + : clientParameters(clientParameters) { preparedArgs = std::move(preparedArgs_); ciphertextBuffers = std::move(ciphertextBuffers_); } -PublicArguments::~PublicArguments() { - if (!clearRuntimeContext) { - return; - } - if (runtimeContext.bsk != nullptr) { - free_lwe_bootstrap_key_u64(runtimeContext.bsk); - } - if (runtimeContext.ksk != nullptr) { - free_lwe_keyswitch_key_u64(runtimeContext.ksk); - runtimeContext.ksk = nullptr; - } -} +PublicArguments::~PublicArguments() {} outcome::checked PublicArguments::serialize(std::ostream &ostream) { @@ -49,7 +35,6 @@ PublicArguments::serialize(std::ostream &ostream) { return StringError( "PublicArguments::serialize: ostream should be in binary mode"); } - ostream << runtimeContext; size_t iPreparedArgs = 0; int iGate = -1; for (auto gate : clientParameters.inputs) { @@ -122,16 +107,10 @@ PublicArguments::unserializeArgs(std::istream &istream) { outcome::checked, StringError> PublicArguments::unserialize(ClientParameters &clientParameters, std::istream &istream) { - RuntimeContext runtimeContext; - istream >> runtimeContext; - if (istream.fail()) { - return StringError("Cannot read runtime context"); - } std::vector empty; std::vector emptyBuffers; auto sArguments = std::make_unique( - clientParameters, runtimeContext, true, std::move(empty), - std::move(emptyBuffers)); + clientParameters, std::move(empty), std::move(emptyBuffers)); OUTCOME_TRYV(sArguments->unserializeArgs(istream)); return std::move(sArguments); } diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp index c94b70bc2..15b34865b 100644 --- a/compiler/lib/ClientLib/Serializers.cpp +++ b/compiler/lib/ClientLib/Serializers.cpp @@ -67,16 +67,14 @@ std::istream &operator>>(std::istream &istream, LweBootstrapKey_u64 *&key) { std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext) { - istream >> runtimeContext.ksk; - istream >> runtimeContext.bsk; + istream >> runtimeContext.evaluationKeys; assert(istream.good()); return istream; } std::ostream &operator<<(std::ostream &ostream, const RuntimeContext &runtimeContext) { - ostream << runtimeContext.ksk; - ostream << runtimeContext.bsk; + ostream << runtimeContext.evaluationKeys; assert(ostream.good()); return ostream; } @@ -147,5 +145,54 @@ TensorData unserializeTensorData( return result; } +std::ostream &operator<<(std::ostream &ostream, + const LweKeyswitchKey &wrappedKsk) { + ostream << wrappedKsk.ksk; + assert(ostream.good()); + return ostream; +} +std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) { + istream >> wrappedKsk.ksk; + assert(istream.good()); + return istream; +} + +std::ostream &operator<<(std::ostream &ostream, + const LweBootstrapKey &wrappedBsk) { + ostream << wrappedBsk.bsk; + assert(ostream.good()); + return ostream; +} +std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) { + istream >> wrappedBsk.bsk; + assert(istream.good()); + return istream; +} + +std::ostream &operator<<(std::ostream &ostream, + const EvaluationKeys &evaluationKeys) { + ostream << *evaluationKeys.sharedKsk; + ostream << *evaluationKeys.sharedBsk; + assert(ostream.good()); + return ostream; +} + +std::istream &operator>>(std::istream &istream, + EvaluationKeys &evaluationKeys) { + auto sharedKsk = LweKeyswitchKey(nullptr); + auto sharedBsk = LweBootstrapKey(nullptr); + + istream >> sharedKsk; + istream >> sharedBsk; + + evaluationKeys.sharedKsk = + std::make_shared(std::move(sharedKsk)); + evaluationKeys.sharedBsk = + std::make_shared(std::move(sharedBsk)); + + assert(istream.good()); + return istream; +} + } // namespace clientlib -} // namespace concretelang \ No newline at end of file +} // namespace concretelang diff --git a/compiler/lib/Runtime/context.cpp b/compiler/lib/Runtime/context.cpp index a524a140a..09a3bc264 100644 --- a/compiler/lib/Runtime/context.cpp +++ b/compiler/lib/Runtime/context.cpp @@ -9,12 +9,12 @@ LweKeyswitchKey_u64 * get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->ksk; + return context->evaluationKeys.getKsk(); } LweBootstrapKey_u64 * get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->bsk; + return context->evaluationKeys.getBsk(); } // Instantiate one engine per thread on demand diff --git a/compiler/lib/ServerLib/ServerLambda.cpp b/compiler/lib/ServerLib/ServerLambda.cpp index 8d4edd785..13760bea3 100644 --- a/compiler/lib/ServerLib/ServerLambda.cpp +++ b/compiler/lib/ServerLib/ServerLambda.cpp @@ -21,7 +21,9 @@ namespace serverlib { using concretelang::clientlib::CircuitGate; using concretelang::clientlib::CircuitGateShape; +using concretelang::clientlib::EvaluationKeys; using concretelang::clientlib::PublicArguments; +using concretelang::clientlib::RuntimeContext; using concretelang::error::StringError; outcome::checked @@ -74,10 +76,14 @@ TensorData dynamicCall(void *(*func)(void *...), } std::unique_ptr -ServerLambda::call(PublicArguments &args) { +ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) { std::vector preparedArgs(args.preparedArgs.begin(), args.preparedArgs.end()); - preparedArgs.push_back((void *)&args.runtimeContext); + + RuntimeContext runtimeContext; + runtimeContext.evaluationKeys = evaluationKeys; + preparedArgs.push_back((void *)&runtimeContext); + return clientlib::PublicResult::fromBuffers( clientParameters, {dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])}); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index c455331dc..62b4f9642 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -76,7 +76,8 @@ uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) { } llvm::Expected> -JITLambda::call(clientlib::PublicArguments &args) { +JITLambda::call(clientlib::PublicArguments &args, + clientlib::EvaluationKeys &evaluationKeys) { #ifndef CONCRETELANG_PARALLEL_EXECUTION_ENABLED if (this->useDataflow) { return StreamStringError( @@ -116,9 +117,12 @@ JITLambda::call(clientlib::PublicArguments &args) { for (auto &arg : args.preparedArgs) { rawArgs[i++] = &arg; } + + RuntimeContext runtimeContext; + runtimeContext.evaluationKeys = evaluationKeys; // Pointer on runtime context, the rawArgs take pointer on actual value that // is passed to the compiled function. - auto rtCtxPtr = &args.runtimeContext; + auto rtCtxPtr = &runtimeContext; rawArgs[i++] = &rtCtxPtr; // Pointers on outputs for (auto &out : outputs) { diff --git a/compiler/tests/python/test_client_server.py b/compiler/tests/python/test_client_server.py new file mode 100644 index 000000000..1e81b6fb0 --- /dev/null +++ b/compiler/tests/python/test_client_server.py @@ -0,0 +1,112 @@ +import numpy as np +import pytest +import shutil +import tempfile + +from concrete.compiler import ( + ClientSupport, + EvaluationKeys, + LibrarySupport, + PublicArguments, + PublicResult, +) + + +@pytest.mark.parametrize( + "mlir, args, expected_result", + [ + pytest.param( + """ + +func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> +} + + """, + (5, 7), + 12, + id="enc_plain_int_args", + marks=pytest.mark.xfail, + ), + pytest.param( + """ + +func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> +} + + """, + (5, 7), + 12, + id="enc_enc_int_args", + ), + pytest.param( + """ + +func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5> { + %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : (tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5> + return %ret : !FHE.eint<5> +} + + """, + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([4, 3, 2, 1], dtype=np.uint8), + ), + 20, + id="enc_plain_ndarray_args", + marks=pytest.mark.xfail, + ), + pytest.param( + """ + +func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> { + %res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> + return %res : tensor<4x!FHE.eint<5>> +} + + """, + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([7, 0, 1, 5], dtype=np.uint8), + ), + np.array([8, 2, 4, 9]), + id="enc_enc_ndarray_args", + ), + ], +) +def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache): + with tempfile.TemporaryDirectory() as tmpdirname: + support = LibrarySupport.new(str(tmpdirname)) + compilation_result = support.compile(mlir) + server_lambda = support.load_server_lambda(compilation_result) + + client_parameters = support.load_client_parameters(compilation_result) + keyset = ClientSupport.key_set(client_parameters, keyset_cache) + + evaluation_keys = keyset.get_evaluation_keys() + evaluation_keys_serialized = evaluation_keys.serialize() + evaluation_keys_unserialized = EvaluationKeys.unserialize( + evaluation_keys_serialized + ) + + args = ClientSupport.encrypt_arguments(client_parameters, keyset, args) + args_serialized = args.serialize() + args_unserialized = PublicArguments.unserialize( + client_parameters, args_serialized + ) + + result = support.server_call( + server_lambda, + args_unserialized, + evaluation_keys_unserialized, + ) + result_serialized = result.serialize() + result_unserialized = PublicResult.unserialize( + client_parameters, result_serialized + ) + + output = ClientSupport.decrypt_result(keyset, result_unserialized) + assert np.array_equal(output, expected_result) diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index cb2bec580..833afcdf4 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -32,7 +32,8 @@ def run(engine, args, compilation_result, keyset_cache): public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) # Server server_lambda = engine.load_server_lambda(compilation_result) - public_result = engine.server_call(server_lambda, public_arguments) + evaluation_keys = key_set.get_evaluation_keys() + public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys) # Client result = ClientSupport.decrypt_result(key_set, public_result) return result diff --git a/compiler/tests/python/test_serialization.py b/compiler/tests/python/test_serialization.py deleted file mode 100644 index ba634e32e..000000000 --- a/compiler/tests/python/test_serialization.py +++ /dev/null @@ -1,152 +0,0 @@ -import pytest -import shutil -import numpy as np -from concrete.compiler import ( - JITSupport, - LibrarySupport, - ClientSupport, - CompilationOptions, - PublicArguments, -) -from concrete.compiler.client_parameters import ClientParameters -from concrete.compiler.public_result import PublicResult - - -def assert_result(result, expected_result): - """Assert that result and expected result are equal. - - result and expected_result can be integers on numpy arrays. - """ - assert type(expected_result) == type(result) - if isinstance(expected_result, int): - assert result == expected_result - else: - assert np.all(result == expected_result) - - -def run_with_serialization( - engine, - args, - compilation_result, - keyset_cache, -): - """Execute engine on the given arguments. Performs serialization betwee client/server. - - Perform required loading, encryption, execution, and decryption.""" - # Client - client_parameters = engine.load_client_parameters(compilation_result) - serialized_client_parameters = client_parameters.serialize() - client_parameters = ClientParameters.unserialize(serialized_client_parameters) - key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) - public_arguments_buffer = public_arguments.serialize() - # Server - public_arguments = PublicArguments.unserialize( - client_parameters, public_arguments_buffer - ) - del public_arguments_buffer - server_lambda = engine.load_server_lambda(compilation_result) - public_result = engine.server_call(server_lambda, public_arguments) - public_result_buffer = public_result.serialize() - # Client - public_result = PublicResult.unserialize(client_parameters, public_result_buffer) - del public_result_buffer - result = ClientSupport.decrypt_result(key_set, public_result) - return result - - -def compile_run_assert_with_serialization( - engine, - mlir_input, - args, - expected_result, - keyset_cache, -): - """Compile run and assert result. Performs serialization betwee client/server. - - Can take both JITSupport or LibrarySupport as engine. - """ - options = CompilationOptions.new("main") - compilation_result = engine.compile(mlir_input, options) - result = run_with_serialization(engine, args, compilation_result, keyset_cache) - assert_result(result, expected_result) - - -end_to_end_fixture = [ - pytest.param( - """ - func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> { - %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>) - return %1: !FHE.eint<5> - } - """, - (5, 7), - 12, - id="enc_plain_int_args", - marks=pytest.mark.xfail, - ), - pytest.param( - """ - func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>) - return %1: !FHE.eint<5> - } - """, - (5, 7), - 12, - id="enc_enc_int_args", - ), - pytest.param( - """ - func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5> - { - %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : - (tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5> - return %ret : !FHE.eint<5> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), - ), - 20, - id="enc_plain_ndarray_args", - marks=pytest.mark.xfail, - ), - pytest.param( - """ - func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> { - %res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> - return %res : tensor<4x!FHE.eint<5>> - } - """, - ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([7, 0, 1, 5], dtype=np.uint8), - ), - np.array([8, 2, 4, 9]), - id="enc_enc_ndarray_args", - ), -] - - -@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) -def test_jit_compile_and_run_with_serialization( - mlir_input, args, expected_result, keyset_cache -): - engine = JITSupport.new() - compile_run_assert_with_serialization( - engine, mlir_input, args, expected_result, keyset_cache - ) - - -@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture) -def test_lib_compile_and_run_with_serialization( - mlir_input, args, expected_result, keyset_cache -): - artifact_dir = "./py_test_lib_compile_and_run" - engine = LibrarySupport.new(artifact_dir) - compile_run_assert_with_serialization( - engine, mlir_input, args, expected_result, keyset_cache - ) - shutil.rmtree(artifact_dir) diff --git a/compiler/tests/unittest/end_to_end_jit_fhe.cc b/compiler/tests/unittest/end_to_end_jit_fhe.cc index d2bfd9e43..d7c6070c3 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhe.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhe.cc @@ -21,6 +21,8 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) { auto keySet = support.keySet(*clientParameters, getTestKeySetCache()); ASSERT_EXPECTED_SUCCESS(keySet); + auto evaluationKeys = (*keySet)->evaluationKeys(); + /* 3 - Load the server lambda */ auto serverLambda = support.loadServerLambda(**compilationResult); ASSERT_EXPECTED_SUCCESS(serverLambda); @@ -41,7 +43,8 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) { ASSERT_EXPECTED_SUCCESS(publicArguments); /* 5 - Call the server lambda */ - auto publicResult = support.serverCall(*serverLambda, **publicArguments); + auto publicResult = + support.serverCall(*serverLambda, **publicArguments, evaluationKeys); ASSERT_EXPECTED_SUCCESS(publicResult); /* 6 - Decrypt the public result */