feat(Clientlib): separate client encryption and server computation

Resolve #200
This commit is contained in:
rudy
2022-01-03 11:06:01 +01:00
committed by Quentin Bourgerie
parent 4f19dce899
commit 8b71e9d476
76 changed files with 10799 additions and 1106 deletions

View File

@@ -5,8 +5,9 @@ print_and_exit() {
exit 1
}
EXCLUDE_DIRS="-path ./compiler/include/boost-single-header -prune -o"
files=$(find ./compiler/{include,lib,src} -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$')
files=$(find ./compiler/{include,lib,src} $EXCLUDE_DIRS -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' -print)
for file in $files
do

View File

@@ -2,7 +2,9 @@
set -e -o pipefail
find ./compiler/{include,lib,src} -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' | xargs clang-format -i -style='file'
EXCLUDE_DIRS="-path ./compiler/include/boost-single-header -prune -o"
find ./compiler/{include,lib,src} $EXCLUDE_DIRS -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' -print | xargs clang-format -i -style='file'
# show changes if any
git --no-pager diff

View File

@@ -12,6 +12,16 @@ if (APPLE)
add_definitions("-Wno-narrowing")
endif()
add_compile_options(-Wfatal-errors) # stop at first error
# variable length array = vla
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# using Clang
add_compile_options(-Wno-vla-extension)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# using GCC
add_compile_options(-Wno-vla)
endif()
# If we are trying to build the compiler with LLVM/MLIR as libraries
if( NOT DEFINED LLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR )
message(FATAL_ERROR "Concrete compiler requires a unified build with LLVM/MLIR")

View File

@@ -5,6 +5,8 @@ PARALLEL_EXECUTION_ENABLED=OFF
CC_COMPILER=
CXX_COMPILER=
EXTERNAL_HEADERS=include/boost-single-header/outcome.hpp
export PATH := $(BUILD_DIR)/bin:$(PATH)
ifeq ($(shell which ccache),)
@@ -51,7 +53,10 @@ $(BUILD_DIR)/configured.stamp:
-DPython3_EXECUTABLE=${Python3_EXECUTABLE}
touch $@
build-initialized: $(BUILD_DIR)/configured.stamp
include/boost-single-header/outcome.hpp:
wget https://github.com/ned14/outcome/raw/master/single-header/outcome.hpp -O $@
build-initialized: $(EXTERNAL_HEADERS) $(BUILD_DIR)/configured.stamp
doc: build-initialized
cmake --build $(BUILD_DIR) --target mlir-doc
@@ -63,6 +68,12 @@ python-bindings: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
clientlib: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
serverlib: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangServerLib
test-check: concretecompiler file-check not
$(BUILD_DIR)/bin/llvm-lit -v tests/
@@ -87,6 +98,12 @@ uninstall_runtime_lib:
# unit-test
clientlib-unit-test: build-clientlib-unit-test
$(BUILD_DIR)/bin/clientlib_unit_test
build-clientlib-unit-test:
cmake --build $(BUILD_DIR) --target clientlib_unit_test
testlib-unit-test: build-testlib-unit-test
$(BUILD_DIR)/bin/testlib_unit_test

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_BOOST_OUTCOME_H
#define CONCRETELANG_BOOST_OUTCOME_H
#include "boost-single-header/outcome.hpp"
namespace outcome = outcome_v2_e261cebd;
#endif

View File

@@ -0,0 +1,141 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
#define CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
#include <cassert>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
using scalar_in = uint8_t;
using scalar_out = uint64_t;
using tensor1_in = std::vector<scalar_in>;
using tensor2_in = std::vector<std::vector<scalar_in>>;
using tensor3_in = std::vector<std::vector<std::vector<scalar_in>>>;
using tensor1_out = std::vector<scalar_out>;
using tensor2_out = std::vector<std::vector<scalar_out>>;
using tensor3_out = std::vector<std::vector<std::vector<scalar_out>>>;
class ClientLambda {
/// Low-level class to create the client side view of a FHE function.
public:
virtual ~ClientLambda() = default;
static outcome::checked<ClientLambda, StringError>
/// Construct a ClientLambda from a ClientParameter file.
load(std::string funcName, std::string jsonPath);
/// Emit a call to the given ostream, no meta-date are include, so it's the
/// responsability of the the caller/callee to verify the add/verify the
/// function to be called.
outcome::checked<void, StringError>
untypedSerializeCall(PublicArguments &publicArguments, std::ostream &ostream);
/// Generate or get from cache a KeySet suitable for this ClientLambda
outcome::checked<std::unique_ptr<KeySet>, StringError>
keySet(std::shared_ptr<KeySetCache> optionalCache, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
decryptReturnedValues(KeySet &keySet, std::istream &istream);
outcome::checked<decrypted_scalar_t, StringError>
decryptReturnedScalar(KeySet &keySet, std::istream &istream);
outcome::checked<decrypted_tensor_1_t, StringError>
decryptReturnedTensor1(KeySet &keySet, std::istream &istream);
outcome::checked<decrypted_tensor_2_t, StringError>
decryptReturnedTensor2(KeySet &keySet, std::istream &istream);
outcome::checked<decrypted_tensor_3_t, StringError>
decryptReturnedTensor3(KeySet &keySet, std::istream &istream);
public:
ClientParameters clientParameters;
};
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream);
template <typename Result, typename... Args>
class TypedClientLambda : public ClientLambda {
public:
static outcome::checked<TypedClientLambda<Result, Args...>, StringError>
load(std::string funcName, std::string jsonPath) {
OUTCOME_TRY(auto lambda, ClientLambda::load(funcName, jsonPath));
return TypedClientLambda(lambda);
}
/// Emit a call on this lambda to a binary ostream.
/// The ostream is responsible for transporting the call to a
/// ServerLambda::real_call_write function. ostream must be in binary mode
/// std::ios_base::openmode::binary
outcome::checked<void, StringError>
serializeCall(Args... args, std::shared_ptr<KeySet> keySet,
std::ostream &ostream) {
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
return ClientLambda::untypedSerializeCall(publicArguments, ostream);
}
outcome::checked<PublicArguments, StringError>
publicArguments(Args... args, std::shared_ptr<KeySet> keySet) {
OUTCOME_TRY(auto clientArguments, EncryptedArgs::create(keySet, args...));
return clientArguments->asPublicArguments(clientParameters,
keySet->runtimeContext());
}
outcome::checked<Result, StringError> decryptReturned(KeySet &keySet,
std::istream &istream) {
return topLevelDecryptResult<Result>((*this), keySet, istream);
}
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
// TODO: check parameter types
// TODO: add static check on types vs lambda inputs/outpus
}
protected:
// Workaround, gcc 6 does not support partial template specialisation in class
template <typename Result_>
friend outcome::checked<Result_, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream);
};
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
std::istream &istream);
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream);
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream);
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -5,18 +5,27 @@
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#include <map>
#include <string>
#include <vector>
#include "boost/outcome.h"
#include "concretelang/Common/Error.h"
#include <llvm/Support/JSON.h>
namespace mlir {
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
const std::string SMALL_KEY = "small";
const std::string BIG_KEY = "big";
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
typedef size_t DecompositionLevelCount;
typedef size_t DecompositionBaseLog;
typedef size_t PolynomialSize;
@@ -31,6 +40,8 @@ struct LweSecretKeyParam {
LweDimension size;
void hash(size_t &seed);
inline uint64_t lweDimension() { return size; }
inline uint64_t lweSize() { return size + 1; }
};
static bool operator==(const LweSecretKeyParam &lhs,
const LweSecretKeyParam &rhs) {
@@ -121,8 +132,17 @@ struct ClientParameters {
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
std::string functionName;
size_t hash();
static outcome::checked<std::vector<ClientParameters>, StringError>
load(std::string path);
static std::string getClientParametersPath(std::string path);
LweSecretKeyParam lweSecretKeyParam(CircuitGate gate);
};
static inline bool operator==(const ClientParameters &lhs,
const ClientParameters &rhs) {
return lhs.secretKeys == rhs.secretKeys &&
@@ -154,12 +174,13 @@ bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
llvm::json::Value toJSON(const ClientParameters &);
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
ClientParameters cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
} // namespace clientlib
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,133 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
#define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
#include <ostream>
#include "boost/outcome.h"
#include "../Common/Error.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/BitsSize.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
class PublicArguments;
class EncryptedArgs {
/// Temporary object used to hold and encrypt parameters before calling a
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
/// Otherwise convert it to a PublicArguments and use
/// serializeCall(PublicArguments, KeySet).
public:
// Create EncryptedArgument that use the given KeySet to perform
// encryption/decryption operations.
template <typename... Args>
static outcome::checked<std::shared_ptr<EncryptedArgs>, StringError>
create(std::shared_ptr<KeySet> keySet, Args... args) {
auto arguments = std::make_shared<EncryptedArgs>();
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
return arguments;
}
/** Low level interface */
public:
// Add a scalar argument.
outcome::checked<void, StringError> pushArg(uint8_t arg,
std::shared_ptr<KeySet> keySet);
// Add a vector-tensor argument.
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet);
template <size_t size>
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
std::shared_ptr<KeySet> keySet) {
return pushArg(8, (void *)arg.data(), {size}, keySet);
}
// Add a matrix-tensor argument.
template <size_t size0, size_t size1>
outcome::checked<void, StringError>
pushArg(std::array<std::array<uint8_t, size1>, size0> arg,
std::shared_ptr<KeySet> keySet) {
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
}
// Add a rank3 tensor.
template <size_t size0, size_t size1, size_t size2>
outcome::checked<void, StringError>
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
std::shared_ptr<KeySet> keySet) {
return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet);
}
// Generalize by computing shape by template recursion
// Set a argument at the given pos as a 1D tensor of T.
template <typename T>
outcome::checked<void, StringError> pushArg(T *data, size_t dim1,
std::shared_ptr<KeySet> keySet) {
return pushArg<T>(data, llvm::ArrayRef<size_t>(&dim1, 1), keySet);
}
// Set a argument at the given pos as a tensor of T.
template <typename T>
outcome::checked<void, StringError> pushArg(T *data,
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet) {
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape, keySet);
}
outcome::checked<void, StringError> pushArg(size_t width, void *data,
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet);
template <typename Arg0, typename... OtherArgs>
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet,
Arg0 arg0, OtherArgs... others) {
OUTCOME_TRYV(pushArg(arg0, keySet));
return pushArgs(keySet, others...);
}
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet) {
return checkAllArgs(keySet);
}
outcome::checked<PublicArguments, StringError>
asPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext);
EncryptedArgs();
~EncryptedArgs();
private:
outcome::checked<void, StringError>
checkPushTooManyArgs(std::shared_ptr<KeySet> keySetPtr);
outcome::checked<void, StringError>
checkAllArgs(std::shared_ptr<KeySet> keySet);
// Add a scalar argument.
outcome::checked<void, StringError> pushArg(uint64_t arg,
std::shared_ptr<KeySet> keySet);
// Position of the next pushed argument
size_t currentPos;
std::vector<void *> preparedArgs;
// Store buffers of ciphertexts
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
};
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -3,11 +3,13 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_KEYSET_H_
#define CONCRETELANG_SUPPORT_KEYSET_H_
#ifndef CONCRETELANG_CLIENTLIB_KEYSET_H_
#define CONCRETELANG_CLIENTLIB_KEYSET_H_
#include <memory>
#include "boost/outcome.h"
extern "C" {
#include "concrete-ffi.h"
}
@@ -15,26 +17,26 @@ extern "C" {
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
namespace mlir {
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
using RuntimeContext = mlir::concretelang::RuntimeContext;
class KeySet {
public:
~KeySet();
static std::unique_ptr<KeySet> uninitialized();
llvm::Error generateKeysFromParams(ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
llvm::Error setupEncryptionMaterial(ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
// allocate a KeySet according the ClientParameters.
static llvm::Expected<std::unique_ptr<KeySet>>
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generateCached(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
// isInputEncrypted return true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
@@ -57,17 +59,19 @@ public:
// allocate a lwe ciphertext buffer for the argument at argPos, set the size
// of the allocated buffer.
llvm::Error allocate_lwe(size_t argPos, uint64_t **ciphertext,
uint64_t &size);
outcome::checked<void, StringError>
allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size);
// encrypt the input to the ciphertext for the argument at argPos.
llvm::Error encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
outcome::checked<void, StringError>
encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
// isOuputEncrypted return true if the output at the given pos is encrypted.
bool isOutputEncrypted(size_t pos);
// decrypt the ciphertext to the output for the argument at argPos.
llvm::Error decrypt_lwe(size_t argPos, uint64_t *ciphertext,
uint64_t &output);
outcome::checked<void, StringError>
decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output);
size_t numInputs() { return inputs.size(); }
size_t numOutputs() { return outputs.size(); }
@@ -77,8 +81,14 @@ public:
void setRuntimeContext(RuntimeContext &context) {
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
context.bsk["_concretelang_base_context_bsk"] =
std::get<1>(this->bootstrapKeys["bsk_v0"]);
context.bsk[RuntimeContext::BASE_CONTEXT_BSK] =
std::get<1>(this->bootstrapKeys.at("bsk_v0"));
}
RuntimeContext runtimeContext() {
RuntimeContext context;
this->setRuntimeContext(context);
return context;
}
const std::map<LweSecretKeyID,
@@ -94,12 +104,23 @@ public:
getKeyswitchKeys();
protected:
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);
llvm::Error generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
EncryptionRandomGenerator *generator);
llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator);
outcome::checked<void, StringError>
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);
outcome::checked<void, StringError>
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
EncryptionRandomGenerator *generator);
outcome::checked<void, StringError>
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator);
outcome::checked<void, StringError>
generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<void, StringError>
setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
friend class KeySetCache;
@@ -127,7 +148,7 @@ private:
keyswitchKeys);
};
} // namespace clientlib
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -3,13 +3,13 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_KEYSETCACHE_H_
#define CONCRETELANG_SUPPORT_KEYSETCACHE_H_
#ifndef CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
#define CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
#include "concretelang/ClientLib/KeySet.h"
namespace mlir {
namespace concretelang {
namespace clientlib {
class KeySet;
@@ -20,17 +20,21 @@ public:
KeySetCache(std::string backingDirectoryPath)
: backingDirectoryPath(backingDirectoryPath) {}
llvm::Expected<std::unique_ptr<KeySet>>
tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
private:
static llvm::Expected<std::unique_ptr<KeySet>>
tryLoadKeys(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb,
llvm::SmallString<0> &folderPath);
static outcome::checked<std::unique_ptr<KeySet>, StringError>
loadKeys(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb,
std::string folderPath);
outcome::checked<std::unique_ptr<KeySet>, StringError>
loadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
};
} // namespace clientlib
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,66 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
#define CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
#include <iostream>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/context.h"
namespace concretelang {
namespace serverlib {
class ServerLambda;
}
} // namespace concretelang
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
class EncryptedArgs;
class PublicArguments {
/// PublicArguments will be sended to the server. It includes encrypted
/// arguments and public keys.
public:
PublicArguments(
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
bool clearRuntimeContext, std::vector<void *> &&preparedArgs,
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers);
PublicArguments(PublicArguments &other) = delete;
// to have proper owership transfer (outcome and local object)
PublicArguments(PublicArguments &&other);
~PublicArguments();
void freeIfNotOwned(std::vector<encrypted_scalar_t> res);
static outcome::checked<std::shared_ptr<PublicArguments>, StringError>
unserialize(ClientParameters &expectedParams, std::istream &istream);
outcome::checked<void, StringError> serialize(std::ostream &ostream);
private:
friend class ::concretelang::serverlib::ServerLambda; // from ServerLib
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
ClientParameters clientParameters;
RuntimeContext runtimeContext;
std::vector<void *> preparedArgs;
// Store buffers of ciphertexts
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
bool clearRuntimeContext;
};
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,78 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
#include <iostream>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Runtime/context.h"
namespace concretelang {
namespace clientlib {
using RuntimeContext = mlir::concretelang::RuntimeContext;
// integers are not serialized as binary values even on a binary stream
// so we cannot rely on << operator directly
template <typename Word>
std::ostream &writeWord(std::ostream &ostream, Word word) {
ostream.write(reinterpret_cast<char *>(&(word)), sizeof(word));
assert(ostream.good());
return ostream;
}
template <typename Size>
std::ostream &writeSize(std::ostream &ostream, Size size) {
return writeWord(ostream, size);
}
// for sake of symetry
template <typename Word>
std::istream &readWord(std::istream &istream, Word &word) {
istream.read(reinterpret_cast<char *>(&(word)), sizeof(word));
assert(istream.good());
return istream;
}
template <typename Size>
std::istream &readSize(std::istream &istream, Size &size) {
return readWord(istream, size);
}
template <typename Stream> bool incorrectMode(Stream &stream) {
auto binary = stream.flags() && std::ios::binary;
if (!binary) {
stream.setstate(std::ios::failbit);
}
return !binary;
}
std::ostream &operator<<(std::ostream &ostream, const ClientParameters &params);
std::istream &operator>>(std::istream &istream, ClientParameters &params);
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
std::ostream &serializeEncryptedValues(std::vector<size_t> &sizes,
encrypted_scalars_t values,
std::ostream &ostream);
std::ostream &
serializeEncryptedValues(encrypted_scalars_and_sizes_t &values_and_sizes,
std::ostream &ostream);
encrypted_scalars_and_sizes_t unserializeEncryptedValues(
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream);
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,58 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
#define CONCRETELANG_CLIENTLIB_TYPES_H_
#include <cstdint>
#include <vector>
extern "C" {
#include "concrete-ffi.h"
}
namespace concretelang {
namespace clientlib {
template <size_t N> struct MemRefDescriptor {
uint64_t *allocated;
uint64_t *aligned;
size_t offset;
size_t sizes[N];
size_t strides[N];
};
using decrypted_scalar_t = std::uint64_t;
using decrypted_tensor_1_t = std::vector<decrypted_scalar_t>;
using decrypted_tensor_2_t = std::vector<decrypted_tensor_1_t>;
using decrypted_tensor_3_t = std::vector<decrypted_tensor_2_t>;
template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
using encrypted_scalar_t = uint64_t *;
using encrypted_scalars_t = uint64_t *;
struct encrypted_scalars_and_sizes_t {
std::vector<uint64_t> values; // tensor of rank r + 1
std::vector<size_t> sizes; // r sizes
inline size_t length() {
if (sizes.empty()) {
assert(false);
return 0;
}
size_t len = 1;
for (auto size : sizes) {
assert(size > 0);
len *= size;
}
return len;
}
inline size_t lweSize() { return sizes.back(); }
};
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_BITS_SIZE_H
#define CONCRETELANG_COMMON_BITS_SIZE_H
#include <stdlib.h>
namespace concretelang {
namespace common {
size_t bitWidthAsWord(size_t exactBitWidth);
}
} // namespace concretelang
#endif

View File

@@ -0,0 +1,43 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_ERROR_H
#define CONCRETELANG_COMMON_ERROR_H
#include <string>
namespace concretelang {
namespace error {
class StringError {
public:
StringError(std::string mesg) : mesg(mesg){};
std::string mesg;
StringError &operator<<(const std::string &v) {
mesg += v;
return *this;
}
StringError &operator<<(const char *v) {
mesg += std::string(v);
return *this;
}
StringError &operator<<(char *v) {
mesg += std::string(v);
return *this;
}
template <typename T> inline StringError &operator<<(const T v) {
mesg += std::to_string(v);
return *this;
}
};
} // namespace error
} // namespace concretelang
#endif

View File

@@ -15,17 +15,20 @@ extern "C" {
namespace mlir {
namespace concretelang {
typedef struct RuntimeContext {
LweKeyswitchKey_u64 *ksk;
std::map<std::string, LweBootstrapKey_u64 *> bsk;
static std::string BASE_CONTEXT_BSK;
~RuntimeContext() {
for (const auto &key : bsk) {
if (key.first != "_concretelang_base_context_bsk")
if (key.first != BASE_CONTEXT_BSK)
free_lwe_bootstrap_key_u64(key.second);
}
}
} RuntimeContext;
} // namespace concretelang
} // namespace mlir

View File

@@ -3,20 +3,23 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_ARITY_CALL_H
#define CONCRETELANG_TESTLIB_DYNAMIC_ARITY_CALL_H
// generated: see genDynamicRandAndArityCall.py
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
#define CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
#include <cassert>
#include <vector>
namespace mlir {
#include "concretelang/ClientLib/Types.h"
namespace concretelang {
namespace serverlib {
template <typename Res>
Res call(Res (*func)(void *...), std::vector<void *> args) {
Res multi_arity_call(Res (*func)(void *...), std::vector<void *> args) {
switch (args.size()) {
// generated part: see genDynamicArityCall.py
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
case 1:
return func(args[0]);
case 2:
@@ -1457,12 +1460,13 @@ Res call(Res (*func)(void *...), std::vector<void *> args) {
args[111], args[112], args[113], args[114], args[115], args[116],
args[117], args[118], args[119], args[120], args[121], args[122],
args[123], args[124], args[125], args[126]);
default:
assert(false);
}
}
} // namespace serverlib
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -6,28 +6,37 @@
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
#define CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
#include "concretelang/ClientLib/ClientParameters.h"
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
namespace mlir {
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::ClientParameters;
using concretelang::error::StringError;
class DynamicModule {
public:
~DynamicModule();
static llvm::Expected<std::shared_ptr<DynamicModule>>
static outcome::checked<std::shared_ptr<DynamicModule>, StringError>
open(std::string libraryPath);
private:
llvm::Error loadClientParametersJSON(std::string path);
llvm::Error loadSharedLibrary(std::string path);
outcome::checked<void, StringError>
loadClientParametersJSON(std::string path);
outcome::checked<void, StringError> loadSharedLibrary(std::string path);
private:
std::vector<ClientParameters> clientParametersList;
void *libraryHandle;
friend class DynamicLambda;
friend class ServerLambda;
};
} // namespace serverlib
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,25 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
#define CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
#include <vector>
#include "concretelang/ClientLib/Types.h"
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::encrypted_scalars_and_sizes_t;
encrypted_scalars_and_sizes_t
multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank);
} // namespace serverlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,52 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
#define CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
#include <cassert>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/DynamicModule.h"
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::encrypted_scalar_t;
using concretelang::clientlib::encrypted_scalars_and_sizes_t;
using concretelang::clientlib::encrypted_scalars_t;
encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
size_t rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned,
size_t offset, size_t *sizes, size_t *strides);
class ServerLambda {
public:
static outcome::checked<ServerLambda, concretelang::error::StringError>
load(std::string funcName, std::string outputLib);
static outcome::checked<ServerLambda, concretelang::error::StringError>
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
outcome::checked<void, concretelang::error::StringError>
read_call_write(std::istream &istream, std::ostream &ostream);
protected:
ClientParameters clientParameters;
void *(*func)(void *...);
// Retain module and open shared lib alive
std::shared_ptr<DynamicModule> module;
};
} // namespace serverlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,45 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
print(
"""// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt
// for license information.
// generated: see genDynamicArityCall.py
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
#define CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
#include <cassert>
#include <vector>
#include "concretelang/ClientLib/Types.h"
namespace mlir {
namespace serverlib {
template <typename Res>
Res multi_arity_call(Res (*func)(void *...), std::vector<void *> args) {
switch (args.size()) {
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
""")
for i in range(1, 128):
args = ','.join(f'args[{j}]' for j in range(i))
print(f' case {i}: return func({args});')
print("""
default:
assert(false);
}
}""")
print("""
} // namespace concretelang
} // namespace mlir
#endif
""")

View File

@@ -99,8 +99,10 @@ public:
llvm::Expected<std::string> emitStatic();
/** Emit a shared library with the previously added compilation result */
llvm::Expected<std::string> emitShared();
/** Emit a shared library with the previously added compilation result */
/** Emit a json ClientParameters corresponding to library content */
llvm::Expected<std::string> emitClientParametersJSON();
/// Emit a client header file for this corresponding to library content
llvm::Expected<std::string> emitCppHeader();
};
// Specification of the exit stage of the compilation pipeline

View File

@@ -15,7 +15,8 @@
namespace mlir {
namespace concretelang {
size_t bitWidthAsWord(size_t exactBitWidth);
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::KeySet;
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.

View File

@@ -16,6 +16,8 @@
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::KeySetCache;
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return

View File

@@ -15,8 +15,7 @@
namespace mlir {
namespace concretelang {
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
mlir::ModuleOp module);
using ::concretelang::clientlib::ClientParameters;
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,

View File

@@ -1,105 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_ARGUMENTS_H
#define CONCRETELANG_TESTLIB_ARGUMENTS_H
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
namespace mlir {
namespace concretelang {
class DynamicLambda;
class Arguments {
public:
Arguments(KeySet &keySet) : currentPos(0), keySet(keySet) {
keySet.setRuntimeContext(context);
}
~Arguments();
// Create EncryptedArgument that use the given KeySet to perform encryption
// and decryption operations.
static std::shared_ptr<Arguments> create(KeySet &keySet);
// Add a scalar argument.
llvm::Error pushArg(uint64_t arg);
// Add a vector-tensor argument.
llvm::Error pushArg(std::vector<uint8_t> arg);
template <size_t size> llvm::Error pushArg(std::array<uint8_t, size> arg) {
return pushArg(8, (void *)arg.data(), {size});
}
// Add a matrix-tensor argument.
template <size_t size0, size_t size1>
llvm::Error pushArg(std::array<std::array<uint8_t, size1>, size0> arg) {
return pushArg(8, (void *)arg.data(), {size0, size1});
}
// Add a rank3 tensor.
template <size_t size0, size_t size1, size_t size2>
llvm::Error pushArg(
std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg) {
return pushArg(8, (void *)arg.data(), {size0, size1, size2});
}
// Generalize by computing shape by template recursion
// Set a argument at the given pos as a 1D tensor of T.
template <typename T> llvm::Error pushArg(T *data, int64_t dim1) {
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1));
}
// Set a argument at the given pos as a tensor of T.
template <typename T>
llvm::Error pushArg(T *data, llvm::ArrayRef<int64_t> shape) {
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape);
}
llvm::Error pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape);
// Push the runtime context to the argument list, this must be called
// after each argument was pushed.
llvm::Error pushContext();
template <typename Arg0, typename... OtherArgs>
llvm::Error pushArgs(Arg0 arg0, OtherArgs... others) {
auto err = pushArg(arg0);
if (err) {
return err;
}
return pushArgs(others...);
}
llvm::Error pushArgs() { return pushContext(); }
private:
friend DynamicLambda;
template <typename Result>
friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
const Arguments &args);
llvm::Error checkPushTooManyArgs();
// Position of the next pushed argument
size_t currentPos;
std::vector<void *> preparedArgs;
// Store allocated lwe ciphertexts (for free)
std::vector<uint64_t *> allocatedCiphertexts;
// Store buffers of ciphertexts
std::vector<uint64_t *> ciphertextBuffers;
KeySet &keySet;
RuntimeContext context;
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -1,123 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
#define CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/TestLib/Arguments.h"
#include "concretelang/TestLib/DynamicModule.h"
namespace mlir {
namespace concretelang {
template <size_t N> struct MemRefDescriptor;
template <typename Result>
llvm::Expected<Result> invoke(DynamicLambda &lambda, const Arguments &args) {
// compile time error if used
using COMPATIBLE_RESULT_TYPE = void;
return (Result)(COMPATIBLE_RESULT_TYPE)0; // invoke does not accept this kind
// of Result
}
template <>
llvm::Expected<u_int64_t> invoke<u_int64_t>(DynamicLambda &lambda,
const Arguments &args);
template <>
llvm::Expected<std::vector<uint64_t>>
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args);
template <>
llvm::Expected<std::vector<std::vector<uint64_t>>>
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
const Arguments &args);
template <>
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
const Arguments &args);
class DynamicLambda {
private:
template <typename... Args>
llvm::Expected<std::shared_ptr<Arguments>> createArguments(Args... args) {
if (keySet == nullptr) {
return StreamStringError("keySet was not initialized");
}
auto arg = Arguments::create(*keySet);
auto err = arg->pushArgs(args...);
if (err) {
return StreamStringError(llvm::toString(std::move(err)));
}
return arg;
}
public:
static llvm::Expected<DynamicLambda> load(std::string funcName,
std::string outputLib);
static llvm::Expected<DynamicLambda>
load(std::shared_ptr<DynamicModule> module, std::string funcName);
template <typename Result, typename... Args>
llvm::Expected<Result> call(Args... args) {
auto argOrErr = createArguments(args...);
if (!argOrErr) {
return argOrErr.takeError();
}
auto arg = argOrErr.get();
return invoke<Result>(*this, *arg);
}
llvm::Error generateKeySet(llvm::Optional<KeySetCache> cache = llvm::None,
uint64_t seed_msb = 0, uint64_t seed_lsb = 0);
protected:
template <typename Result>
friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
const Arguments &args);
template <size_t Rank>
llvm::Expected<MemRefDescriptor<Rank>>
invokeMemRefDecriptor(const Arguments &args);
ClientParameters clientParameters;
std::shared_ptr<KeySet> keySet;
void *func;
// Retain module and open shared lib alive
std::shared_ptr<DynamicModule> module;
};
template <typename Result, typename... Args>
class TypedDynamicLambda : public DynamicLambda {
public:
static llvm::Expected<TypedDynamicLambda<Result, Args...>>
load(std::string funcName, std::string outputLib) {
auto lambda = DynamicLambda::load(funcName, outputLib);
if (!lambda) {
return lambda.takeError();
}
return TypedDynamicLambda(*lambda);
}
llvm::Expected<Result> call(Args... args) {
return DynamicLambda::call<Result>(args...);
}
// TODO: check parameter types
TypedDynamicLambda(DynamicLambda &lambda) : DynamicLambda(lambda) {
// TODO: add static check on types vs lambda inputs/outpus
}
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,118 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
#define CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace concretelang {
namespace testlib {
using concretelang::clientlib::ClientLambda;
using concretelang::clientlib::ClientParameters;
using concretelang::clientlib::KeySet;
using concretelang::clientlib::KeySetCache;
using concretelang::error::StringError;
using concretelang::serverlib::ServerLambda;
inline void freeStringMemory(std::string &s) {
std::string empty;
s.swap(empty);
}
template <typename Result, typename... Args>
class TestTypedLambda
: public concretelang::clientlib::TypedClientLambda<Result, Args...> {
template <typename Result_, typename... Args_>
using TypedClientLambda =
concretelang::clientlib::TypedClientLambda<Result_, Args_...>;
public:
static outcome::checked<TestTypedLambda, StringError>
load(std::string funcName, std::string outputLib, uint64_t seed_msb = 0,
uint64_t seed_lsb = 0,
std::shared_ptr<KeySetCache> unsecure_cache = nullptr) {
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
OUTCOME_TRY(auto cLambda, ClientLambda::load(funcName, jsonPath));
OUTCOME_TRY(auto sLambda, ServerLambda::load(funcName, outputLib));
OUTCOME_TRY(std::shared_ptr<KeySet> keySet,
KeySetCache::generate(unsecure_cache, cLambda.clientParameters,
seed_msb, seed_lsb));
return TestTypedLambda(cLambda, sLambda, keySet);
}
TestTypedLambda(ClientLambda &cLambda, ServerLambda &sLambda,
std::shared_ptr<KeySet> keySet)
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
keySet(keySet) {}
TestTypedLambda(TypedClientLambda<Result, Args...> &cLambda,
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet)
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
keySet(keySet) {}
outcome::checked<Result, StringError> call(Args... args) {
// client
auto BINARY = std::ios::binary;
std::string message;
{
// client
std::ostringstream clientOuput(BINARY);
OUTCOME_TRYV(this->serializeCall(args..., keySet, clientOuput));
if (clientOuput.fail()) {
return StringError("Error on clientOuput");
}
message = clientOuput.str();
}
{
// server
std::istringstream serverInput(message, BINARY);
freeStringMemory(message);
assert(serverInput.tellg() == 0);
std::ostringstream serverOutput(BINARY);
OUTCOME_TRYV(serverLambda.read_call_write(serverInput, serverOutput));
if (serverInput.fail()) {
return StringError("Error on serverOutput");
}
if (serverOutput.fail()) {
return StringError("Error on serverOutput");
}
message = serverOutput.str();
}
{
// client
std::istringstream clientInput(message, BINARY);
freeStringMemory(message);
OUTCOME_TRY(auto result, this->decryptReturned(*keySet, clientInput));
assert(clientInput.good());
return result;
}
}
private:
ServerLambda serverLambda;
std::shared_ptr<KeySet> keySet;
};
template <typename Result, typename... Args>
static TestTypedLambda<Result, Args...> TestTypedLambdaFrom(
concretelang::clientlib::TypedClientLambda<Result, Args...> &cLambda,
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet) {
return TestTypedLambda<Result, Args...>(cLambda, sLambda, keySet);
}
} // namespace testlib
} // namespace concretelang
#endif

View File

@@ -1,6 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
for i in range(128):
args = ','.join(f'args[{j}]' for j in range(i))
print(f' case {i}: return func({args});')

View File

@@ -50,7 +50,7 @@ lambdaArgument invokeLambda(lambda l, executionArguments args) {
}
// Set the integer/tensor arguments
std::vector<mlir::concretelang::LambdaArgument *> lambdaArgumentsRef;
for (auto i = 0; i < args.size; i++) {
for (auto i = 0u; i < args.size; i++) {
lambdaArgumentsRef.push_back(args.data[i].ptr.get());
}
// Run lambda

View File

@@ -5,7 +5,8 @@ add_subdirectory(Support)
add_subdirectory(Runtime)
add_subdirectory(ClientLib)
add_subdirectory(Bindings)
add_subdirectory(TestLib)
add_subdirectory(ServerLib)
add_subdirectory(Common)
# CAPI needed only for python bindings
if (CONCRETELANG_BINDINGS_PYTHON_ENABLED)

View File

@@ -1,9 +1,29 @@
add_compile_options( -Werror )
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# using Clang
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# using GCC
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
endif()
endif()
add_mlir_library(
ConcretelangClientLib
ClientLambda.cpp
ClientParameters.cpp
EncryptedArgs.cpp
KeySet.cpp
KeySetCache.cpp
PublicArguments.cpp
Serializers.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
ConcretelangCommon
ConcretelangRuntime
ConcretelangSupportLib
)

View File

@@ -0,0 +1,207 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <dlfcn.h>
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/Serializers.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
outcome::checked<ClientLambda, StringError>
ClientLambda::load(std::string functionName, std::string jsonPath) {
OUTCOME_TRY(auto all_params, ClientParameters::load(jsonPath));
auto param = llvm::find_if(all_params, [&](ClientParameters param) {
return param.functionName == functionName;
});
if (param == all_params.end()) {
return StringError("ClientLambda: cannot find function ")
<< functionName << " in client parameters" << jsonPath;
}
if (param->outputs.size() != 1) {
return StringError("ClientLambda: output arity (")
<< std::to_string(param->outputs.size())
<< ") != 1 is not supported";
}
if (!param->outputs[0].encryption.hasValue()) {
return StringError("ClientLambda: clear output is not yet supported");
}
ClientLambda lambda;
lambda.clientParameters = *param;
return lambda;
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
uint64_t seed_msb, uint64_t seed_lsb) {
return KeySetCache::generate(optionalCache, clientParameters, seed_msb,
seed_lsb);
}
outcome::checked<void, StringError>
ClientLambda::untypedSerializeCall(PublicArguments &serverArguments,
std::ostream &ostream) {
return serverArguments.serialize(ostream);
}
outcome::checked<decrypted_scalar_t, StringError>
ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) {
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream));
return v[0];
}
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) {
auto lweSize =
clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize();
std::vector<int64_t> sizes = clientParameters.outputs[0].shape.dimensions;
sizes.push_back(lweSize);
auto encryptedValues = unserializeEncryptedValues(sizes, istream);
if (istream.fail()) {
return StringError("Encrypted scalars has not the right size");
}
auto len = encryptedValues.length();
decrypted_tensor_1_t decryptedValues(len / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]);
OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i]));
}
return decryptedValues;
}
outcome::checked<void, StringError> errorResultRank(size_t expected,
size_t actual) {
return StringError("Expected result has rank ")
<< expected << " and cannot be converted to rank " << actual;
}
StringError errorIncoherentSizes(size_t flatSize, size_t structuredSize) {
return StringError("Received ")
<< flatSize << " values but is sizes indicates as global size of "
<< structuredSize;
}
template <typename DecryptedTensor>
DecryptedTensor flatToTensor(decrypted_tensor_1_t &values, size_t *sizes);
template <>
decrypted_tensor_1_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
return values;
}
template <>
decrypted_tensor_2_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
decrypted_tensor_2_t result(sizes[0]);
size_t position = 0;
for (auto &dest0 : result) {
dest0.resize(sizes[1]);
for (auto &dest1 : dest0) {
dest1 = values[position++];
}
}
return result;
}
template <>
decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
decrypted_tensor_3_t result(sizes[0]);
size_t position = 0;
for (auto &dest0 : result) {
dest0.resize(sizes[1]);
for (auto &dest1 : dest0) {
dest1.resize(sizes[2]);
for (auto &dest2 : dest1) {
dest2 = values[position++];
}
}
}
return result;
}
template <typename DecryptedTensor>
outcome::checked<DecryptedTensor, StringError>
decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
ClientParameters &params, size_t expectedRank,
KeySet &keySet) {
auto shape = params.outputs[0].shape;
size_t rank = shape.dimensions.size();
if (rank != expectedRank) {
return StringError("Function returns a tensor of rank ")
<< expectedRank << " which cannot be decrypted to rank " << rank;
}
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream));
size_t sizes[rank];
for (size_t dim = 0; dim < rank; dim++) {
sizes[dim] = shape.dimensions[dim];
}
return flatToTensor<DecryptedTensor>(values, sizes);
}
outcome::checked<decrypted_tensor_1_t, StringError>
ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) {
return decryptReturnedTensor<decrypted_tensor_1_t>(
istream, *this, this->clientParameters, 1, keySet);
}
outcome::checked<decrypted_tensor_2_t, StringError>
ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) {
return decryptReturnedTensor<decrypted_tensor_2_t>(
istream, *this, this->clientParameters, 2, keySet);
}
outcome::checked<decrypted_tensor_3_t, StringError>
ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) {
return decryptReturnedTensor<decrypted_tensor_3_t>(
istream, *this, this->clientParameters, 3, keySet);
}
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream) {
// compile time error if used
using COMPATIBLE_RESULT_TYPE = void;
return (Result)(COMPATIBLE_RESULT_TYPE)0;
}
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedScalar(keySet, istream);
}
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor1(keySet, istream);
}
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor2(keySet, istream);
}
template <>
outcome::checked<decrypted_tensor_3_t, StringError>
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor3(keySet, istream);
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -3,53 +3,60 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <fstream>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
namespace mlir {
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
// https://stackoverflow.com/a/38140932
static inline void hash(std::size_t &seed) {}
static inline void hash_(std::size_t &seed) {}
template <typename T, typename... Rest>
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
static inline void hash_(std::size_t &seed, const T &v, Rest... rest) {
// See https://softwareengineering.stackexchange.com/a/402543
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
const std::hash<T> hasher;
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
hash(seed, rest...);
hash_(seed, rest...);
}
void LweSecretKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, size);
}
void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, size); }
void BootstrapKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, glweDimension, variance);
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
glweDimension, variance);
}
void KeyswitchKeyParam::hash(size_t &seed) {
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, variance);
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog, variance);
}
std::size_t ClientParameters::hash() {
std::size_t currentHash = 1;
for (auto secretKeyParam : secretKeys) {
mlir::concretelang::hash(currentHash, secretKeyParam.first);
hash_(currentHash, secretKeyParam.first);
secretKeyParam.second.hash(currentHash);
}
for (auto bootstrapKeyParam : bootstrapKeys) {
mlir::concretelang::hash(currentHash, bootstrapKeyParam.first);
hash_(currentHash, bootstrapKeyParam.first);
bootstrapKeyParam.second.hash(currentHash);
}
for (auto keyswitchParam : keyswitchKeys) {
mlir::concretelang::hash(currentHash, keyswitchParam.first);
hash_(currentHash, keyswitchParam.first);
keyswitchParam.second.hash(currentHash);
}
return currentHash;
}
LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) {
return secretKeys.find(gate.encryption->secretKeyID)->second;
}
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
llvm::json::Object object{
{"size", v.size},
@@ -384,5 +391,27 @@ bool fromJSON(const llvm::json::Value j, ClientParameters &v,
return true;
}
} // namespace concretelang
} // namespace mlir
std::string ClientParameters::getClientParametersPath(std::string path) {
return path + CLIENT_PARAMETERS_EXT;
}
outcome::checked<std::vector<ClientParameters>, StringError>
ClientParameters::load(std::string jsonPath) {
std::ifstream file(jsonPath);
std::string content((std::istreambuf_iterator<char>(file)),
(std::istreambuf_iterator<char>()));
if (file.fail()) {
return StringError("Cannot read file: ") << jsonPath;
}
auto expectedClientParams =
llvm::json::parse<std::vector<ClientParameters>>(content);
if (auto err = expectedClientParams.takeError()) {
return StringError("Cannot open client parameters: ")
<< llvm::toString(std::move(err)) << "\n"
<< content << "\n";
}
return expectedClientParams.get();
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -0,0 +1,181 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/PublicArguments.h"
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
EncryptedArgs::~EncryptedArgs() {
// There is no explicit allocation
// All buffers are owned by ciphertextBuffers
}
EncryptedArgs::EncryptedArgs() : currentPos(0) {}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
return pushArg((uint64_t)arg, keySet);
}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
// TODO: NON ENCRYPTED
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet->inputGate(pos);
if (input.shape.size != 0) {
return StringError("argument #") << pos << " is not a scalar";
}
if (!input.encryption.hasValue()) {
// clear scalar: just push the argument
if (input.shape.width != 64) {
return StringError(
"scalar argument of with != 64 is not supported for DynamicLambda");
}
preparedArgs.push_back((void *)arg);
return outcome::success();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1;
values_and_sizes.sizes.push_back(lweSize);
values_and_sizes.values.resize(lweSize);
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values_and_sizes.values.data(), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
// offset
preparedArgs.push_back((void *)0);
// size
preparedArgs.push_back((void *)values_and_sizes.values.size());
// stride
preparedArgs.push_back((void *)1);
currentPos++;
return outcome::success();
}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet) {
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet->inputGate(pos);
// Check the width of data
if (input.shape.width > 64) {
return StringError("argument #")
<< pos << " width > 64 bits is not supported";
}
auto roundedSize = concretelang::common::bitWidthAsWord(input.shape.width);
if (width != roundedSize) {
return StringError("argument #") << pos << "width mismatch, got " << width
<< " expected " << roundedSize;
}
// Check the shape of tensor
if (input.shape.dimensions.empty()) {
return StringError("argument #") << pos << "is not a tensor";
}
if (shape.size() != input.shape.dimensions.size()) {
return StringError("argument #")
<< pos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
for (size_t i = 0; i < shape.size(); i++) {
values_and_sizes.sizes.push_back(shape[i]);
if (shape[i] != input.shape.dimensions[i]) {
return StringError("argument #")
<< pos << " has not the expected dimension #" << i << " , got "
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}
if (input.encryption.hasValue()) {
auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1;
values_and_sizes.sizes.push_back(lweSize);
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
if (width != 8) {
return StringError("argument #")
<< pos << " width mismatch, expected 8 got " << width;
}
const uint8_t *data8 = (const uint8_t *)data;
// Allocate a buffer for ciphertexts of size of tensor
values_and_sizes.values.resize(input.shape.size * lweSize);
auto &values = values_and_sizes.values;
// Allocate ciphertexts and encrypt, for every values in tensor
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values.data() + offset, data8[i]));
}
} // TODO: NON ENCRYPTED, COPY CONTENT TO values_and_sizes
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : values_and_sizes.sizes) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = values_and_sizes.length();
// If encrypted +1 set the stride for the lwe size rank
for (size_t size : values_and_sizes.sizes) {
stride /= size;
preparedArgs.push_back((void *)stride);
}
currentPos++;
return outcome::success();
}
outcome::checked<void, StringError>
EncryptedArgs::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
size_t arity = keySet->numInputs();
if (currentPos < arity) {
return outcome::success();
}
return StringError("function has arity ")
<< arity << " but is applied to too many arguments";
}
outcome::checked<void, StringError>
EncryptedArgs::checkAllArgs(std::shared_ptr<KeySet> keySet) {
size_t arity = keySet->numInputs();
if (currentPos == arity) {
return outcome::success();
}
return StringError("function expects ")
<< arity << " arguments but has been called with " << currentPos
<< " arguments";
}
outcome::checked<PublicArguments, StringError>
EncryptedArgs::asPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext) {
// On client side the runtimeContext is hold by the KeySet
bool clearContext = false;
return PublicArguments(clientParameters, runtimeContext, clearContext,
std::move(preparedArgs), std::move(ciphertextBuffers));
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -6,8 +6,17 @@
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/Support/Error.h"
namespace mlir {
#define CAPI_ERR_TO_STRINGERROR(instr, msg) \
{ \
int err; \
instr; \
if (err != 0) { \
return concretelang::error::StringError(msg); \
} \
}
namespace concretelang {
namespace clientlib {
KeySet::~KeySet() {
for (auto it : secretKeys) {
@@ -22,31 +31,20 @@ KeySet::~KeySet() {
free_encryption_generator(encryptionRandomGenerator);
}
llvm::Expected<std::unique_ptr<KeySet>>
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySet::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto keySet = std::make_unique<KeySet>();
auto keySet = uninitialized();
if (auto error = keySet->generateKeysFromParams(params, seed_msb, seed_lsb)) {
return std::move(error);
}
if (auto error =
keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)) {
return std::move(error);
}
OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb));
OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb));
return std::move(keySet);
}
std::unique_ptr<KeySet> KeySet::uninitialized() {
return std::make_unique<KeySet>();
}
llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
uint64_t seed_msb,
uint64_t seed_lsb) {
outcome::checked<void, StringError>
KeySet::setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
@@ -55,10 +53,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
if (param.encryption.hasValue()) {
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (inputSk == this->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"input encryption secret key (" + param.encryption->secretKeyID +
") does not exist ",
llvm::inconvertibleErrorCode());
return StringError("input encryption secret key (")
<< param.encryption->secretKeyID << ") does not exist ";
}
secretKeyParam = inputSk->second.first;
secretKey = inputSk->second.second;
@@ -73,9 +69,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
if (param.encryption.hasValue()) {
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (outputSk == this->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
return StringError(
"cannot find output key to generate bootstrap key");
}
secretKeyParam = outputSk->second.first;
secretKey = outputSk->second.second;
@@ -89,50 +84,41 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
this->encryptionRandomGenerator =
allocate_encryption_generator(seed_msb, seed_lsb);
return llvm::Error::success();
return outcome::success();
}
llvm::Error KeySet::generateKeysFromParams(ClientParameters &params,
uint64_t seed_msb,
uint64_t seed_lsb) {
outcome::checked<void, StringError>
KeySet::generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
{
// Generate LWE secret keys
SecretRandomGenerator *generator;
generator = allocate_secret_generator(seed_msb, seed_lsb);
for (auto secretKeyParam : params.secretKeys) {
auto e = this->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
if (e) {
return std::move(e);
}
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator));
}
free_secret_generator(generator);
}
// Allocate the encryption random generator
this->encryptionRandomGenerator =
allocate_encryption_generator(seed_msb, seed_lsb);
// Generate bootstrap and keyswitch keys
{
for (auto bootstrapKeyParam : params.bootstrapKeys) {
auto e = this->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second,
this->encryptionRandomGenerator);
if (e) {
return std::move(e);
}
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second,
this->encryptionRandomGenerator));
}
for (auto keyswitchParam : params.keyswitchKeys) {
auto e = this->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second,
this->encryptionRandomGenerator);
if (e) {
return std::move(e);
}
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second,
this->encryptionRandomGenerator));
}
}
return llvm::Error::success();
return outcome::success();
}
void KeySet::setKeys(
@@ -149,9 +135,9 @@ void KeySet::setKeys(
this->keyswitchKeys = keyswitchKeys;
}
llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
LweSecretKeyParam param,
SecretRandomGenerator *generator) {
outcome::checked<void, StringError>
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator) {
LweSecretKey_u64 *sk;
sk = allocate_lwe_secret_key_u64({param.size});
@@ -159,24 +145,20 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
secretKeys[id] = {param, sk};
return llvm::Error::success();
return outcome::success();
}
llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
BootstrapKeyParam param,
EncryptionRandomGenerator *generator) {
outcome::checked<void, StringError>
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
EncryptionRandomGenerator *generator) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate bootstrap key",
llvm::inconvertibleErrorCode());
return StringError("cannot find input key to generate bootstrap key");
}
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
return StringError("cannot find output key to generate bootstrap key");
}
// Allocate the bootstrap key
LweBootstrapKey_u64 *bsk;
@@ -207,24 +189,20 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator,
{param.variance});
free_glwe_secret_key_u64(glwe_sk);
return llvm::Error::success();
return outcome::success();
}
llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
KeyswitchKeyParam param,
EncryptionRandomGenerator *generator) {
outcome::checked<void, StringError>
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate keyswitch key",
llvm::inconvertibleErrorCode());
return StringError("cannot find input key to generate keyswitch key");
}
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate keyswitch key",
llvm::inconvertibleErrorCode());
return StringError("cannot find output key to generate keyswitch key");
}
// Allocate the keyswitch key
LweKeyswitchKey_u64 *ksk;
@@ -240,21 +218,19 @@ llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second,
outputSk->second.second, generator,
{param.variance});
return llvm::Error::success();
return outcome::success();
}
llvm::Error KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext,
uint64_t &size) {
outcome::checked<void, StringError>
KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
if (argPos >= inputs.size()) {
return llvm::make_error<llvm::StringError>(
"allocate_lwe position of argument is too high",
llvm::inconvertibleErrorCode());
return StringError("allocate_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
size = std::get<1>(inputSk).size + 1;
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
return llvm::Error::success();
return outcome::success();
}
bool KeySet::isInputEncrypted(size_t argPos) {
@@ -267,18 +243,14 @@ bool KeySet::isOutputEncrypted(size_t argPos) {
std::get<0>(outputs[argPos]).encryption.hasValue();
}
llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext,
uint64_t input) {
outcome::checked<void, StringError>
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
if (argPos >= inputs.size()) {
return llvm::make_error<llvm::StringError>(
"encrypt_lwe position of argument is too high",
llvm::inconvertibleErrorCode());
return StringError("encrypt_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
if (!std::get<0>(inputSk).encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"encrypt_lwe the positional argument is not encrypted",
llvm::inconvertibleErrorCode());
return StringError("encrypt_lwe the positional argument is not encrypted");
}
// Encode - TODO we could check if the input value is in the right range
uint64_t plaintext =
@@ -286,22 +258,18 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext,
encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext,
encryptionRandomGenerator,
{std::get<0>(inputSk).encryption->variance});
return llvm::Error::success();
return outcome::success();
}
llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext,
uint64_t &output) {
outcome::checked<void, StringError>
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
if (argPos >= outputs.size()) {
return llvm::make_error<llvm::StringError>(
"decrypt_lwe: position of argument is too high",
llvm::inconvertibleErrorCode());
return StringError("decrypt_lwe: position of argument is too high");
}
auto outputSk = outputs[argPos];
if (!std::get<0>(outputSk).encryption.hasValue()) {
return llvm::make_error<llvm::StringError>(
"decrypt_lwe: the positional argument is not encrypted",
llvm::inconvertibleErrorCode());
return StringError("decrypt_lwe: the positional argument is not encrypted");
}
uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext);
// Decode
@@ -309,7 +277,7 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext,
output = plaintext >> (64 - precision - 2);
size_t carry = output % 2;
output = ((output >> 1) + carry) % (1 << (precision + 1));
return llvm::Error::success();
return outcome::success();
}
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
@@ -329,5 +297,5 @@ KeySet::getKeyswitchKeys() {
return keyswitchKeys;
}
} // namespace clientlib
} // namespace concretelang
} // namespace mlir

