refactor: separate runtime context from public arguments

This commit is contained in:
Umut
2022-05-18 17:32:42 +02:00
parent fcad585546
commit b052157fae
32 changed files with 548 additions and 269 deletions

View File

@@ -53,7 +53,8 @@ jit_load_server_lambda(JITSupport_C support,
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda,
concretelang::clientlib::PublicArguments &args);
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys);
// Library Support bindings ///////////////////////////////////////////////////
@@ -82,7 +83,8 @@ library_load_server_lambda(LibrarySupport_C support,
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibrarySupport_C support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args);
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys);
MLIR_CAPI_EXPORTED std::string
library_get_shared_lib_path(LibrarySupport_C support);
@@ -128,6 +130,12 @@ publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
MLIR_CAPI_EXPORTED std::string
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult);
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
evaluationKeysUnserialize(const std::string &buffer);
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
concretelang::clientlib::EvaluationKeys &evaluationKeys);
// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);

View File

@@ -0,0 +1,108 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#include <memory>
extern "C" {
#include "concrete-ffi.h"
}
namespace concretelang {
namespace clientlib {
// =============================================
/// Wrapper for `LweKeyswitchKey_u64` so that it cleans up properly.
class LweKeyswitchKey {
private:
LweKeyswitchKey_u64 *ksk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
friend std::istream &operator>>(std::istream &istream,
LweKeyswitchKey &wrappedKsk);
public:
LweKeyswitchKey(LweKeyswitchKey_u64 *ksk) : ksk{ksk} {}
LweKeyswitchKey(LweKeyswitchKey &other) = delete;
LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} {
other.ksk = nullptr;
}
~LweKeyswitchKey() {
if (this->ksk != nullptr) {
free_lwe_keyswitch_key_u64(this->ksk);
this->ksk = nullptr;
}
}
LweKeyswitchKey_u64 *get() { return this->ksk; }
};
// =============================================
/// Wrapper for `LweBootstrapKey_u64` so that it cleans up properly.
class LweBootstrapKey {
private:
LweBootstrapKey_u64 *bsk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
friend std::istream &operator>>(std::istream &istream,
LweBootstrapKey &wrappedBsk);
public:
LweBootstrapKey(LweBootstrapKey_u64 *bsk) : bsk{bsk} {}
LweBootstrapKey(LweBootstrapKey &other) = delete;
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
other.bsk = nullptr;
}
~LweBootstrapKey() {
if (this->bsk != nullptr) {
free_lwe_bootstrap_key_u64(this->bsk);
this->bsk = nullptr;
}
}
LweBootstrapKey_u64 *get() { return this->bsk; }
};
// =============================================
/// Evalution keys required for execution.
class EvaluationKeys {
private:
std::shared_ptr<LweKeyswitchKey> sharedKsk;
std::shared_ptr<LweBootstrapKey> sharedBsk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
friend std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys);
public:
EvaluationKeys()
: sharedKsk{std::shared_ptr<LweKeyswitchKey>(nullptr)},
sharedBsk{std::shared_ptr<LweBootstrapKey>(nullptr)} {}
EvaluationKeys(std::shared_ptr<LweKeyswitchKey> sharedKsk,
std::shared_ptr<LweBootstrapKey> sharedBsk)
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk} {}
LweKeyswitchKey_u64 *getKsk() { return this->sharedKsk->get(); }
LweBootstrapKey_u64 *getBsk() { return this->sharedBsk->get(); }
};
// =============================================
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -16,6 +16,7 @@ extern "C" {
#include "concretelang/Runtime/context.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
@@ -77,28 +78,29 @@ public:
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
void setRuntimeContext(RuntimeContext &context) {
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
context.bsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
}
RuntimeContext runtimeContext() {
RuntimeContext context;
this->setRuntimeContext(context);
context.evaluationKeys = this->evaluationKeys();
return context;
}
EvaluationKeys evaluationKeys() {
auto sharedKsk = std::get<1>(this->keyswitchKeys.at("ksk_v0"));
auto sharedBsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
return EvaluationKeys(sharedKsk, sharedBsk);
}
const std::map<LweSecretKeyID,
std::pair<LweSecretKeyParam, LweSecretKey_u64 *>> &
getSecretKeys();
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
getBootstrapKeys();
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
&getBootstrapKeys();
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
getKeyswitchKeys();
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
&getKeyswitchKeys();
protected:
outcome::checked<void, StringError>
@@ -124,9 +126,11 @@ private:
Engine *engine;
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys;
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys;
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *>>
inputs;
@@ -137,10 +141,10 @@ private:
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys);
};

