mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: separate runtime context from public arguments
This commit is contained in:
@@ -53,7 +53,8 @@ jit_load_server_lambda(JITSupport_C support,
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args);
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
@@ -82,7 +83,8 @@ library_load_server_lambda(LibrarySupport_C support,
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibrarySupport_C support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args);
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_C support);
|
||||
@@ -128,6 +130,12 @@ publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
|
||||
108
compiler/include/concretelang/ClientLib/EvaluationKeys.h
Normal file
108
compiler/include/concretelang/ClientLib/EvaluationKeys.h
Normal file
@@ -0,0 +1,108 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Wrapper for `LweKeyswitchKey_u64` so that it cleans up properly.
|
||||
class LweKeyswitchKey {
|
||||
private:
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
LweKeyswitchKey &wrappedKsk);
|
||||
|
||||
public:
|
||||
LweKeyswitchKey(LweKeyswitchKey_u64 *ksk) : ksk{ksk} {}
|
||||
LweKeyswitchKey(LweKeyswitchKey &other) = delete;
|
||||
LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} {
|
||||
other.ksk = nullptr;
|
||||
}
|
||||
~LweKeyswitchKey() {
|
||||
if (this->ksk != nullptr) {
|
||||
free_lwe_keyswitch_key_u64(this->ksk);
|
||||
this->ksk = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
LweKeyswitchKey_u64 *get() { return this->ksk; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Wrapper for `LweBootstrapKey_u64` so that it cleans up properly.
|
||||
class LweBootstrapKey {
|
||||
private:
|
||||
LweBootstrapKey_u64 *bsk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
LweBootstrapKey &wrappedBsk);
|
||||
|
||||
public:
|
||||
LweBootstrapKey(LweBootstrapKey_u64 *bsk) : bsk{bsk} {}
|
||||
LweBootstrapKey(LweBootstrapKey &other) = delete;
|
||||
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
|
||||
other.bsk = nullptr;
|
||||
}
|
||||
~LweBootstrapKey() {
|
||||
if (this->bsk != nullptr) {
|
||||
free_lwe_bootstrap_key_u64(this->bsk);
|
||||
this->bsk = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *get() { return this->bsk; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Evalution keys required for execution.
|
||||
class EvaluationKeys {
|
||||
private:
|
||||
std::shared_ptr<LweKeyswitchKey> sharedKsk;
|
||||
std::shared_ptr<LweBootstrapKey> sharedBsk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
EvaluationKeys &evaluationKeys);
|
||||
|
||||
public:
|
||||
EvaluationKeys()
|
||||
: sharedKsk{std::shared_ptr<LweKeyswitchKey>(nullptr)},
|
||||
sharedBsk{std::shared_ptr<LweBootstrapKey>(nullptr)} {}
|
||||
|
||||
EvaluationKeys(std::shared_ptr<LweKeyswitchKey> sharedKsk,
|
||||
std::shared_ptr<LweBootstrapKey> sharedBsk)
|
||||
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk} {}
|
||||
|
||||
LweKeyswitchKey_u64 *getKsk() { return this->sharedKsk->get(); }
|
||||
LweBootstrapKey_u64 *getBsk() { return this->sharedBsk->get(); }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -16,6 +16,7 @@ extern "C" {
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
@@ -77,28 +78,29 @@ public:
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
void setRuntimeContext(RuntimeContext &context) {
|
||||
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
context.bsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
|
||||
}
|
||||
|
||||
RuntimeContext runtimeContext() {
|
||||
RuntimeContext context;
|
||||
this->setRuntimeContext(context);
|
||||
context.evaluationKeys = this->evaluationKeys();
|
||||
return context;
|
||||
}
|
||||
|
||||
EvaluationKeys evaluationKeys() {
|
||||
auto sharedKsk = std::get<1>(this->keyswitchKeys.at("ksk_v0"));
|
||||
auto sharedBsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
|
||||
return EvaluationKeys(sharedKsk, sharedBsk);
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<LweSecretKeyParam, LweSecretKey_u64 *>> &
|
||||
getSecretKeys();
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
|
||||
getBootstrapKeys();
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
&getBootstrapKeys();
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
|
||||
getKeyswitchKeys();
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
&getKeyswitchKeys();
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError>
|
||||
@@ -124,9 +126,11 @@ private:
|
||||
Engine *engine;
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
inputs;
|
||||
@@ -137,10 +141,10 @@ private:
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys);
|
||||
};
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ class PublicArguments {
|
||||
/// arguments and public keys.
|
||||
public:
|
||||
PublicArguments(const ClientParameters &clientParameters,
|
||||
RuntimeContext runtimeContext, bool clearRuntimeContext,
|
||||
std::vector<void *> &&preparedArgs,
|
||||
std::vector<TensorData> &&ciphertextBuffers);
|
||||
~PublicArguments();
|
||||
@@ -56,13 +55,9 @@ private:
|
||||
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
RuntimeContext runtimeContext;
|
||||
std::vector<void *> preparedArgs;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<TensorData> ciphertextBuffers;
|
||||
|
||||
// Indicates if this public argument own the runtime keys.
|
||||
bool clearRuntimeContext;
|
||||
};
|
||||
|
||||
struct PublicResult {
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
@@ -67,6 +68,18 @@ TensorData unserializeTensorData(
|
||||
// accomodate non static sizes
|
||||
std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk);
|
||||
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk);
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys);
|
||||
std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <mutex>
|
||||
#include <pthread.h>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
@@ -18,16 +20,18 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef struct RuntimeContext {
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
LweBootstrapKey_u64 *bsk;
|
||||
::concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
std::map<pthread_t, Engine *> engines;
|
||||
std::mutex engines_map_guard;
|
||||
|
||||
RuntimeContext() {}
|
||||
|
||||
// Ensure that the engines map is not copied
|
||||
RuntimeContext(const RuntimeContext &ctx) : ksk(ctx.ksk), bsk(ctx.bsk) {}
|
||||
RuntimeContext(const RuntimeContext &ctx)
|
||||
: evaluationKeys(ctx.evaluationKeys) {}
|
||||
RuntimeContext(const RuntimeContext &&other)
|
||||
: ksk(other.ksk), bsk(other.bsk) {}
|
||||
: evaluationKeys(other.evaluationKeys) {}
|
||||
|
||||
~RuntimeContext() {
|
||||
for (const auto &key : engines) {
|
||||
free_engine(key.second);
|
||||
@@ -35,8 +39,7 @@ typedef struct RuntimeContext {
|
||||
}
|
||||
|
||||
RuntimeContext &operator=(const RuntimeContext &rhs) {
|
||||
ksk = rhs.ksk;
|
||||
bsk = rhs.bsk;
|
||||
this->evaluationKeys = rhs.evaluationKeys;
|
||||
return *this;
|
||||
}
|
||||
} RuntimeContext;
|
||||
|
||||
@@ -39,7 +39,8 @@ public:
|
||||
|
||||
/// Call the ServerLambda with public arguments.
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
call(clientlib::PublicArguments &args);
|
||||
call(clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
protected:
|
||||
ClientParameters clientParameters;
|
||||
@@ -51,4 +52,4 @@ protected:
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -51,8 +51,9 @@ public:
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
|
||||
clientlib::PublicArguments &args) override {
|
||||
return lambda->call(args);
|
||||
clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) override {
|
||||
return lambda->call(args, evaluationKeys);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -36,7 +36,8 @@ public:
|
||||
|
||||
/// Call the JIT lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
call(clientlib::PublicArguments &args);
|
||||
call(clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
void setUseDataflow(bool option) { this->useDataflow = option; }
|
||||
|
||||
|
||||
@@ -275,7 +275,8 @@ public:
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
|
||||
Lambda lambda, clientlib::PublicArguments &args) = 0;
|
||||
Lambda lambda, clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) = 0;
|
||||
|
||||
/// Build the client KeySet from the client parameters.
|
||||
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
|
||||
@@ -302,11 +303,12 @@ public:
|
||||
}
|
||||
|
||||
template <typename ResT>
|
||||
static llvm::Expected<ResT>
|
||||
call(Lambda lambda, clientlib::PublicArguments &publicArguments) {
|
||||
static llvm::Expected<ResT> call(Lambda lambda,
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
// Call the lambda
|
||||
auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
|
||||
lambda, publicArguments);
|
||||
lambda, publicArguments, evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -357,7 +359,9 @@ public:
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
auto publicResult = support.serverCall(lambda, **publicArguments);
|
||||
auto evaluationKeys = this->keySet->evaluationKeys();
|
||||
auto publicResult =
|
||||
support.serverCall(lambda, **publicArguments, evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -375,7 +379,9 @@ public:
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
auto publicResult = support.serverCall(lambda, *publicArguments.value());
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult =
|
||||
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -394,7 +400,9 @@ public:
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
auto publicResult = support.serverCall(lambda, *publicArguments.value());
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult =
|
||||
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
@@ -103,9 +103,9 @@ public:
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &args) override {
|
||||
return lambda.call(args);
|
||||
serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) override {
|
||||
return lambda.call(args, evaluationKeys);
|
||||
}
|
||||
|
||||
/// Get path to shared library
|
||||
|
||||
@@ -89,7 +89,8 @@ public:
|
||||
// serverInput));
|
||||
|
||||
// server function call
|
||||
auto publicResult = serverLambda.call(*publicArgument);
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult = serverLambda.call(*publicArgument, evaluationKeys);
|
||||
|
||||
// client result decryption
|
||||
return this->decryptResult(*keySet, *publicResult);
|
||||
|
||||
@@ -41,6 +41,7 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources
|
||||
concrete/compiler/library_lambda.py
|
||||
concrete/compiler/public_arguments.py
|
||||
concrete/compiler/public_result.py
|
||||
concrete/compiler/evaluation_keys.py
|
||||
concrete/compiler/utils.py
|
||||
concrete/compiler/wrapper.py
|
||||
concrete/__init__.py
|
||||
@@ -119,4 +120,4 @@ add_mlir_python_modules(ConcretelangPythonModules
|
||||
ConcretelangBindingsPythonSources
|
||||
COMMON_CAPI_LINK_LIBS
|
||||
ConcretelangBindingsPythonCAPI
|
||||
)
|
||||
)
|
||||
|
||||
@@ -90,8 +90,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](JITSupport_C &support, concretelang::JITLambda &lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
return jit_server_call(support, lambda, publicArguments);
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return jit_server_call(support, lambda, publicArguments,
|
||||
evaluationKeys);
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
|
||||
@@ -132,8 +134,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](LibrarySupport_C &support, serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
return library_server_call(support, lambda, publicArguments);
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return library_server_call(support, lambda, publicArguments,
|
||||
evaluationKeys);
|
||||
})
|
||||
.def("get_shared_lib_path",
|
||||
[](LibrarySupport_C &support) {
|
||||
@@ -185,7 +189,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
clientParametersSerialize(clientParameters));
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::KeySet>(m, "KeySet");
|
||||
pybind11::class_<clientlib::KeySet>(m, "KeySet")
|
||||
.def("get_evaluation_keys",
|
||||
[](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); });
|
||||
|
||||
pybind11::class_<clientlib::PublicArguments,
|
||||
std::unique_ptr<clientlib::PublicArguments>>(
|
||||
m, "PublicArguments")
|
||||
@@ -207,6 +214,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
return pybind11::bytes(publicResultSerialize(publicResult));
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::EvaluationKeys>(m, "EvaluationKeys")
|
||||
.def_static("unserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
return evaluationKeysUnserialize(buffer);
|
||||
})
|
||||
.def("serialize", [](clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return pybind11::bytes(evaluationKeysSerialize(evaluationKeys));
|
||||
});
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
.def_static("from_tensor",
|
||||
[](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
|
||||
|
||||
@@ -26,6 +26,7 @@ from .library_lambda import LibraryLambda
|
||||
from .client_support import ClientSupport
|
||||
from .jit_support import JITSupport
|
||||
from .library_support import LibrarySupport
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
|
||||
|
||||
# Terminate parallelization in the compiler (if init) during cleanup
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""EvaluationKeys."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
EvaluationKeys as _EvaluationKeys,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class EvaluationKeys(WrapperCpp):
|
||||
"""
|
||||
EvaluationKeys required for execution.
|
||||
"""
|
||||
|
||||
def __init__(self, evaluation_keys: _EvaluationKeys):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
evaluation_keys (_EvaluationKeys): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if evaluation_keys is not of type _EvaluationKeys
|
||||
"""
|
||||
if not isinstance(evaluation_keys, _EvaluationKeys):
|
||||
raise TypeError(
|
||||
f"evaluation_keys must be of type _EvaluationKeys, not {type(evaluation_keys)}"
|
||||
)
|
||||
super().__init__(evaluation_keys)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize the EvaluationKeys.
|
||||
|
||||
Returns:
|
||||
bytes: serialized object
|
||||
"""
|
||||
return self.cpp().serialize()
|
||||
|
||||
@staticmethod
|
||||
def unserialize(serialized_evaluation_keys: bytes) -> "EvaluationKeys":
|
||||
"""Unserialize EvaluationKeys from bytes.
|
||||
|
||||
Args:
|
||||
serialized_evaluation_keys (bytes): previously serialized EvaluationKeys
|
||||
|
||||
Raises:
|
||||
TypeError: if serialized_evaluation_keys is not of type bytes
|
||||
|
||||
Returns:
|
||||
EvaluationKeys: unserialized object
|
||||
"""
|
||||
if not isinstance(serialized_evaluation_keys, bytes):
|
||||
raise TypeError(
|
||||
f"serialized_evaluation_keys must be of type bytes, "
|
||||
f"not {type(serialized_evaluation_keys)}"
|
||||
)
|
||||
return EvaluationKeys.wrap(
|
||||
_EvaluationKeys.unserialize(serialized_evaluation_keys)
|
||||
)
|
||||
@@ -23,6 +23,7 @@ from .jit_lambda import JITLambda
|
||||
from .public_arguments import PublicArguments
|
||||
from .public_result import PublicResult
|
||||
from .wrapper import WrapperCpp
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
|
||||
|
||||
class JITSupport(WrapperCpp):
|
||||
@@ -139,17 +140,22 @@ class JITSupport(WrapperCpp):
|
||||
return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp()))
|
||||
|
||||
def server_call(
|
||||
self, jit_lambda: JITLambda, public_arguments: PublicArguments
|
||||
self,
|
||||
jit_lambda: JITLambda,
|
||||
public_arguments: PublicArguments,
|
||||
evaluation_keys: EvaluationKeys,
|
||||
) -> PublicResult:
|
||||
"""Call the JITLambda with public_arguments.
|
||||
|
||||
Args:
|
||||
jit_lambda (JITLambda): A server lambda to call.
|
||||
public_arguments (PublicArguments): The arguments of the call.
|
||||
evaluation_keys (EvaluationKeys): Evalutation keys of the call.
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_lambda is not of type JITLambda
|
||||
TypeError: if public_arguments is not of type PublicArguments
|
||||
TypeError: if evaluation_keys is not of type EvaluationKeys
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda.
|
||||
@@ -162,6 +168,12 @@ class JITSupport(WrapperCpp):
|
||||
raise TypeError(
|
||||
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
|
||||
)
|
||||
if not isinstance(evaluation_keys, EvaluationKeys):
|
||||
raise TypeError(
|
||||
f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
|
||||
)
|
||||
return PublicResult.wrap(
|
||||
self.cpp().server_call(jit_lambda.cpp(), public_arguments.cpp())
|
||||
self.cpp().server_call(
|
||||
jit_lambda.cpp(), public_arguments.cpp(), evaluation_keys.cpp()
|
||||
)
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from mlir._mlir_libs._concretelang._compiler import (
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
|
||||
|
||||
class KeySet(WrapperCpp):
|
||||
@@ -34,3 +35,13 @@ class KeySet(WrapperCpp):
|
||||
if not isinstance(keyset, _KeySet):
|
||||
raise TypeError(f"keyset must be of type _KeySet, not {type(keyset)}")
|
||||
super().__init__(keyset)
|
||||
|
||||
def get_evaluation_keys(self) -> EvaluationKeys:
|
||||
"""
|
||||
Get evaluation keys for execution.
|
||||
|
||||
Returns:
|
||||
EvaluationKeys:
|
||||
evaluation keys for execution
|
||||
"""
|
||||
return EvaluationKeys(self.cpp().get_evaluation_keys())
|
||||
|
||||
@@ -23,6 +23,7 @@ from .public_result import PublicResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .wrapper import WrapperCpp
|
||||
from .utils import lookup_runtime_lib
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
|
||||
|
||||
# Default output path for compilation artifacts
|
||||
@@ -211,17 +212,22 @@ class LibrarySupport(WrapperCpp):
|
||||
)
|
||||
|
||||
def server_call(
|
||||
self, library_lambda: LibraryLambda, public_arguments: PublicArguments
|
||||
self,
|
||||
library_lambda: LibraryLambda,
|
||||
public_arguments: PublicArguments,
|
||||
evaluation_keys: EvaluationKeys,
|
||||
) -> PublicResult:
|
||||
"""Call the library with public_arguments.
|
||||
|
||||
Args:
|
||||
library_lambda (LibraryLambda): reference to the compiled library
|
||||
public_arguments (PublicArguments): arguments to use for execution
|
||||
evaluation_keys (EvaluationKeys): evaluation keys to use for execution
|
||||
|
||||
Raises:
|
||||
TypeError: if library_lambda is not of type LibraryLambda
|
||||
TypeError: if public_arguments is not of type PublicArguments
|
||||
TypeError: if evaluation_keys is not of type EvaluationKeys
|
||||
|
||||
Returns:
|
||||
PublicResult: result of the execution
|
||||
@@ -234,8 +240,16 @@ class LibrarySupport(WrapperCpp):
|
||||
raise TypeError(
|
||||
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
|
||||
)
|
||||
if not isinstance(evaluation_keys, EvaluationKeys):
|
||||
raise TypeError(
|
||||
f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
|
||||
)
|
||||
return PublicResult.wrap(
|
||||
self.cpp().server_call(library_lambda.cpp(), public_arguments.cpp())
|
||||
self.cpp().server_call(
|
||||
library_lambda.cpp(),
|
||||
public_arguments.cpp(),
|
||||
evaluation_keys.cpp(),
|
||||
)
|
||||
)
|
||||
|
||||
def get_shared_lib_path(self) -> str:
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
@@ -53,8 +54,9 @@ jit_load_server_lambda(JITSupport_C support,
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args));
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
@@ -97,9 +99,10 @@ library_load_server_lambda(
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibrarySupport_C support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult,
|
||||
support.support.serverCall(lambda, args));
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
publicResult, support.support.serverCall(lambda, args, evaluationKeys));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
@@ -192,6 +195,27 @@ publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) {
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
concretelang::clientlib::operator>>(istream, evaluationKeys);
|
||||
|
||||
if (istream.fail()) {
|
||||
throw std::runtime_error("Cannot read evaluation keys");
|
||||
}
|
||||
|
||||
return evaluationKeys;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
concretelang::clientlib::operator<<(buffer, evaluationKeys);
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
|
||||
@@ -25,11 +25,8 @@ size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
// On client side the runtimeContext is hold by the KeySet
|
||||
bool clearContext = false;
|
||||
return std::make_unique<PublicArguments>(
|
||||
clientParameters, runtimeContext, clearContext, std::move(preparedArgs),
|
||||
std::move(ciphertextBuffers));
|
||||
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
|
||||
@@ -24,12 +24,6 @@ KeySet::~KeySet() {
|
||||
for (auto it : secretKeys) {
|
||||
free_lwe_secret_key_u64(it.second.second);
|
||||
}
|
||||
for (auto it : bootstrapKeys) {
|
||||
free_lwe_bootstrap_key_u64(it.second.second);
|
||||
}
|
||||
for (auto it : keyswitchKeys) {
|
||||
free_lwe_keyswitch_key_u64(it.second.second);
|
||||
}
|
||||
free_engine(engine);
|
||||
}
|
||||
|
||||
@@ -115,10 +109,10 @@ void KeySet::setKeys(
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys) {
|
||||
this->secretKeys = secretKeys;
|
||||
this->bootstrapKeys = bootstrapKeys;
|
||||
@@ -160,7 +154,7 @@ KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
|
||||
param.level, param.variance, param.glweDimension, polynomialSize);
|
||||
|
||||
// Store the bootstrap key
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
@@ -184,7 +178,7 @@ KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) {
|
||||
param.baseLog, param.variance);
|
||||
|
||||
// Store the keyswitch key
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
@@ -253,13 +247,13 @@ const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>> &
|
||||
KeySet::getBootstrapKeys() {
|
||||
return bootstrapKeys;
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>> &
|
||||
KeySet::getKeyswitchKeys() {
|
||||
return keyswitchKeys;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
@@ -96,9 +97,11 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys;
|
||||
|
||||
// Load LWE secret keys
|
||||
@@ -117,7 +120,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
OUTCOME_TRY(LweBootstrapKey_u64 * bsk, loadBootstrapKey(path));
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
|
||||
}
|
||||
// Load keyswitch keys
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
@@ -126,7 +129,7 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
OUTCOME_TRY(LweKeyswitchKey_u64 * ksk, loadKeyswitchKey(path));
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
|
||||
}
|
||||
|
||||
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys);
|
||||
@@ -162,7 +165,7 @@ outcome::checked<void, StringError> saveKeys(KeySet &key_set,
|
||||
auto key = bootstrapKeyParam.second.second;
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
saveBootstrapKey(path, key);
|
||||
saveBootstrapKey(path, key->get());
|
||||
}
|
||||
// Save keyswitch keys
|
||||
for (auto keyswitchParam : key_set.getKeyswitchKeys()) {
|
||||
@@ -170,7 +173,7 @@ outcome::checked<void, StringError> saveKeys(KeySet &key_set,
|
||||
auto key = keyswitchParam.second.second;
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
saveKeyswitchKey(path, key);
|
||||
saveKeyswitchKey(path, key->get());
|
||||
}
|
||||
|
||||
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
|
||||
|
||||
@@ -20,28 +20,14 @@ using concretelang::error::StringError;
|
||||
|
||||
// TODO: optimize the move
|
||||
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
|
||||
RuntimeContext runtimeContext,
|
||||
bool clearRuntimeContext,
|
||||
std::vector<void *> &&preparedArgs_,
|
||||
std::vector<TensorData> &&ciphertextBuffers_)
|
||||
: clientParameters(clientParameters), runtimeContext(runtimeContext),
|
||||
clearRuntimeContext(clearRuntimeContext) {
|
||||
: clientParameters(clientParameters) {
|
||||
preparedArgs = std::move(preparedArgs_);
|
||||
ciphertextBuffers = std::move(ciphertextBuffers_);
|
||||
}
|
||||
|
||||
PublicArguments::~PublicArguments() {
|
||||
if (!clearRuntimeContext) {
|
||||
return;
|
||||
}
|
||||
if (runtimeContext.bsk != nullptr) {
|
||||
free_lwe_bootstrap_key_u64(runtimeContext.bsk);
|
||||
}
|
||||
if (runtimeContext.ksk != nullptr) {
|
||||
free_lwe_keyswitch_key_u64(runtimeContext.ksk);
|
||||
runtimeContext.ksk = nullptr;
|
||||
}
|
||||
}
|
||||
PublicArguments::~PublicArguments() {}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicArguments::serialize(std::ostream &ostream) {
|
||||
@@ -49,7 +35,6 @@ PublicArguments::serialize(std::ostream &ostream) {
|
||||
return StringError(
|
||||
"PublicArguments::serialize: ostream should be in binary mode");
|
||||
}
|
||||
ostream << runtimeContext;
|
||||
size_t iPreparedArgs = 0;
|
||||
int iGate = -1;
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
@@ -122,16 +107,10 @@ PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
std::istream &istream) {
|
||||
RuntimeContext runtimeContext;
|
||||
istream >> runtimeContext;
|
||||
if (istream.fail()) {
|
||||
return StringError("Cannot read runtime context");
|
||||
}
|
||||
std::vector<void *> empty;
|
||||
std::vector<TensorData> emptyBuffers;
|
||||
auto sArguments = std::make_unique<PublicArguments>(
|
||||
clientParameters, runtimeContext, true, std::move(empty),
|
||||
std::move(emptyBuffers));
|
||||
clientParameters, std::move(empty), std::move(emptyBuffers));
|
||||
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
|
||||
return std::move(sArguments);
|
||||
}
|
||||
|
||||
@@ -67,16 +67,14 @@ std::istream &operator>>(std::istream &istream, LweBootstrapKey_u64 *&key) {
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
RuntimeContext &runtimeContext) {
|
||||
istream >> runtimeContext.ksk;
|
||||
istream >> runtimeContext.bsk;
|
||||
istream >> runtimeContext.evaluationKeys;
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext) {
|
||||
ostream << runtimeContext.ksk;
|
||||
ostream << runtimeContext.bsk;
|
||||
ostream << runtimeContext.evaluationKeys;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
@@ -147,5 +145,54 @@ TensorData unserializeTensorData(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk) {
|
||||
ostream << wrappedKsk.ksk;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) {
|
||||
istream >> wrappedKsk.ksk;
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk) {
|
||||
ostream << wrappedBsk.bsk;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) {
|
||||
istream >> wrappedBsk.bsk;
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys) {
|
||||
ostream << *evaluationKeys.sharedKsk;
|
||||
ostream << *evaluationKeys.sharedBsk;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
EvaluationKeys &evaluationKeys) {
|
||||
auto sharedKsk = LweKeyswitchKey(nullptr);
|
||||
auto sharedBsk = LweBootstrapKey(nullptr);
|
||||
|
||||
istream >> sharedKsk;
|
||||
istream >> sharedBsk;
|
||||
|
||||
evaluationKeys.sharedKsk =
|
||||
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
|
||||
evaluationKeys.sharedBsk =
|
||||
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
|
||||
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -9,12 +9,12 @@
|
||||
|
||||
LweKeyswitchKey_u64 *
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->ksk;
|
||||
return context->evaluationKeys.getKsk();
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->bsk;
|
||||
return context->evaluationKeys.getBsk();
|
||||
}
|
||||
|
||||
// Instantiate one engine per thread on demand
|
||||
|
||||
@@ -21,7 +21,9 @@ namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::CircuitGate;
|
||||
using concretelang::clientlib::CircuitGateShape;
|
||||
using concretelang::clientlib::EvaluationKeys;
|
||||
using concretelang::clientlib::PublicArguments;
|
||||
using concretelang::clientlib::RuntimeContext;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
outcome::checked<ServerLambda, StringError>
|
||||
@@ -74,10 +76,14 @@ TensorData dynamicCall(void *(*func)(void *...),
|
||||
}
|
||||
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
ServerLambda::call(PublicArguments &args) {
|
||||
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
|
||||
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
|
||||
args.preparedArgs.end());
|
||||
preparedArgs.push_back((void *)&args.runtimeContext);
|
||||
|
||||
RuntimeContext runtimeContext;
|
||||
runtimeContext.evaluationKeys = evaluationKeys;
|
||||
preparedArgs.push_back((void *)&runtimeContext);
|
||||
|
||||
return clientlib::PublicResult::fromBuffers(
|
||||
clientParameters,
|
||||
{dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])});
|
||||
|
||||
@@ -76,7 +76,8 @@ uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
JITLambda::call(clientlib::PublicArguments &args) {
|
||||
JITLambda::call(clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
#ifndef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
if (this->useDataflow) {
|
||||
return StreamStringError(
|
||||
@@ -116,9 +117,12 @@ JITLambda::call(clientlib::PublicArguments &args) {
|
||||
for (auto &arg : args.preparedArgs) {
|
||||
rawArgs[i++] = &arg;
|
||||
}
|
||||
|
||||
RuntimeContext runtimeContext;
|
||||
runtimeContext.evaluationKeys = evaluationKeys;
|
||||
// Pointer on runtime context, the rawArgs take pointer on actual value that
|
||||
// is passed to the compiled function.
|
||||
auto rtCtxPtr = &args.runtimeContext;
|
||||
auto rtCtxPtr = &runtimeContext;
|
||||
rawArgs[i++] = &rtCtxPtr;
|
||||
// Pointers on outputs
|
||||
for (auto &out : outputs) {
|
||||
|
||||
112
compiler/tests/python/test_client_server.py
Normal file
112
compiler/tests/python/test_client_server.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from concrete.compiler import (
|
||||
ClientSupport,
|
||||
EvaluationKeys,
|
||||
LibrarySupport,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir, args, expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
"""
|
||||
|
||||
func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
id="enc_plain_int_args",
|
||||
marks=pytest.mark.xfail,
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
|
||||
func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
id="enc_enc_int_args",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
|
||||
func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5> {
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : (tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5>
|
||||
return %ret : !FHE.eint<5>
|
||||
}
|
||||
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
20,
|
||||
id="enc_plain_ndarray_args",
|
||||
marks=pytest.mark.xfail,
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
|
||||
func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> {
|
||||
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>>
|
||||
return %res : tensor<4x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([7, 0, 1, 5], dtype=np.uint8),
|
||||
),
|
||||
np.array([8, 2, 4, 9]),
|
||||
id="enc_enc_ndarray_args",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
support = LibrarySupport.new(str(tmpdirname))
|
||||
compilation_result = support.compile(mlir)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
keyset = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
|
||||
evaluation_keys = keyset.get_evaluation_keys()
|
||||
evaluation_keys_serialized = evaluation_keys.serialize()
|
||||
evaluation_keys_unserialized = EvaluationKeys.unserialize(
|
||||
evaluation_keys_serialized
|
||||
)
|
||||
|
||||
args = ClientSupport.encrypt_arguments(client_parameters, keyset, args)
|
||||
args_serialized = args.serialize()
|
||||
args_unserialized = PublicArguments.unserialize(
|
||||
client_parameters, args_serialized
|
||||
)
|
||||
|
||||
result = support.server_call(
|
||||
server_lambda,
|
||||
args_unserialized,
|
||||
evaluation_keys_unserialized,
|
||||
)
|
||||
result_serialized = result.serialize()
|
||||
result_unserialized = PublicResult.unserialize(
|
||||
client_parameters, result_serialized
|
||||
)
|
||||
|
||||
output = ClientSupport.decrypt_result(keyset, result_unserialized)
|
||||
assert np.array_equal(output, expected_result)
|
||||
@@ -32,7 +32,8 @@ def run(engine, args, compilation_result, keyset_cache):
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
evaluation_keys = key_set.get_evaluation_keys()
|
||||
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
return result
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
import pytest
|
||||
import shutil
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
JITSupport,
|
||||
LibrarySupport,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
PublicArguments,
|
||||
)
|
||||
from concrete.compiler.client_parameters import ClientParameters
|
||||
from concrete.compiler.public_result import PublicResult
|
||||
|
||||
|
||||
def assert_result(result, expected_result):
|
||||
"""Assert that result and expected result are equal.
|
||||
|
||||
result and expected_result can be integers on numpy arrays.
|
||||
"""
|
||||
assert type(expected_result) == type(result)
|
||||
if isinstance(expected_result, int):
|
||||
assert result == expected_result
|
||||
else:
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
def run_with_serialization(
|
||||
engine,
|
||||
args,
|
||||
compilation_result,
|
||||
keyset_cache,
|
||||
):
|
||||
"""Execute engine on the given arguments. Performs serialization betwee client/server.
|
||||
|
||||
Perform required loading, encryption, execution, and decryption."""
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
serialized_client_parameters = client_parameters.serialize()
|
||||
client_parameters = ClientParameters.unserialize(serialized_client_parameters)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
public_arguments_buffer = public_arguments.serialize()
|
||||
# Server
|
||||
public_arguments = PublicArguments.unserialize(
|
||||
client_parameters, public_arguments_buffer
|
||||
)
|
||||
del public_arguments_buffer
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
public_result = engine.server_call(server_lambda, public_arguments)
|
||||
public_result_buffer = public_result.serialize()
|
||||
# Client
|
||||
public_result = PublicResult.unserialize(client_parameters, public_result_buffer)
|
||||
del public_result_buffer
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
return result
|
||||
|
||||
|
||||
def compile_run_assert_with_serialization(
|
||||
engine,
|
||||
mlir_input,
|
||||
args,
|
||||
expected_result,
|
||||
keyset_cache,
|
||||
):
|
||||
"""Compile run and assert result. Performs serialization betwee client/server.
|
||||
|
||||
Can take both JITSupport or LibrarySupport as engine.
|
||||
"""
|
||||
options = CompilationOptions.new("main")
|
||||
compilation_result = engine.compile(mlir_input, options)
|
||||
result = run_with_serialization(engine, args, compilation_result, keyset_cache)
|
||||
assert_result(result, expected_result)
|
||||
|
||||
|
||||
end_to_end_fixture = [
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<5>, %arg1: i6) -> !FHE.eint<5> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<5>, i6) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
id="enc_plain_int_args",
|
||||
marks=pytest.mark.xfail,
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<5>, !FHE.eint<5>) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
id="enc_enc_int_args",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!FHE.eint<5>>, %arg1: tensor<4xi6>) -> !FHE.eint<5>
|
||||
{
|
||||
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!FHE.eint<5>>, tensor<4xi6>) -> !FHE.eint<5>
|
||||
return %ret : !FHE.eint<5>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
20,
|
||||
id="enc_plain_ndarray_args",
|
||||
marks=pytest.mark.xfail,
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%a0: tensor<4x!FHE.eint<5>>, %a1: tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>> {
|
||||
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<5>>, tensor<4x!FHE.eint<5>>) -> tensor<4x!FHE.eint<5>>
|
||||
return %res : tensor<4x!FHE.eint<5>>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint8),
|
||||
np.array([7, 0, 1, 5], dtype=np.uint8),
|
||||
),
|
||||
np.array([8, 2, 4, 9]),
|
||||
id="enc_enc_ndarray_args",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_jit_compile_and_run_with_serialization(
|
||||
mlir_input, args, expected_result, keyset_cache
|
||||
):
|
||||
engine = JITSupport.new()
|
||||
compile_run_assert_with_serialization(
|
||||
engine, mlir_input, args, expected_result, keyset_cache
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
|
||||
def test_lib_compile_and_run_with_serialization(
|
||||
mlir_input, args, expected_result, keyset_cache
|
||||
):
|
||||
artifact_dir = "./py_test_lib_compile_and_run"
|
||||
engine = LibrarySupport.new(artifact_dir)
|
||||
compile_run_assert_with_serialization(
|
||||
engine, mlir_input, args, expected_result, keyset_cache
|
||||
)
|
||||
shutil.rmtree(artifact_dir)
|
||||
@@ -21,6 +21,8 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) {
|
||||
auto keySet = support.keySet(*clientParameters, getTestKeySetCache());
|
||||
ASSERT_EXPECTED_SUCCESS(keySet);
|
||||
|
||||
auto evaluationKeys = (*keySet)->evaluationKeys();
|
||||
|
||||
/* 3 - Load the server lambda */
|
||||
auto serverLambda = support.loadServerLambda(**compilationResult);
|
||||
ASSERT_EXPECTED_SUCCESS(serverLambda);
|
||||
@@ -41,7 +43,8 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) {
|
||||
ASSERT_EXPECTED_SUCCESS(publicArguments);
|
||||
|
||||
/* 5 - Call the server lambda */
|
||||
auto publicResult = support.serverCall(*serverLambda, **publicArguments);
|
||||
auto publicResult =
|
||||
support.serverCall(*serverLambda, **publicArguments, evaluationKeys);
|
||||
ASSERT_EXPECTED_SUCCESS(publicResult);
|
||||
|
||||
/* 6 - Decrypt the public result */
|
||||
|
||||
Reference in New Issue
Block a user