View File

@@ -3,14 +3,14 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "boost/outcome.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Support/Error.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include <fstream>
#include <functional>
#include <sstream>
#include <string>
@@ -18,8 +18,10 @@ extern "C" {
#include "concrete-ffi.h"
}
namespace mlir {
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
static std::string readFile(llvm::SmallString<0> &path) {
std::ifstream in((std::string)path, std::ofstream::binary);
@@ -70,11 +72,11 @@ void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey_u64 *key) {
free(buffer.pointer);
}
llvm::Expected<std::unique_ptr<KeySet>>
KeySetCache::tryLoadKeys(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb, llvm::SmallString<0> &folderPath) {
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb, std::string folderPath) {
// TODO: text dump of all parameter in /hash
auto key_set = KeySet::uninitialized();
auto key_set = std::make_unique<KeySet>();
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys;
@@ -87,7 +89,7 @@ KeySetCache::tryLoadKeys(ClientParameters &params, uint64_t seed_msb,
for (auto secretKeyParam : params.secretKeys) {
auto id = secretKeyParam.first;
auto param = secretKeyParam.second;
llvm::SmallString<0> path = folderPath;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "secretKey_" + id);
LweSecretKey_u64 *sk = loadSecretKey(path);
secretKeys[id] = {param, sk};
@@ -96,7 +98,7 @@ KeySetCache::tryLoadKeys(ClientParameters &params, uint64_t seed_msb,
for (auto bootstrapKeyParam : params.bootstrapKeys) {
auto id = bootstrapKeyParam.first;
auto param = bootstrapKeyParam.second;
llvm::SmallString<0> path = folderPath;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pbsKey_" + id);
LweBootstrapKey_u64 *bsk = loadBootstrapKey(path);
bootstrapKeys[id] = {param, bsk};
@@ -105,7 +107,7 @@ KeySetCache::tryLoadKeys(ClientParameters &params, uint64_t seed_msb,
for (auto keyswitchParam : params.keyswitchKeys) {
auto id = keyswitchParam.first;
auto param = keyswitchParam.second;
llvm::SmallString<0> path = folderPath;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "ksKey_" + id);
LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path);
keyswitchKeys[id] = {param, ksk};
@@ -113,24 +115,21 @@ KeySetCache::tryLoadKeys(ClientParameters &params, uint64_t seed_msb,
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys);
auto err = key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb);
if (err) {
return StreamStringError() << "Cannot setup encryption material: " << err;
}
OUTCOME_TRYV(key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb));
return std::move(key_set);
}
llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) {
outcome::checked<void, StringError> saveKeys(KeySet &key_set,
llvm::SmallString<0> &folderPath) {
llvm::SmallString<0> folderIncompletePath = folderPath;
folderIncompletePath.append(".incomplete");
auto err = llvm::sys::fs::create_directories(folderIncompletePath);
if (err) {
return StreamStringError()
<< "Cannot create directory \"" << folderIncompletePath
<< "\": " << err.message();
return StringError("Cannot create directory \"")
<< std::string(folderIncompletePath) << "\": " << err.message();
}
// Save LWE secret keys
@@ -163,16 +162,16 @@ llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) {
llvm::sys::fs::remove_directories(folderIncompletePath);
}
if (!llvm::sys::fs::exists(folderPath)) {
return StreamStringError()
<< "Cannot save directory \"" << folderPath << "\"";
return StringError("Cannot save directory \"")
<< std::string(folderPath) << "\"";
}
return llvm::Error::success();
return outcome::success();
}
llvm::Expected<std::unique_ptr<KeySet>>
KeySetCache::tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::loadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
llvm::SmallString<0> folderPath =
llvm::SmallString<0>(this->backingDirectoryPath);
@@ -183,7 +182,7 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
std::to_string(seed_lsb));
if (llvm::sys::fs::exists(folderPath)) {
return tryLoadKeys(params, seed_msb, seed_lsb, folderPath);
return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
}
// Creating a lock for concurrent generation
@@ -198,8 +197,8 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
if (err) {
// parent does not exists OR right issue (creation or write)
return StreamStringError()
<< "Cannot access \"" << lockPath << "\": " << err.message();
return StringError("Cannot access \"")
<< std::string(lockPath) << "\": " << err.message();
}
// The first to lock will generate while the others waits
@@ -211,23 +210,23 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
if (llvm::sys::fs::exists(folderPath)) {
// Others returns here
return tryLoadKeys(params, seed_msb, seed_lsb, folderPath);
return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
}
auto key_set = KeySet::generate(params, seed_msb, seed_lsb);
OUTCOME_TRY(auto key_set, KeySet::generate(params, seed_msb, seed_lsb));
if (!key_set) {
return StreamStringError()
<< "Cannot generate key set: " << key_set.takeError();
}
OUTCOME_TRYV(saveKeys(*key_set, folderPath));
auto savedErr = saveKeys(*(key_set.get()), folderPath);
if (savedErr) {
return StreamStringError() << "Cannot save key set: " << savedErr;
}
return key_set;
return std::move(key_set);
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb)
: KeySet::generate(params, seed_msb, seed_lsb);
}
} // namespace clientlib
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,156 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <iostream>
#include <stdlib.h>
extern "C" {
#include "concrete-ffi.h"
}
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Serializers.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
// TODO: optimize the move
PublicArguments::PublicArguments(
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
bool clearRuntimeContext, std::vector<void *> &&preparedArgs_,
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers_)
: clientParameters(clientParameters), runtimeContext(runtimeContext),
clearRuntimeContext(clearRuntimeContext) {
preparedArgs = std::move(preparedArgs_);
ciphertextBuffers = std::move(ciphertextBuffers_);
}
PublicArguments::PublicArguments(PublicArguments &&other) {
clientParameters = other.clientParameters;
runtimeContext = other.runtimeContext;
runtimeContext.bsk = std::move(other.runtimeContext.bsk);
clearRuntimeContext = other.clearRuntimeContext;
preparedArgs = std::move(other.preparedArgs);
ciphertextBuffers = std::move(other.ciphertextBuffers);
// transfer ownership
other.clearRuntimeContext = false;
other.runtimeContext.ksk = nullptr;
}
PublicArguments::~PublicArguments() {
if (!clearRuntimeContext) {
return;
}
for (auto bsk_entry : runtimeContext.bsk) {
free_lwe_bootstrap_key_u64(bsk_entry.second);
}
runtimeContext.bsk.clear();
if (runtimeContext.ksk != nullptr) {
free_lwe_keyswitch_key_u64(runtimeContext.ksk);
runtimeContext.ksk = nullptr;
}
}
outcome::checked<void, StringError>
PublicArguments::serialize(std::ostream &ostream) {
if (incorrectMode(ostream)) {
return StringError(
"PublicArguments::serialize: ostream should be in binary mode");
}
ostream << runtimeContext;
size_t iPreparedArgs = 0;
int iGate = -1;
for (auto gate : clientParameters.inputs) {
iGate++;
size_t rank = gate.shape.dimensions.size();
if (!gate.encryption.hasValue()) {
return StringError("PublicArguments::serialize: Clear arguments "
"are not supported. Argument ")
<< iGate;
}
/*auto allocated = */ preparedArgs[iPreparedArgs++];
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
assert(aligned != nullptr);
auto offset = (size_t)preparedArgs[iPreparedArgs++];
std::vector<size_t> sizes; // includes lweSize as last dim
sizes.resize(rank + 1);
for (auto dim = 0u; dim < sizes.size(); dim++) {
// sizes are part of the client parameters signature
// it's static now but some day it could be dynamic so we serialize
// them.
sizes[dim] = (size_t)preparedArgs[iPreparedArgs++];
}
std::vector<size_t> strides(rank + 1);
/* strides should be zero here and are not serialized */
for (auto dim = 0u; dim < strides.size(); dim++) {
strides[dim] = (size_t)preparedArgs[iPreparedArgs++];
}
// TODO: STRIDES
auto values = aligned + offset;
serializeEncryptedValues(sizes, values, ostream);
}
return outcome::success();
}
outcome::checked<void, StringError>
PublicArguments::unserializeArgs(std::istream &istream) {
int iGate = -1;
for (auto gate : clientParameters.inputs) {
iGate++;
if (!gate.encryption.hasValue()) {
return StringError("Clear values are not handled");
}
auto lweSize = clientParameters.lweSecretKeyParam(gate).lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
ciphertextBuffers.push_back(
std::move(unserializeEncryptedValues(sizes, istream)));
auto &values_and_sizes = ciphertextBuffers.back();
if (istream.fail()) {
return StringError(
"PublicArguments::unserializeArgs: Failed to read argument ")
<< iGate;
}
preparedArgs.push_back(/*allocated*/ nullptr);
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back(/*offset*/ 0);
// sizes
for (auto size : values_and_sizes.sizes) {
preparedArgs.push_back((void *)size);
}
// strides has been removed by serialization
auto stride = values_and_sizes.length();
for (auto size : sizes) {
stride /= size;
preparedArgs.push_back((void *)stride);
}
}
return outcome::success();
}
outcome::checked<std::shared_ptr<PublicArguments>, StringError>
PublicArguments::unserialize(ClientParameters &clientParameters,
std::istream &istream) {
RuntimeContext runtimeContext;
istream >> runtimeContext;
if (istream.fail()) {
return StringError("Cannot read runtime context");
}
std::vector<void *> empty;
std::vector<encrypted_scalars_and_sizes_t> emptyBuffers;
// On server side the PublicArguments is responsible for the context
auto clearRuntimeContext = true;
auto sArguments = std::make_shared<PublicArguments>(
clientParameters, runtimeContext, clearRuntimeContext, std::move(empty),
std::move(emptyBuffers));
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
sArguments->preparedArgs.push_back((void *)&runtimeContext);
return sArguments;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -0,0 +1,185 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <malloc.h>
#include <iosfwd>
#include <iostream>
#include <stdlib.h>
extern "C" {
#include "concrete-ffi.h"
}
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Serializers.h"
namespace concretelang {
namespace clientlib {
template <typename Result>
Result read_deser(std::istream &istream, Result (*deser)(BufferView)) {
size_t length;
readSize(istream, length);
// buffer is too big to be allocated on stack
// vector ensures everything is deallocated w.r.t. new
std::vector<uint8_t> buffer(length);
istream.read((char *)buffer.data(), length);
assert(istream.good());
return deser({buffer.data(), length});
}
template <typename BufferLike>
std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) {
writeSize(ostream, buffer.length);
ostream.write((const char *)buffer.pointer, buffer.length);
assert(ostream.good());
return ostream;
}
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey_u64 *key) {
Buffer b = serialize_lwe_keyswitching_key_u64(key);
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey_u64 *key) {
Buffer b = serialize_lwe_bootstrap_key_u64(key);
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::istream &operator>>(std::istream &istream, LweKeyswitchKey_u64 *&key) {
key = read_deser(istream, deserialize_lwe_keyswitching_key_u64);
return istream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey_u64 *&key) {
key = read_deser(istream, deserialize_lwe_bootstrap_key_u64);
return istream;
}
std::ostream &operator<<(std::ostream &ostream, const ClientParameters &cp) {
// For binary stream == not formatting
std::string json;
llvm::raw_string_ostream tmpostream(json);
tmpostream << toJSON(cp);
writeSize(ostream, json.size());
assert(ostream.good());
ostream << json;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, ClientParameters &params) {
size_t size;
readSize(istream, size);
char buffer[size + 1];
buffer[size] = '\0'; // llvm::json::parse requires \0 ended buffer.
istream.read(buffer, size);
auto paramsOrErr = llvm::json::parse<ClientParameters>(buffer);
if (auto err = paramsOrErr.takeError()) {
llvm::errs() << "Parsing client parameters error: " << std::move(err)
<< "\n";
istream.setstate(std::ios::failbit);
return istream;
}
params = paramsOrErr.get();
assert(istream.good());
return istream;
}
std::istream &operator>>(std::istream &istream,
RuntimeContext &runtimeContext) {
istream >> runtimeContext.ksk;
istream >> runtimeContext.bsk[RuntimeContext::BASE_CONTEXT_BSK];
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext) {
ostream << runtimeContext.ksk;
ostream << runtimeContext.bsk.at(RuntimeContext::BASE_CONTEXT_BSK);
assert(ostream.good());
return ostream;
}
std::ostream &serializeEncryptedValues(encrypted_scalars_t values,
size_t length, std::ostream &ostream) {
if (incorrectMode(ostream)) {
return ostream;
}
writeSize(ostream, length);
for (size_t i = 0; i < length; i++) {
writeWord(ostream, values[i]);
}
return ostream;
}
std::ostream &serializeEncryptedValues(std::vector<size_t> &sizes,
encrypted_scalars_t values,
std::ostream &ostream) {
size_t length = 1;
for (auto size : sizes) {
length *= size;
writeSize(ostream, size);
}
serializeEncryptedValues(values, length, ostream);
assert(ostream.good());
return ostream;
}
std::ostream &
serializeEncryptedValues(encrypted_scalars_and_sizes_t &values_and_sizes,
std::ostream &ostream) {
std::vector<size_t> &sizes = values_and_sizes.sizes;
encrypted_scalars_t values = values_and_sizes.values.data();
return serializeEncryptedValues(sizes, values, ostream);
}
encrypted_scalars_and_sizes_t unserializeEncryptedValues(
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream) {
encrypted_scalars_and_sizes_t result;
if (incorrectMode(istream)) {
return result;
}
for (auto expectedSize : expectedSizes) {
size_t actualSize;
readSize(istream, actualSize);
if ((size_t)expectedSize != actualSize) {
istream.setstate(std::ios::badbit);
}
assert(actualSize > 0);
result.sizes.push_back(actualSize);
assert(result.sizes.back() > 0);
}
size_t expectedLen = result.length();
assert(expectedLen > 0);
// TODO: full read in one step
size_t actualLen;
readSize(istream, actualLen);
if (expectedLen != actualLen) {
istream.setstate(std::ios::badbit);
}
assert(actualLen == expectedLen);
result.values.resize(actualLen);
for (size_t &value : result.values) {
value = 0;
readWord(istream, value);
}
return result;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -0,0 +1,23 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "concretelang/Common/BitsSize.h"
namespace concretelang {
namespace common {
size_t bitWidthAsWord(size_t exactBitWidth) {
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
size_t previousWidth = 0;
for (auto currentWidth : sortedWordBitWidths) {
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
return currentWidth;
}
}
return exactBitWidth;
}
} // namespace common
} // namespace concretelang

View File

@@ -0,0 +1,19 @@
add_compile_options( -Werror )
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# using Clang
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# using GCC
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
endif()
endif()
add_mlir_library(
ConcretelangCommon
BitsSize.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Common
)

View File

@@ -11,6 +11,14 @@
#include <hpx/include/runtime.hpp>
#endif
namespace mlir {
namespace concretelang {
std::string RuntimeContext::BASE_CONTEXT_BSK = "_concretelang_base_context_bsk";
} // namespace concretelang
} // namespace mlir
LweKeyswitchKey_u64 *
get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
return context->ksk;
@@ -18,6 +26,7 @@ get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
LweBootstrapKey_u64 *
get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
using RuntimeContext = mlir::concretelang::RuntimeContext;
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
std::string threadName = hpx::get_thread_name();
auto bskIt = context->bsk.find(threadName);
@@ -26,12 +35,11 @@ get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
.insert(std::pair<std::string, LweBootstrapKey_u64 *>(
threadName,
clone_lwe_bootstrap_key_u64(
context->bsk["_concretelang_base_context_bsk"])))
context->bsk[RuntimeContext::BASE_CONTEXT_BSK])))
.first;
}
#else
std::string baseName = "_concretelang_base_context_bsk";
auto bskIt = context->bsk.find(baseName);
auto bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK);
#endif
assert(bskIt->second && "No bootstrap key available in context");
return bskIt->second;

