mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(Clientlib): separate client encryption and server computation
Resolve #200
This commit is contained in:
@@ -5,8 +5,9 @@ print_and_exit() {
|
||||
exit 1
|
||||
}
|
||||
|
||||
EXCLUDE_DIRS="-path ./compiler/include/boost-single-header -prune -o"
|
||||
|
||||
files=$(find ./compiler/{include,lib,src} -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$')
|
||||
files=$(find ./compiler/{include,lib,src} $EXCLUDE_DIRS -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' -print)
|
||||
|
||||
for file in $files
|
||||
do
|
||||
|
||||
4
.github/workflows/scripts/format_cpp.sh
vendored
4
.github/workflows/scripts/format_cpp.sh
vendored
@@ -2,7 +2,9 @@
|
||||
|
||||
set -e -o pipefail
|
||||
|
||||
find ./compiler/{include,lib,src} -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' | xargs clang-format -i -style='file'
|
||||
EXCLUDE_DIRS="-path ./compiler/include/boost-single-header -prune -o"
|
||||
|
||||
find ./compiler/{include,lib,src} $EXCLUDE_DIRS -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' -print | xargs clang-format -i -style='file'
|
||||
|
||||
# show changes if any
|
||||
git --no-pager diff
|
||||
|
||||
@@ -12,6 +12,16 @@ if (APPLE)
|
||||
add_definitions("-Wno-narrowing")
|
||||
endif()
|
||||
|
||||
add_compile_options(-Wfatal-errors) # stop at first error
|
||||
# variable length array = vla
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# using Clang
|
||||
add_compile_options(-Wno-vla-extension)
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
# using GCC
|
||||
add_compile_options(-Wno-vla)
|
||||
endif()
|
||||
|
||||
# If we are trying to build the compiler with LLVM/MLIR as libraries
|
||||
if( NOT DEFINED LLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR )
|
||||
message(FATAL_ERROR "Concrete compiler requires a unified build with LLVM/MLIR")
|
||||
|
||||
@@ -5,6 +5,8 @@ PARALLEL_EXECUTION_ENABLED=OFF
|
||||
CC_COMPILER=
|
||||
CXX_COMPILER=
|
||||
|
||||
EXTERNAL_HEADERS=include/boost-single-header/outcome.hpp
|
||||
|
||||
export PATH := $(BUILD_DIR)/bin:$(PATH)
|
||||
|
||||
ifeq ($(shell which ccache),)
|
||||
@@ -51,7 +53,10 @@ $(BUILD_DIR)/configured.stamp:
|
||||
-DPython3_EXECUTABLE=${Python3_EXECUTABLE}
|
||||
touch $@
|
||||
|
||||
build-initialized: $(BUILD_DIR)/configured.stamp
|
||||
include/boost-single-header/outcome.hpp:
|
||||
wget https://github.com/ned14/outcome/raw/master/single-header/outcome.hpp -O $@
|
||||
|
||||
build-initialized: $(EXTERNAL_HEADERS) $(BUILD_DIR)/configured.stamp
|
||||
|
||||
doc: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target mlir-doc
|
||||
@@ -63,6 +68,12 @@ python-bindings: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
|
||||
|
||||
clientlib: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
|
||||
|
||||
serverlib: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangServerLib
|
||||
|
||||
test-check: concretecompiler file-check not
|
||||
$(BUILD_DIR)/bin/llvm-lit -v tests/
|
||||
|
||||
@@ -87,6 +98,12 @@ uninstall_runtime_lib:
|
||||
|
||||
# unit-test
|
||||
|
||||
clientlib-unit-test: build-clientlib-unit-test
|
||||
$(BUILD_DIR)/bin/clientlib_unit_test
|
||||
|
||||
build-clientlib-unit-test:
|
||||
cmake --build $(BUILD_DIR) --target clientlib_unit_test
|
||||
|
||||
testlib-unit-test: build-testlib-unit-test
|
||||
$(BUILD_DIR)/bin/testlib_unit_test
|
||||
|
||||
|
||||
7939
compiler/include/boost-single-header/outcome.hpp
Normal file
7939
compiler/include/boost-single-header/outcome.hpp
Normal file
File diff suppressed because it is too large
Load Diff
13
compiler/include/boost/outcome.h
Normal file
13
compiler/include/boost/outcome.h
Normal file
@@ -0,0 +1,13 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BOOST_OUTCOME_H
|
||||
#define CONCRETELANG_BOOST_OUTCOME_H
|
||||
|
||||
#include "boost-single-header/outcome.hpp"
|
||||
|
||||
namespace outcome = outcome_v2_e261cebd;
|
||||
|
||||
#endif
|
||||
141
compiler/include/concretelang/ClientLib/ClientLambda.h
Normal file
141
compiler/include/concretelang/ClientLib/ClientLambda.h
Normal file
@@ -0,0 +1,141 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
using scalar_in = uint8_t;
|
||||
using scalar_out = uint64_t;
|
||||
using tensor1_in = std::vector<scalar_in>;
|
||||
using tensor2_in = std::vector<std::vector<scalar_in>>;
|
||||
using tensor3_in = std::vector<std::vector<std::vector<scalar_in>>>;
|
||||
using tensor1_out = std::vector<scalar_out>;
|
||||
using tensor2_out = std::vector<std::vector<scalar_out>>;
|
||||
using tensor3_out = std::vector<std::vector<std::vector<scalar_out>>>;
|
||||
|
||||
class ClientLambda {
|
||||
/// Low-level class to create the client side view of a FHE function.
|
||||
public:
|
||||
virtual ~ClientLambda() = default;
|
||||
static outcome::checked<ClientLambda, StringError>
|
||||
/// Construct a ClientLambda from a ClientParameter file.
|
||||
load(std::string funcName, std::string jsonPath);
|
||||
|
||||
/// Emit a call to the given ostream, no meta-date are include, so it's the
|
||||
/// responsability of the the caller/callee to verify the add/verify the
|
||||
/// function to be called.
|
||||
outcome::checked<void, StringError>
|
||||
untypedSerializeCall(PublicArguments &publicArguments, std::ostream &ostream);
|
||||
|
||||
/// Generate or get from cache a KeySet suitable for this ClientLambda
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
keySet(std::shared_ptr<KeySetCache> optionalCache, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
decryptReturnedValues(KeySet &keySet, std::istream &istream);
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
decryptReturnedScalar(KeySet &keySet, std::istream &istream);
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
decryptReturnedTensor1(KeySet &keySet, std::istream &istream);
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
decryptReturnedTensor2(KeySet &keySet, std::istream &istream);
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
decryptReturnedTensor3(KeySet &keySet, std::istream &istream);
|
||||
|
||||
public:
|
||||
ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream);
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TypedClientLambda : public ClientLambda {
|
||||
|
||||
public:
|
||||
static outcome::checked<TypedClientLambda<Result, Args...>, StringError>
|
||||
load(std::string funcName, std::string jsonPath) {
|
||||
OUTCOME_TRY(auto lambda, ClientLambda::load(funcName, jsonPath));
|
||||
return TypedClientLambda(lambda);
|
||||
}
|
||||
|
||||
/// Emit a call on this lambda to a binary ostream.
|
||||
/// The ostream is responsible for transporting the call to a
|
||||
/// ServerLambda::real_call_write function. ostream must be in binary mode
|
||||
/// std::ios_base::openmode::binary
|
||||
outcome::checked<void, StringError>
|
||||
serializeCall(Args... args, std::shared_ptr<KeySet> keySet,
|
||||
std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
|
||||
return ClientLambda::untypedSerializeCall(publicArguments, ostream);
|
||||
}
|
||||
|
||||
outcome::checked<PublicArguments, StringError>
|
||||
publicArguments(Args... args, std::shared_ptr<KeySet> keySet) {
|
||||
OUTCOME_TRY(auto clientArguments, EncryptedArgs::create(keySet, args...));
|
||||
return clientArguments->asPublicArguments(clientParameters,
|
||||
keySet->runtimeContext());
|
||||
}
|
||||
|
||||
outcome::checked<Result, StringError> decryptReturned(KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return topLevelDecryptResult<Result>((*this), keySet, istream);
|
||||
}
|
||||
|
||||
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
|
||||
// TODO: check parameter types
|
||||
// TODO: add static check on types vs lambda inputs/outpus
|
||||
}
|
||||
|
||||
protected:
|
||||
// Workaround, gcc 6 does not support partial template specialisation in class
|
||||
template <typename Result_>
|
||||
friend outcome::checked<Result_, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream);
|
||||
};
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -5,18 +5,27 @@
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
#include <llvm/Support/JSON.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
const std::string SMALL_KEY = "small";
|
||||
const std::string BIG_KEY = "big";
|
||||
|
||||
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
|
||||
|
||||
typedef size_t DecompositionLevelCount;
|
||||
typedef size_t DecompositionBaseLog;
|
||||
typedef size_t PolynomialSize;
|
||||
@@ -31,6 +40,8 @@ struct LweSecretKeyParam {
|
||||
LweDimension size;
|
||||
|
||||
void hash(size_t &seed);
|
||||
inline uint64_t lweDimension() { return size; }
|
||||
inline uint64_t lweSize() { return size + 1; }
|
||||
};
|
||||
static bool operator==(const LweSecretKeyParam &lhs,
|
||||
const LweSecretKeyParam &rhs) {
|
||||
@@ -121,8 +132,17 @@ struct ClientParameters {
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
std::string functionName;
|
||||
|
||||
size_t hash();
|
||||
|
||||
static outcome::checked<std::vector<ClientParameters>, StringError>
|
||||
load(std::string path);
|
||||
|
||||
static std::string getClientParametersPath(std::string path);
|
||||
|
||||
LweSecretKeyParam lweSecretKeyParam(CircuitGate gate);
|
||||
};
|
||||
|
||||
static inline bool operator==(const ClientParameters &lhs,
|
||||
const ClientParameters &rhs) {
|
||||
return lhs.secretKeys == rhs.secretKeys &&
|
||||
@@ -154,12 +174,13 @@ bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &);
|
||||
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
|
||||
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
ClientParameters cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
133
compiler/include/concretelang/ClientLib/EncryptedArgs.h
Normal file
133
compiler/include/concretelang/ClientLib/EncryptedArgs.h
Normal file
@@ -0,0 +1,133 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
|
||||
#define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "../Common/Error.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class PublicArguments;
|
||||
|
||||
class EncryptedArgs {
|
||||
/// Temporary object used to hold and encrypt parameters before calling a
|
||||
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
|
||||
/// Otherwise convert it to a PublicArguments and use
|
||||
/// serializeCall(PublicArguments, KeySet).
|
||||
public:
|
||||
// Create EncryptedArgument that use the given KeySet to perform
|
||||
// encryption/decryption operations.
|
||||
template <typename... Args>
|
||||
static outcome::checked<std::shared_ptr<EncryptedArgs>, StringError>
|
||||
create(std::shared_ptr<KeySet> keySet, Args... args) {
|
||||
auto arguments = std::make_shared<EncryptedArgs>();
|
||||
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
|
||||
return arguments;
|
||||
}
|
||||
|
||||
/** Low level interface */
|
||||
public:
|
||||
// Add a scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint8_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
// Add a vector-tensor argument.
|
||||
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
template <size_t size>
|
||||
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size}, keySet);
|
||||
}
|
||||
|
||||
// Add a matrix-tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
|
||||
}
|
||||
|
||||
// Add a rank3 tensor.
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet);
|
||||
}
|
||||
|
||||
// Generalize by computing shape by template recursion
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(T *data, size_t dim1,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<size_t>(&dim1, 1), keySet);
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(T *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> pushArg(size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet,
|
||||
Arg0 arg0, OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet) {
|
||||
return checkAllArgs(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<PublicArguments, StringError>
|
||||
asPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext);
|
||||
|
||||
EncryptedArgs();
|
||||
~EncryptedArgs();
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError>
|
||||
checkPushTooManyArgs(std::shared_ptr<KeySet> keySetPtr);
|
||||
outcome::checked<void, StringError>
|
||||
checkAllArgs(std::shared_ptr<KeySet> keySet);
|
||||
// Add a scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
// Position of the next pushed argument
|
||||
size_t currentPos;
|
||||
std::vector<void *> preparedArgs;
|
||||
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -3,11 +3,13 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_KEYSET_H_
|
||||
#define CONCRETELANG_SUPPORT_KEYSET_H_
|
||||
#ifndef CONCRETELANG_CLIENTLIB_KEYSET_H_
|
||||
#define CONCRETELANG_CLIENTLIB_KEYSET_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
@@ -15,26 +17,26 @@ extern "C" {
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
|
||||
class KeySet {
|
||||
public:
|
||||
~KeySet();
|
||||
|
||||
static std::unique_ptr<KeySet> uninitialized();
|
||||
|
||||
llvm::Error generateKeysFromParams(ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
llvm::Error setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
// allocate a KeySet according the ClientParameters.
|
||||
static llvm::Expected<std::unique_ptr<KeySet>>
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generateCached(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
bool isInputEncrypted(size_t pos);
|
||||
|
||||
@@ -57,17 +59,19 @@ public:
|
||||
|
||||
// allocate a lwe ciphertext buffer for the argument at argPos, set the size
|
||||
// of the allocated buffer.
|
||||
llvm::Error allocate_lwe(size_t argPos, uint64_t **ciphertext,
|
||||
uint64_t &size);
|
||||
outcome::checked<void, StringError>
|
||||
allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size);
|
||||
|
||||
// encrypt the input to the ciphertext for the argument at argPos.
|
||||
llvm::Error encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
|
||||
outcome::checked<void, StringError>
|
||||
encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
|
||||
|
||||
// isOuputEncrypted return true if the output at the given pos is encrypted.
|
||||
bool isOutputEncrypted(size_t pos);
|
||||
|
||||
// decrypt the ciphertext to the output for the argument at argPos.
|
||||
llvm::Error decrypt_lwe(size_t argPos, uint64_t *ciphertext,
|
||||
uint64_t &output);
|
||||
outcome::checked<void, StringError>
|
||||
decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output);
|
||||
|
||||
size_t numInputs() { return inputs.size(); }
|
||||
size_t numOutputs() { return outputs.size(); }
|
||||
@@ -77,8 +81,14 @@ public:
|
||||
|
||||
void setRuntimeContext(RuntimeContext &context) {
|
||||
context.ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
context.bsk["_concretelang_base_context_bsk"] =
|
||||
std::get<1>(this->bootstrapKeys["bsk_v0"]);
|
||||
context.bsk[RuntimeContext::BASE_CONTEXT_BSK] =
|
||||
std::get<1>(this->bootstrapKeys.at("bsk_v0"));
|
||||
}
|
||||
|
||||
RuntimeContext runtimeContext() {
|
||||
RuntimeContext context;
|
||||
this->setRuntimeContext(context);
|
||||
return context;
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
@@ -94,12 +104,23 @@ public:
|
||||
getKeyswitchKeys();
|
||||
|
||||
protected:
|
||||
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
llvm::Error generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
outcome::checked<void, StringError>
|
||||
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
outcome::checked<void, StringError>
|
||||
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
outcome::checked<void, StringError>
|
||||
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
friend class KeySetCache;
|
||||
|
||||
@@ -127,7 +148,7 @@ private:
|
||||
keyswitchKeys);
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_KEYSETCACHE_H_
|
||||
#define CONCRETELANG_SUPPORT_KEYSETCACHE_H_
|
||||
#ifndef CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
|
||||
#define CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
class KeySet;
|
||||
|
||||
@@ -20,17 +20,21 @@ public:
|
||||
KeySetCache(std::string backingDirectoryPath)
|
||||
: backingDirectoryPath(backingDirectoryPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
private:
|
||||
static llvm::Expected<std::unique_ptr<KeySet>>
|
||||
tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb,
|
||||
llvm::SmallString<0> &folderPath);
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb,
|
||||
std::string folderPath);
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
66
compiler/include/concretelang/ClientLib/PublicArguments.h
Normal file
66
compiler/include/concretelang/ClientLib/PublicArguments.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
|
||||
#define CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
class ServerLambda;
|
||||
}
|
||||
} // namespace concretelang
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class EncryptedArgs;
|
||||
class PublicArguments {
|
||||
/// PublicArguments will be sended to the server. It includes encrypted
|
||||
/// arguments and public keys.
|
||||
public:
|
||||
PublicArguments(
|
||||
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
|
||||
bool clearRuntimeContext, std::vector<void *> &&preparedArgs,
|
||||
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers);
|
||||
PublicArguments(PublicArguments &other) = delete;
|
||||
// to have proper owership transfer (outcome and local object)
|
||||
PublicArguments(PublicArguments &&other);
|
||||
~PublicArguments();
|
||||
|
||||
void freeIfNotOwned(std::vector<encrypted_scalar_t> res);
|
||||
|
||||
static outcome::checked<std::shared_ptr<PublicArguments>, StringError>
|
||||
unserialize(ClientParameters &expectedParams, std::istream &istream);
|
||||
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
private:
|
||||
friend class ::concretelang::serverlib::ServerLambda; // from ServerLib
|
||||
|
||||
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
RuntimeContext runtimeContext;
|
||||
std::vector<void *> preparedArgs;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
|
||||
bool clearRuntimeContext;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
78
compiler/include/concretelang/ClientLib/Serializers.h
Normal file
78
compiler/include/concretelang/ClientLib/Serializers.h
Normal file
@@ -0,0 +1,78 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
|
||||
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
|
||||
// integers are not serialized as binary values even on a binary stream
|
||||
// so we cannot rely on << operator directly
|
||||
template <typename Word>
|
||||
std::ostream &writeWord(std::ostream &ostream, Word word) {
|
||||
ostream.write(reinterpret_cast<char *>(&(word)), sizeof(word));
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
template <typename Size>
|
||||
std::ostream &writeSize(std::ostream &ostream, Size size) {
|
||||
return writeWord(ostream, size);
|
||||
}
|
||||
|
||||
// for sake of symetry
|
||||
template <typename Word>
|
||||
std::istream &readWord(std::istream &istream, Word &word) {
|
||||
istream.read(reinterpret_cast<char *>(&(word)), sizeof(word));
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
template <typename Size>
|
||||
std::istream &readSize(std::istream &istream, Size &size) {
|
||||
return readWord(istream, size);
|
||||
}
|
||||
|
||||
template <typename Stream> bool incorrectMode(Stream &stream) {
|
||||
auto binary = stream.flags() && std::ios::binary;
|
||||
if (!binary) {
|
||||
stream.setstate(std::ios::failbit);
|
||||
}
|
||||
return !binary;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const ClientParameters ¶ms);
|
||||
std::istream &operator>>(std::istream &istream, ClientParameters ¶ms);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext);
|
||||
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
|
||||
|
||||
std::ostream &serializeEncryptedValues(std::vector<size_t> &sizes,
|
||||
encrypted_scalars_t values,
|
||||
std::ostream &ostream);
|
||||
|
||||
std::ostream &
|
||||
serializeEncryptedValues(encrypted_scalars_and_sizes_t &values_and_sizes,
|
||||
std::ostream &ostream);
|
||||
|
||||
encrypted_scalars_and_sizes_t unserializeEncryptedValues(
|
||||
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
|
||||
// accomodate non static sizes
|
||||
std::istream &istream);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
58
compiler/include/concretelang/ClientLib/Types.h
Normal file
58
compiler/include/concretelang/ClientLib/Types.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
#define CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
uint64_t *allocated;
|
||||
uint64_t *aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
using decrypted_scalar_t = std::uint64_t;
|
||||
using decrypted_tensor_1_t = std::vector<decrypted_scalar_t>;
|
||||
using decrypted_tensor_2_t = std::vector<decrypted_tensor_1_t>;
|
||||
using decrypted_tensor_3_t = std::vector<decrypted_tensor_2_t>;
|
||||
|
||||
template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
|
||||
using encrypted_scalar_t = uint64_t *;
|
||||
using encrypted_scalars_t = uint64_t *;
|
||||
|
||||
struct encrypted_scalars_and_sizes_t {
|
||||
std::vector<uint64_t> values; // tensor of rank r + 1
|
||||
std::vector<size_t> sizes; // r sizes
|
||||
|
||||
inline size_t length() {
|
||||
if (sizes.empty()) {
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
size_t len = 1;
|
||||
for (auto size : sizes) {
|
||||
assert(size > 0);
|
||||
len *= size;
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
inline size_t lweSize() { return sizes.back(); }
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
#endif
|
||||
19
compiler/include/concretelang/Common/BitsSize.h
Normal file
19
compiler/include/concretelang/Common/BitsSize.h
Normal file
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_BITS_SIZE_H
|
||||
#define CONCRETELANG_COMMON_BITS_SIZE_H
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace concretelang {
|
||||
namespace common {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth);
|
||||
|
||||
}
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
43
compiler/include/concretelang/Common/Error.h
Normal file
43
compiler/include/concretelang/Common/Error.h
Normal file
@@ -0,0 +1,43 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
#ifndef CONCRETELANG_COMMON_ERROR_H
|
||||
#define CONCRETELANG_COMMON_ERROR_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace concretelang {
|
||||
namespace error {
|
||||
|
||||
class StringError {
|
||||
public:
|
||||
StringError(std::string mesg) : mesg(mesg){};
|
||||
|
||||
std::string mesg;
|
||||
|
||||
StringError &operator<<(const std::string &v) {
|
||||
mesg += v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
StringError &operator<<(const char *v) {
|
||||
mesg += std::string(v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
StringError &operator<<(char *v) {
|
||||
mesg += std::string(v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T> inline StringError &operator<<(const T v) {
|
||||
mesg += std::to_string(v);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace error
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -15,17 +15,20 @@ extern "C" {
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef struct RuntimeContext {
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
std::map<std::string, LweBootstrapKey_u64 *> bsk;
|
||||
|
||||
static std::string BASE_CONTEXT_BSK;
|
||||
~RuntimeContext() {
|
||||
for (const auto &key : bsk) {
|
||||
if (key.first != "_concretelang_base_context_bsk")
|
||||
if (key.first != BASE_CONTEXT_BSK)
|
||||
free_lwe_bootstrap_key_u64(key.second);
|
||||
}
|
||||
}
|
||||
} RuntimeContext;
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -3,20 +3,23 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_ARITY_CALL_H
|
||||
#define CONCRETELANG_TESTLIB_DYNAMIC_ARITY_CALL_H
|
||||
// generated: see genDynamicRandAndArityCall.py
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
|
||||
#define CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
template <typename Res>
|
||||
Res call(Res (*func)(void *...), std::vector<void *> args) {
|
||||
Res multi_arity_call(Res (*func)(void *...), std::vector<void *> args) {
|
||||
switch (args.size()) {
|
||||
// generated part: see genDynamicArityCall.py
|
||||
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
|
||||
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
|
||||
case 1:
|
||||
return func(args[0]);
|
||||
case 2:
|
||||
@@ -1457,12 +1460,13 @@ Res call(Res (*func)(void *...), std::vector<void *> args) {
|
||||
args[111], args[112], args[113], args[114], args[115], args[116],
|
||||
args[117], args[118], args[119], args[120], args[121], args[122],
|
||||
args[123], args[124], args[125], args[126]);
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,28 +6,37 @@
|
||||
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
|
||||
#define CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::ClientParameters;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class DynamicModule {
|
||||
public:
|
||||
~DynamicModule();
|
||||
static llvm::Expected<std::shared_ptr<DynamicModule>>
|
||||
|
||||
static outcome::checked<std::shared_ptr<DynamicModule>, StringError>
|
||||
open(std::string libraryPath);
|
||||
|
||||
private:
|
||||
llvm::Error loadClientParametersJSON(std::string path);
|
||||
llvm::Error loadSharedLibrary(std::string path);
|
||||
outcome::checked<void, StringError>
|
||||
loadClientParametersJSON(std::string path);
|
||||
|
||||
outcome::checked<void, StringError> loadSharedLibrary(std::string path);
|
||||
|
||||
private:
|
||||
std::vector<ClientParameters> clientParametersList;
|
||||
void *libraryHandle;
|
||||
|
||||
friend class DynamicLambda;
|
||||
friend class ServerLambda;
|
||||
};
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
#endif
|
||||
25
compiler/include/concretelang/ServerLib/DynamicRankCall.h
Normal file
25
compiler/include/concretelang/ServerLib/DynamicRankCall.h
Normal file
@@ -0,0 +1,25 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
|
||||
#define CONCRETELANG_SERVERLIB_DYNAMIC_RANK_CALL_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::encrypted_scalars_and_sizes_t;
|
||||
|
||||
encrypted_scalars_and_sizes_t
|
||||
multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args, size_t rank);
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
52
compiler/include/concretelang/ServerLib/ServerLambda.h
Normal file
52
compiler/include/concretelang/ServerLib/ServerLambda.h
Normal file
@@ -0,0 +1,52 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
|
||||
#define CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::encrypted_scalar_t;
|
||||
using concretelang::clientlib::encrypted_scalars_and_sizes_t;
|
||||
using concretelang::clientlib::encrypted_scalars_t;
|
||||
|
||||
encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
|
||||
size_t rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned,
|
||||
size_t offset, size_t *sizes, size_t *strides);
|
||||
|
||||
class ServerLambda {
|
||||
|
||||
public:
|
||||
static outcome::checked<ServerLambda, concretelang::error::StringError>
|
||||
load(std::string funcName, std::string outputLib);
|
||||
|
||||
static outcome::checked<ServerLambda, concretelang::error::StringError>
|
||||
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
|
||||
|
||||
outcome::checked<void, concretelang::error::StringError>
|
||||
read_call_write(std::istream &istream, std::ostream &ostream);
|
||||
|
||||
protected:
|
||||
ClientParameters clientParameters;
|
||||
void *(*func)(void *...);
|
||||
// Retain module and open shared lib alive
|
||||
std::shared_ptr<DynamicModule> module;
|
||||
};
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,45 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
print(
|
||||
"""// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
// generated: see genDynamicArityCall.py
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
|
||||
#define CONCRETELANG_SERVERLIB_DYNAMIC_ARITY_CALL_H
|
||||
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace serverlib {
|
||||
|
||||
|
||||
template <typename Res>
|
||||
Res multi_arity_call(Res (*func)(void *...), std::vector<void *> args) {
|
||||
switch (args.size()) {
|
||||
// TODO C17++: https://en.cppreference.com/w/cpp/utility/apply
|
||||
""")
|
||||
|
||||
for i in range(1, 128):
|
||||
args = ','.join(f'args[{j}]' for j in range(i))
|
||||
print(f' case {i}: return func({args});')
|
||||
print("""
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}""")
|
||||
|
||||
print("""
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
""")
|
||||
@@ -99,8 +99,10 @@ public:
|
||||
llvm::Expected<std::string> emitStatic();
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
llvm::Expected<std::string> emitShared();
|
||||
/** Emit a shared library with the previously added compilation result */
|
||||
/** Emit a json ClientParameters corresponding to library content */
|
||||
llvm::Expected<std::string> emitClientParametersJSON();
|
||||
/// Emit a client header file for this corresponding to library content
|
||||
llvm::Expected<std::string> emitCppHeader();
|
||||
};
|
||||
|
||||
// Specification of the exit stage of the compilation pipeline
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth);
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::KeySet;
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::KeySetCache;
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
// `typedResult` must be declared at namespace scope due to return
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
|
||||
mlir::ModuleOp module);
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef functionName,
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_ARGUMENTS_H
|
||||
#define CONCRETELANG_TESTLIB_ARGUMENTS_H
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
class DynamicLambda;
|
||||
|
||||
class Arguments {
|
||||
public:
|
||||
Arguments(KeySet &keySet) : currentPos(0), keySet(keySet) {
|
||||
keySet.setRuntimeContext(context);
|
||||
}
|
||||
|
||||
~Arguments();
|
||||
|
||||
// Create EncryptedArgument that use the given KeySet to perform encryption
|
||||
// and decryption operations.
|
||||
static std::shared_ptr<Arguments> create(KeySet &keySet);
|
||||
|
||||
// Add a scalar argument.
|
||||
llvm::Error pushArg(uint64_t arg);
|
||||
|
||||
// Add a vector-tensor argument.
|
||||
llvm::Error pushArg(std::vector<uint8_t> arg);
|
||||
|
||||
template <size_t size> llvm::Error pushArg(std::array<uint8_t, size> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {size});
|
||||
}
|
||||
|
||||
// Add a matrix-tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
llvm::Error pushArg(std::array<std::array<uint8_t, size1>, size0> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1});
|
||||
}
|
||||
|
||||
// Add a rank3 tensor.
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
llvm::Error pushArg(
|
||||
std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1, size2});
|
||||
}
|
||||
|
||||
// Generalize by computing shape by template recursion
|
||||
|
||||
// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T> llvm::Error pushArg(T *data, int64_t dim1) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1));
|
||||
}
|
||||
|
||||
// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
llvm::Error pushArg(T *data, llvm::ArrayRef<int64_t> shape) {
|
||||
return pushArg(8 * sizeof(T), static_cast<void *>(data), shape);
|
||||
}
|
||||
|
||||
llvm::Error pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
// Push the runtime context to the argument list, this must be called
|
||||
// after each argument was pushed.
|
||||
llvm::Error pushContext();
|
||||
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
llvm::Error pushArgs(Arg0 arg0, OtherArgs... others) {
|
||||
auto err = pushArg(arg0);
|
||||
if (err) {
|
||||
return err;
|
||||
}
|
||||
return pushArgs(others...);
|
||||
}
|
||||
|
||||
llvm::Error pushArgs() { return pushContext(); }
|
||||
|
||||
private:
|
||||
friend DynamicLambda;
|
||||
template <typename Result>
|
||||
friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
llvm::Error checkPushTooManyArgs();
|
||||
|
||||
// Position of the next pushed argument
|
||||
size_t currentPos;
|
||||
std::vector<void *> preparedArgs;
|
||||
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<uint64_t *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<uint64_t *> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
RuntimeContext context;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,123 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
|
||||
#define CONCRETELANG_TESTLIB_DYNAMIC_LAMBDA_H
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/TestLib/Arguments.h"
|
||||
#include "concretelang/TestLib/DynamicModule.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor;
|
||||
|
||||
template <typename Result>
|
||||
llvm::Expected<Result> invoke(DynamicLambda &lambda, const Arguments &args) {
|
||||
// compile time error if used
|
||||
using COMPATIBLE_RESULT_TYPE = void;
|
||||
return (Result)(COMPATIBLE_RESULT_TYPE)0; // invoke does not accept this kind
|
||||
// of Result
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<u_int64_t> invoke<u_int64_t>(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args);
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<uint64_t>>>
|
||||
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
|
||||
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
class DynamicLambda {
|
||||
private:
|
||||
template <typename... Args>
|
||||
llvm::Expected<std::shared_ptr<Arguments>> createArguments(Args... args) {
|
||||
if (keySet == nullptr) {
|
||||
return StreamStringError("keySet was not initialized");
|
||||
}
|
||||
auto arg = Arguments::create(*keySet);
|
||||
auto err = arg->pushArgs(args...);
|
||||
if (err) {
|
||||
return StreamStringError(llvm::toString(std::move(err)));
|
||||
}
|
||||
return arg;
|
||||
}
|
||||
|
||||
public:
|
||||
static llvm::Expected<DynamicLambda> load(std::string funcName,
|
||||
std::string outputLib);
|
||||
|
||||
static llvm::Expected<DynamicLambda>
|
||||
load(std::shared_ptr<DynamicModule> module, std::string funcName);
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
llvm::Expected<Result> call(Args... args) {
|
||||
auto argOrErr = createArguments(args...);
|
||||
if (!argOrErr) {
|
||||
return argOrErr.takeError();
|
||||
}
|
||||
auto arg = argOrErr.get();
|
||||
return invoke<Result>(*this, *arg);
|
||||
}
|
||||
|
||||
llvm::Error generateKeySet(llvm::Optional<KeySetCache> cache = llvm::None,
|
||||
uint64_t seed_msb = 0, uint64_t seed_lsb = 0);
|
||||
|
||||
protected:
|
||||
template <typename Result>
|
||||
friend llvm::Expected<Result> invoke(DynamicLambda &lambda,
|
||||
const Arguments &args);
|
||||
|
||||
template <size_t Rank>
|
||||
llvm::Expected<MemRefDescriptor<Rank>>
|
||||
invokeMemRefDecriptor(const Arguments &args);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
std::shared_ptr<KeySet> keySet;
|
||||
void *func;
|
||||
// Retain module and open shared lib alive
|
||||
std::shared_ptr<DynamicModule> module;
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TypedDynamicLambda : public DynamicLambda {
|
||||
|
||||
public:
|
||||
static llvm::Expected<TypedDynamicLambda<Result, Args...>>
|
||||
load(std::string funcName, std::string outputLib) {
|
||||
auto lambda = DynamicLambda::load(funcName, outputLib);
|
||||
if (!lambda) {
|
||||
return lambda.takeError();
|
||||
}
|
||||
return TypedDynamicLambda(*lambda);
|
||||
}
|
||||
|
||||
llvm::Expected<Result> call(Args... args) {
|
||||
return DynamicLambda::call<Result>(args...);
|
||||
}
|
||||
|
||||
// TODO: check parameter types
|
||||
TypedDynamicLambda(DynamicLambda &lambda) : DynamicLambda(lambda) {
|
||||
// TODO: add static check on types vs lambda inputs/outpus
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
118
compiler/include/concretelang/TestLib/TestTypedLambda.h
Normal file
118
compiler/include/concretelang/TestLib/TestTypedLambda.h
Normal file
@@ -0,0 +1,118 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
|
||||
#define CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace testlib {
|
||||
|
||||
using concretelang::clientlib::ClientLambda;
|
||||
using concretelang::clientlib::ClientParameters;
|
||||
using concretelang::clientlib::KeySet;
|
||||
using concretelang::clientlib::KeySetCache;
|
||||
using concretelang::error::StringError;
|
||||
using concretelang::serverlib::ServerLambda;
|
||||
|
||||
inline void freeStringMemory(std::string &s) {
|
||||
std::string empty;
|
||||
s.swap(empty);
|
||||
}
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TestTypedLambda
|
||||
: public concretelang::clientlib::TypedClientLambda<Result, Args...> {
|
||||
|
||||
template <typename Result_, typename... Args_>
|
||||
using TypedClientLambda =
|
||||
concretelang::clientlib::TypedClientLambda<Result_, Args_...>;
|
||||
|
||||
public:
|
||||
static outcome::checked<TestTypedLambda, StringError>
|
||||
load(std::string funcName, std::string outputLib, uint64_t seed_msb = 0,
|
||||
uint64_t seed_lsb = 0,
|
||||
std::shared_ptr<KeySetCache> unsecure_cache = nullptr) {
|
||||
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
|
||||
OUTCOME_TRY(auto cLambda, ClientLambda::load(funcName, jsonPath));
|
||||
OUTCOME_TRY(auto sLambda, ServerLambda::load(funcName, outputLib));
|
||||
OUTCOME_TRY(std::shared_ptr<KeySet> keySet,
|
||||
KeySetCache::generate(unsecure_cache, cLambda.clientParameters,
|
||||
seed_msb, seed_lsb));
|
||||
return TestTypedLambda(cLambda, sLambda, keySet);
|
||||
}
|
||||
|
||||
TestTypedLambda(ClientLambda &cLambda, ServerLambda &sLambda,
|
||||
std::shared_ptr<KeySet> keySet)
|
||||
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
|
||||
keySet(keySet) {}
|
||||
|
||||
TestTypedLambda(TypedClientLambda<Result, Args...> &cLambda,
|
||||
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet)
|
||||
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
|
||||
keySet(keySet) {}
|
||||
|
||||
outcome::checked<Result, StringError> call(Args... args) {
|
||||
// client
|
||||
auto BINARY = std::ios::binary;
|
||||
std::string message;
|
||||
{
|
||||
// client
|
||||
std::ostringstream clientOuput(BINARY);
|
||||
OUTCOME_TRYV(this->serializeCall(args..., keySet, clientOuput));
|
||||
if (clientOuput.fail()) {
|
||||
return StringError("Error on clientOuput");
|
||||
}
|
||||
message = clientOuput.str();
|
||||
}
|
||||
{
|
||||
// server
|
||||
std::istringstream serverInput(message, BINARY);
|
||||
freeStringMemory(message);
|
||||
assert(serverInput.tellg() == 0);
|
||||
std::ostringstream serverOutput(BINARY);
|
||||
OUTCOME_TRYV(serverLambda.read_call_write(serverInput, serverOutput));
|
||||
if (serverInput.fail()) {
|
||||
return StringError("Error on serverOutput");
|
||||
}
|
||||
if (serverOutput.fail()) {
|
||||
return StringError("Error on serverOutput");
|
||||
}
|
||||
message = serverOutput.str();
|
||||
}
|
||||
{
|
||||
// client
|
||||
std::istringstream clientInput(message, BINARY);
|
||||
freeStringMemory(message);
|
||||
OUTCOME_TRY(auto result, this->decryptReturned(*keySet, clientInput));
|
||||
assert(clientInput.good());
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
ServerLambda serverLambda;
|
||||
std::shared_ptr<KeySet> keySet;
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
static TestTypedLambda<Result, Args...> TestTypedLambdaFrom(
|
||||
concretelang::clientlib::TypedClientLambda<Result, Args...> &cLambda,
|
||||
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet) {
|
||||
return TestTypedLambda<Result, Args...>(cLambda, sLambda, keySet);
|
||||
}
|
||||
|
||||
} // namespace testlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,6 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
for i in range(128):
|
||||
args = ','.join(f'args[{j}]' for j in range(i))
|
||||
print(f' case {i}: return func({args});')
|
||||
@@ -50,7 +50,7 @@ lambdaArgument invokeLambda(lambda l, executionArguments args) {
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
std::vector<mlir::concretelang::LambdaArgument *> lambdaArgumentsRef;
|
||||
for (auto i = 0; i < args.size; i++) {
|
||||
for (auto i = 0u; i < args.size; i++) {
|
||||
lambdaArgumentsRef.push_back(args.data[i].ptr.get());
|
||||
}
|
||||
// Run lambda
|
||||
|
||||
@@ -5,7 +5,8 @@ add_subdirectory(Support)
|
||||
add_subdirectory(Runtime)
|
||||
add_subdirectory(ClientLib)
|
||||
add_subdirectory(Bindings)
|
||||
add_subdirectory(TestLib)
|
||||
add_subdirectory(ServerLib)
|
||||
add_subdirectory(Common)
|
||||
|
||||
# CAPI needed only for python bindings
|
||||
if (CONCRETELANG_BINDINGS_PYTHON_ENABLED)
|
||||
|
||||
@@ -1,9 +1,29 @@
|
||||
add_compile_options( -Werror )
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# using Clang
|
||||
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
# using GCC
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangClientLib
|
||||
ClientLambda.cpp
|
||||
ClientParameters.cpp
|
||||
EncryptedArgs.cpp
|
||||
KeySet.cpp
|
||||
KeySetCache.cpp
|
||||
PublicArguments.cpp
|
||||
Serializers.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
|
||||
|
||||
ConcretelangCommon
|
||||
ConcretelangRuntime
|
||||
ConcretelangSupportLib
|
||||
)
|
||||
207
compiler/lib/ClientLib/ClientLambda.cpp
Normal file
207
compiler/lib/ClientLib/ClientLambda.cpp
Normal file
@@ -0,0 +1,207 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
outcome::checked<ClientLambda, StringError>
|
||||
ClientLambda::load(std::string functionName, std::string jsonPath) {
|
||||
OUTCOME_TRY(auto all_params, ClientParameters::load(jsonPath));
|
||||
auto param = llvm::find_if(all_params, [&](ClientParameters param) {
|
||||
return param.functionName == functionName;
|
||||
});
|
||||
|
||||
if (param == all_params.end()) {
|
||||
return StringError("ClientLambda: cannot find function ")
|
||||
<< functionName << " in client parameters" << jsonPath;
|
||||
}
|
||||
|
||||
if (param->outputs.size() != 1) {
|
||||
return StringError("ClientLambda: output arity (")
|
||||
<< std::to_string(param->outputs.size())
|
||||
<< ") != 1 is not supported";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.hasValue()) {
|
||||
return StringError("ClientLambda: clear output is not yet supported");
|
||||
}
|
||||
ClientLambda lambda;
|
||||
lambda.clientParameters = *param;
|
||||
return lambda;
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
|
||||
uint64_t seed_msb, uint64_t seed_lsb) {
|
||||
return KeySetCache::generate(optionalCache, clientParameters, seed_msb,
|
||||
seed_lsb);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
ClientLambda::untypedSerializeCall(PublicArguments &serverArguments,
|
||||
std::ostream &ostream) {
|
||||
return serverArguments.serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) {
|
||||
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream));
|
||||
return v[0];
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) {
|
||||
auto lweSize =
|
||||
clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize();
|
||||
std::vector<int64_t> sizes = clientParameters.outputs[0].shape.dimensions;
|
||||
sizes.push_back(lweSize);
|
||||
auto encryptedValues = unserializeEncryptedValues(sizes, istream);
|
||||
if (istream.fail()) {
|
||||
return StringError("Encrypted scalars has not the right size");
|
||||
}
|
||||
auto len = encryptedValues.length();
|
||||
decrypted_tensor_1_t decryptedValues(len / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]);
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i]));
|
||||
}
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> errorResultRank(size_t expected,
|
||||
size_t actual) {
|
||||
return StringError("Expected result has rank ")
|
||||
<< expected << " and cannot be converted to rank " << actual;
|
||||
}
|
||||
|
||||
StringError errorIncoherentSizes(size_t flatSize, size_t structuredSize) {
|
||||
return StringError("Received ")
|
||||
<< flatSize << " values but is sizes indicates as global size of "
|
||||
<< structuredSize;
|
||||
}
|
||||
|
||||
template <typename DecryptedTensor>
|
||||
DecryptedTensor flatToTensor(decrypted_tensor_1_t &values, size_t *sizes);
|
||||
|
||||
template <>
|
||||
decrypted_tensor_1_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
return values;
|
||||
}
|
||||
|
||||
template <>
|
||||
decrypted_tensor_2_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
decrypted_tensor_2_t result(sizes[0]);
|
||||
size_t position = 0;
|
||||
for (auto &dest0 : result) {
|
||||
dest0.resize(sizes[1]);
|
||||
for (auto &dest1 : dest0) {
|
||||
dest1 = values[position++];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
decrypted_tensor_3_t result(sizes[0]);
|
||||
size_t position = 0;
|
||||
for (auto &dest0 : result) {
|
||||
dest0.resize(sizes[1]);
|
||||
for (auto &dest1 : dest0) {
|
||||
dest1.resize(sizes[2]);
|
||||
for (auto &dest2 : dest1) {
|
||||
dest2 = values[position++];
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename DecryptedTensor>
|
||||
outcome::checked<DecryptedTensor, StringError>
|
||||
decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
ClientParameters ¶ms, size_t expectedRank,
|
||||
KeySet &keySet) {
|
||||
auto shape = params.outputs[0].shape;
|
||||
size_t rank = shape.dimensions.size();
|
||||
if (rank != expectedRank) {
|
||||
return StringError("Function returns a tensor of rank ")
|
||||
<< expectedRank << " which cannot be decrypted to rank " << rank;
|
||||
}
|
||||
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream));
|
||||
size_t sizes[rank];
|
||||
for (size_t dim = 0; dim < rank; dim++) {
|
||||
sizes[dim] = shape.dimensions[dim];
|
||||
}
|
||||
return flatToTensor<DecryptedTensor>(values, sizes);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) {
|
||||
return decryptReturnedTensor<decrypted_tensor_1_t>(
|
||||
istream, *this, this->clientParameters, 1, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) {
|
||||
return decryptReturnedTensor<decrypted_tensor_2_t>(
|
||||
istream, *this, this->clientParameters, 2, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) {
|
||||
return decryptReturnedTensor<decrypted_tensor_3_t>(
|
||||
istream, *this, this->clientParameters, 3, keySet);
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
// compile time error if used
|
||||
using COMPATIBLE_RESULT_TYPE = void;
|
||||
return (Result)(COMPATIBLE_RESULT_TYPE)0;
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedScalar(keySet, istream);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor1(keySet, istream);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor2(keySet, istream);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor3(keySet, istream);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -3,53 +3,60 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
// https://stackoverflow.com/a/38140932
|
||||
static inline void hash(std::size_t &seed) {}
|
||||
static inline void hash_(std::size_t &seed) {}
|
||||
template <typename T, typename... Rest>
|
||||
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
|
||||
static inline void hash_(std::size_t &seed, const T &v, Rest... rest) {
|
||||
// See https://softwareengineering.stackexchange.com/a/402543
|
||||
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
|
||||
const std::hash<T> hasher;
|
||||
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
|
||||
hash(seed, rest...);
|
||||
hash_(seed, rest...);
|
||||
}
|
||||
|
||||
void LweSecretKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, size);
|
||||
}
|
||||
void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, size); }
|
||||
|
||||
void BootstrapKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, glweDimension, variance);
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
|
||||
glweDimension, variance);
|
||||
}
|
||||
|
||||
void KeyswitchKeyParam::hash(size_t &seed) {
|
||||
mlir::concretelang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, variance);
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog, variance);
|
||||
}
|
||||
|
||||
std::size_t ClientParameters::hash() {
|
||||
std::size_t currentHash = 1;
|
||||
for (auto secretKeyParam : secretKeys) {
|
||||
mlir::concretelang::hash(currentHash, secretKeyParam.first);
|
||||
hash_(currentHash, secretKeyParam.first);
|
||||
secretKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto bootstrapKeyParam : bootstrapKeys) {
|
||||
mlir::concretelang::hash(currentHash, bootstrapKeyParam.first);
|
||||
hash_(currentHash, bootstrapKeyParam.first);
|
||||
bootstrapKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto keyswitchParam : keyswitchKeys) {
|
||||
mlir::concretelang::hash(currentHash, keyswitchParam.first);
|
||||
hash_(currentHash, keyswitchParam.first);
|
||||
keyswitchParam.second.hash(currentHash);
|
||||
}
|
||||
return currentHash;
|
||||
}
|
||||
|
||||
LweSecretKeyParam ClientParameters::lweSecretKeyParam(CircuitGate gate) {
|
||||
return secretKeys.find(gate.encryption->secretKeyID)->second;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"size", v.size},
|
||||
@@ -384,5 +391,27 @@ bool fromJSON(const llvm::json::Value j, ClientParameters &v,
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
std::string ClientParameters::getClientParametersPath(std::string path) {
|
||||
return path + CLIENT_PARAMETERS_EXT;
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<ClientParameters>, StringError>
|
||||
ClientParameters::load(std::string jsonPath) {
|
||||
std::ifstream file(jsonPath);
|
||||
std::string content((std::istreambuf_iterator<char>(file)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
if (file.fail()) {
|
||||
return StringError("Cannot read file: ") << jsonPath;
|
||||
}
|
||||
auto expectedClientParams =
|
||||
llvm::json::parse<std::vector<ClientParameters>>(content);
|
||||
if (auto err = expectedClientParams.takeError()) {
|
||||
return StringError("Cannot open client parameters: ")
|
||||
<< llvm::toString(std::move(err)) << "\n"
|
||||
<< content << "\n";
|
||||
}
|
||||
return expectedClientParams.get();
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
181
compiler/lib/ClientLib/EncryptedArgs.cpp
Normal file
181
compiler/lib/ClientLib/EncryptedArgs.cpp
Normal file
@@ -0,0 +1,181 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
EncryptedArgs::~EncryptedArgs() {
|
||||
// There is no explicit allocation
|
||||
// All buffers are owned by ciphertextBuffers
|
||||
}
|
||||
|
||||
EncryptedArgs::EncryptedArgs() : currentPos(0) {}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg((uint64_t)arg, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
// TODO: NON ENCRYPTED
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet->inputGate(pos);
|
||||
if (input.shape.size != 0) {
|
||||
return StringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
if (!input.encryption.hasValue()) {
|
||||
// clear scalar: just push the argument
|
||||
if (input.shape.width != 64) {
|
||||
return StringError(
|
||||
"scalar argument of with != 64 is not supported for DynamicLambda");
|
||||
}
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return outcome::success();
|
||||
}
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
|
||||
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
|
||||
auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1;
|
||||
values_and_sizes.sizes.push_back(lweSize);
|
||||
values_and_sizes.values.resize(lweSize);
|
||||
|
||||
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values_and_sizes.values.data(), arg));
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// size
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.size());
|
||||
// stride
|
||||
preparedArgs.push_back((void *)1);
|
||||
currentPos++;
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet->inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
auto roundedSize = concretelang::common::bitWidthAsWord(input.shape.width);
|
||||
if (width != roundedSize) {
|
||||
return StringError("argument #") << pos << "width mismatch, got " << width
|
||||
<< " expected " << roundedSize;
|
||||
}
|
||||
// Check the shape of tensor
|
||||
if (input.shape.dimensions.empty()) {
|
||||
return StringError("argument #") << pos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != input.shape.dimensions.size()) {
|
||||
return StringError("argument #")
|
||||
<< pos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << input.shape.dimensions.size();
|
||||
}
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
|
||||
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
values_and_sizes.sizes.push_back(shape[i]);
|
||||
if (shape[i] != input.shape.dimensions[i]) {
|
||||
return StringError("argument #")
|
||||
<< pos << " has not the expected dimension #" << i << " , got "
|
||||
<< shape[i] << " expected " << input.shape.dimensions[i];
|
||||
}
|
||||
}
|
||||
if (input.encryption.hasValue()) {
|
||||
auto lweSize = keySet->getInputLweSecretKeyParam(pos).size + 1;
|
||||
values_and_sizes.sizes.push_back(lweSize);
|
||||
|
||||
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
|
||||
if (width != 8) {
|
||||
return StringError("argument #")
|
||||
<< pos << " width mismatch, expected 8 got " << width;
|
||||
}
|
||||
const uint8_t *data8 = (const uint8_t *)data;
|
||||
|
||||
// Allocate a buffer for ciphertexts of size of tensor
|
||||
values_and_sizes.values.resize(input.shape.size * lweSize);
|
||||
auto &values = values_and_sizes.values;
|
||||
// Allocate ciphertexts and encrypt, for every values in tensor
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
OUTCOME_TRYV(keySet->encrypt_lwe(pos, values.data() + offset, data8[i]));
|
||||
}
|
||||
} // TODO: NON ENCRYPTED, COPY CONTENT TO values_and_sizes
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : values_and_sizes.sizes) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = values_and_sizes.length();
|
||||
// If encrypted +1 set the stride for the lwe size rank
|
||||
for (size_t size : values_and_sizes.sizes) {
|
||||
stride /= size;
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
currentPos++;
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
|
||||
size_t arity = keySet->numInputs();
|
||||
if (currentPos < arity) {
|
||||
return outcome::success();
|
||||
}
|
||||
return StringError("function has arity ")
|
||||
<< arity << " but is applied to too many arguments";
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::checkAllArgs(std::shared_ptr<KeySet> keySet) {
|
||||
size_t arity = keySet->numInputs();
|
||||
if (currentPos == arity) {
|
||||
return outcome::success();
|
||||
}
|
||||
return StringError("function expects ")
|
||||
<< arity << " arguments but has been called with " << currentPos
|
||||
<< " arguments";
|
||||
}
|
||||
|
||||
outcome::checked<PublicArguments, StringError>
|
||||
EncryptedArgs::asPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
// On client side the runtimeContext is hold by the KeySet
|
||||
bool clearContext = false;
|
||||
return PublicArguments(clientParameters, runtimeContext, clearContext,
|
||||
std::move(preparedArgs), std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -6,8 +6,17 @@
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
namespace mlir {
|
||||
#define CAPI_ERR_TO_STRINGERROR(instr, msg) \
|
||||
{ \
|
||||
int err; \
|
||||
instr; \
|
||||
if (err != 0) { \
|
||||
return concretelang::error::StringError(msg); \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
KeySet::~KeySet() {
|
||||
for (auto it : secretKeys) {
|
||||
@@ -22,31 +31,20 @@ KeySet::~KeySet() {
|
||||
free_encryption_generator(encryptionRandomGenerator);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto keySet = std::make_unique<KeySet>();
|
||||
|
||||
auto keySet = uninitialized();
|
||||
|
||||
if (auto error = keySet->generateKeysFromParams(params, seed_msb, seed_lsb)) {
|
||||
return std::move(error);
|
||||
}
|
||||
|
||||
if (auto error =
|
||||
keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)) {
|
||||
return std::move(error);
|
||||
}
|
||||
OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb));
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb));
|
||||
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
std::unique_ptr<KeySet> KeySet::uninitialized() {
|
||||
return std::make_unique<KeySet>();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
@@ -55,10 +53,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
if (param.encryption.hasValue()) {
|
||||
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (inputSk == this->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"input encryption secret key (" + param.encryption->secretKeyID +
|
||||
") does not exist ",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("input encryption secret key (")
|
||||
<< param.encryption->secretKeyID << ") does not exist ";
|
||||
}
|
||||
secretKeyParam = inputSk->second.first;
|
||||
secretKey = inputSk->second.second;
|
||||
@@ -73,9 +69,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
if (param.encryption.hasValue()) {
|
||||
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (outputSk == this->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError(
|
||||
"cannot find output key to generate bootstrap key");
|
||||
}
|
||||
secretKeyParam = outputSk->second.first;
|
||||
secretKey = outputSk->second.second;
|
||||
@@ -89,50 +84,41 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
this->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(seed_msb, seed_lsb);
|
||||
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateKeysFromParams(ClientParameters ¶ms,
|
||||
uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
{
|
||||
// Generate LWE secret keys
|
||||
SecretRandomGenerator *generator;
|
||||
|
||||
generator = allocate_secret_generator(seed_msb, seed_lsb);
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
auto e = this->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator);
|
||||
if (e) {
|
||||
return std::move(e);
|
||||
}
|
||||
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator));
|
||||
}
|
||||
free_secret_generator(generator);
|
||||
}
|
||||
// Allocate the encryption random generator
|
||||
|
||||
this->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(seed_msb, seed_lsb);
|
||||
// Generate bootstrap and keyswitch keys
|
||||
{
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
auto e = this->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second,
|
||||
this->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return std::move(e);
|
||||
}
|
||||
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second,
|
||||
this->encryptionRandomGenerator));
|
||||
}
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
auto e = this->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second,
|
||||
this->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return std::move(e);
|
||||
}
|
||||
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second,
|
||||
this->encryptionRandomGenerator));
|
||||
}
|
||||
}
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
void KeySet::setKeys(
|
||||
@@ -149,9 +135,9 @@ void KeySet::setKeys(
|
||||
this->keyswitchKeys = keyswitchKeys;
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
|
||||
LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator) {
|
||||
LweSecretKey_u64 *sk;
|
||||
sk = allocate_lwe_secret_key_u64({param.size});
|
||||
|
||||
@@ -159,24 +145,20 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
|
||||
|
||||
secretKeys[id] = {param, sk};
|
||||
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
|
||||
BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("cannot find input key to generate bootstrap key");
|
||||
}
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("cannot find output key to generate bootstrap key");
|
||||
}
|
||||
// Allocate the bootstrap key
|
||||
LweBootstrapKey_u64 *bsk;
|
||||
@@ -207,24 +189,20 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
|
||||
fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator,
|
||||
{param.variance});
|
||||
free_glwe_secret_key_u64(glwe_sk);
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
|
||||
KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate keyswitch key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("cannot find input key to generate keyswitch key");
|
||||
}
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate keyswitch key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("cannot find output key to generate keyswitch key");
|
||||
}
|
||||
// Allocate the keyswitch key
|
||||
LweKeyswitchKey_u64 *ksk;
|
||||
@@ -240,21 +218,19 @@ llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
|
||||
fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second,
|
||||
outputSk->second.second, generator,
|
||||
{param.variance});
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext,
|
||||
uint64_t &size) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"allocate_lwe position of argument is too high",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("allocate_lwe position of argument is too high");
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
|
||||
size = std::get<1>(inputSk).size + 1;
|
||||
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
bool KeySet::isInputEncrypted(size_t argPos) {
|
||||
@@ -267,18 +243,14 @@ bool KeySet::isOutputEncrypted(size_t argPos) {
|
||||
std::get<0>(outputs[argPos]).encryption.hasValue();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext,
|
||||
uint64_t input) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"encrypt_lwe position of argument is too high",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("encrypt_lwe position of argument is too high");
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
if (!std::get<0>(inputSk).encryption.hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"encrypt_lwe the positional argument is not encrypted",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("encrypt_lwe the positional argument is not encrypted");
|
||||
}
|
||||
// Encode - TODO we could check if the input value is in the right range
|
||||
uint64_t plaintext =
|
||||
@@ -286,22 +258,18 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext,
|
||||
encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext,
|
||||
encryptionRandomGenerator,
|
||||
{std::get<0>(inputSk).encryption->variance});
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext,
|
||||
uint64_t &output) {
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
|
||||
if (argPos >= outputs.size()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"decrypt_lwe: position of argument is too high",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("decrypt_lwe: position of argument is too high");
|
||||
}
|
||||
auto outputSk = outputs[argPos];
|
||||
if (!std::get<0>(outputSk).encryption.hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"decrypt_lwe: the positional argument is not encrypted",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
}
|
||||
uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext);
|
||||
// Decode
|
||||
@@ -309,7 +277,7 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext,
|
||||
output = plaintext >> (64 - precision - 2);
|
||||
size_t carry = output % 2;
|
||||
output = ((output >> 1) + carry) % (1 << (precision + 1));
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
@@ -329,5 +297,5 @@ KeySet::getKeyswitchKeys() {
|
||||
return keyswitchKeys;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
@@ -18,8 +18,10 @@ extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
static std::string readFile(llvm::SmallString<0> &path) {
|
||||
std::ifstream in((std::string)path, std::ofstream::binary);
|
||||
@@ -70,11 +72,11 @@ void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey_u64 *key) {
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb, llvm::SmallString<0> &folderPath) {
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb, std::string folderPath) {
|
||||
// TODO: text dump of all parameter in /hash
|
||||
auto key_set = KeySet::uninitialized();
|
||||
auto key_set = std::make_unique<KeySet>();
|
||||
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys;
|
||||
@@ -87,7 +89,7 @@ KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
auto id = secretKeyParam.first;
|
||||
auto param = secretKeyParam.second;
|
||||
llvm::SmallString<0> path = folderPath;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "secretKey_" + id);
|
||||
LweSecretKey_u64 *sk = loadSecretKey(path);
|
||||
secretKeys[id] = {param, sk};
|
||||
@@ -96,7 +98,7 @@ KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
auto id = bootstrapKeyParam.first;
|
||||
auto param = bootstrapKeyParam.second;
|
||||
llvm::SmallString<0> path = folderPath;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
LweBootstrapKey_u64 *bsk = loadBootstrapKey(path);
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
@@ -105,7 +107,7 @@ KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
auto id = keyswitchParam.first;
|
||||
auto param = keyswitchParam.second;
|
||||
llvm::SmallString<0> path = folderPath;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path);
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
@@ -113,24 +115,21 @@ KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
|
||||
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys);
|
||||
|
||||
auto err = key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb);
|
||||
if (err) {
|
||||
return StreamStringError() << "Cannot setup encryption material: " << err;
|
||||
}
|
||||
OUTCOME_TRYV(key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb));
|
||||
|
||||
return std::move(key_set);
|
||||
}
|
||||
|
||||
llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) {
|
||||
outcome::checked<void, StringError> saveKeys(KeySet &key_set,
|
||||
llvm::SmallString<0> &folderPath) {
|
||||
llvm::SmallString<0> folderIncompletePath = folderPath;
|
||||
|
||||
folderIncompletePath.append(".incomplete");
|
||||
|
||||
auto err = llvm::sys::fs::create_directories(folderIncompletePath);
|
||||
if (err) {
|
||||
return StreamStringError()
|
||||
<< "Cannot create directory \"" << folderIncompletePath
|
||||
<< "\": " << err.message();
|
||||
return StringError("Cannot create directory \"")
|
||||
<< std::string(folderIncompletePath) << "\": " << err.message();
|
||||
}
|
||||
|
||||
// Save LWE secret keys
|
||||
@@ -163,16 +162,16 @@ llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) {
|
||||
llvm::sys::fs::remove_directories(folderIncompletePath);
|
||||
}
|
||||
if (!llvm::sys::fs::exists(folderPath)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot save directory \"" << folderPath << "\"";
|
||||
return StringError("Cannot save directory \"")
|
||||
<< std::string(folderPath) << "\"";
|
||||
}
|
||||
|
||||
return llvm::Error::success();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
llvm::SmallString<0> folderPath =
|
||||
llvm::SmallString<0>(this->backingDirectoryPath);
|
||||
@@ -183,7 +182,7 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
std::to_string(seed_lsb));
|
||||
|
||||
if (llvm::sys::fs::exists(folderPath)) {
|
||||
return tryLoadKeys(params, seed_msb, seed_lsb, folderPath);
|
||||
return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
|
||||
}
|
||||
|
||||
// Creating a lock for concurrent generation
|
||||
@@ -198,8 +197,8 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
|
||||
if (err) {
|
||||
// parent does not exists OR right issue (creation or write)
|
||||
return StreamStringError()
|
||||
<< "Cannot access \"" << lockPath << "\": " << err.message();
|
||||
return StringError("Cannot access \"")
|
||||
<< std::string(lockPath) << "\": " << err.message();
|
||||
}
|
||||
|
||||
// The first to lock will generate while the others waits
|
||||
@@ -211,23 +210,23 @@ KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
|
||||
if (llvm::sys::fs::exists(folderPath)) {
|
||||
// Others returns here
|
||||
return tryLoadKeys(params, seed_msb, seed_lsb, folderPath);
|
||||
return loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
|
||||
}
|
||||
|
||||
auto key_set = KeySet::generate(params, seed_msb, seed_lsb);
|
||||
OUTCOME_TRY(auto key_set, KeySet::generate(params, seed_msb, seed_lsb));
|
||||
|
||||
if (!key_set) {
|
||||
return StreamStringError()
|
||||
<< "Cannot generate key set: " << key_set.takeError();
|
||||
}
|
||||
OUTCOME_TRYV(saveKeys(*key_set, folderPath));
|
||||
|
||||
auto savedErr = saveKeys(*(key_set.get()), folderPath);
|
||||
if (savedErr) {
|
||||
return StreamStringError() << "Cannot save key set: " << savedErr;
|
||||
}
|
||||
|
||||
return key_set;
|
||||
return std::move(key_set);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
|
||||
ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb)
|
||||
: KeySet::generate(params, seed_msb, seed_lsb);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
156
compiler/lib/ClientLib/PublicArguments.cpp
Normal file
156
compiler/lib/ClientLib/PublicArguments.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
// TODO: optimize the move
|
||||
PublicArguments::PublicArguments(
|
||||
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
|
||||
bool clearRuntimeContext, std::vector<void *> &&preparedArgs_,
|
||||
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers_)
|
||||
: clientParameters(clientParameters), runtimeContext(runtimeContext),
|
||||
clearRuntimeContext(clearRuntimeContext) {
|
||||
preparedArgs = std::move(preparedArgs_);
|
||||
ciphertextBuffers = std::move(ciphertextBuffers_);
|
||||
}
|
||||
|
||||
PublicArguments::PublicArguments(PublicArguments &&other) {
|
||||
clientParameters = other.clientParameters;
|
||||
runtimeContext = other.runtimeContext;
|
||||
runtimeContext.bsk = std::move(other.runtimeContext.bsk);
|
||||
clearRuntimeContext = other.clearRuntimeContext;
|
||||
preparedArgs = std::move(other.preparedArgs);
|
||||
ciphertextBuffers = std::move(other.ciphertextBuffers);
|
||||
// transfer ownership
|
||||
other.clearRuntimeContext = false;
|
||||
other.runtimeContext.ksk = nullptr;
|
||||
}
|
||||
|
||||
PublicArguments::~PublicArguments() {
|
||||
if (!clearRuntimeContext) {
|
||||
return;
|
||||
}
|
||||
for (auto bsk_entry : runtimeContext.bsk) {
|
||||
free_lwe_bootstrap_key_u64(bsk_entry.second);
|
||||
}
|
||||
runtimeContext.bsk.clear();
|
||||
if (runtimeContext.ksk != nullptr) {
|
||||
free_lwe_keyswitch_key_u64(runtimeContext.ksk);
|
||||
runtimeContext.ksk = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicArguments::serialize(std::ostream &ostream) {
|
||||
if (incorrectMode(ostream)) {
|
||||
return StringError(
|
||||
"PublicArguments::serialize: ostream should be in binary mode");
|
||||
}
|
||||
ostream << runtimeContext;
|
||||
size_t iPreparedArgs = 0;
|
||||
int iGate = -1;
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
iGate++;
|
||||
size_t rank = gate.shape.dimensions.size();
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("PublicArguments::serialize: Clear arguments "
|
||||
"are not supported. Argument ")
|
||||
<< iGate;
|
||||
}
|
||||
/*auto allocated = */ preparedArgs[iPreparedArgs++];
|
||||
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
|
||||
assert(aligned != nullptr);
|
||||
auto offset = (size_t)preparedArgs[iPreparedArgs++];
|
||||
std::vector<size_t> sizes; // includes lweSize as last dim
|
||||
sizes.resize(rank + 1);
|
||||
for (auto dim = 0u; dim < sizes.size(); dim++) {
|
||||
// sizes are part of the client parameters signature
|
||||
// it's static now but some day it could be dynamic so we serialize
|
||||
// them.
|
||||
sizes[dim] = (size_t)preparedArgs[iPreparedArgs++];
|
||||
}
|
||||
std::vector<size_t> strides(rank + 1);
|
||||
/* strides should be zero here and are not serialized */
|
||||
for (auto dim = 0u; dim < strides.size(); dim++) {
|
||||
strides[dim] = (size_t)preparedArgs[iPreparedArgs++];
|
||||
}
|
||||
// TODO: STRIDES
|
||||
auto values = aligned + offset;
|
||||
serializeEncryptedValues(sizes, values, ostream);
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
int iGate = -1;
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
iGate++;
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("Clear values are not handled");
|
||||
}
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).lweSize();
|
||||
std::vector<int64_t> sizes = gate.shape.dimensions;
|
||||
sizes.push_back(lweSize);
|
||||
ciphertextBuffers.push_back(
|
||||
std::move(unserializeEncryptedValues(sizes, istream)));
|
||||
auto &values_and_sizes = ciphertextBuffers.back();
|
||||
if (istream.fail()) {
|
||||
return StringError(
|
||||
"PublicArguments::unserializeArgs: Failed to read argument ")
|
||||
<< iGate;
|
||||
}
|
||||
preparedArgs.push_back(/*allocated*/ nullptr);
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
preparedArgs.push_back(/*offset*/ 0);
|
||||
// sizes
|
||||
for (auto size : values_and_sizes.sizes) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
// strides has been removed by serialization
|
||||
auto stride = values_and_sizes.length();
|
||||
for (auto size : sizes) {
|
||||
stride /= size;
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<std::shared_ptr<PublicArguments>, StringError>
|
||||
PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
std::istream &istream) {
|
||||
RuntimeContext runtimeContext;
|
||||
istream >> runtimeContext;
|
||||
if (istream.fail()) {
|
||||
return StringError("Cannot read runtime context");
|
||||
}
|
||||
std::vector<void *> empty;
|
||||
std::vector<encrypted_scalars_and_sizes_t> emptyBuffers;
|
||||
// On server side the PublicArguments is responsible for the context
|
||||
auto clearRuntimeContext = true;
|
||||
auto sArguments = std::make_shared<PublicArguments>(
|
||||
clientParameters, runtimeContext, clearRuntimeContext, std::move(empty),
|
||||
std::move(emptyBuffers));
|
||||
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
|
||||
sArguments->preparedArgs.push_back((void *)&runtimeContext);
|
||||
return sArguments;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
185
compiler/lib/ClientLib/Serializers.cpp
Normal file
185
compiler/lib/ClientLib/Serializers.cpp
Normal file
@@ -0,0 +1,185 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <malloc.h>
|
||||
|
||||
#include <iosfwd>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
template <typename Result>
|
||||
Result read_deser(std::istream &istream, Result (*deser)(BufferView)) {
|
||||
size_t length;
|
||||
readSize(istream, length);
|
||||
// buffer is too big to be allocated on stack
|
||||
// vector ensures everything is deallocated w.r.t. new
|
||||
std::vector<uint8_t> buffer(length);
|
||||
istream.read((char *)buffer.data(), length);
|
||||
assert(istream.good());
|
||||
return deser({buffer.data(), length});
|
||||
}
|
||||
|
||||
template <typename BufferLike>
|
||||
std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) {
|
||||
writeSize(ostream, buffer.length);
|
||||
ostream.write((const char *)buffer.pointer, buffer.length);
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey_u64 *key) {
|
||||
Buffer b = serialize_lwe_keyswitching_key_u64(key);
|
||||
writeBufferLike(ostream, b);
|
||||
free((void *)b.pointer);
|
||||
b.pointer = nullptr;
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey_u64 *key) {
|
||||
Buffer b = serialize_lwe_bootstrap_key_u64(key);
|
||||
writeBufferLike(ostream, b);
|
||||
free((void *)b.pointer);
|
||||
b.pointer = nullptr;
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweKeyswitchKey_u64 *&key) {
|
||||
key = read_deser(istream, deserialize_lwe_keyswitching_key_u64);
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey_u64 *&key) {
|
||||
key = read_deser(istream, deserialize_lwe_bootstrap_key_u64);
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const ClientParameters &cp) {
|
||||
// For binary stream == not formatting
|
||||
std::string json;
|
||||
llvm::raw_string_ostream tmpostream(json);
|
||||
tmpostream << toJSON(cp);
|
||||
writeSize(ostream, json.size());
|
||||
assert(ostream.good());
|
||||
ostream << json;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, ClientParameters ¶ms) {
|
||||
size_t size;
|
||||
readSize(istream, size);
|
||||
char buffer[size + 1];
|
||||
buffer[size] = '\0'; // llvm::json::parse requires \0 ended buffer.
|
||||
istream.read(buffer, size);
|
||||
auto paramsOrErr = llvm::json::parse<ClientParameters>(buffer);
|
||||
if (auto err = paramsOrErr.takeError()) {
|
||||
llvm::errs() << "Parsing client parameters error: " << std::move(err)
|
||||
<< "\n";
|
||||
istream.setstate(std::ios::failbit);
|
||||
return istream;
|
||||
}
|
||||
params = paramsOrErr.get();
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
RuntimeContext &runtimeContext) {
|
||||
istream >> runtimeContext.ksk;
|
||||
istream >> runtimeContext.bsk[RuntimeContext::BASE_CONTEXT_BSK];
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext) {
|
||||
ostream << runtimeContext.ksk;
|
||||
ostream << runtimeContext.bsk.at(RuntimeContext::BASE_CONTEXT_BSK);
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &serializeEncryptedValues(encrypted_scalars_t values,
|
||||
size_t length, std::ostream &ostream) {
|
||||
if (incorrectMode(ostream)) {
|
||||
return ostream;
|
||||
}
|
||||
writeSize(ostream, length);
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
writeWord(ostream, values[i]);
|
||||
}
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &serializeEncryptedValues(std::vector<size_t> &sizes,
|
||||
encrypted_scalars_t values,
|
||||
std::ostream &ostream) {
|
||||
size_t length = 1;
|
||||
for (auto size : sizes) {
|
||||
length *= size;
|
||||
writeSize(ostream, size);
|
||||
}
|
||||
serializeEncryptedValues(values, length, ostream);
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &
|
||||
serializeEncryptedValues(encrypted_scalars_and_sizes_t &values_and_sizes,
|
||||
std::ostream &ostream) {
|
||||
std::vector<size_t> &sizes = values_and_sizes.sizes;
|
||||
encrypted_scalars_t values = values_and_sizes.values.data();
|
||||
return serializeEncryptedValues(sizes, values, ostream);
|
||||
}
|
||||
|
||||
encrypted_scalars_and_sizes_t unserializeEncryptedValues(
|
||||
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
|
||||
// accomodate non static sizes
|
||||
std::istream &istream) {
|
||||
encrypted_scalars_and_sizes_t result;
|
||||
if (incorrectMode(istream)) {
|
||||
return result;
|
||||
}
|
||||
for (auto expectedSize : expectedSizes) {
|
||||
size_t actualSize;
|
||||
readSize(istream, actualSize);
|
||||
if ((size_t)expectedSize != actualSize) {
|
||||
istream.setstate(std::ios::badbit);
|
||||
}
|
||||
assert(actualSize > 0);
|
||||
result.sizes.push_back(actualSize);
|
||||
assert(result.sizes.back() > 0);
|
||||
}
|
||||
size_t expectedLen = result.length();
|
||||
assert(expectedLen > 0);
|
||||
// TODO: full read in one step
|
||||
size_t actualLen;
|
||||
readSize(istream, actualLen);
|
||||
if (expectedLen != actualLen) {
|
||||
istream.setstate(std::ios::badbit);
|
||||
}
|
||||
assert(actualLen == expectedLen);
|
||||
result.values.resize(actualLen);
|
||||
for (size_t &value : result.values) {
|
||||
value = 0;
|
||||
readWord(istream, value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
23
compiler/lib/Common/BitsSize.cpp
Normal file
23
compiler/lib/Common/BitsSize.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace common {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
|
||||
size_t previousWidth = 0;
|
||||
for (auto currentWidth : sortedWordBitWidths) {
|
||||
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
|
||||
return currentWidth;
|
||||
}
|
||||
}
|
||||
return exactBitWidth;
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace concretelang
|
||||
19
compiler/lib/Common/CMakeLists.txt
Normal file
19
compiler/lib/Common/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
add_compile_options( -Werror )
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# using Clang
|
||||
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
# using GCC
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangCommon
|
||||
BitsSize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Common
|
||||
)
|
||||
@@ -11,6 +11,14 @@
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#endif
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::string RuntimeContext::BASE_CONTEXT_BSK = "_concretelang_base_context_bsk";
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
LweKeyswitchKey_u64 *
|
||||
get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->ksk;
|
||||
@@ -18,6 +26,7 @@ get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
LweBootstrapKey_u64 *
|
||||
get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
|
||||
using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
std::string threadName = hpx::get_thread_name();
|
||||
auto bskIt = context->bsk.find(threadName);
|
||||
@@ -26,12 +35,11 @@ get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
|
||||
.insert(std::pair<std::string, LweBootstrapKey_u64 *>(
|
||||
threadName,
|
||||
clone_lwe_bootstrap_key_u64(
|
||||
context->bsk["_concretelang_base_context_bsk"])))
|
||||
context->bsk[RuntimeContext::BASE_CONTEXT_BSK])))
|
||||
.first;
|
||||
}
|
||||
#else
|
||||
std::string baseName = "_concretelang_base_context_bsk";
|
||||
auto bskIt = context->bsk.find(baseName);
|
||||
auto bskIt = context->bsk.find(RuntimeContext::BASE_CONTEXT_BSK);
|
||||
#endif
|
||||
assert(bskIt->second && "No bootstrap key available in context");
|
||||
return bskIt->second;
|
||||
|
||||
27
compiler/lib/ServerLib/CMakeLists.txt
Normal file
27
compiler/lib/ServerLib/CMakeLists.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
add_compile_options( -Werror )
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# using Clang
|
||||
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
# using GCC
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
add_compile_options( -Wno-error=cast-function-type -Wno-cast-function-type)
|
||||
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangServerLib
|
||||
|
||||
DynamicRankCall.cpp
|
||||
ServerLambda.cpp
|
||||
DynamicModule.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ServerLib
|
||||
|
||||
ConcretelangRuntime
|
||||
ConcretelangSupportLib
|
||||
ConcretelangClientLib
|
||||
)
|
||||
54
compiler/lib/ServerLib/DynamicModule.cpp
Normal file
54
compiler/lib/ServerLib/DynamicModule.cpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <fstream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
using mlir::concretelang::CompilerEngine;
|
||||
|
||||
DynamicModule::~DynamicModule() {
|
||||
if (libraryHandle != nullptr) {
|
||||
dlclose(libraryHandle);
|
||||
}
|
||||
}
|
||||
|
||||
outcome::checked<std::shared_ptr<DynamicModule>, StringError>
|
||||
DynamicModule::open(std::string libPath) {
|
||||
std::shared_ptr<DynamicModule> module = std::make_shared<DynamicModule>();
|
||||
OUTCOME_TRYV(module->loadClientParametersJSON(libPath));
|
||||
OUTCOME_TRYV(module->loadSharedLibrary(libPath));
|
||||
return module;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
DynamicModule::loadSharedLibrary(std::string path) {
|
||||
libraryHandle = dlopen(
|
||||
CompilerEngine::Library::getSharedLibraryPath(path).c_str(), RTLD_LAZY);
|
||||
if (!libraryHandle) {
|
||||
return StringError("Cannot open shared library") << dlerror();
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
DynamicModule::loadClientParametersJSON(std::string libPath) {
|
||||
auto jsonPath = CompilerEngine::Library::getClientParametersPath(libPath);
|
||||
OUTCOME_TRY(auto clientParams, ClientParameters::load(jsonPath));
|
||||
this->clientParametersList = clientParams;
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
163
compiler/lib/ServerLib/DynamicRankCall.cpp
Normal file
163
compiler/lib/ServerLib/DynamicRankCall.cpp
Normal file
@@ -0,0 +1,163 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
// generated: see genDynamicRandCall.py
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
encrypted_scalars_and_sizes_t
|
||||
multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args, size_t rank) {
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = &encrypted_scalars_and_sizes_t_from_MemRef;
|
||||
switch (rank) {
|
||||
case 0: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<1>(*)(void *...))func, args);
|
||||
return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 1: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<2>(*)(void *...))func, args);
|
||||
return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 2: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<3>(*)(void *...))func, args);
|
||||
return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 3: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<4>(*)(void *...))func, args);
|
||||
return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 4: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<5>(*)(void *...))func, args);
|
||||
return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 5: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<6>(*)(void *...))func, args);
|
||||
return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 6: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<7>(*)(void *...))func, args);
|
||||
return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 7: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<8>(*)(void *...))func, args);
|
||||
return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 8: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<9>(*)(void *...))func, args);
|
||||
return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 9: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<10>(*)(void *...))func, args);
|
||||
return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 10: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<11>(*)(void *...))func, args);
|
||||
return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 11: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<12>(*)(void *...))func, args);
|
||||
return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 12: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<13>(*)(void *...))func, args);
|
||||
return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 13: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<14>(*)(void *...))func, args);
|
||||
return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 14: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<15>(*)(void *...))func, args);
|
||||
return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 15: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<16>(*)(void *...))func, args);
|
||||
return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 16: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<17>(*)(void *...))func, args);
|
||||
return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 17: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<18>(*)(void *...))func, args);
|
||||
return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 18: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<19>(*)(void *...))func, args);
|
||||
return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 19: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<20>(*)(void *...))func, args);
|
||||
return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 20: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<21>(*)(void *...))func, args);
|
||||
return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 21: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<22>(*)(void *...))func, args);
|
||||
return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 22: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<23>(*)(void *...))func, args);
|
||||
return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 23: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<24>(*)(void *...))func, args);
|
||||
return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 24: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<25>(*)(void *...))func, args);
|
||||
return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 25: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<26>(*)(void *...))func, args);
|
||||
return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 26: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<27>(*)(void *...))func, args);
|
||||
return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 27: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<28>(*)(void *...))func, args);
|
||||
return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 28: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<29>(*)(void *...))func, args);
|
||||
return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 29: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<30>(*)(void *...))func, args);
|
||||
return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 30: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<31>(*)(void *...))func, args);
|
||||
return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 31: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<32>(*)(void *...))func, args);
|
||||
return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 32: {
|
||||
auto m = multi_arity_call((MemRefDescriptor<33>(*)(void *...))func, args);
|
||||
return convert(33, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
159
compiler/lib/ServerLib/ServerLambda.cpp
Normal file
159
compiler/lib/ServerLib/ServerLambda.cpp
Normal file
@@ -0,0 +1,159 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
#include "concretelang/ServerLib/DynamicRankCall.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::CircuitGate;
|
||||
using concretelang::clientlib::CircuitGateShape;
|
||||
using concretelang::clientlib::PublicArguments;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
|
||||
// increase multi dim index
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
if (index[r] < sizes[r] - 1) {
|
||||
index[r]++;
|
||||
return;
|
||||
}
|
||||
index[r] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
|
||||
size_t rank) {
|
||||
// compute global index from multi dim index
|
||||
size_t g_index = 0;
|
||||
size_t default_stride = 1;
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]);
|
||||
default_stride *= sizes[r];
|
||||
}
|
||||
return g_index;
|
||||
}
|
||||
|
||||
/** Helper function to convert from MemRefDescriptor to
|
||||
* encrypted_scalars_and_sizes_t assuming MemRefDescriptor are bufferized */
|
||||
encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
|
||||
size_t memref_rank, encrypted_scalars_t allocated,
|
||||
encrypted_scalars_t aligned, size_t offset, size_t *sizes,
|
||||
size_t *strides) {
|
||||
|
||||
(void)allocated;
|
||||
encrypted_scalars_and_sizes_t result;
|
||||
assert(aligned != nullptr);
|
||||
result.sizes.resize(memref_rank);
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
result.sizes[r] = sizes[r];
|
||||
}
|
||||
size_t
|
||||
index[memref_rank]; // ephemeral multi dim index to compute global strides
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
index[r] = 0;
|
||||
}
|
||||
auto len = result.length();
|
||||
result.values.resize(len);
|
||||
// TODO: add a fast path for dense result (no real strides)
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
int g_index = offset + global_index(index, sizes, strides, memref_rank);
|
||||
result.values[i] = aligned[offset + g_index];
|
||||
next_coord_index(index, sizes, memref_rank);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
outcome::checked<ServerLambda, StringError>
|
||||
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
|
||||
std::string funcName) {
|
||||
ServerLambda lambda;
|
||||
lambda.module =
|
||||
module; // prevent module and library handler from being destroyed
|
||||
lambda.func =
|
||||
(void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
|
||||
|
||||
if (auto err = dlerror()) {
|
||||
return StringError("Cannot open lambda:") << std::string(err);
|
||||
}
|
||||
|
||||
auto param =
|
||||
llvm::find_if(module->clientParametersList, [&](ClientParameters param) {
|
||||
return param.functionName == funcName;
|
||||
});
|
||||
|
||||
if (param == module->clientParametersList.end()) {
|
||||
return StringError("cannot find function ")
|
||||
<< funcName << "in client parameters";
|
||||
}
|
||||
|
||||
if (param->outputs.size() != 1) {
|
||||
return StringError("ServerLambda: output arity (")
|
||||
<< std::to_string(param->outputs.size())
|
||||
<< ") != 1 is not supported";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.hasValue()) {
|
||||
return StringError("ServerLambda: clear output is not yet supported");
|
||||
}
|
||||
|
||||
lambda.clientParameters = *param;
|
||||
return lambda;
|
||||
}
|
||||
|
||||
outcome::checked<ServerLambda, StringError>
|
||||
ServerLambda::load(std::string funcName, std::string outputLib) {
|
||||
OUTCOME_TRY(auto module, DynamicModule::open(outputLib));
|
||||
return ServerLambda::loadFromModule(module, funcName);
|
||||
}
|
||||
|
||||
encrypted_scalars_and_sizes_t dynamicCall(void *(*func)(void *...),
|
||||
std::vector<void *> &preparedArgs,
|
||||
CircuitGate &output,
|
||||
std::ostream &ostream) {
|
||||
size_t rank = output.shape.dimensions.size();
|
||||
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
ServerLambda::read_call_write(std::istream &istream, std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto argumentsPtr,
|
||||
PublicArguments::unserialize(clientParameters, istream));
|
||||
assert(istream.good());
|
||||
|
||||
PublicArguments &arguments = *argumentsPtr;
|
||||
// The runtime context is always the last argument list
|
||||
arguments.preparedArgs.push_back((void *)&arguments.runtimeContext);
|
||||
auto values_and_sizes = dynamicCall(this->func, arguments.preparedArgs,
|
||||
clientParameters.outputs[0], ostream);
|
||||
auto shape = clientParameters.outputs[0].shape;
|
||||
size_t rank = shape.dimensions.size();
|
||||
for (size_t dim = 0; dim < rank; dim++) {
|
||||
if (values_and_sizes.sizes[dim] != (size_t)shape.dimensions[dim]) {
|
||||
return StringError("Dimension mismatch on dim ")
|
||||
<< dim << " actual: " << values_and_sizes.sizes[dim]
|
||||
<< " vs expected: " << shape.dimensions[dim] << "\n";
|
||||
}
|
||||
}
|
||||
serializeEncryptedValues(values_and_sizes, ostream);
|
||||
if (ostream.fail()) {
|
||||
return StringError("Cannot write result");
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
44
compiler/lib/ServerLib/genDynamicRankCall.py
Normal file
44
compiler/lib/ServerLib/genDynamicRankCall.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt for license information.
|
||||
|
||||
print(
|
||||
"""// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
// generated: see genDynamicRandCall.py
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
encrypted_scalars_and_sizes_t
|
||||
multi_arity_call_dynamic_rank(void* (*func)(void *...), std::vector<void *> args, size_t rank) {
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = &encrypted_scalars_and_sizes_t_from_MemRef;
|
||||
switch (rank) {""")
|
||||
|
||||
for tensor_rank in range(0, 33):
|
||||
memref_rank = tensor_rank + 1
|
||||
print(f""" case {tensor_rank}:
|
||||
{{
|
||||
auto m = multi_arity_call((MemRefDescriptor<{memref_rank}>(*)(void *...))func, args);
|
||||
return convert({memref_rank}, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}}""")
|
||||
|
||||
print("""
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}""")
|
||||
|
||||
print("""
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang""")
|
||||
@@ -30,7 +30,8 @@ add_mlir_library(ConcretelangSupport
|
||||
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
|
||||
ConcretelangCommon
|
||||
ConcretelangRuntime
|
||||
ConcretelangClientLib
|
||||
)
|
||||
|
||||
@@ -5,10 +5,12 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/Path.h>
|
||||
#include <llvm/Support/SMLoc.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
@@ -19,6 +21,7 @@
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
|
||||
#include <concretelang/ClientLib/ClientParameters.h>
|
||||
#include <concretelang/Dialect/BConcrete/IR/BConcreteDialect.h>
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
@@ -280,9 +283,10 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
auto funcName = this->clientParametersFuncName.getValueOr("main");
|
||||
if (this->generateClientParameters || target == Target::LIBRARY) {
|
||||
if (!res.fheContext.hasValue()) {
|
||||
// Some tests can involves a usual function
|
||||
res.clientParameters =
|
||||
mlir::concretelang::emptyClientParametersForV0(funcName, module);
|
||||
// Some tests involve call a to non encrypted functions
|
||||
ClientParameters emptyParams;
|
||||
emptyParams.functionName = funcName;
|
||||
res.clientParameters = emptyParams;
|
||||
} else {
|
||||
auto clientParametersOrErr =
|
||||
mlir::concretelang::createClientParametersForV0(*res.fheContext,
|
||||
@@ -425,7 +429,7 @@ std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) {
|
||||
|
||||
/** Returns the path of the static library */
|
||||
std::string CompilerEngine::Library::getClientParametersPath(std::string path) {
|
||||
return path + CLIENT_PARAMETERS_EXT;
|
||||
return ClientParameters::getClientParametersPath(path);
|
||||
}
|
||||
|
||||
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
|
||||
@@ -466,6 +470,92 @@ CompilerEngine::Library::emitClientParametersJSON() {
|
||||
return clientParamsPath;
|
||||
}
|
||||
|
||||
static std::string ccpResultType(size_t rank) {
|
||||
if (rank == 0) {
|
||||
return "scalar_out";
|
||||
} else {
|
||||
return "tensor" + std::to_string(rank) + "_out";
|
||||
}
|
||||
}
|
||||
|
||||
static std::string ccpArgType(size_t rank) {
|
||||
if (rank == 0) {
|
||||
return "scalar_in";
|
||||
} else {
|
||||
return "tensor" + std::to_string(rank) + "_in";
|
||||
}
|
||||
}
|
||||
|
||||
static std::string cppArgsType(std::vector<CircuitGate> inputs) {
|
||||
std::string args;
|
||||
for (auto input : inputs) {
|
||||
if (!args.empty()) {
|
||||
args += ", ";
|
||||
}
|
||||
args += ccpArgType(input.shape.dimensions.size());
|
||||
}
|
||||
return args;
|
||||
}
|
||||
|
||||
llvm::Expected<std::string> CompilerEngine::Library::emitCppHeader() {
|
||||
auto libraryName = llvm::sys::path::filename(libraryPath).str();
|
||||
auto headerName = libraryName + "-client.h";
|
||||
auto headerPath = std::regex_replace(
|
||||
libraryPath, std::regex(libraryName + "$"), headerName);
|
||||
|
||||
std::error_code error;
|
||||
llvm::raw_fd_ostream out(headerPath, error);
|
||||
if (error) {
|
||||
StreamStringError("Cannot emit header: ")
|
||||
<< headerPath << ", " << error.message() << "\n";
|
||||
}
|
||||
|
||||
out << "#include \"boost/outcome.h\"\n";
|
||||
out << "#include \"concretelang/ClientLib/ClientLambda.h\"\n";
|
||||
out << "#include \"concretelang/ClientLib/KeySetCache.h\"\n";
|
||||
out << "#include \"concretelang/ClientLib/Types.h\"\n";
|
||||
out << "#include \"concretelang/Common/Error.h\"\n";
|
||||
out << "\n";
|
||||
out << "namespace " << libraryName << " {\n";
|
||||
out << "namespace client {\n";
|
||||
|
||||
for (auto params : clientParametersList) {
|
||||
std::string args;
|
||||
std::string result;
|
||||
if (params.outputs.size() > 0) {
|
||||
args = cppArgsType(params.inputs);
|
||||
} else {
|
||||
args = "void";
|
||||
}
|
||||
if (params.outputs.size() > 0) {
|
||||
size_t rank = params.outputs[0].shape.dimensions.size();
|
||||
result = ccpResultType(rank);
|
||||
} else {
|
||||
result = "void";
|
||||
}
|
||||
out << "\n";
|
||||
out << "namespace " << params.functionName << " {\n";
|
||||
out << " using namespace concretelang::clientlib;\n";
|
||||
out << " using concretelang::error::StringError;\n";
|
||||
out << " using " << params.functionName << "_t = TypedClientLambda<"
|
||||
<< result << ", " << args << ">;\n";
|
||||
out << " static const std::string name = \"" << params.functionName
|
||||
<< "\";\n";
|
||||
out << "\n";
|
||||
out << " static outcome::checked<extract_t, StringError>\n";
|
||||
out << " load(std::string outputLib)\n";
|
||||
out << " { return extract_t::load(name, outputLib); }\n";
|
||||
out << "} // namespace " << params.functionName << "\n";
|
||||
}
|
||||
out << "\n";
|
||||
out << "} // namespace client\n";
|
||||
out << "} // namespace " << libraryName << "\n";
|
||||
|
||||
out.close();
|
||||
|
||||
return headerPath;
|
||||
}
|
||||
|
||||
llvm::Expected<std::string>
|
||||
CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
|
||||
llvm::Module *module = compilation.llvmModule.get();
|
||||
@@ -538,6 +628,9 @@ llvm::Error CompilerEngine::Library::emitArtifacts() {
|
||||
if (auto err = emitClientParametersJSON().takeError()) {
|
||||
return err;
|
||||
}
|
||||
if (auto err = emitCppHeader().takeError()) {
|
||||
return err;
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/logging.h>
|
||||
@@ -198,12 +199,14 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
// Else if is encryted, allocate ciphertext and encrypt.
|
||||
uint64_t *ctArg;
|
||||
uint64_t ctSize;
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctArg, ctSize)) {
|
||||
return std::move(err);
|
||||
auto check = this->keySet.allocate_lwe(pos, &ctArg, ctSize);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return std::move(err);
|
||||
check = this->keySet.encrypt_lwe(pos, ctArg, arg);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
// memref calling convention
|
||||
// allocated
|
||||
@@ -224,17 +227,6 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
|
||||
size_t previousWidth = 0;
|
||||
for (auto currentWidth : sortedWordBitWidths) {
|
||||
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
|
||||
return currentWidth;
|
||||
}
|
||||
}
|
||||
return exactBitWidth;
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
const void *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
@@ -253,7 +245,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto roundedSize = bitWidthAsWord(info.shape.width);
|
||||
auto roundedSize = ::concretelang::common::bitWidthAsWord(info.shape.width);
|
||||
if (width != roundedSize) {
|
||||
auto msg = "Bad argument (pos=" + llvm::Twine(pos) + ") : expected " +
|
||||
llvm::Twine(roundedSize) + "bits" + " but received " +
|
||||
@@ -315,9 +307,9 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
for (size_t i = 0, offset = 0; i < info.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
|
||||
if (auto err =
|
||||
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
|
||||
return std::move(err);
|
||||
auto check = this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i]);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
}
|
||||
// Replace the data by the buffer to ciphertext
|
||||
@@ -387,8 +379,9 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
}
|
||||
// Else if is encryted, decrypt
|
||||
uint64_t *ct = (uint64_t *)(outputs[offset + 1]);
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
|
||||
return std::move(err);
|
||||
auto check = this->keySet.decrypt_lwe(pos, ct, res);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
@@ -504,8 +497,9 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
|
||||
|
||||
for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) {
|
||||
uint64_t *ct = ((uint64_t *)alignedBytes) + o;
|
||||
if (auto err = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i])) {
|
||||
return std::move(err);
|
||||
auto check = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i]);
|
||||
if (!check) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,16 +116,18 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
"parameters has not been computed");
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::concretelang::KeySet>> keySetOrErr =
|
||||
(cache.hasValue())
|
||||
? cache->tryLoadOrGenerateSave(*compRes.clientParameters, 0, 0)
|
||||
: KeySet::generate(*compRes.clientParameters, 0, 0);
|
||||
std::shared_ptr<KeySetCache> cachePtr;
|
||||
if (cache.hasValue()) {
|
||||
cachePtr = std::make_shared<KeySetCache>(cache.getValue());
|
||||
}
|
||||
auto keySetOrErr =
|
||||
KeySetCache::generate(cachePtr, *compRes.clientParameters, 0, 0);
|
||||
|
||||
if (!keySetOrErr) {
|
||||
return keySetOrErr.takeError();
|
||||
return StreamStringError(keySetOrErr.error().mesg);
|
||||
}
|
||||
|
||||
auto keySet = std::move(keySetOrErr.get());
|
||||
auto keySet = std::move(keySetOrErr.value());
|
||||
|
||||
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)};
|
||||
}
|
||||
|
||||
@@ -16,6 +16,15 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::BIG_KEY;
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
using ::concretelang::clientlib::EncryptionGate;
|
||||
using ::concretelang::clientlib::LweSecretKeyID;
|
||||
using ::concretelang::clientlib::Precision;
|
||||
using ::concretelang::clientlib::SMALL_KEY;
|
||||
using ::concretelang::clientlib::Variance;
|
||||
|
||||
const auto securityLevel = SECURITY_LEVEL_128;
|
||||
const auto keyFormat = KEY_FORMAT_BINARY;
|
||||
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
|
||||
@@ -81,13 +90,6 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
|
||||
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
ClientParameters emptyClientParametersForV0(llvm::StringRef functionName,
|
||||
mlir::ModuleOp module) {
|
||||
ClientParameters c;
|
||||
c.functionName = (std::string)functionName;
|
||||
return c;
|
||||
}
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext fheContext,
|
||||
llvm::StringRef functionName,
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/TestLib/Arguments.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
Arguments::~Arguments() {
|
||||
for (auto ct : allocatedCiphertexts) {
|
||||
free(ct);
|
||||
}
|
||||
for (auto ctBuffer : ciphertextBuffers) {
|
||||
free(ctBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Arguments> Arguments::create(KeySet &keySet) {
|
||||
auto args = std::make_shared<Arguments>(keySet);
|
||||
return args;
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushArg(uint64_t arg) {
|
||||
if (auto err = checkPushTooManyArgs()) {
|
||||
return err;
|
||||
}
|
||||
|
||||
auto pos = currentPos++;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
if (input.shape.size != 0) {
|
||||
return StreamStringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
if (!input.encryption.hasValue()) {
|
||||
// clear scalar: just push the argument
|
||||
if (input.shape.width != 64) {
|
||||
return StreamStringError(
|
||||
"scalar argument of with != 64 is not supported for DynamicLambda");
|
||||
}
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// encrypted scalar: allocate, encrypt and push
|
||||
uint64_t *ctArg;
|
||||
uint64_t ctSize = 0;
|
||||
if (auto err = keySet.allocate_lwe(pos, &ctArg, ctSize)) {
|
||||
return err;
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return err;
|
||||
}
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(ctArg);
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// size
|
||||
preparedArgs.push_back((void *)ctSize);
|
||||
// stride
|
||||
preparedArgs.push_back((void *)1);
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushArg(std::vector<uint8_t> arg) {
|
||||
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()});
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushArg(size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape) {
|
||||
if (auto err = checkPushTooManyArgs()) {
|
||||
return err;
|
||||
}
|
||||
auto pos = currentPos;
|
||||
currentPos = currentPos + 1;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
auto roundedSize = bitWidthAsWord(input.shape.width);
|
||||
if (width != roundedSize) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << "width mismatch, got " << width << " expected "
|
||||
<< roundedSize;
|
||||
}
|
||||
// Check the shape of tensor
|
||||
if (input.shape.dimensions.empty()) {
|
||||
return StreamStringError("argument #") << pos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != input.shape.dimensions.size()) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << input.shape.dimensions.size();
|
||||
}
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i] != input.shape.dimensions[i]) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << " has not the expected dimension #" << i << " , got "
|
||||
<< shape[i] << " expected " << input.shape.dimensions[i];
|
||||
}
|
||||
}
|
||||
if (input.encryption.hasValue()) {
|
||||
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
|
||||
if (width != 8) {
|
||||
return StreamStringError("argument #")
|
||||
<< pos << " width mismatch, expected 8 got " << width;
|
||||
}
|
||||
const uint8_t *data8 = (const uint8_t *)data;
|
||||
|
||||
// Allocate a buffer for ciphertexts of size of tensor
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
|
||||
auto ctBuffer =
|
||||
(uint64_t *)malloc(input.shape.size * lweSize * sizeof(uint64_t));
|
||||
ciphertextBuffers.push_back(ctBuffer);
|
||||
// Allocate ciphertexts and encrypt, for every values in tensor
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
|
||||
if (auto err =
|
||||
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
// Replace the data by the buffer to ciphertext
|
||||
data = (void *)ctBuffer;
|
||||
}
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(data);
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
preparedArgs.push_back((void *)shape[i]);
|
||||
}
|
||||
// If encrypted +1 for the lwe size rank
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
preparedArgs.push_back(
|
||||
(void *)(keySet.getInputLweSecretKeyParam(pos).size + 1));
|
||||
}
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = 1;
|
||||
// If encrypted +1 set the stride for the lwe size rank
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
|
||||
}
|
||||
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
|
||||
preparedArgs.push_back((void *)stride);
|
||||
stride *= shape[i];
|
||||
}
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
preparedArgs.push_back((void *)1);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Arguments::pushContext() {
|
||||
if (currentPos < keySet.numInputs()) {
|
||||
return StreamStringError("Missing arguments");
|
||||
}
|
||||
preparedArgs.push_back(&context);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Arguments::checkPushTooManyArgs() {
|
||||
size_t arity = keySet.numInputs();
|
||||
if (currentPos < arity) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
return StreamStringError("function has arity ")
|
||||
<< arity << " but is applied to too many arguments";
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,16 +0,0 @@
|
||||
add_mlir_library(ConcretelangTestLib
|
||||
Arguments.cpp
|
||||
DynamicLambda.cpp
|
||||
DynamicModule.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/TestLib
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
|
||||
ConcretelangSupport
|
||||
ConcretelangClientLib
|
||||
)
|
||||
@@ -1,217 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/TestLib/DynamicLambda.h"
|
||||
#include "concretelang/TestLib/dynamicArityCall.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
uint64_t *allocated;
|
||||
uint64_t *aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
decryptSlice(KeySet &keySet, uint64_t *aligned, size_t size) {
|
||||
auto pos = 0;
|
||||
std::vector<uint64_t> result(size);
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
size_t offset = i * lweSize;
|
||||
auto err = keySet.decrypt_lwe(pos, aligned + offset, result[i]);
|
||||
if (err) {
|
||||
return StreamStringError()
|
||||
<< "cannot decrypt result #" << i << ", err:" << err;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::concretelang::DynamicLambda>
|
||||
DynamicLambda::load(std::string funcName, std::string outputLib) {
|
||||
auto moduleOrErr = mlir::concretelang::DynamicModule::open(outputLib);
|
||||
if (!moduleOrErr) {
|
||||
return moduleOrErr.takeError();
|
||||
}
|
||||
return mlir::concretelang::DynamicLambda::load(*moduleOrErr, funcName);
|
||||
}
|
||||
|
||||
llvm::Expected<DynamicLambda>
|
||||
DynamicLambda::load(std::shared_ptr<DynamicModule> module,
|
||||
std::string funcName) {
|
||||
DynamicLambda lambda;
|
||||
lambda.module =
|
||||
module; // prevent module and library handler from being destroyed
|
||||
lambda.func = dlsym(module->libraryHandle, funcName.c_str());
|
||||
|
||||
if (auto err = dlerror()) {
|
||||
return StreamStringError("Cannot open lambda: ") << err;
|
||||
}
|
||||
|
||||
auto param =
|
||||
llvm::find_if(module->clientParametersList, [&](ClientParameters param) {
|
||||
return param.functionName == funcName;
|
||||
});
|
||||
|
||||
if (param == module->clientParametersList.end()) {
|
||||
return StreamStringError("cannot find function ")
|
||||
<< funcName << "in client parameters";
|
||||
}
|
||||
|
||||
if (param->outputs.size() != 1) {
|
||||
return StreamStringError("DynamicLambda: output arity (")
|
||||
<< std::to_string(param->outputs.size())
|
||||
<< ") != 1 is not supported";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.hasValue()) {
|
||||
return StreamStringError(
|
||||
"DynamicLambda: clear output is not yet supported");
|
||||
}
|
||||
|
||||
lambda.clientParameters = *param;
|
||||
return lambda;
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<uint64_t> invoke<uint64_t>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto output = lambda.clientParameters.outputs[0];
|
||||
if (output.shape.size != 0) {
|
||||
return StreamStringError("the function doesn't return a scalar");
|
||||
}
|
||||
// Scalar encrypted result
|
||||
auto fCasted = (MemRefDescriptor<1>(*)(void *...))(lambda.func);
|
||||
MemRefDescriptor<1> lweResult =
|
||||
mlir::concretelang::call(fCasted, args.preparedArgs);
|
||||
|
||||
uint64_t decryptedResult;
|
||||
if (auto err =
|
||||
lambda.keySet->decrypt_lwe(0, lweResult.aligned, decryptedResult)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return decryptedResult;
|
||||
}
|
||||
|
||||
template <size_t Rank>
|
||||
llvm::Expected<MemRefDescriptor<Rank>>
|
||||
DynamicLambda::invokeMemRefDecriptor(const Arguments &args) {
|
||||
auto output = clientParameters.outputs[0];
|
||||
if (output.shape.size == 0) {
|
||||
return StreamStringError("the function doesn't return a tensor");
|
||||
}
|
||||
if (output.shape.dimensions.size() != Rank - 1) {
|
||||
return StreamStringError("the function doesn't return a tensor of rank ")
|
||||
<< Rank - 1;
|
||||
}
|
||||
// Tensor encrypted result
|
||||
auto fCasted = (MemRefDescriptor<Rank>(*)(void *...))(func);
|
||||
auto encryptedResult = mlir::concretelang::call(fCasted, args.preparedArgs);
|
||||
|
||||
for (size_t dim = 0; dim < Rank - 1; dim++) {
|
||||
size_t actual_size = encryptedResult.sizes[dim];
|
||||
size_t expected_size = output.shape.dimensions[dim];
|
||||
if (actual_size != expected_size) {
|
||||
return StreamStringError("the function returned a vector of size ")
|
||||
<< actual_size << " instead of size " << expected_size;
|
||||
}
|
||||
}
|
||||
return encryptedResult;
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<2>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
return decryptSlice(*keySet, encryptedResult.aligned,
|
||||
encryptedResult.sizes[0]);
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<uint64_t>>>
|
||||
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
|
||||
std::vector<std::vector<uint64_t>> result;
|
||||
result.reserve(encryptedResult.sizes[0]);
|
||||
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
|
||||
int offset = encryptedResult.offset + i * encryptedResult.strides[1];
|
||||
auto slice = decryptSlice(*lambda.keySet, encryptedResult.aligned + offset,
|
||||
encryptedResult.sizes[1]);
|
||||
if (!slice) {
|
||||
return StreamStringError(llvm::toString(slice.takeError()));
|
||||
}
|
||||
result.push_back(slice.get());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
|
||||
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<4>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
|
||||
std::vector<std::vector<std::vector<uint64_t>>> result0;
|
||||
result0.reserve(encryptedResult.sizes[0]);
|
||||
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
|
||||
std::vector<std::vector<uint64_t>> result1;
|
||||
result1.reserve(encryptedResult.sizes[1]);
|
||||
for (size_t j = 0; j < encryptedResult.sizes[1]; j++) {
|
||||
int offset = encryptedResult.offset + (i * encryptedResult.sizes[1] + j) *
|
||||
encryptedResult.strides[1];
|
||||
auto slice = decryptSlice(*keySet, encryptedResult.aligned + offset,
|
||||
encryptedResult.sizes[2]);
|
||||
if (!slice) {
|
||||
return StreamStringError(llvm::toString(slice.takeError()));
|
||||
}
|
||||
result1.push_back(slice.get());
|
||||
}
|
||||
result0.push_back(result1);
|
||||
}
|
||||
return result0;
|
||||
}
|
||||
|
||||
llvm::Error DynamicLambda::generateKeySet(llvm::Optional<KeySetCache> cache,
|
||||
uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto maybeKeySet =
|
||||
cache.hasValue()
|
||||
? cache->tryLoadOrGenerateSave(clientParameters, seed_msb, seed_lsb)
|
||||
: KeySet::generate(clientParameters, seed_msb, seed_lsb);
|
||||
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
return err;
|
||||
}
|
||||
keySet = std::move(maybeKeySet.get());
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,61 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <fstream>
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
#include "concretelang/TestLib/DynamicModule.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
DynamicModule::~DynamicModule() {
|
||||
if (libraryHandle != nullptr) {
|
||||
dlclose(libraryHandle);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::shared_ptr<DynamicModule>>
|
||||
DynamicModule::open(std::string path) {
|
||||
std::shared_ptr<DynamicModule> module = std::make_shared<DynamicModule>();
|
||||
if (auto err = module->loadClientParametersJSON(path)) {
|
||||
return StreamStringError("Cannot load client parameters: ")
|
||||
<< llvm::toString(std::move(err));
|
||||
}
|
||||
if (auto err = module->loadSharedLibrary(path)) {
|
||||
return StreamStringError("Cannot load client parameters: ")
|
||||
<< llvm::toString(std::move(err));
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
llvm::Error DynamicModule::loadSharedLibrary(std::string path) {
|
||||
libraryHandle = dlopen(
|
||||
CompilerEngine::Library::getSharedLibraryPath(path).c_str(), RTLD_LAZY);
|
||||
if (!libraryHandle) {
|
||||
return StreamStringError("Cannot open shared library") << dlerror();
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error DynamicModule::loadClientParametersJSON(std::string path) {
|
||||
|
||||
std::ifstream file(CompilerEngine::Library::getClientParametersPath(path));
|
||||
std::string content((std::istreambuf_iterator<char>(file)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
llvm::Expected<std::vector<ClientParameters>> expectedClientParams =
|
||||
llvm::json::parse<std::vector<ClientParameters>>(content);
|
||||
if (auto err = expectedClientParams.takeError()) {
|
||||
return StreamStringError("Cannot open client parameters: ") << err;
|
||||
}
|
||||
this->clientParametersList = *expectedClientParams;
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -23,7 +23,6 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
|
||||
RTDialect
|
||||
|
||||
ConcretelangSupport
|
||||
ConcretelangTestLib
|
||||
|
||||
-Wl,-rpath,${CMAKE_BINARY_DIR}/lib/Runtime
|
||||
-Wl,-rpath,${HPX_DIR}/../../
|
||||
|
||||
@@ -16,6 +16,7 @@ set_source_files_properties(
|
||||
target_link_libraries(
|
||||
support_unit_test
|
||||
gtest_main
|
||||
ConcretelangClientLib
|
||||
ConcretelangSupport
|
||||
)
|
||||
|
||||
|
||||
@@ -2,63 +2,69 @@
|
||||
|
||||
#include "../unittest/end_to_end_jit_test.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
|
||||
namespace CL = mlir::concretelang;
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
|
||||
TEST(Support, client_parameters_json_serde) {
|
||||
mlir::concretelang::ClientParameters params0;
|
||||
clientlib::ClientParameters params0;
|
||||
params0.secretKeys = {
|
||||
{mlir::concretelang::SMALL_KEY, {/*.size = */ 12}},
|
||||
{mlir::concretelang::BIG_KEY, {/*.size = */ 14}},
|
||||
{clientlib::SMALL_KEY, {/*.size = */ 12}},
|
||||
{clientlib::BIG_KEY, {/*.size = */ 14}},
|
||||
};
|
||||
params0.bootstrapKeys = {
|
||||
{"bsk_v0",
|
||||
{/*.inputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ mlir::concretelang::BIG_KEY,
|
||||
{
|
||||
"bsk_v0", {
|
||||
/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 3,
|
||||
/*.variance = */ 0.001}},
|
||||
{"wtf_bsk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
|
||||
/*.level = */ 3,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 1,
|
||||
/*.variance = */ 0.0001,
|
||||
}},
|
||||
/*.variance = */ 0.001
|
||||
}
|
||||
},{
|
||||
"wtf_bsk_v0", {
|
||||
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.level = */ 3,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 1,
|
||||
/*.variance = */ 0.0001,
|
||||
}
|
||||
},
|
||||
};
|
||||
params0.keyswitchKeys = {
|
||||
{"ksk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ mlir::concretelang::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ mlir::concretelang::SMALL_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.variance = */ 3,
|
||||
}}};
|
||||
{
|
||||
"ksk_v0", {
|
||||
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.variance = */ 3,
|
||||
}
|
||||
}
|
||||
};
|
||||
params0.inputs = {
|
||||
{
|
||||
/*.encryption = */ {{CL::SMALL_KEY, 0.01, {4}}},
|
||||
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
|
||||
},
|
||||
{
|
||||
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
|
||||
},
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.01, {4}}},
|
||||
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
|
||||
},
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
|
||||
},
|
||||
};
|
||||
params0.outputs = {
|
||||
{
|
||||
/*.encryption = */ {{CL::SMALL_KEY, 0.03, {5}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
|
||||
},
|
||||
};
|
||||
auto json = mlir::concretelang::toJSON(params0);
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.03, {5}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4},
|
||||
},
|
||||
};
|
||||
auto json = clientlib::toJSON(params0);
|
||||
std::string jsonStr;
|
||||
llvm::raw_string_ostream os(jsonStr);
|
||||
os << json;
|
||||
auto parseResult =
|
||||
llvm::json::parse<mlir::concretelang::ClientParameters>(jsonStr);
|
||||
llvm::json::parse<clientlib::ClientParameters>(jsonStr);
|
||||
ASSERT_EXPECTED_VALUE(parseResult, params0);
|
||||
}
|
||||
|
||||
@@ -15,9 +15,12 @@ set_source_files_properties(
|
||||
|
||||
target_link_libraries(
|
||||
testlib_unit_test
|
||||
gtest_main
|
||||
ConcretelangCommon
|
||||
ConcretelangRuntime
|
||||
ConcretelangTestLib
|
||||
ConcretelangSupport
|
||||
ConcretelangClientLib
|
||||
ConcretelangServerLib
|
||||
gtest_main
|
||||
)
|
||||
|
||||
include(GoogleTest)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
#include "boost/outcome.h"
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace call_2t_1s_with_header {
|
||||
namespace client {
|
||||
|
||||
namespace extract {
|
||||
using namespace concretelang::clientlib;
|
||||
using concretelang::error::StringError;
|
||||
using extract_t = TypedClientLambda<scalar_out, tensor1_in, tensor1_in>;
|
||||
static const std::string name = "extract";
|
||||
|
||||
static outcome::checked<extract_t, StringError>
|
||||
load(std::string outputLib)
|
||||
{ return extract_t::load(name, outputLib); }
|
||||
} // namespace extract
|
||||
|
||||
} // namespace client
|
||||
} // namespace call_2t_1s_with_header
|
||||
@@ -1,40 +1,88 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "../unittest/end_to_end_jit_test.h"
|
||||
#include "concretelang/TestLib/DynamicLambda.h"
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/TestLib/TestTypedLambda.h"
|
||||
|
||||
#include "call_2t_1s_with_header-client.h.generated"
|
||||
|
||||
const std::string FUNCNAME = "main";
|
||||
|
||||
template<typename... Params>
|
||||
using TypedDynamicLambda = mlir::concretelang::TypedDynamicLambda<Params...>;
|
||||
using namespace concretelang::testlib;
|
||||
|
||||
using scalar = uint64_t;
|
||||
using tensor1_in = std::vector<uint8_t>;
|
||||
using tensor1_out = std::vector<uint64_t>;
|
||||
using tensor2_out = std::vector<std::vector<uint64_t>>;
|
||||
using tensor3_out = std::vector<std::vector<std::vector<uint64_t>>>;
|
||||
using concretelang::clientlib::scalar_in;
|
||||
using concretelang::clientlib::scalar_out;
|
||||
using concretelang::clientlib::tensor1_in;
|
||||
using concretelang::clientlib::tensor2_in;
|
||||
using concretelang::clientlib::tensor1_out;
|
||||
using concretelang::clientlib::tensor2_out;
|
||||
using concretelang::clientlib::tensor3_out;
|
||||
|
||||
std::vector<uint8_t>
|
||||
values_7bits() {
|
||||
return {0, 1, 2, 63, 64, 65, 125, 126};
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::concretelang::CompilerEngine::Library>
|
||||
compile(std::string outputLib, std::string source) {
|
||||
mlir::concretelang::CompilerEngine::Library
|
||||
compile(std::string outputLib, std::string source, std::string funcname = FUNCNAME) {
|
||||
std::vector<std::string> sources = {source};
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::JitCompilerEngine ce {ccx};
|
||||
ce.setClientParametersFuncName(FUNCNAME);
|
||||
return ce.compile(sources, outputLib);
|
||||
ce.setClientParametersFuncName(funcname);
|
||||
auto result = ce.compile(sources, outputLib);
|
||||
assert(result);
|
||||
return result.get();
|
||||
}
|
||||
|
||||
static const std::string THIS_TEST_DIRECTORY = "tests/TestLib";
|
||||
static const std::string OUT_DIRECTORY = THIS_TEST_DIRECTORY + "/out";
|
||||
|
||||
template<typename Info>
|
||||
std::string outputLibFromThis(Info *info) {
|
||||
return "tests/TestLib/out/" + std::string(info->name());
|
||||
return OUT_DIRECTORY + "/" + std::string(info->name());
|
||||
}
|
||||
|
||||
template<typename Lambda>
|
||||
Lambda load(std::string outputLib) {
|
||||
auto l = Lambda::load(FUNCNAME, outputLib, 0, 0, getTestKeySetCachePtr());
|
||||
assert(l.has_value());
|
||||
return l.value();
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1s_1s_client_view) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %arg0: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
using MyLambda = clientlib::TypedClientLambda<scalar_out, scalar_in>;
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
|
||||
auto maybeLambda = MyLambda::load("main", jsonPath);
|
||||
ASSERT_TRUE(maybeLambda.has_value());
|
||||
auto lambda = maybeLambda.value();
|
||||
auto maybeKeySet = lambda.keySet(getTestKeySetCachePtr(), 0, 0);
|
||||
ASSERT_TRUE(maybeKeySet.has_value());
|
||||
std::shared_ptr<KeySet> keySet = std::move(maybeKeySet.value());
|
||||
auto maybePublicArguments = lambda.publicArguments(1, keySet);
|
||||
ASSERT_TRUE(maybePublicArguments.has_value());
|
||||
auto publicArguments = std::move(maybePublicArguments.value());
|
||||
std::ostringstream osstream(std::ios::binary);
|
||||
EXPECT_TRUE(lambda.untypedSerializeCall(publicArguments, osstream));
|
||||
EXPECT_TRUE(osstream.good());
|
||||
// Direct call without intermediate
|
||||
EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream));
|
||||
EXPECT_TRUE(osstream.good());
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1s_1s) {
|
||||
@@ -45,13 +93,28 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
|
||||
for(auto a: values_7bits()) {
|
||||
auto res = lambda->call(a);
|
||||
ASSERT_EXPECTED_VALUE(res, a);
|
||||
auto res = lambda.call(a);
|
||||
ASSERT_EQ_OUTCOME(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2s_1s_choose) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
return %arg0: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
|
||||
for(auto a: values_7bits()) for(auto b: values_7bits()) {
|
||||
if (a > b) {
|
||||
continue;
|
||||
}
|
||||
auto res = lambda.call(a, b);
|
||||
ASSERT_EQ_OUTCOME(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,16 +127,30 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, scalar, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<scalar_out, scalar_in, scalar_in>>(outputLib);
|
||||
for(auto a: values_7bits()) for(auto b: values_7bits()) {
|
||||
auto res = lambda->call(a, b);
|
||||
ASSERT_EXPECTED_VALUE(res, a + b);
|
||||
if (a > b) {
|
||||
continue;
|
||||
}
|
||||
auto res = lambda.call(a, b);
|
||||
ASSERT_EQ_OUTCOME(res, a + b);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1s_1s_bad_call) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda = load<TestTypedLambda<scalar_out, scalar_in>>(outputLib);
|
||||
auto res = lambda.call(1);
|
||||
ASSERT_FALSE(res.has_value());
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1s_1t) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
|
||||
@@ -83,14 +160,11 @@ func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor1_out, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<tensor1_out, scalar_in>>(outputLib);
|
||||
for(auto a: values_7bits()) {
|
||||
auto res = lambda->call(a);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor1_out v = res.get();
|
||||
auto res = lambda.call(a);
|
||||
EXPECT_TRUE(res);
|
||||
tensor1_out v = res.value();
|
||||
EXPECT_EQ(v[0], a);
|
||||
}
|
||||
}
|
||||
@@ -104,16 +178,13 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<7>> {
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor1_out, scalar, scalar>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<tensor1_out, scalar_in, scalar_in>>(outputLib);
|
||||
for(auto a : values_7bits()) {
|
||||
auto res = lambda->call(a, a+1);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor1_out v = res.get();
|
||||
EXPECT_EQ((scalar)v[0], a);
|
||||
EXPECT_EQ((scalar)v[1], a + 1u);
|
||||
auto res = lambda.call(a, a+1);
|
||||
EXPECT_TRUE(res);
|
||||
tensor1_out v = res.value();
|
||||
EXPECT_EQ(v[0], (scalar_out)a);
|
||||
EXPECT_EQ(v[1], (scalar_out)(a + 1u));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,14 +198,11 @@ func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, tensor1_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<scalar_out, tensor1_in>>(outputLib);
|
||||
for(uint8_t a : values_7bits()) {
|
||||
tensor1_in ta = {a};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_VALUE(res, a);
|
||||
auto res = lambda.call(ta);
|
||||
ASSERT_EQ_OUTCOME(res, a);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,14 +214,11 @@ func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor1_out, tensor1_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<tensor1_out, tensor1_in>>(outputLib);
|
||||
tensor1_in ta = {1, 2, 3};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor1_out v = res.get();
|
||||
auto res = lambda.call(ta);
|
||||
ASSERT_TRUE(res);
|
||||
tensor1_out v = res.value();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
EXPECT_EQ(v[i], ta[i]);
|
||||
}
|
||||
@@ -171,16 +236,13 @@ func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<scalar, tensor1_in, std::array<uint8_t, 3>>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<scalar_out, tensor1_in, std::array<uint8_t, 3>>>(outputLib);
|
||||
tensor1_in ta {1, 2, 3};
|
||||
std::array<uint8_t, 3> tb {5, 7, 9};
|
||||
auto res = lambda->call(ta, tb);
|
||||
auto res = lambda.call(ta, tb);
|
||||
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
|
||||
std::accumulate(tb.begin(), tb.end(), 0u);
|
||||
ASSERT_EXPECTED_VALUE(res, expected);
|
||||
ASSERT_EQ_OUTCOME(res, expected);
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_1tr2_1tr2) {
|
||||
@@ -192,17 +254,14 @@ func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
|
||||
using tensor2_in = std::array<std::array<uint8_t, 3>, 2>;
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor2_out, tensor2_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<tensor2_out, tensor2_in>>(outputLib);
|
||||
tensor2_in ta = {{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6}
|
||||
}};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor2_out v = res.get();
|
||||
auto res = lambda.call(ta);
|
||||
ASSERT_TRUE(res);
|
||||
tensor2_out v = res.value();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
for(size_t j = 0; j < v.size(); j++) {
|
||||
EXPECT_EQ(v[i][j], ta[i][j]);
|
||||
@@ -210,7 +269,6 @@ func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(CompiledModule, call_1tr3_1tr3) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
|
||||
@@ -220,17 +278,14 @@ func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
|
||||
using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
ASSERT_EXPECTED_SUCCESS(compiled);
|
||||
auto lambda = TypedDynamicLambda<tensor3_out, tensor3_in>::load(FUNCNAME, outputLib);
|
||||
ASSERT_EXPECTED_SUCCESS(lambda);
|
||||
ASSERT_LLVM_ERROR(lambda->generateKeySet(getTestKeySetCache()));
|
||||
auto lambda = load<TestTypedLambda<tensor3_out, tensor3_in>>(outputLib);
|
||||
tensor3_in ta = {{
|
||||
{{ {1}, {2}, {3} }},
|
||||
{{ {4}, {5}, {6} }}
|
||||
}};
|
||||
auto res = lambda->call(ta);
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
tensor3_out v = res.get();
|
||||
auto res = lambda.call(ta);
|
||||
ASSERT_TRUE(res);
|
||||
tensor3_out v = res.value();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
for(size_t j = 0; j < v[i].size(); j++) {
|
||||
for(size_t k = 0; k < v[i][j].size(); k++) {
|
||||
@@ -239,3 +294,73 @@ func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2tr3_1tr3) {
|
||||
std::string source = R"(
|
||||
func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>, %arg1: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>> {
|
||||
%1 = "FHELinalg.add_eint"(%arg0, %arg1): (tensor<2x3x1x!FHE.eint<7>>, tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>>
|
||||
return %1: tensor<2x3x1x!FHE.eint<7>>
|
||||
}
|
||||
)";
|
||||
using tensor3_in = std::array<std::array<std::array<uint8_t, 1>, 3>, 2>;
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
auto compiled = compile(outputLib, source);
|
||||
auto lambda = load<TestTypedLambda<tensor3_out, tensor3_in, tensor3_in>>(outputLib);
|
||||
tensor3_in ta = {{
|
||||
{{ {1}, {2}, {3} }},
|
||||
{{ {4}, {5}, {6} }}
|
||||
}};
|
||||
auto res = lambda.call(ta, ta);
|
||||
ASSERT_TRUE(res);
|
||||
tensor3_out v = res.value();
|
||||
for(size_t i = 0; i < v.size(); i++) {
|
||||
for(size_t j = 0; j < v[i].size(); j++) {
|
||||
for(size_t k = 0; k < v[i][j].size(); k++) {
|
||||
EXPECT_EQ(v[i][j][k], 2 * ta[i][j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::string fileContent(std::string path) {
|
||||
std::ifstream file(path);
|
||||
std::stringstream buffer;
|
||||
buffer << file.rdbuf();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2t_1s_with_header) {
|
||||
std::string source = R"(
|
||||
func @extract(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%1 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>>
|
||||
%c1 = arith.constant 1 : i8
|
||||
%2 = tensor.from_elements %c1, %c1, %c1 : tensor<3xi8>
|
||||
%3 = "FHELinalg.dot_eint_int"(%1, %2) : (tensor<3x!FHE.eint<7>>, tensor<3xi8>) -> !FHE.eint<7>
|
||||
return %3: !FHE.eint<7>
|
||||
}
|
||||
)";
|
||||
std::string outputLib = outputLibFromThis(this->test_info_);
|
||||
namespace extract = call_2t_1s_with_header::client::extract;
|
||||
auto compiled = compile(outputLib, source, extract::name);
|
||||
std::string jsonPath = ClientParameters::getClientParametersPath(outputLib);
|
||||
auto cLambda_ = extract::load(jsonPath);
|
||||
ASSERT_TRUE(cLambda_);
|
||||
tensor1_in ta {1, 2, 3};
|
||||
tensor1_in tb {5, 7, 9};
|
||||
auto sLambda_ = ServerLambda::load(extract::name, outputLib);
|
||||
ASSERT_TRUE(sLambda_);
|
||||
auto cLambda = cLambda_.value();
|
||||
auto sLambda = sLambda_.value();
|
||||
auto keySet_ = cLambda.keySet(getTestKeySetCachePtr(), 0, 0);
|
||||
ASSERT_TRUE(keySet_.has_value());
|
||||
std::shared_ptr<KeySet> keySet = std::move(keySet_.value());
|
||||
auto testLambda = TestTypedLambdaFrom(cLambda, sLambda, keySet);
|
||||
auto res = testLambda.call(ta, tb);
|
||||
auto expected = std::accumulate(ta.begin(), ta.end(), 0u) +
|
||||
std::accumulate(tb.begin(), tb.end(), 0u);
|
||||
ASSERT_EQ_OUTCOME(res, expected);
|
||||
|
||||
EXPECT_EQ(
|
||||
fileContent(THIS_TEST_DIRECTORY + "/call_2t_1s_with_header-client.h.generated"),
|
||||
fileContent(OUT_DIRECTORY + "/call_2t_1s_with_header-client.h"));
|
||||
}
|
||||
|
||||
@@ -153,8 +153,8 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
const size_t numDim = 2;
|
||||
const int64_t dim0 = 2;
|
||||
const int64_t dim1 = 10;
|
||||
const size_t dim0 = 2;
|
||||
const size_t dim1 = 10;
|
||||
const int64_t dims[numDim]{dim0, dim1};
|
||||
static std::vector<uint64_t> tensor2D{
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
const size_t numDim = 2;
|
||||
const int64_t dim0 = 2;
|
||||
const int64_t dim1 = 10;
|
||||
const size_t dim0 = 2;
|
||||
const size_t dim1 = 10;
|
||||
const int64_t dims[numDim]{dim0, dim1};
|
||||
static std::vector<uint8_t> tensor2D{
|
||||
63, 12, 7, 43, 52, 9, 26, 34, 22, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
@@ -138,10 +138,10 @@ func @main(%t: tensor<8x4x5x3x!FHE.eint<6>>, %d0: index, %d1: index, %d2: index,
|
||||
uint8_t A[dimSizes[0]][dimSizes[1]][dimSizes[2]][dimSizes[3]];
|
||||
|
||||
// Fill with some reproducible pattern
|
||||
for (int64_t d0 = 0; d0 < dimSizes[0]; d0++) {
|
||||
for (int64_t d1 = 0; d1 < dimSizes[1]; d1++) {
|
||||
for (int64_t d2 = 0; d2 < dimSizes[2]; d2++) {
|
||||
for (int64_t d3 = 0; d3 < dimSizes[3]; d3++) {
|
||||
for (size_t d0 = 0; d0 < dimSizes[0]; d0++) {
|
||||
for (size_t d1 = 0; d1 < dimSizes[1]; d1++) {
|
||||
for (size_t d2 = 0; d2 < dimSizes[2]; d2++) {
|
||||
for (size_t d3 = 0; d3 < dimSizes[3]; d3++) {
|
||||
A[d0][d1][d2][d3] = d0 + d1 + d2 + d3;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +93,13 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ASSERT_EQ_OUTCOME(val, exp) \
|
||||
if(!val.has_value()) { \
|
||||
std::string msg = "ERROR: <" + val.error().mesg + "> \n"; \
|
||||
GTEST_FATAL_FAILURE_(msg.c_str()); \
|
||||
}; \
|
||||
ASSERT_EQ(val.value(), exp);
|
||||
|
||||
static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache() {
|
||||
|
||||
llvm::SmallString<0> cachePath;
|
||||
@@ -104,6 +111,11 @@ static inline llvm::Optional<mlir::concretelang::KeySetCache> getTestKeySetCache
|
||||
mlir::concretelang::KeySetCache(cachePathStr));
|
||||
}
|
||||
|
||||
static inline std::shared_ptr<mlir::concretelang::KeySetCache> getTestKeySetCachePtr() {
|
||||
return std::make_shared<mlir::concretelang::KeySetCache>(
|
||||
getTestKeySetCache().getValue());
|
||||
}
|
||||
|
||||
// Jit-compiles the function specified by `func` from `src` and
|
||||
// returns the corresponding lambda. Any compilation errors are caught
|
||||
// and reult in abnormal termination.
|
||||
|
||||
16
docs/cpp_api/clientlib.rst
Normal file
16
docs/cpp_api/clientlib.rst
Normal file
@@ -0,0 +1,16 @@
|
||||
ClientLib
|
||||
=========
|
||||
|
||||
.. toctree::
|
||||
clientlib/intro
|
||||
clientlib/client_lambda
|
||||
clientlib/server_lambda
|
||||
|
||||
|
||||
Index
|
||||
=====
|
||||
.. toctree::
|
||||
:glob:
|
||||
:maxdepth: 2
|
||||
|
||||
clientlib/*
|
||||
4
docs/cpp_api/clientlib/arguments.rst
Normal file
4
docs/cpp_api/clientlib/arguments.rst
Normal file
@@ -0,0 +1,4 @@
|
||||
LambdaArgument:
|
||||
===============
|
||||
|
||||
.. doxygenfile:: Arguments.h
|
||||
16
docs/cpp_api/clientlib/client_example.cpp
Normal file
16
docs/cpp_api/clientlib/client_example.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include <concretelang/ClientLib/ClientLambda.h>
|
||||
|
||||
// Include the header that has been generated by the compiler
|
||||
#include "myinclude/fhe_service/additions.h"
|
||||
|
||||
void query_server(MyConnection conn) {
|
||||
auto libPath = "/opt/fhe_service/libs/additions.so";
|
||||
std::ostream to_server = conn.ostream();
|
||||
std::ostream from_server = conn.istream();
|
||||
// In a real code only load once.
|
||||
auto add2int = additions::add2int::load(libPath, seed_msb, seed_lsb);
|
||||
auto err = add2int->callSerialize(1, 2, to_server);
|
||||
if( err ) { throw MyException() };
|
||||
auto result = add2int->decryptReturned(from_server);
|
||||
assert(result == 3);
|
||||
}
|
||||
5
docs/cpp_api/clientlib/client_lambda.rst
Normal file
5
docs/cpp_api/clientlib/client_lambda.rst
Normal file
@@ -0,0 +1,5 @@
|
||||
ClientLambda:
|
||||
===============
|
||||
|
||||
.. doxygenfile:: ClientLambda.h
|
||||
:sections: briefdescription innernamespace typedef innerclass public-static-func public-func
|
||||
5
docs/cpp_api/clientlib/client_parameters.rst
Normal file
5
docs/cpp_api/clientlib/client_parameters.rst
Normal file
@@ -0,0 +1,5 @@
|
||||
ClientParameters:
|
||||
===============
|
||||
|
||||
.. doxygenstruct:: mlir::concretelang::ClientParameters
|
||||
.. doxygenfile:: ClientParameters.h
|
||||
23
docs/cpp_api/clientlib/intro.rst
Normal file
23
docs/cpp_api/clientlib/intro.rst
Normal file
@@ -0,0 +1,23 @@
|
||||
Description
|
||||
========
|
||||
|
||||
ClientLambda represents a FHE function on the client side.
|
||||
ServerLambda represents a FHE function on the server side.
|
||||
|
||||
These object read/write on istreams/ostreams.
|
||||
|
||||
Implementing a client/server consists in connecting the two, by connecting to actual streams.
|
||||
|
||||
Example on client side:
|
||||
|
||||
.. literalinclude:: client_example.cpp
|
||||
:linenos:
|
||||
:language: bash
|
||||
..
|
||||
For some reason cpp does not work well here.
|
||||
|
||||
Example on server side:
|
||||
|
||||
.. literalinclude:: server_example.cpp
|
||||
:linenos:
|
||||
:language: bash
|
||||
10
docs/cpp_api/clientlib/server_example.cpp
Normal file
10
docs/cpp_api/clientlib/server_example.cpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#include <concretelang/ServerLib/ServerLambda.h>
|
||||
|
||||
void answer_client(MyConnection conn) {
|
||||
std::istream from_client = conn.istream();
|
||||
std::ostream to_client = conn.ostream();
|
||||
auto err = serverLambda.read_call_write(serverInput, serverOutput);
|
||||
if (err) {
|
||||
throw MyException();
|
||||
}
|
||||
}
|
||||
4
docs/cpp_api/clientlib/server_lambda.rst
Normal file
4
docs/cpp_api/clientlib/server_lambda.rst
Normal file
@@ -0,0 +1,4 @@
|
||||
ServerLambda:
|
||||
===============
|
||||
|
||||
.. doxygenfile:: ServerLambda.h
|
||||
@@ -1,4 +0,0 @@
|
||||
ClientParameters:
|
||||
===============
|
||||
|
||||
.. doxygenfile:: ClientParameters.h
|
||||
Reference in New Issue
Block a user