View File

@@ -37,7 +37,6 @@ class PublicArguments {
/// arguments and public keys.
public:
PublicArguments(const ClientParameters &clientParameters,
RuntimeContext runtimeContext, bool clearRuntimeContext,
std::vector<void *> &&preparedArgs,
std::vector<TensorData> &&ciphertextBuffers);
~PublicArguments();
@@ -56,13 +55,9 @@ private:
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
ClientParameters clientParameters;
RuntimeContext runtimeContext;
std::vector<void *> preparedArgs;
// Store buffers of ciphertexts
std::vector<TensorData> ciphertextBuffers;
// Indicates if this public argument own the runtime keys.
bool clearRuntimeContext;
};
struct PublicResult {

View File

@@ -9,6 +9,7 @@
#include <iostream>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Runtime/context.h"
@@ -67,6 +68,18 @@ TensorData unserializeTensorData(
// accomodate non static sizes
std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk);
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk);
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys);
} // namespace clientlib
} // namespace concretelang

View File

@@ -10,6 +10,8 @@
#include <mutex>
#include <pthread.h>
#include "concretelang/ClientLib/EvaluationKeys.h"
extern "C" {
#include "concrete-ffi.h"
}
@@ -18,16 +20,18 @@ namespace mlir {
namespace concretelang {
typedef struct RuntimeContext {
LweKeyswitchKey_u64 *ksk;
LweBootstrapKey_u64 *bsk;
::concretelang::clientlib::EvaluationKeys evaluationKeys;
std::map<pthread_t, Engine *> engines;
std::mutex engines_map_guard;
RuntimeContext() {}
// Ensure that the engines map is not copied
RuntimeContext(const RuntimeContext &ctx) : ksk(ctx.ksk), bsk(ctx.bsk) {}
RuntimeContext(const RuntimeContext &ctx)
: evaluationKeys(ctx.evaluationKeys) {}
RuntimeContext(const RuntimeContext &&other)
: ksk(other.ksk), bsk(other.bsk) {}
: evaluationKeys(other.evaluationKeys) {}
~RuntimeContext() {
for (const auto &key : engines) {
free_engine(key.second);
@@ -35,8 +39,7 @@ typedef struct RuntimeContext {
}
RuntimeContext &operator=(const RuntimeContext &rhs) {
ksk = rhs.ksk;
bsk = rhs.bsk;
this->evaluationKeys = rhs.evaluationKeys;
return *this;
}
} RuntimeContext;

View File

@@ -39,7 +39,8 @@ public:
/// Call the ServerLambda with public arguments.
std::unique_ptr<clientlib::PublicResult>
call(clientlib::PublicArguments &args);
call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys);
protected:
ClientParameters clientParameters;
@@ -51,4 +52,4 @@ protected:
} // namespace serverlib
} // namespace concretelang
#endif
#endif

View File

@@ -51,8 +51,9 @@ public:
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
clientlib::PublicArguments &args) override {
return lambda->call(args);
clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) override {
return lambda->call(args, evaluationKeys);
}
private:

View File

@@ -36,7 +36,8 @@ public:
/// Call the JIT lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
call(clientlib::PublicArguments &args);
call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys);
void setUseDataflow(bool option) { this->useDataflow = option; }

View File

@@ -275,7 +275,8 @@ public:
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
Lambda lambda, clientlib::PublicArguments &args) = 0;
Lambda lambda, clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) = 0;
/// Build the client KeySet from the client parameters.
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
@@ -302,11 +303,12 @@ public:
}
template <typename ResT>
static llvm::Expected<ResT>
call(Lambda lambda, clientlib::PublicArguments &publicArguments) {
static llvm::Expected<ResT> call(Lambda lambda,
clientlib::PublicArguments &publicArguments,
clientlib::EvaluationKeys &evaluationKeys) {
// Call the lambda
auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
lambda, publicArguments);
lambda, publicArguments, evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
@@ -357,7 +359,9 @@ public:
return std::move(err);
}
auto publicResult = support.serverCall(lambda, **publicArguments);
auto evaluationKeys = this->keySet->evaluationKeys();
auto publicResult =
support.serverCall(lambda, **publicArguments, evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
@@ -375,7 +379,9 @@ public:
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
auto publicResult = support.serverCall(lambda, *publicArguments.value());
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult =
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
@@ -394,7 +400,9 @@ public:
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}
auto publicResult = support.serverCall(lambda, *publicArguments.value());
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult =
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}