View File

@@ -0,0 +1,27 @@
add_compile_options( -Werror )
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# using Clang
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# using GCC
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
add_compile_options( -Wno-error=cast-function-type -Wno-cast-function-type)
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
endif()
endif()
add_mlir_library(
ConcretelangServerLib
DynamicRankCall.cpp
ServerLambda.cpp
DynamicModule.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/ServerLib
ConcretelangRuntime
ConcretelangSupportLib
ConcretelangClientLib
)

View File

@@ -0,0 +1,54 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <dlfcn.h>
#include <fstream>
#include "boost/outcome.h"
#include "concretelang/ServerLib/DynamicModule.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
namespace concretelang {
namespace serverlib {
using concretelang::error::StringError;
using mlir::concretelang::CompilerEngine;
DynamicModule::~DynamicModule() {
if (libraryHandle != nullptr) {
dlclose(libraryHandle);
}
}
outcome::checked<std::shared_ptr<DynamicModule>, StringError>
DynamicModule::open(std::string libPath) {
std::shared_ptr<DynamicModule> module = std::make_shared<DynamicModule>();
OUTCOME_TRYV(module->loadClientParametersJSON(libPath));
OUTCOME_TRYV(module->loadSharedLibrary(libPath));
return module;
}
outcome::checked<void, StringError>
DynamicModule::loadSharedLibrary(std::string path) {
libraryHandle = dlopen(
CompilerEngine::Library::getSharedLibraryPath(path).c_str(), RTLD_LAZY);
if (!libraryHandle) {
return StringError("Cannot open shared library") << dlerror();
}
return outcome::success();
}
outcome::checked<void, StringError>
DynamicModule::loadClientParametersJSON(std::string libPath) {
auto jsonPath = CompilerEngine::Library::getClientParametersPath(libPath);
OUTCOME_TRY(auto clientParams, ClientParameters::load(jsonPath));
this->clientParametersList = clientParams;
return outcome::success();
}
} // namespace serverlib
} // namespace concretelang

View File

@@ -0,0 +1,163 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
// generated: see genDynamicRandCall.py
#include <cassert>
#include <vector>
#include "concretelang/ClientLib/Types.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace concretelang {
namespace serverlib {
encrypted_scalars_and_sizes_t
multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = &encrypted_scalars_and_sizes_t_from_MemRef;
switch (rank) {
case 0: {
auto m = multi_arity_call((MemRefDescriptor<1>(*)(void *...))func, args);
return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 1: {
auto m = multi_arity_call((MemRefDescriptor<2>(*)(void *...))func, args);
return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 2: {
auto m = multi_arity_call((MemRefDescriptor<3>(*)(void *...))func, args);
return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 3: {
auto m = multi_arity_call((MemRefDescriptor<4>(*)(void *...))func, args);
return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 4: {
auto m = multi_arity_call((MemRefDescriptor<5>(*)(void *...))func, args);
return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 5: {
auto m = multi_arity_call((MemRefDescriptor<6>(*)(void *...))func, args);
return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 6: {
auto m = multi_arity_call((MemRefDescriptor<7>(*)(void *...))func, args);
return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 7: {
auto m = multi_arity_call((MemRefDescriptor<8>(*)(void *...))func, args);
return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 8: {
auto m = multi_arity_call((MemRefDescriptor<9>(*)(void *...))func, args);
return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 9: {
auto m = multi_arity_call((MemRefDescriptor<10>(*)(void *...))func, args);
return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 10: {
auto m = multi_arity_call((MemRefDescriptor<11>(*)(void *...))func, args);
return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 11: {
auto m = multi_arity_call((MemRefDescriptor<12>(*)(void *...))func, args);
return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 12: {
auto m = multi_arity_call((MemRefDescriptor<13>(*)(void *...))func, args);
return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 13: {
auto m = multi_arity_call((MemRefDescriptor<14>(*)(void *...))func, args);
return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 14: {
auto m = multi_arity_call((MemRefDescriptor<15>(*)(void *...))func, args);
return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 15: {
auto m = multi_arity_call((MemRefDescriptor<16>(*)(void *...))func, args);
return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 16: {
auto m = multi_arity_call((MemRefDescriptor<17>(*)(void *...))func, args);
return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 17: {
auto m = multi_arity_call((MemRefDescriptor<18>(*)(void *...))func, args);
return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 18: {
auto m = multi_arity_call((MemRefDescriptor<19>(*)(void *...))func, args);
return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 19: {
auto m = multi_arity_call((MemRefDescriptor<20>(*)(void *...))func, args);
return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 20: {
auto m = multi_arity_call((MemRefDescriptor<21>(*)(void *...))func, args);
return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 21: {
auto m = multi_arity_call((MemRefDescriptor<22>(*)(void *...))func, args);
return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 22: {
auto m = multi_arity_call((MemRefDescriptor<23>(*)(void *...))func, args);
return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 23: {
auto m = multi_arity_call((MemRefDescriptor<24>(*)(void *...))func, args);
return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 24: {
auto m = multi_arity_call((MemRefDescriptor<25>(*)(void *...))func, args);
return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 25: {
auto m = multi_arity_call((MemRefDescriptor<26>(*)(void *...))func, args);
return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 26: {
auto m = multi_arity_call((MemRefDescriptor<27>(*)(void *...))func, args);
return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 27: {
auto m = multi_arity_call((MemRefDescriptor<28>(*)(void *...))func, args);
return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 28: {
auto m = multi_arity_call((MemRefDescriptor<29>(*)(void *...))func, args);
return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 29: {
auto m = multi_arity_call((MemRefDescriptor<30>(*)(void *...))func, args);
return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 30: {
auto m = multi_arity_call((MemRefDescriptor<31>(*)(void *...))func, args);
return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 31: {
auto m = multi_arity_call((MemRefDescriptor<32>(*)(void *...))func, args);
return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 32: {
auto m = multi_arity_call((MemRefDescriptor<33>(*)(void *...))func, args);
return convert(33, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
default:
assert(false);
}
}
} // namespace serverlib
} // namespace concretelang

View File

@@ -0,0 +1,159 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <dlfcn.h>
#include "boost/outcome.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/DynamicModule.h"
#include "concretelang/ServerLib/DynamicRankCall.h"
#include "concretelang/ServerLib/ServerLambda.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::CircuitGate;
using concretelang::clientlib::CircuitGateShape;
using concretelang::clientlib::PublicArguments;
using concretelang::error::StringError;
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
// increase multi dim index
for (int r = rank - 1; r >= 0; r--) {
if (index[r] < sizes[r] - 1) {
index[r]++;
return;
}
index[r] = 0;
}
}
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
size_t rank) {
// compute global index from multi dim index
size_t g_index = 0;
size_t default_stride = 1;
for (int r = rank - 1; r >= 0; r--) {
g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]);
default_stride *= sizes[r];
}
return g_index;
}
/** Helper function to convert from MemRefDescriptor to
* encrypted_scalars_and_sizes_t assuming MemRefDescriptor are bufferized */
encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
size_t memref_rank, encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset, size_t *sizes,
size_t *strides) {
(void)allocated;
encrypted_scalars_and_sizes_t result;
assert(aligned != nullptr);
result.sizes.resize(memref_rank);
for (size_t r = 0; r < memref_rank; r++) {
result.sizes[r] = sizes[r];
}
size_t
index[memref_rank]; // ephemeral multi dim index to compute global strides
for (size_t r = 0; r < memref_rank; r++) {
index[r] = 0;
}
auto len = result.length();
result.values.resize(len);
// TODO: add a fast path for dense result (no real strides)
for (size_t i = 0; i < len; i++) {
int g_index = offset + global_index(index, sizes, strides, memref_rank);
result.values[i] = aligned[offset + g_index];
next_coord_index(index, sizes, memref_rank);
}
return result;
}
outcome::checked<ServerLambda, StringError>
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
std::string funcName) {
ServerLambda lambda;
lambda.module =
module; // prevent module and library handler from being destroyed
lambda.func =
(void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
if (auto err = dlerror()) {
return StringError("Cannot open lambda:") << std::string(err);
}
auto param =
llvm::find_if(module->clientParametersList, [&](ClientParameters param) {
return param.functionName == funcName;
});
if (param == module->clientParametersList.end()) {
return StringError("cannot find function ")
<< funcName << "in client parameters";
}
if (param->outputs.size() != 1) {
return StringError("ServerLambda: output arity (")
<< std::to_string(param->outputs.size())
<< ") != 1 is not supported";
}
if (!param->outputs[0].encryption.hasValue()) {
return StringError("ServerLambda: clear output is not yet supported");
}
lambda.clientParameters = *param;
return lambda;
}
outcome::checked<ServerLambda, StringError>
ServerLambda::load(std::string funcName, std::string outputLib) {
OUTCOME_TRY(auto module, DynamicModule::open(outputLib));
return ServerLambda::loadFromModule(module, funcName);
}
encrypted_scalars_and_sizes_t dynamicCall(void *(*func)(void *...),
std::vector<void *> &preparedArgs,
CircuitGate &output,
std::ostream &ostream) {
size_t rank = output.shape.dimensions.size();
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
}
outcome::checked<void, StringError>
ServerLambda::read_call_write(std::istream &istream, std::ostream &ostream) {
OUTCOME_TRY(auto argumentsPtr,
PublicArguments::unserialize(clientParameters, istream));
assert(istream.good());
PublicArguments &arguments = *argumentsPtr;
// The runtime context is always the last argument list
arguments.preparedArgs.push_back((void *)&arguments.runtimeContext);
auto values_and_sizes = dynamicCall(this->func, arguments.preparedArgs,
clientParameters.outputs[0], ostream);
auto shape = clientParameters.outputs[0].shape;
size_t rank = shape.dimensions.size();
for (size_t dim = 0; dim < rank; dim++) {
if (values_and_sizes.sizes[dim] != (size_t)shape.dimensions[dim]) {
return StringError("Dimension mismatch on dim ")
<< dim << " actual: " << values_and_sizes.sizes[dim]
<< " vs expected: " << shape.dimensions[dim] << "\n";
}
}
serializeEncryptedValues(values_and_sizes, ostream);
if (ostream.fail()) {
return StringError("Cannot write result");
}
return outcome::success();
}
} // namespace serverlib
} // namespace concretelang

View File

@@ -0,0 +1,44 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
print(
"""// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
// generated: see genDynamicRandCall.py
#include <cassert>
#include <vector>
#include "concretelang/ClientLib/Types.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace concretelang {
namespace serverlib {
encrypted_scalars_and_sizes_t
multi_arity_call_dynamic_rank(void* (*func)(void *...), std::vector<void *> args, size_t rank) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = &encrypted_scalars_and_sizes_t_from_MemRef;
switch (rank) {""")
for tensor_rank in range(0, 33):
memref_rank = tensor_rank + 1
print(f""" case {tensor_rank}:
{{
auto m = multi_arity_call((MemRefDescriptor<{memref_rank}>(*)(void *...))func, args);
return convert({memref_rank}, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}}""")
print("""
default:
assert(false);
}
}""")
print("""
} // namespace serverlib
} // namespace concretelang""")

View File

@@ -30,7 +30,8 @@ add_mlir_library(ConcretelangSupport
MLIRExecutionEngine
${LLVM_PTHREAD_LIB}
ConcretelangCommon
ConcretelangRuntime
ConcretelangClientLib
)

View File

@@ -5,10 +5,12 @@
#include <fstream>
#include <iostream>
#include <regex>
#include <stdio.h>
#include <string>
#include <llvm/Support/Error.h>
#include <llvm/Support/Path.h>
#include <llvm/Support/SMLoc.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
@@ -19,6 +21,7 @@
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <concretelang/ClientLib/ClientParameters.h>
#include <concretelang/Dialect/BConcrete/IR/BConcreteDialect.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
@@ -280,9 +283,10 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
auto funcName = this->clientParametersFuncName.getValueOr("main");
if (this->generateClientParameters || target == Target::LIBRARY) {
if (!res.fheContext.hasValue()) {
// Some tests can involves a usual function
res.clientParameters =
mlir::concretelang::emptyClientParametersForV0(funcName, module);
// Some tests involve call a to non encrypted functions
ClientParameters emptyParams;
emptyParams.functionName = funcName;
res.clientParameters = emptyParams;
} else {
auto clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(*res.fheContext,
@@ -425,7 +429,7 @@ std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) {
/** Returns the path of the static library */
std::string CompilerEngine::Library::getClientParametersPath(std::string path) {
return path + CLIENT_PARAMETERS_EXT;
return ClientParameters::getClientParametersPath(path);
}
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
@@ -466,6 +470,92 @@ CompilerEngine::Library::emitClientParametersJSON() {
return clientParamsPath;
}
static std::string ccpResultType(size_t rank) {
if (rank == 0) {
return "scalar_out";
} else {
return "tensor" + std::to_string(rank) + "_out";
}
}
static std::string ccpArgType(size_t rank) {
if (rank == 0) {
return "scalar_in";
} else {
return "tensor" + std::to_string(rank) + "_in";
}
}
static std::string cppArgsType(std::vector<CircuitGate> inputs) {
std::string args;
for (auto input : inputs) {
if (!args.empty()) {
args += ", ";
}
args += ccpArgType(input.shape.dimensions.size());
}
return args;
}
llvm::Expected<std::string> CompilerEngine::Library::emitCppHeader() {
auto libraryName = llvm::sys::path::filename(libraryPath).str();
auto headerName = libraryName + "-client.h";
auto headerPath = std::regex_replace(
libraryPath, std::regex(libraryName + "$"), headerName);
std::error_code error;
llvm::raw_fd_ostream out(headerPath, error);
if (error) {
StreamStringError("Cannot emit header: ")
<< headerPath << ", " << error.message() << "\n";
}
out << "#include \"boost/outcome.h\"\n";
out << "#include \"concretelang/ClientLib/ClientLambda.h\"\n";
out << "#include \"concretelang/ClientLib/KeySetCache.h\"\n";
out << "#include \"concretelang/ClientLib/Types.h\"\n";
out << "#include \"concretelang/Common/Error.h\"\n";
out << "\n";
out << "namespace " << libraryName << " {\n";
out << "namespace client {\n";
for (auto params : clientParametersList) {
std::string args;
std::string result;
if (params.outputs.size() > 0) {
args = cppArgsType(params.inputs);
} else {
args = "void";
}
if (params.outputs.size() > 0) {
size_t rank = params.outputs[0].shape.dimensions.size();
result = ccpResultType(rank);
} else {
result = "void";
}
out << "\n";
out << "namespace " << params.functionName << " {\n";
out << " using namespace concretelang::clientlib;\n";
out << " using concretelang::error::StringError;\n";
out << " using " << params.functionName << "_t = TypedClientLambda<"
<< result << ", " << args << ">;\n";
out << " static const std::string name = \"" << params.functionName
<< "\";\n";
out << "\n";
out << " static outcome::checked<extract_t, StringError>\n";
out << " load(std::string outputLib)\n";
out << " { return extract_t::load(name, outputLib); }\n";
out << "} // namespace " << params.functionName << "\n";
}
out << "\n";
out << "} // namespace client\n";
out << "} // namespace " << libraryName << "\n";
out.close();
return headerPath;
}
llvm::Expected<std::string>
CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
llvm::Module *module = compilation.llvmModule.get();
@@ -538,6 +628,9 @@ llvm::Error CompilerEngine::Library::emitArtifacts() {
if (auto err = emitClientParametersJSON().takeError()) {
return err;
}
if (auto err = emitCppHeader().takeError()) {
return err;
}
return llvm::Error::success();
}

View File

@@ -12,6 +12,7 @@
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include "concretelang/Common/BitsSize.h"
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/logging.h>
@@ -198,12 +199,14 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
// Else if is encryted, allocate ciphertext and encrypt.
uint64_t *ctArg;
uint64_t ctSize;
if (auto err = this->keySet.allocate_lwe(pos, &ctArg, ctSize)) {
return std::move(err);
auto check = this->keySet.allocate_lwe(pos, &ctArg, ctSize);
if (!check) {
return StreamStringError(check.error().mesg);
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
return std::move(err);
check = this->keySet.encrypt_lwe(pos, ctArg, arg);
if (!check) {
return StreamStringError(check.error().mesg);
}
// memref calling convention
// allocated
@@ -224,17 +227,6 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
return llvm::Error::success();
}
size_t bitWidthAsWord(size_t exactBitWidth) {
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
size_t previousWidth = 0;
for (auto currentWidth : sortedWordBitWidths) {
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
return currentWidth;
}
}
return exactBitWidth;
}
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
const void *data,
llvm::ArrayRef<int64_t> shape) {
@@ -253,7 +245,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
}
auto roundedSize = bitWidthAsWord(info.shape.width);
auto roundedSize = ::concretelang::common::bitWidthAsWord(info.shape.width);
if (width != roundedSize) {
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : expected " +
llvm::Twine(roundedSize) + "bits" + " but received " +
@@ -315,9 +307,9 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
for (size_t i = 0, offset = 0; i < info.shape.size;
i++, offset += lweSize) {
if (auto err =
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
return std::move(err);
auto check = this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i]);
if (!check) {
return StreamStringError(check.error().mesg);
}
}
// Replace the data by the buffer to ciphertext
@@ -387,8 +379,9 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
}
// Else if is encryted, decrypt
uint64_t *ct = (uint64_t *)(outputs[offset + 1]);
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
return std::move(err);
auto check = this->keySet.decrypt_lwe(pos, ct, res);
if (!check) {
return StreamStringError(check.error().mesg);
}
return llvm::Error::success();
}
@@ -504,8 +497,9 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) {
uint64_t *ct = ((uint64_t *)alignedBytes) + o;
if (auto err = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i])) {
return std::move(err);
auto check = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i]);
if (!check) {
return StreamStringError(check.error().mesg);
}
}
}

View File

@@ -116,16 +116,18 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
"parameters has not been computed");
}
llvm::Expected<std::unique_ptr<mlir::concretelang::KeySet>> keySetOrErr =
(cache.hasValue())
? cache->tryLoadOrGenerateSave(*compRes.clientParameters, 0, 0)
: KeySet::generate(*compRes.clientParameters, 0, 0);
std::shared_ptr<KeySetCache> cachePtr;
if (cache.hasValue()) {
cachePtr = std::make_shared<KeySetCache>(cache.getValue());
}
auto keySetOrErr =
KeySetCache::generate(cachePtr, *compRes.clientParameters, 0, 0);
if (!keySetOrErr) {
return keySetOrErr.takeError();
return StreamStringError(keySetOrErr.error().mesg);
}
auto keySet = std::move(keySetOrErr.get());
auto keySet = std::move(keySetOrErr.value());
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)};
}

View File

@@ -16,6 +16,15 @@
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::BIG_KEY;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::EncryptionGate;
using ::concretelang::clientlib::LweSecretKeyID;
using ::concretelang::clientlib::Precision;
using ::concretelang::clientlib::SMALL_KEY;
using ::concretelang::clientlib::Variance;
const auto securityLevel = SECURITY_LEVEL_128;
const auto keyFormat = KEY_FORMAT_BINARY;
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
@@ -81,13 +90,6 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
}
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
mlir::ModuleOp module) {
ClientParameters c;
c.functionName = (std::string)functionName;
return c;
}
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext fheContext,
llvm::StringRef functionName,

View File

@@ -1,187 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "concretelang/TestLib/Arguments.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/Jit.h"
namespace mlir {
namespace concretelang {
Arguments::~Arguments() {
for (auto ct : allocatedCiphertexts) {
free(ct);
}
for (auto ctBuffer : ciphertextBuffers) {
free(ctBuffer);
}
}
std::shared_ptr<Arguments> Arguments::create(KeySet &keySet) {
auto args = std::make_shared<Arguments>(keySet);
return args;
}
llvm::Error Arguments::pushArg(uint64_t arg) {
if (auto err = checkPushTooManyArgs()) {
return err;
}
auto pos = currentPos++;
CircuitGate input = keySet.inputGate(pos);
if (input.shape.size != 0) {
return StreamStringError("argument #") << pos << " is not a scalar";
}
if (!input.encryption.hasValue()) {
// clear scalar: just push the argument
if (input.shape.width != 64) {
return StreamStringError(
"scalar argument of with != 64 is not supported for DynamicLambda");
}
preparedArgs.push_back((void *)arg);
return llvm::Error::success();
}
// encrypted scalar: allocate, encrypt and push
uint64_t *ctArg;
uint64_t ctSize = 0;
if (auto err = keySet.allocate_lwe(pos, &ctArg, ctSize)) {
return err;
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) {
return err;
}
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back(ctArg);
// offset
preparedArgs.push_back((void *)0);
// size
preparedArgs.push_back((void *)ctSize);
// stride
preparedArgs.push_back((void *)1);
return llvm::Error::success();
}
llvm::Error Arguments::pushArg(std::vector<uint8_t> arg) {
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()});
}
llvm::Error Arguments::pushArg(size_t width, void *data,
llvm::ArrayRef<int64_t> shape) {
if (auto err = checkPushTooManyArgs()) {
return err;
}
auto pos = currentPos;
currentPos = currentPos + 1;
CircuitGate input = keySet.inputGate(pos);
// Check the width of data
if (input.shape.width > 64) {
return StreamStringError("argument #")
<< pos << " width > 64 bits is not supported";
}
auto roundedSize = bitWidthAsWord(input.shape.width);
if (width != roundedSize) {
return StreamStringError("argument #")
<< pos << "width mismatch, got " << width << " expected "
<< roundedSize;
}
// Check the shape of tensor
if (input.shape.dimensions.empty()) {
return StreamStringError("argument #") << pos << "is not a tensor";
}
if (shape.size() != input.shape.dimensions.size()) {
return StreamStringError("argument #")
<< pos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] != input.shape.dimensions[i]) {
return StreamStringError("argument #")
<< pos << " has not the expected dimension #" << i << " , got "
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}
if (input.encryption.hasValue()) {
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
if (width != 8) {
return StreamStringError("argument #")
<< pos << " width mismatch, expected 8 got " << width;
}
const uint8_t *data8 = (const uint8_t *)data;
// Allocate a buffer for ciphertexts of size of tensor
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
auto ctBuffer =
(uint64_t *)malloc(input.shape.size * lweSize * sizeof(uint64_t));
ciphertextBuffers.push_back(ctBuffer);
// Allocate ciphertexts and encrypt, for every values in tensor
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
if (auto err =
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
return err;
}
}
// Replace the data by the buffer to ciphertext
data = (void *)ctBuffer;
}
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back(data);
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t i = 0; i < shape.size(); i++) {
preparedArgs.push_back((void *)shape[i]);
}
// If encrypted +1 for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
preparedArgs.push_back(
(void *)(keySet.getInputLweSecretKeyParam(pos).size + 1));
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = 1;
// If encrypted +1 set the stride for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
}
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
preparedArgs.push_back((void *)stride);
stride *= shape[i];
}
if (keySet.isInputEncrypted(pos)) {
preparedArgs.push_back((void *)1);
}
return llvm::Error::success();
}
llvm::Error Arguments::pushContext() {
if (currentPos < keySet.numInputs()) {
return StreamStringError("Missing arguments");
}
preparedArgs.push_back(&context);
return llvm::Error::success();
}
llvm::Error Arguments::checkPushTooManyArgs() {
size_t arity = keySet.numInputs();
if (currentPos < arity) {
return llvm::Error::success();
}
return StreamStringError("function has arity ")
<< arity << " but is applied to too many arguments";
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,16 +0,0 @@
add_mlir_library(ConcretelangTestLib
Arguments.cpp
DynamicLambda.cpp
DynamicModule.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/TestLib
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
ConcretelangSupport
ConcretelangClientLib
)

View File

@@ -1,217 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <dlfcn.h>
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/TestLib/DynamicLambda.h"
#include "concretelang/TestLib/dynamicArityCall.h"
namespace mlir {
namespace concretelang {
template <size_t N> struct MemRefDescriptor {
uint64_t *allocated;
uint64_t *aligned;
size_t offset;
size_t sizes[N];
size_t strides[N];
};
llvm::Expected<std::vector<uint64_t>>
decryptSlice(KeySet &keySet, uint64_t *aligned, size_t size) {
auto pos = 0;
std::vector<uint64_t> result(size);
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
for (size_t i = 0; i < size; i++) {
size_t offset = i * lweSize;
auto err = keySet.decrypt_lwe(pos, aligned + offset, result[i]);
if (err) {
return StreamStringError()
<< "cannot decrypt result #" << i << ", err:" << err;
}
}
return result;
}
llvm::Expected<mlir::concretelang::DynamicLambda>
DynamicLambda::load(std::string funcName, std::string outputLib) {
auto moduleOrErr = mlir::concretelang::DynamicModule::open(outputLib);
if (!moduleOrErr) {
return moduleOrErr.takeError();
}
return mlir::concretelang::DynamicLambda::load(*moduleOrErr, funcName);
}
llvm::Expected<DynamicLambda>
DynamicLambda::load(std::shared_ptr<DynamicModule> module,
std::string funcName) {
DynamicLambda lambda;
lambda.module =
module; // prevent module and library handler from being destroyed
lambda.func = dlsym(module->libraryHandle, funcName.c_str());
if (auto err = dlerror()) {
return StreamStringError("Cannot open lambda: ") << err;
}
auto param =
llvm::find_if(module->clientParametersList, [&](ClientParameters param) {
return param.functionName == funcName;
});
if (param == module->clientParametersList.end()) {
return StreamStringError("cannot find function ")
<< funcName << "in client parameters";
}
if (param->outputs.size() != 1) {
return StreamStringError("DynamicLambda: output arity (")
<< std::to_string(param->outputs.size())
<< ") != 1 is not supported";
}
if (!param->outputs[0].encryption.hasValue()) {
return StreamStringError(
"DynamicLambda: clear output is not yet supported");
}
lambda.clientParameters = *param;
return lambda;
}
template <>
llvm::Expected<uint64_t> invoke<uint64_t>(DynamicLambda &lambda,
const Arguments &args) {
auto output = lambda.clientParameters.outputs[0];
if (output.shape.size != 0) {
return StreamStringError("the function doesn't return a scalar");
}
// Scalar encrypted result
auto fCasted = (MemRefDescriptor<1>(*)(void *...))(lambda.func);
MemRefDescriptor<1> lweResult =
mlir::concretelang::call(fCasted, args.preparedArgs);
uint64_t decryptedResult;
if (auto err =
lambda.keySet->decrypt_lwe(0, lweResult.aligned, decryptedResult)) {
return std::move(err);
}
return decryptedResult;
}
template <size_t Rank>
llvm::Expected<MemRefDescriptor<Rank>>
DynamicLambda::invokeMemRefDecriptor(const Arguments &args) {
auto output = clientParameters.outputs[0];
if (output.shape.size == 0) {
return StreamStringError("the function doesn't return a tensor");
}
if (output.shape.dimensions.size() != Rank - 1) {
return StreamStringError("the function doesn't return a tensor of rank ")
<< Rank - 1;
}
// Tensor encrypted result
auto fCasted = (MemRefDescriptor<Rank>(*)(void *...))(func);
auto encryptedResult = mlir::concretelang::call(fCasted, args.preparedArgs);
for (size_t dim = 0; dim < Rank - 1; dim++) {
size_t actual_size = encryptedResult.sizes[dim];
size_t expected_size = output.shape.dimensions[dim];
if (actual_size != expected_size) {
return StreamStringError("the function returned a vector of size ")
<< actual_size << " instead of size " << expected_size;
}
}
return encryptedResult;
}
template <>
llvm::Expected<std::vector<uint64_t>>
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args) {
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<2>(args);
if (!encryptedResultOrErr) {
return encryptedResultOrErr.takeError();
}
auto &encryptedResult = encryptedResultOrErr.get();
auto &keySet = lambda.keySet;
return decryptSlice(*keySet, encryptedResult.aligned,
encryptedResult.sizes[0]);
}
template <>
llvm::Expected<std::vector<std::vector<uint64_t>>>
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
const Arguments &args) {
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args);
if (!encryptedResultOrErr) {
return encryptedResultOrErr.takeError();
}
auto &encryptedResult = encryptedResultOrErr.get();
std::vector<std::vector<uint64_t>> result;
result.reserve(encryptedResult.sizes[0]);
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
int offset = encryptedResult.offset + i * encryptedResult.strides[1];
auto slice = decryptSlice(*lambda.keySet, encryptedResult.aligned + offset,
encryptedResult.sizes[1]);
if (!slice) {
return StreamStringError(llvm::toString(slice.takeError()));
}
result.push_back(slice.get());
}
return result;
}
template <>
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
const Arguments &args) {
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<4>(args);
if (!encryptedResultOrErr) {
return encryptedResultOrErr.takeError();
}
auto &encryptedResult = encryptedResultOrErr.get();
auto &keySet = lambda.keySet;
std::vector<std::vector<std::vector<uint64_t>>> result0;
result0.reserve(encryptedResult.sizes[0]);
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
std::vector<std::vector<uint64_t>> result1;
result1.reserve(encryptedResult.sizes[1]);
for (size_t j = 0; j < encryptedResult.sizes[1]; j++) {
int offset = encryptedResult.offset + (i * encryptedResult.sizes[1] + j) *
encryptedResult.strides[1];
auto slice = decryptSlice(*keySet, encryptedResult.aligned + offset,
encryptedResult.sizes[2]);
if (!slice) {
return StreamStringError(llvm::toString(slice.takeError()));
}
result1.push_back(slice.get());
}
result0.push_back(result1);
}
return result0;
}
llvm::Error DynamicLambda::generateKeySet(llvm::Optional<KeySetCache> cache,
uint64_t seed_msb,
uint64_t seed_lsb) {
auto maybeKeySet =
cache.hasValue()
? cache->tryLoadOrGenerateSave(clientParameters, seed_msb, seed_lsb)
: KeySet::generate(clientParameters, seed_msb, seed_lsb);
if (auto err = maybeKeySet.takeError()) {
return err;
}
keySet = std::move(maybeKeySet.get());
return llvm::Error::success();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,61 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <dlfcn.h>
#include <fstream>
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/TestLib/DynamicModule.h"
namespace mlir {
namespace concretelang {
DynamicModule::~DynamicModule() {
if (libraryHandle != nullptr) {
dlclose(libraryHandle);
}
}
llvm::Expected<std::shared_ptr<DynamicModule>>
DynamicModule::open(std::string path) {
std::shared_ptr<DynamicModule> module = std::make_shared<DynamicModule>();
if (auto err = module->loadClientParametersJSON(path)) {
return StreamStringError("Cannot load client parameters: ")
<< llvm::toString(std::move(err));
}
if (auto err = module->loadSharedLibrary(path)) {
return StreamStringError("Cannot load client parameters: ")
<< llvm::toString(std::move(err));
}
return module;
}
llvm::Error DynamicModule::loadSharedLibrary(std::string path) {
libraryHandle = dlopen(
CompilerEngine::Library::getSharedLibraryPath(path).c_str(), RTLD_LAZY);
if (!libraryHandle) {
return StreamStringError("Cannot open shared library") << dlerror();
}
return llvm::Error::success();
}
llvm::Error DynamicModule::loadClientParametersJSON(std::string path) {
std::ifstream file(CompilerEngine::Library::getClientParametersPath(path));
std::string content((std::istreambuf_iterator<char>(file)),
(std::istreambuf_iterator<char>()));
llvm::Expected<std::vector<ClientParameters>> expectedClientParams =
llvm::json::parse<std::vector<ClientParameters>>(content);
if (auto err = expectedClientParams.takeError()) {
return StreamStringError("Cannot open client parameters: ") << err;
}
this->clientParametersList = *expectedClientParams;
return llvm::Error::success();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -23,7 +23,6 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
RTDialect
ConcretelangSupport
ConcretelangTestLib
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
-Wl,-rpath,${HPX_DIR}/../../

View File

@@ -16,6 +16,7 @@ set_source_files_properties(
target_link_libraries(
support_unit_test
gtest_main
ConcretelangClientLib
ConcretelangSupport
)

View File

@@ -2,63 +2,69 @@
#include "../unittest/end_to_end_jit_test.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArgs.h"
namespace CL = mlir::concretelang;
namespace clientlib = concretelang::clientlib;
TEST(Support, client_parameters_json_serde) {
mlir::concretelang::ClientParameters params0;
clientlib::ClientParameters params0;
params0.secretKeys = {
{mlir::concretelang::SMALL_KEY, {/*.size = */ 12}},
{mlir::concretelang::BIG_KEY, {/*.size = */ 14}},
{clientlib::SMALL_KEY, {/*.size = */ 12}},
{clientlib::BIG_KEY, {/*.size = */ 14}},
};
params0.bootstrapKeys = {
{"bsk_v0",
{/*.inputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
/*.outputSecretKeyID = */ mlir::concretelang::BIG_KEY,
{
"bsk_v0", {
/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.glweDimension = */ 3,
/*.variance = */ 0.001}},
{"wtf_bsk_v0",
{
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.glweDimension = */ 1,
/*.variance = */ 0.0001,
}},
/*.variance = */ 0.001
}
},{
"wtf_bsk_v0", {
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.glweDimension = */ 1,
/*.variance = */ 0.0001,
}
},
};
params0.keyswitchKeys = {
{"ksk_v0",
{
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
}}};
{
"ksk_v0", {
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
}
}
};
params0.inputs = {
{
/*.encryption = */ {{CL::SMALL_KEY, 0.01, {4}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
},
{
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.01, {4}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
},
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
params0.outputs = {
{
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
auto json = mlir::concretelang::toJSON(params0);
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
},
};
auto json = clientlib::toJSON(params0);
std::string jsonStr;
llvm::raw_string_ostream os(jsonStr);
os << json;
auto parseResult =
llvm::json::parse<mlir::concretelang::ClientParameters>(jsonStr);
llvm::json::parse<clientlib::ClientParameters>(jsonStr);
ASSERT_EXPECTED_VALUE(parseResult, params0);
}

View File

@@ -15,9 +15,12 @@ set_source_files_properties(
target_link_libraries(
testlib_unit_test
gtest_main
ConcretelangCommon
ConcretelangRuntime
ConcretelangTestLib
ConcretelangSupport
ConcretelangClientLib
ConcretelangServerLib
gtest_main
)
include(GoogleTest)

View File

@@ -0,0 +1,22 @@
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
namespace call_2t_1s_with_header {
namespace client {
namespace extract {
using namespace concretelang::clientlib;
using concretelang::error::StringError;
using extract_t = TypedClientLambda<scalar_out, tensor1_in, tensor1_in>;
static const std::string name = "extract";
static outcome::checked<extract_t, StringError>
load(std::string outputLib)
{ return extract_t::load(name, outputLib); }
} // namespace extract
} // namespace client
} // namespace call_2t_1s_with_header

View File

@@ -1,40 +1,88 @@
#include <gtest/gtest.h>
#include <numeric>
#include <cassert>
#include <fstream>
#include <numeric>
#include "boost/outcome.h"
#include "../unittest/end_to_end_jit_test.h"
#include "concretelang/TestLib/DynamicLambda.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/Common/Error.h"
#include "concretelang/TestLib/TestTypedLambda.h"
#include "call_2t_1s_with_header-client.h.generated"
const std::string FUNCNAME = "main";
template<typename... Params>
using TypedDynamicLambda = mlir::concretelang::TypedDynamicLambda<Params...>;
using namespace concretelang::testlib;
using scalar = uint64_t;
using tensor1_in = std::vector<uint8_t>;
using tensor1_out = std::vector<uint64_t>;
using tensor2_out = std::vector<std::vector<uint64_t>>;
using tensor3_out = std::vector<std::vector<std::vector<uint64_t>>>;
using concretelang::clientlib::scalar_in;
using concretelang::clientlib::scalar_out;
using concretelang::clientlib::tensor1_in;
using concretelang::clientlib::tensor2_in;
using concretelang::clientlib::tensor1_out;
using concretelang::clientlib::tensor2_out;
using concretelang::clientlib::tensor3_out;
std::vector<uint8_t>
values_7bits() {
return {0, 1, 2, 63, 64, 65, 125, 126};
}
llvm::Expected<mlir::concretelang::CompilerEngine::Library>
compile(std::string outputLib, std::string source) {
mlir::concretelang::CompilerEngine::Library
compile(std::string outputLib, std::string source, std::string funcname = FUNCNAME) {
std::vector<std::string> sources = {source};
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::JitCompilerEngine ce {ccx};
ce.setClientParametersFuncName(FUNCNAME);
return ce.compile(sources, outputLib);
ce.setClientParametersFuncName(funcname);
auto result = ce.compile(sources, outputLib);
assert(result);
return result.get();
}
static const std::string THIS_TEST_DIRECTORY = "tests/TestLib";
static const std::string OUT_DIRECTORY = THIS_TEST_DIRECTORY + "/out";
template<typename Info>
std::string outputLibFromThis(Info *info) {
return "tests/TestLib/out/" + std::string(info->name());
return OUT_DIRECTORY + "/" + std::string(info->name());
}
template<typename Lambda>
Lambda load(std::string outputLib) {
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
assert(l.has_value());
return l.value();
}
TEST(CompiledModule, call_1s_1s_client_view) {
std::string source = R"(
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
return %arg0: !FHE.eint<7>
}
)";
namespace clientlib = concretelang::clientlib;
using MyLambda = clientlib::TypedClientLambda<scalar_out, scalar_in>;
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
auto maybeLambda = MyLambda::load("main", jsonPath);
ASSERT_TRUE(maybeLambda.has_value());
auto lambda = maybeLambda.value();
auto maybeKeySet = lambda.keySet(getTestKeySetCachePtr(), 0, 0);
ASSERT_TRUE(maybeKeySet.has_value());
std::shared_ptr<KeySet> keySet = std::move(maybeKeySet.value());
auto maybePublicArguments = lambda.publicArguments(1, keySet);
ASSERT_TRUE(maybePublicArguments.has_value());
auto publicArguments = std::move(maybePublicArguments.value());
std::ostringstream osstream(std::ios::binary);
EXPECT_TRUE(lambda.untypedSerializeCall(publicArguments, osstream));
EXPECT_TRUE(osstream.good());
// Direct call without intermediate
EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream));
EXPECT_TRUE(osstream.good());
}
TEST(CompiledModule, call_1s_1s) {
@@ -45,13 +93,28 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<scalar, scalar>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
for(auto a: values_7bits()) {
auto res = lambda->call(a);
ASSERT_EXPECTED_VALUE(res, a);
auto res = lambda.call(a);
ASSERT_EQ_OUTCOME(res, a);
}
}
TEST(CompiledModule, call_2s_1s_choose) {
std::string source = R"(
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
return %arg0: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for(auto a: values_7bits()) for(auto b: values_7bits()) {
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a);
}
}
@@ -64,16 +127,30 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<scalar, scalar, scalar>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
for(auto a: values_7bits()) for(auto b: values_7bits()) {
auto res = lambda->call(a, b);
ASSERT_EXPECTED_VALUE(res, a + b);
if (a > b) {
continue;
}
auto res = lambda.call(a, b);
ASSERT_EQ_OUTCOME(res, a + b);
}
}
TEST(CompiledModule, call_1s_1s_bad_call) {
std::string source = R"(
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
auto res = lambda.call(1);
ASSERT_FALSE(res.has_value());
}
TEST(CompiledModule, call_1s_1t) {
std::string source = R"(
func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
@@ -83,14 +160,11 @@ func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<tensor1_out, scalar>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<tensor1_out, scalar_in>>(outputLib);
for(auto a: values_7bits()) {
auto res = lambda->call(a);
ASSERT_EXPECTED_SUCCESS(res);
tensor1_out v = res.get();
auto res = lambda.call(a);
EXPECT_TRUE(res);
tensor1_out v = res.value();
EXPECT_EQ(v[0], a);
}
}
@@ -104,16 +178,13 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<7>> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<tensor1_out, scalar, scalar>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<tensor1_out, scalar_in, scalar_in>>(outputLib);
for(auto a : values_7bits()) {
auto res = lambda->call(a, a+1);
ASSERT_EXPECTED_SUCCESS(res);
tensor1_out v = res.get();
EXPECT_EQ((scalar)v[0], a);
EXPECT_EQ((scalar)v[1], a + 1u);
auto res = lambda.call(a, a+1);
EXPECT_TRUE(res);
tensor1_out v = res.value();
EXPECT_EQ(v[0], (scalar_out)a);
EXPECT_EQ(v[1], (scalar_out)(a + 1u));
}
}
@@ -127,14 +198,11 @@ func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<scalar, tensor1_in>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<scalar_out, tensor1_in>>(outputLib);
for(uint8_t a : values_7bits()) {
tensor1_in ta = {a};
auto res = lambda->call(ta);
ASSERT_EXPECTED_VALUE(res, a);
auto res = lambda.call(ta);
ASSERT_EQ_OUTCOME(res, a);
}
}
@@ -146,14 +214,11 @@ func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<tensor1_out, tensor1_in>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<tensor1_out, tensor1_in>>(outputLib);
tensor1_in ta = {1, 2, 3};
auto res = lambda->call(ta);
ASSERT_EXPECTED_SUCCESS(res);
tensor1_out v = res.get();
auto res = lambda.call(ta);
ASSERT_TRUE(res);
tensor1_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
EXPECT_EQ(v[i], ta[i]);
}
@@ -171,16 +236,13 @@ func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE
)";
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<scalar, tensor1_in, std::array<uint8_t, 3>>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<scalar_out, tensor1_in, std::array<uint8_t, 3>>>(outputLib);
tensor1_in ta {1, 2, 3};
std::array<uint8_t, 3> tb {5, 7, 9};
auto res = lambda->call(ta, tb);
auto res = lambda.call(ta, tb);
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
std::accumulate(tb.begin(), tb.end(), 0u);
ASSERT_EXPECTED_VALUE(res, expected);
ASSERT_EQ_OUTCOME(res, expected);
}
TEST(CompiledModule, call_1tr2_1tr2) {
@@ -192,17 +254,14 @@ func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
using tensor2_in = std::array<std::array<uint8_t, 3>, 2>;
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<tensor2_out, tensor2_in>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<tensor2_out, tensor2_in>>(outputLib);
tensor2_in ta = {{
{1, 2, 3},
{4, 5, 6}
}};
auto res = lambda->call(ta);
ASSERT_EXPECTED_SUCCESS(res);
tensor2_out v = res.get();
auto res = lambda.call(ta);
ASSERT_TRUE(res);
tensor2_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for(size_t j = 0; j < v.size(); j++) {
EXPECT_EQ(v[i][j], ta[i][j]);
@@ -210,7 +269,6 @@ func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
}
}
TEST(CompiledModule, call_1tr3_1tr3) {
std::string source = R"(
func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
@@ -220,17 +278,14 @@ func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
ASSERT_EXPECTED_SUCCESS(compiled);
auto lambda = TypedDynamicLambda<tensor3_out, tensor3_in>::load(FUNCNAME, outputLib);
ASSERT_EXPECTED_SUCCESS(lambda);
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
auto lambda = load<TestTypedLambda<tensor3_out, tensor3_in>>(outputLib);
tensor3_in ta = {{
{{ {1}, {2}, {3} }},
{{ {4}, {5}, {6} }}
}};
auto res = lambda->call(ta);
ASSERT_EXPECTED_SUCCESS(res);
tensor3_out v = res.get();
auto res = lambda.call(ta);
ASSERT_TRUE(res);
tensor3_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for(size_t j = 0; j < v[i].size(); j++) {
for(size_t k = 0; k < v[i][j].size(); k++) {
@@ -239,3 +294,73 @@ func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
}
}
}
TEST(CompiledModule, call_2tr3_1tr3) {
std::string source = R"(
func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>, %arg1: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
%1 = "FHELinalg.add_eint"(%arg0, %arg1): (tensor<2x3x1x!FHE.eint<7>>, tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>>
return %1: tensor<2x3x1x!FHE.eint<7>>
}
)";
using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
std::string outputLib = outputLibFromThis(this->test_info_);
auto compiled = compile(outputLib, source);
auto lambda = load<TestTypedLambda<tensor3_out, tensor3_in, tensor3_in>>(outputLib);
tensor3_in ta = {{
{{ {1}, {2}, {3} }},
{{ {4}, {5}, {6} }}
}};
auto res = lambda.call(ta, ta);
ASSERT_TRUE(res);
tensor3_out v = res.value();
for(size_t i = 0; i < v.size(); i++) {
for(size_t j = 0; j < v[i].size(); j++) {
for(size_t k = 0; k < v[i][j].size(); k++) {
EXPECT_EQ(v[i][j][k], 2 * ta[i][j][k]);
}
}
}
}
static std::string fileContent(std::string path) {
std::ifstream file(path);
std::stringstream buffer;
buffer << file.rdbuf();
return buffer.str();
}
TEST(CompiledModule, call_2t_1s_with_header) {
std::string source = R"(
func @extract(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE.eint<7> {
%1 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
%c1 = arith.constant 1 : i8
%2 = tensor.from_elements %c1, %c1, %c1 : tensor<3xi8>
%3 = "FHELinalg.dot_eint_int"(%1, %2) : (tensor<3x!FHE.eint<7>>, tensor<3xi8>) -> !FHE.eint<7>
return %3: !FHE.eint<7>
}
)";
std::string outputLib = outputLibFromThis(this->test_info_);
namespace extract = call_2t_1s_with_header::client::extract;
auto compiled = compile(outputLib, source, extract::name);
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
auto cLambda_ = extract::load(jsonPath);
ASSERT_TRUE(cLambda_);
tensor1_in ta {1, 2, 3};
tensor1_in tb {5, 7, 9};
auto sLambda_ = ServerLambda::load(extract::name, outputLib);
ASSERT_TRUE(sLambda_);
auto cLambda = cLambda_.value();
auto sLambda = sLambda_.value();
auto keySet_ = cLambda.keySet(getTestKeySetCachePtr(), 0, 0);
ASSERT_TRUE(keySet_.has_value());
std::shared_ptr<KeySet> keySet = std::move(keySet_.value());
auto testLambda = TestTypedLambdaFrom(cLambda, sLambda, keySet);
auto res = testLambda.call(ta, tb);
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
std::accumulate(tb.begin(), tb.end(), 0u);
ASSERT_EQ_OUTCOME(res, expected);
EXPECT_EQ(
fileContent(THIS_TEST_DIRECTORY + "/call_2t_1s_with_header-client.h.generated"),
fileContent(OUT_DIRECTORY + "/call_2t_1s_with_header-client.h"));
}

View File

@@ -153,8 +153,8 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{
///////////////////////////////////////////////////////////////////////////////
const size_t numDim = 2;
const int64_t dim0 = 2;
const int64_t dim1 = 10;
const size_t dim0 = 2;
const size_t dim1 = 10;
const int64_t dims[numDim]{dim0, dim1};
static std::vector<uint64_t> tensor2D{
0xFFFFFFFFFFFFFFFF,

View File

@@ -5,8 +5,8 @@
///////////////////////////////////////////////////////////////////////////////
const size_t numDim = 2;
const int64_t dim0 = 2;
const int64_t dim1 = 10;
const size_t dim0 = 2;
const size_t dim1 = 10;
const int64_t dims[numDim]{dim0, dim1};
static std::vector<uint8_t> tensor2D{
63, 12, 7, 43, 52, 9, 26, 34, 22, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
@@ -138,10 +138,10 @@ func @main(%t: tensor<8x4x5x3x!FHE.eint<6>>, %d0: index, %d1: index, %d2: index,
uint8_t A[dimSizes[0]][dimSizes[1]][dimSizes[2]][dimSizes[3]];
// Fill with some reproducible pattern
for (int64_t d0 = 0; d0 < dimSizes[0]; d0++) {
for (int64_t d1 = 0; d1 < dimSizes[1]; d1++) {
for (int64_t d2 = 0; d2 < dimSizes[2]; d2++) {
for (int64_t d3 = 0; d3 < dimSizes[3]; d3++) {
for (size_t d0 = 0; d0 < dimSizes[0]; d0++) {
for (size_t d1 = 0; d1 < dimSizes[1]; d1++) {
for (size_t d2 = 0; d2 < dimSizes[2]; d2++) {
for (size_t d3 = 0; d3 < dimSizes[3]; d3++) {
A[d0][d1][d2][d3] = d0 + d1 + d2 + d3;
}
}

View File

@@ -93,6 +93,13 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
} \
} while (0)
#define ASSERT_EQ_OUTCOME(val, exp) \
if(!val.has_value()) { \
std::string msg = "ERROR: <" + val.error().mesg + "> \n"; \
GTEST_FATAL_FAILURE_(msg.c_str()); \
}; \
ASSERT_EQ(val.value(), exp);
static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache() {
llvm::SmallString<0> cachePath;
@@ -104,6 +111,11 @@ static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache
mlir::concretelang::KeySetCache(cachePathStr));
}
static inline std::shared_ptr<mlir::concretelang::KeySetCache> getTestKeySetCachePtr() {
return std::make_shared<mlir::concretelang::KeySetCache>(
getTestKeySetCache().getValue());
}
// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and reult in abnormal termination.

View File

@@ -0,0 +1,16 @@
ClientLib
=========
.. toctree::
clientlib/intro
clientlib/client_lambda
clientlib/server_lambda
Index
=====
.. toctree::
:glob:
:maxdepth: 2
clientlib/*

View File

@@ -0,0 +1,4 @@
LambdaArgument:
===============
.. doxygenfile:: Arguments.h

View File

@@ -0,0 +1,16 @@
#include <concretelang/ClientLib/ClientLambda.h>
// Include the header that has been generated by the compiler
#include "myinclude/fhe_service/additions.h"
void query_server(MyConnection conn) {
auto libPath = "/opt/fhe_service/libs/additions.so";
std::ostream to_server = conn.ostream();
std::ostream from_server = conn.istream();
// In a real code only load once.
auto add2int = additions::add2int::load(libPath, seed_msb, seed_lsb);
auto err = add2int->callSerialize(1, 2, to_server);
if( err ) { throw MyException() };
auto result = add2int->decryptReturned(from_server);
assert(result == 3);
}

View File

@@ -0,0 +1,5 @@
ClientLambda:
===============
.. doxygenfile:: ClientLambda.h
:sections: briefdescription innernamespace typedef innerclass public-static-func public-func

View File

@@ -0,0 +1,5 @@
ClientParameters:
===============
.. doxygenstruct:: mlir::concretelang::ClientParameters
.. doxygenfile:: ClientParameters.h

View File

@@ -0,0 +1,23 @@
Description
========
ClientLambda represents a FHE function on the client side.
ServerLambda represents a FHE function on the server side.
These object read/write on istreams/ostreams.
Implementing a client/server consists in connecting the two, by connecting to actual streams.
Example on client side:
.. literalinclude:: client_example.cpp
:linenos:
:language: bash
..
For some reason cpp does not work well here.
Example on server side:
.. literalinclude:: server_example.cpp
:linenos:
:language: bash

View File

@@ -0,0 +1,10 @@
#include <concretelang/ServerLib/ServerLambda.h>
void answer_client(MyConnection conn) {
std::istream from_client = conn.istream();
std::ostream to_client = conn.ostream();
auto err = serverLambda.read_call_write(serverInput, serverOutput);
if (err) {
throw MyException();
}
}

View File

@@ -0,0 +1,4 @@
ServerLambda:
===============
.. doxygenfile:: ServerLambda.h

View File

@@ -1,4 +0,0 @@
ClientParameters:
===============
.. doxygenfile:: ClientParameters.h