View File

@@ -103,9 +103,9 @@ public:
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(serverlib::ServerLambda lambda,
clientlib::PublicArguments &args) override {
return lambda.call(args);
serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) override {
return lambda.call(args, evaluationKeys);
}
/// Get path to shared library

View File

@@ -89,7 +89,8 @@ public:
// serverInput));
// server function call
auto publicResult = serverLambda.call(*publicArgument);
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult = serverLambda.call(*publicArgument, evaluationKeys);
// client result decryption
return this->decryptResult(*keySet, *publicResult);

View File

@@ -41,6 +41,7 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources
concrete/compiler/library_lambda.py
concrete/compiler/public_arguments.py
concrete/compiler/public_result.py
concrete/compiler/evaluation_keys.py
concrete/compiler/utils.py
concrete/compiler/wrapper.py
concrete/__init__.py
@@ -119,4 +120,4 @@ add_mlir_python_modules(ConcretelangPythonModules
ConcretelangBindingsPythonSources
COMMON_CAPI_LINK_LIBS
ConcretelangBindingsPythonCAPI
)
)

View File

@@ -90,8 +90,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::return_value_policy::reference)
.def("server_call",
[](JITSupport_C &support, concretelang::JITLambda &lambda,
clientlib::PublicArguments &publicArguments) {
return jit_server_call(support, lambda, publicArguments);
clientlib::PublicArguments &publicArguments,
clientlib::EvaluationKeys &evaluationKeys) {
return jit_server_call(support, lambda, publicArguments,
evaluationKeys);
});
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
@@ -132,8 +134,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::return_value_policy::reference)
.def("server_call",
[](LibrarySupport_C &support, serverlib::ServerLambda lambda,
clientlib::PublicArguments &publicArguments) {
return library_server_call(support, lambda, publicArguments);
clientlib::PublicArguments &publicArguments,
clientlib::EvaluationKeys &evaluationKeys) {
return library_server_call(support, lambda, publicArguments,
evaluationKeys);
})
.def("get_shared_lib_path",
[](LibrarySupport_C &support) {
@@ -185,7 +189,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
clientParametersSerialize(clientParameters));
});
pybind11::class_<clientlib::KeySet>(m, "KeySet");
pybind11::class_<clientlib::KeySet>(m, "KeySet")
.def("get_evaluation_keys",
[](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); });
pybind11::class_<clientlib::PublicArguments,
std::unique_ptr<clientlib::PublicArguments>>(
m, "PublicArguments")
@@ -207,6 +214,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return pybind11::bytes(publicResultSerialize(publicResult));
});
pybind11::class_<clientlib::EvaluationKeys>(m, "EvaluationKeys")
.def_static("unserialize",
[](const pybind11::bytes &buffer) {
return evaluationKeysUnserialize(buffer);
})
.def("serialize", [](clientlib::EvaluationKeys &evaluationKeys) {
return pybind11::bytes(evaluationKeysSerialize(evaluationKeys));
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def_static("from_tensor",
[](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {

View File

@@ -26,6 +26,7 @@ from .library_lambda import LibraryLambda
from .client_support import ClientSupport
from .jit_support import JITSupport
from .library_support import LibrarySupport
from .evaluation_keys import EvaluationKeys
# Terminate parallelization in the compiler (if init) during cleanup

View File

@@ -0,0 +1,63 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
"""EvaluationKeys."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
EvaluationKeys as _EvaluationKeys,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class EvaluationKeys(WrapperCpp):
"""
EvaluationKeys required for execution.
"""
def __init__(self, evaluation_keys: _EvaluationKeys):
"""Wrap the native Cpp object.
Args:
evaluation_keys (_EvaluationKeys): object to wrap
Raises:
TypeError: if evaluation_keys is not of type _EvaluationKeys
"""
if not isinstance(evaluation_keys, _EvaluationKeys):
raise TypeError(
f"evaluation_keys must be of type _EvaluationKeys, not {type(evaluation_keys)}"
)
super().__init__(evaluation_keys)
def serialize(self) -> bytes:
"""Serialize the EvaluationKeys.
Returns:
bytes: serialized object
"""
return self.cpp().serialize()
@staticmethod
def unserialize(serialized_evaluation_keys: bytes) -> "EvaluationKeys":
"""Unserialize EvaluationKeys from bytes.
Args:
serialized_evaluation_keys (bytes): previously serialized EvaluationKeys
Raises:
TypeError: if serialized_evaluation_keys is not of type bytes
Returns:
EvaluationKeys: unserialized object
"""
if not isinstance(serialized_evaluation_keys, bytes):
raise TypeError(
f"serialized_evaluation_keys must be of type bytes, "
f"not {type(serialized_evaluation_keys)}"
)
return EvaluationKeys.wrap(
_EvaluationKeys.unserialize(serialized_evaluation_keys)
)

View File

@@ -23,6 +23,7 @@ from .jit_lambda import JITLambda
from .public_arguments import PublicArguments
from .public_result import PublicResult
from .wrapper import WrapperCpp
from .evaluation_keys import EvaluationKeys
class JITSupport(WrapperCpp):
@@ -139,17 +140,22 @@ class JITSupport(WrapperCpp):
return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp()))
def server_call(
self, jit_lambda: JITLambda, public_arguments: PublicArguments
self,
jit_lambda: JITLambda,
public_arguments: PublicArguments,
evaluation_keys: EvaluationKeys,
) -> PublicResult:
"""Call the JITLambda with public_arguments.
Args:
jit_lambda (JITLambda): A server lambda to call.
public_arguments (PublicArguments): The arguments of the call.
evaluation_keys (EvaluationKeys): Evalutation keys of the call.
Raises:
TypeError: if jit_lambda is not of type JITLambda
TypeError: if public_arguments is not of type PublicArguments
TypeError: if evaluation_keys is not of type EvaluationKeys
Returns:
PublicResult: the result of the call of the server lambda.
@@ -162,6 +168,12 @@ class JITSupport(WrapperCpp):
raise TypeError(
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
)
if not isinstance(evaluation_keys, EvaluationKeys):
raise TypeError(
f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
)
return PublicResult.wrap(
self.cpp().server_call(jit_lambda.cpp(), public_arguments.cpp())
self.cpp().server_call(
jit_lambda.cpp(), public_arguments.cpp(), evaluation_keys.cpp()
)
)

View File

@@ -14,6 +14,7 @@ from mlir._mlir_libs._concretelang._compiler import (
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
from .evaluation_keys import EvaluationKeys
class KeySet(WrapperCpp):
@@ -34,3 +35,13 @@ class KeySet(WrapperCpp):
if not isinstance(keyset, _KeySet):
raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}")
super().__init__(keyset)
def get_evaluation_keys(self) -> EvaluationKeys:
"""
Get evaluation keys for execution.
Returns:
EvaluationKeys:
evaluation keys for execution
"""
return EvaluationKeys(self.cpp().get_evaluation_keys())

View File

@@ -23,6 +23,7 @@ from .public_result import PublicResult
from .client_parameters import ClientParameters
from .wrapper import WrapperCpp
from .utils import lookup_runtime_lib
from .evaluation_keys import EvaluationKeys
# Default output path for compilation artifacts
@@ -211,17 +212,22 @@ class LibrarySupport(WrapperCpp):
)
def server_call(
self, library_lambda: LibraryLambda, public_arguments: PublicArguments
self,
library_lambda: LibraryLambda,
public_arguments: PublicArguments,
evaluation_keys: EvaluationKeys,
) -> PublicResult:
"""Call the library with public_arguments.
Args:
library_lambda (LibraryLambda): reference to the compiled library
public_arguments (PublicArguments): arguments to use for execution
evaluation_keys (EvaluationKeys): evaluation keys to use for execution
Raises:
TypeError: if library_lambda is not of type LibraryLambda
TypeError: if public_arguments is not of type PublicArguments
TypeError: if evaluation_keys is not of type EvaluationKeys
Returns:
PublicResult: result of the execution
@@ -234,8 +240,16 @@ class LibrarySupport(WrapperCpp):
raise TypeError(
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
)
if not isinstance(evaluation_keys, EvaluationKeys):
raise TypeError(
f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
)
return PublicResult.wrap(
self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp())
self.cpp().server_call(
library_lambda.cpp(),
public_arguments.cpp(),
evaluation_keys.cpp(),
)
)
def get_shared_lib_path(self) -> str:

View File

@@ -7,6 +7,7 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/JITSupport.h"
@@ -53,8 +54,9 @@ jit_load_server_lambda(JITSupport_C support,
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args));
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys));
return std::move(*publicResult);
}
@@ -97,9 +99,10 @@ library_load_server_lambda(
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibrarySupport_C support,
concretelang::serverlib::ServerLambda lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult,
support.support.serverCall(lambda, args));
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
GET_OR_THROW_LLVM_EXPECTED(
publicResult, support.support.serverCall(lambda, args, evaluationKeys));
return std::move(*publicResult);
}
@@ -192,6 +195,27 @@ publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) {
return buffer.str();
}
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
evaluationKeysUnserialize(const std::string &buffer) {
std::stringstream istream(buffer);
concretelang::clientlib::EvaluationKeys evaluationKeys;
concretelang::clientlib::operator>>(istream, evaluationKeys);
if (istream.fail()) {
throw std::runtime_error("Cannot read evaluation keys");
}
return evaluationKeys;
}
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
std::ostringstream buffer(std::ios::binary);
concretelang::clientlib::operator<<(buffer, evaluationKeys);
return buffer.str();
}
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
clientParametersUnserialize(const std::string &json) {
GET_OR_THROW_LLVM_EXPECTED(

View File

@@ -25,11 +25,8 @@ size_t bitWidthAsWord(size_t exactBitWidth) {
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext) {
// On client side the runtimeContext is hold by the KeySet
bool clearContext = false;
return std::make_unique<PublicArguments>(
clientParameters, runtimeContext, clearContext, std::move(preparedArgs),
std::move(ciphertextBuffers));
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
}
outcome::checked<void, StringError>

View File

@@ -24,12 +24,6 @@ KeySet::~KeySet() {
for (auto it : secretKeys) {
free_lwe_secret_key_u64(it.second.second);
}
for (auto it : bootstrapKeys) {
free_lwe_bootstrap_key_u64(it.second.second);
}
for (auto it : keyswitchKeys) {
free_lwe_keyswitch_key_u64(it.second.second);
}
free_engine(engine);
}
@@ -115,10 +109,10 @@ void KeySet::setKeys(
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys) {
this->secretKeys = secretKeys;
this->bootstrapKeys = bootstrapKeys;
@@ -160,7 +154,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
param.level, param.variance, param.glweDimension, polynomialSize);
// Store the bootstrap key
bootstrapKeys[id] = {param, bsk};
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
return outcome::success();
}
@@ -184,7 +178,7 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) {
param.baseLog, param.variance);
// Store the keyswitch key
keyswitchKeys[id] = {param, ksk};
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
return outcome::success();
}
@@ -253,13 +247,13 @@ const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
}
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>> &
KeySet::getBootstrapKeys() {
return bootstrapKeys;
}
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>> &
KeySet::getKeyswitchKeys() {
return keyswitchKeys;
}

View File

@@ -5,6 +5,7 @@
#include "boost/outcome.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FileSystem.h"
@@ -96,9 +97,11 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys;
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys;
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys;
// Load LWE secret keys
@@ -117,7 +120,7 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pbsKey_" + id);
OUTCOME_TRY(LweBootstrapKey_u64 * bsk, loadBootstrapKey(path));
bootstrapKeys[id] = {param, bsk};
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
}
// Load keyswitch keys
for (auto keyswitchParam : params.keyswitchKeys) {
@@ -126,7 +129,7 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "ksKey_" + id);
OUTCOME_TRY(LweKeyswitchKey_u64 * ksk, loadKeyswitchKey(path));
keyswitchKeys[id] = {param, ksk};
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
}
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys);
@@ -162,7 +165,7 @@ outcome::checked<void, StringError> saveKeys(KeySet &key_set,
auto key = bootstrapKeyParam.second.second;
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "pbsKey_" + id);
saveBootstrapKey(path, key);
saveBootstrapKey(path, key->get());
}
// Save keyswitch keys
for (auto keyswitchParam : key_set.getKeyswitchKeys()) {
@@ -170,7 +173,7 @@ outcome::checked<void, StringError> saveKeys(KeySet &key_set,
auto key = keyswitchParam.second.second;
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "ksKey_" + id);
saveKeyswitchKey(path, key);
saveKeyswitchKey(path, key->get());
}
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);

View File

@@ -20,28 +20,14 @@ using concretelang::error::StringError;
// TODO: optimize the move
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
RuntimeContext runtimeContext,
bool clearRuntimeContext,
std::vector<void *> &&preparedArgs_,
std::vector<TensorData> &&ciphertextBuffers_)
: clientParameters(clientParameters), runtimeContext(runtimeContext),
clearRuntimeContext(clearRuntimeContext) {
: clientParameters(clientParameters) {
preparedArgs = std::move(preparedArgs_);
ciphertextBuffers = std::move(ciphertextBuffers_);
}
PublicArguments::~PublicArguments() {
if (!clearRuntimeContext) {
return;
}
if (runtimeContext.bsk != nullptr) {
free_lwe_bootstrap_key_u64(runtimeContext.bsk);
}
if (runtimeContext.ksk != nullptr) {
free_lwe_keyswitch_key_u64(runtimeContext.ksk);
runtimeContext.ksk = nullptr;
}
}
PublicArguments::~PublicArguments() {}
outcome::checked<void, StringError>
PublicArguments::serialize(std::ostream &ostream) {
@@ -49,7 +35,6 @@ PublicArguments::serialize(std::ostream &ostream) {
return StringError(
"PublicArguments::serialize: ostream should be in binary mode");
}
ostream << runtimeContext;
size_t iPreparedArgs = 0;
int iGate = -1;
for (auto gate : clientParameters.inputs) {
@@ -122,16 +107,10 @@ PublicArguments::unserializeArgs(std::istream &istream) {
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
PublicArguments::unserialize(ClientParameters &clientParameters,
std::istream &istream) {
RuntimeContext runtimeContext;
istream >> runtimeContext;
if (istream.fail()) {
return StringError("Cannot read runtime context");
}
std::vector<void *> empty;
std::vector<TensorData> emptyBuffers;
auto sArguments = std::make_unique<PublicArguments>(
clientParameters, runtimeContext, true, std::move(empty),
std::move(emptyBuffers));
clientParameters, std::move(empty), std::move(emptyBuffers));
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
return std::move(sArguments);
}

View File

@@ -67,16 +67,14 @@ std::istream &operator>>(std::istream &istream, LweBootstrapKey_u64 *&key) {
std::istream &operator>>(std::istream &istream,
RuntimeContext &runtimeContext) {
istream >> runtimeContext.ksk;
istream >> runtimeContext.bsk;
istream >> runtimeContext.evaluationKeys;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext) {
ostream << runtimeContext.ksk;
ostream << runtimeContext.bsk;
ostream << runtimeContext.evaluationKeys;
assert(ostream.good());
return ostream;
}
@@ -147,5 +145,54 @@ TensorData unserializeTensorData(
return result;
}
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk) {
ostream << wrappedKsk.ksk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) {
istream >> wrappedKsk.ksk;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk) {
ostream << wrappedBsk.bsk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) {
istream >> wrappedBsk.bsk;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys) {
ostream << *evaluationKeys.sharedKsk;
ostream << *evaluationKeys.sharedBsk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys) {
auto sharedKsk = LweKeyswitchKey(nullptr);
auto sharedBsk = LweBootstrapKey(nullptr);
istream >> sharedKsk;
istream >> sharedBsk;
evaluationKeys.sharedKsk =
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
evaluationKeys.sharedBsk =
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
assert(istream.good());
return istream;
}
} // namespace clientlib
} // namespace concretelang
} // namespace concretelang

View File

@@ -9,12 +9,12 @@
LweKeyswitchKey_u64 *
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->ksk;
return context->evaluationKeys.getKsk();
}
LweBootstrapKey_u64 *
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->bsk;
return context->evaluationKeys.getBsk();
}
// Instantiate one engine per thread on demand

View File

@@ -21,7 +21,9 @@ namespace serverlib {
using concretelang::clientlib::CircuitGate;
using concretelang::clientlib::CircuitGateShape;
using concretelang::clientlib::EvaluationKeys;
using concretelang::clientlib::PublicArguments;
using concretelang::clientlib::RuntimeContext;
using concretelang::error::StringError;
outcome::checked<ServerLambda, StringError>
@@ -74,10 +76,14 @@ TensorData dynamicCall(void *(*func)(void *...),
}
std::unique_ptr<clientlib::PublicResult>
ServerLambda::call(PublicArguments &args) {
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
args.preparedArgs.end());
preparedArgs.push_back((void *)&args.runtimeContext);
RuntimeContext runtimeContext;
runtimeContext.evaluationKeys = evaluationKeys;
preparedArgs.push_back((void *)&runtimeContext);
return clientlib::PublicResult::fromBuffers(
clientParameters,
{dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])});

View File

@@ -76,7 +76,8 @@ uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
JITLambda::call(clientlib::PublicArguments &args) {
JITLambda::call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) {
#ifndef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
if (this->useDataflow) {
return StreamStringError(
@@ -116,9 +117,12 @@ JITLambda::call(clientlib::PublicArguments &args) {
for (auto &arg : args.preparedArgs) {
rawArgs[i++] = &arg;
}
RuntimeContext runtimeContext;
runtimeContext.evaluationKeys = evaluationKeys;
// Pointer on runtime context, the rawArgs take pointer on actual value that
// is passed to the compiled function.
auto rtCtxPtr = &args.runtimeContext;
auto rtCtxPtr = &runtimeContext;
rawArgs[i++] = &rtCtxPtr;
// Pointers on outputs
for (auto &out : outputs) {

View File

@@ -0,0 +1,112 @@
import numpy as np
import pytest
import shutil
import tempfile
from concrete.compiler import (
ClientSupport,
EvaluationKeys,
LibrarySupport,
PublicArguments,
PublicResult,
)
@pytest.mark.parametrize(
"mlir, args, expected_result",
[
pytest.param(
"""
func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>)
return %1: !FHE.eint<5>
}
""",
(5, 7),
12,
id="enc_plain_int_args",
marks=pytest.mark.xfail,
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>)
return %1: !FHE.eint<5>
}
""",
(5, 7),
12,
id="enc_enc_int_args",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5> {
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : (tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5>
return %ret : !FHE.eint<5>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([4, 3, 2, 1], dtype=np.uint8),
),
20,
id="enc_plain_ndarray_args",
marks=pytest.mark.xfail,
),
pytest.param(
"""
func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> {
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>>
return %res : tensor<4x!FHE.eint<5>>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([7, 0, 1, 5], dtype=np.uint8),
),
np.array([8, 2, 4, 9]),
id="enc_enc_ndarray_args",
),
],
)
def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):
with tempfile.TemporaryDirectory() as tmpdirname:
support = LibrarySupport.new(str(tmpdirname))
compilation_result = support.compile(mlir)
server_lambda = support.load_server_lambda(compilation_result)
client_parameters = support.load_client_parameters(compilation_result)
keyset = ClientSupport.key_set(client_parameters, keyset_cache)
evaluation_keys = keyset.get_evaluation_keys()
evaluation_keys_serialized = evaluation_keys.serialize()
evaluation_keys_unserialized = EvaluationKeys.unserialize(
evaluation_keys_serialized
)
args = ClientSupport.encrypt_arguments(client_parameters, keyset, args)
args_serialized = args.serialize()
args_unserialized = PublicArguments.unserialize(
client_parameters, args_serialized
)
result = support.server_call(
server_lambda,
args_unserialized,
evaluation_keys_unserialized,
)
result_serialized = result.serialize()
result_unserialized = PublicResult.unserialize(
client_parameters, result_serialized
)
output = ClientSupport.decrypt_result(keyset, result_unserialized)
assert np.array_equal(output, expected_result)

View File

@@ -32,7 +32,8 @@ def run(engine, args, compilation_result, keyset_cache):
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
evaluation_keys = key_set.get_evaluation_keys()
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
# Client
result = ClientSupport.decrypt_result(key_set, public_result)
return result

View File

@@ -1,152 +0,0 @@
import pytest
import shutil
import numpy as np
from concrete.compiler import (
JITSupport,
LibrarySupport,
ClientSupport,
CompilationOptions,
PublicArguments,
)
from concrete.compiler.client_parameters import ClientParameters
from concrete.compiler.public_result import PublicResult
def assert_result(result, expected_result):
"""Assert that result and expected result are equal.
result and expected_result can be integers on numpy arrays.
"""
assert type(expected_result) == type(result)
if isinstance(expected_result, int):
assert result == expected_result
else:
assert np.all(result == expected_result)
def run_with_serialization(
engine,
args,
compilation_result,
keyset_cache,
):
"""Execute engine on the given arguments. Performs serialization betwee client/server.
Perform required loading, encryption, execution, and decryption."""
# Client
client_parameters = engine.load_client_parameters(compilation_result)
serialized_client_parameters = client_parameters.serialize()
client_parameters = ClientParameters.unserialize(serialized_client_parameters)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
public_arguments_buffer = public_arguments.serialize()
# Server
public_arguments = PublicArguments.unserialize(
client_parameters, public_arguments_buffer
)
del public_arguments_buffer
server_lambda = engine.load_server_lambda(compilation_result)
public_result = engine.server_call(server_lambda, public_arguments)
public_result_buffer = public_result.serialize()
# Client
public_result = PublicResult.unserialize(client_parameters, public_result_buffer)
del public_result_buffer
result = ClientSupport.decrypt_result(key_set, public_result)
return result
def compile_run_assert_with_serialization(
engine,
mlir_input,
args,
expected_result,
keyset_cache,
):
"""Compile run and assert result. Performs serialization betwee client/server.
Can take both JITSupport or LibrarySupport as engine.
"""
options = CompilationOptions.new("main")
compilation_result = engine.compile(mlir_input, options)
result = run_with_serialization(engine, args, compilation_result, keyset_cache)
assert_result(result, expected_result)
end_to_end_fixture = [
pytest.param(
"""
func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>)
return %1: !FHE.eint<5>
}
""",
(5, 7),
12,
id="enc_plain_int_args",
marks=pytest.mark.xfail,
),
pytest.param(
"""
func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>)
return %1: !FHE.eint<5>
}
""",
(5, 7),
12,
id="enc_enc_int_args",
),
pytest.param(
"""
func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5>
return %ret : !FHE.eint<5>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([4, 3, 2, 1], dtype=np.uint8),
),
20,
id="enc_plain_ndarray_args",
marks=pytest.mark.xfail,
),
pytest.param(
"""
func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> {
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>>
return %res : tensor<4x!FHE.eint<5>>
}
""",
(
np.array([1, 2, 3, 4], dtype=np.uint8),
np.array([7, 0, 1, 5], dtype=np.uint8),
),
np.array([8, 2, 4, 9]),
id="enc_enc_ndarray_args",
),
]
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run_with_serialization(
mlir_input, args, expected_result, keyset_cache
):
engine = JITSupport.new()
compile_run_assert_with_serialization(
engine, mlir_input, args, expected_result, keyset_cache
)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run_with_serialization(
mlir_input, args, expected_result, keyset_cache
):
artifact_dir = "./py_test_lib_compile_and_run"
engine = LibrarySupport.new(artifact_dir)
compile_run_assert_with_serialization(
engine, mlir_input, args, expected_result, keyset_cache
)
shutil.rmtree(artifact_dir)

View File

@@ -21,6 +21,8 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) {
auto keySet = support.keySet(*clientParameters, getTestKeySetCache());
ASSERT_EXPECTED_SUCCESS(keySet);
auto evaluationKeys = (*keySet)->evaluationKeys();
/* 3 - Load the server lambda */
auto serverLambda = support.loadServerLambda(**compilationResult);
ASSERT_EXPECTED_SUCCESS(serverLambda);
@@ -41,7 +43,8 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) {
ASSERT_EXPECTED_SUCCESS(publicArguments);
/* 5 - Call the server lambda */
auto publicResult = support.serverCall(*serverLambda, **publicArguments);
auto publicResult =
support.serverCall(*serverLambda, **publicArguments, evaluationKeys);
ASSERT_EXPECTED_SUCCESS(publicResult);
/* 6 - Decrypt the public result */