diff --git a/.github/workflows/aws_build_gpu.yml b/.github/workflows/aws_build_gpu.yml index fc8c69f89..f9a993b69 100644 --- a/.github/workflows/aws_build_gpu.yml +++ b/.github/workflows/aws_build_gpu.yml @@ -73,7 +73,7 @@ jobs: - name: Create build dir run: mkdir build - - name: Build compiler + - name: Build and test compiler uses: addnab/docker-run-action@v3 id: build-compiler with: @@ -94,34 +94,7 @@ jobs: set -e cd /compiler rm -rf /build/* - make DATAFLOW_EXECUTION_ENABLED=ON CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC BUILD_DIR=/build CUDA_SUPPORT=ON CUDA_PATH=${{ env.CUDA_PATH }} all python-package build-end-to-end-dataflow-tests + mkdir -p /tmp/concrete_compiler/gpu_tests/ + make BINDINGS_PYTHON_ENABLED=OFF CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC CUDA_SUPPORT=ON CUDA_PATH=${{ env.CUDA_PATH }} run-end-to-end-tests-gpu echo "Debug: ccache statistics (after the build):" ccache -s - - - name: Archive python package - uses: actions/upload-artifact@v3 - with: - name: concrete-compiler-gpu - path: build/wheels/*.whl - - - name: Test compiler - uses: addnab/docker-run-action@v3 - with: - registry: ghcr.io - image: ${{ env.DOCKER_IMAGE_TEST }} - username: ${{ secrets.GHCR_LOGIN }} - password: ${{ secrets.GHCR_PASSWORD }} - options: >- - -v ${{ github.workspace }}/llvm-project:/llvm-project - -v ${{ github.workspace }}/compiler:/compiler - -v ${{ github.workspace }}/build:/build - --gpus all - shell: bash - run: | - set -e - cd /compiler - pip install pytest - sed "s/pytest/python -m pytest/g" -i Makefile - mkdir -p /tmp/concrete_compiler/gpu_tests/ - make DATAFLOW_EXECUTION_ENABLED=ON CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC BUILD_DIR=/build run-end-to-end-tests-gpu - chmod -R ugo+rwx /tmp/KeySetCache diff --git a/.gitmodules b/.gitmodules index 8067b5368..34e964a07 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,8 @@ [submodule "compiler/parameter-curves"] path = compiler/parameter-curves url = git@github.com:zama-ai/parameter-curves.git + shallow = true +[submodule "compiler/concrete-cpu"] + path = compiler/concrete-cpu + url = git@github.com:zama-ai/concrete-cpu.git + shallow = true diff --git a/compiler/.gitignore b/compiler/.gitignore index 68a0d350c..7ae8b1d18 100644 --- a/compiler/.gitignore +++ b/compiler/.gitignore @@ -4,9 +4,6 @@ build*/ *.mlir.script *.lit_test_times.txt -# core ffi artifacts -concrete-core-ffi* - # Test-generated artifacts concrete-compiler_compilation_artifacts/ py_test_lib_compile_and_run_custom_perror/ diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index 09ce56170..774751b3d 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -61,11 +61,23 @@ set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) include_directories(${PROJECT_SOURCE_DIR}/parameter-curves/concrete-security-curves-cpp/include) # ------------------------------------------------------------------------------- -# Concrete FFI Configuration +# Concrete CPU Configuration # ------------------------------------------------------------------------------- -include_directories(${CONCRETE_FFI_RELEASE}) -add_library(Concrete STATIC IMPORTED) -set_target_properties(Concrete PROPERTIES IMPORTED_LOCATION ${CONCRETE_FFI_RELEASE}/libconcrete_core_ffi.a) +set(CONCRETE_CPU_STATIC_LIB "${PROJECT_SOURCE_DIR}/concrete-cpu/target/release/libconcrete_cpu.a") +ExternalProject_Add( + concrete_cpu_rust + DOWNLOAD_COMMAND "" + CONFIGURE_COMMAND "" OUTPUT "${CONCRETE_CPU_STATIC_LIB}" + BUILD_COMMAND cargo build + COMMAND cargo build --release + BINARY_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu" + INSTALL_COMMAND "" + LOG_BUILD ON) +add_library(concrete_cpu STATIC IMPORTED) +# TODO - Change that to a location in the release dir +set(CONCRETE_CPU_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu/concrete-cpu") +set_target_properties(concrete_cpu PROPERTIES IMPORTED_LOCATION "${CONCRETE_CPU_STATIC_LIB}") +add_dependencies(concrete_cpu concrete_cpu_rust) # -------------------------------------------------------------------------------- # Concrete Cuda Configuration diff --git a/compiler/Makefile b/compiler/Makefile index 50de08a66..8a61f2b09 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -24,8 +24,6 @@ HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION) HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR)/build -CONCRETE_CORE_FFI_VERSION?=0.2.0 - ML_BENCH_SUBSET_ID= # Find OS @@ -50,8 +48,6 @@ else ARCHITECTURE=amd64 endif -CONCRETE_CORE_FFI_TARBALL=concrete-core-ffi_$(CONCRETE_CORE_FFI_VERSION)_$(OS)_$(ARCHITECTURE).tar.gz - export PATH := $(abspath $(BUILD_DIR))/bin:$(PATH) ifeq ($(shell which ccache),) @@ -87,23 +83,6 @@ endif all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc rust-bindings -# concrete-core-ffi ####################################### - -CONCRETE_CORE_FFI_FOLDER = $(shell pwd)/concrete-core-ffi-$(CONCRETE_CORE_FFI_VERSION) - -CONCRETE_CORE_FFI_ARTIFACTS=$(CONCRETE_CORE_FFI_FOLDER)/libconcrete_core_ffi.a $(CONCRETE_CORE_FFI_FOLDER)/concrete-core-ffi.h - -concrete-core-ffi: $(CONCRETE_CORE_FFI_ARTIFACTS) - -$(CONCRETE_CORE_FFI_ARTIFACTS): $(CONCRETE_CORE_FFI_FOLDER) - -$(CONCRETE_CORE_FFI_FOLDER): $(CONCRETE_CORE_FFI_TARBALL) - mkdir -p $(CONCRETE_CORE_FFI_FOLDER) - tar -xvzf $(CONCRETE_CORE_FFI_TARBALL) --directory $(CONCRETE_CORE_FFI_FOLDER) - -$(CONCRETE_CORE_FFI_TARBALL): - curl -L https://github.com/zama-ai/concrete-core/releases/download/concrete-core-ffi-$(CONCRETE_CORE_FFI_VERSION)/$(CONCRETE_CORE_FFI_TARBALL) -o $(CONCRETE_CORE_FFI_TARBALL) - # concrete-optimizer ###################################### LIB_CONCRETE_OPTIMIZER_CPP = $(CONCRETE_OPTIMIZER_DIR)/target/libconcrete_optimizer_cpp.a @@ -143,7 +122,6 @@ $(BUILD_DIR)/configured.stamp: -DCONCRETELANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \ -DCONCRETELANG_DATAFLOW_EXECUTION_ENABLED=$(DATAFLOW_EXECUTION_ENABLED) \ -DCONCRETELANG_TIMING_ENABLED=$(TIMING_ENABLED) \ - -DCONCRETE_FFI_RELEASE=$(CONCRETE_CORE_FFI_FOLDER) \ -DHPX_DIR=${HPX_INSTALL_DIR}/lib/cmake/HPX \ -DLLVM_EXTERNAL_PROJECTS=concretelang \ -DLLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR=. \ @@ -154,7 +132,7 @@ $(BUILD_DIR)/configured.stamp: -DCUDAToolkit_ROOT=$(CUDA_PATH) touch $@ -build-initialized: concrete-optimizer-lib concrete-core-ffi $(BUILD_DIR)/configured.stamp +build-initialized: concrete-optimizer-lib $(BUILD_DIR)/configured.stamp doc: build-initialized cmake --build $(BUILD_DIR) --target mlir-doc @@ -330,7 +308,7 @@ $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end $(BENCHMARK_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py $(Python3_EXECUTABLE) $< > $@ -generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_leveled.yaml +generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml SECURITY_TO_BENCH=128 run-cpu-benchmarks: build-benchmarks generate-cpu-benchmarks @@ -530,7 +508,6 @@ install: concretecompiler concrete-optimizer-lib CAPI install-deps python-format \ check-python-format \ concrete-optimizer-lib \ - concrete-core-ffi \ build-tests \ run-tests \ run-check-tests \ @@ -540,7 +517,6 @@ install: concretecompiler concrete-optimizer-lib CAPI install-deps build-end-to-end-tests \ build-end-to-end-dataflow-tests \ run-end-to-end-dataflow-tests \ - concrete-core-ffi \ opt \ mlir-opt \ mlir-cpu-runner \ diff --git a/compiler/concrete-cpu b/compiler/concrete-cpu new file mode 160000 index 000000000..db262714c --- /dev/null +++ b/compiler/concrete-cpu @@ -0,0 +1 @@ +Subproject commit db262714cde546344d25f0a81e7974fd0277a55f diff --git a/compiler/include/concretelang/ClientLib/ClientLambda.h b/compiler/include/concretelang/ClientLib/ClientLambda.h index abca94f9f..f96c6c4e6 100644 --- a/compiler/include/concretelang/ClientLib/ClientLambda.h +++ b/compiler/include/concretelang/ClientLib/ClientLambda.h @@ -92,8 +92,7 @@ public: OUTCOME_TRY(auto clientArguments, EncryptedArguments::create(keySet, args...)); - return clientArguments->exportPublicArguments(clientParameters, - keySet.runtimeContext()); + return clientArguments->exportPublicArguments(clientParameters); } outcome::checked decryptResult(KeySet &keySet, diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index 1275cca0d..c9a144823 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -35,10 +35,8 @@ namespace clientlib { using concretelang::error::StringError; -const std::string SMALL_KEY = "small"; -const std::string BIG_KEY = "big"; -const std::string BOOTSTRAP_KEY = "bsk_v0"; -const std::string KEYSWITCH_KEY = "ksk_v0"; +const uint64_t SMALL_KEY = 1; +const uint64_t BIG_KEY = 0; const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json"; @@ -52,7 +50,7 @@ typedef std::vector CRTDecomposition; typedef uint64_t LweDimension; typedef uint64_t GlweDimension; -typedef std::string LweSecretKeyID; +typedef uint64_t LweSecretKeyID; struct LweSecretKeyParam { LweDimension dimension; @@ -66,7 +64,7 @@ static bool operator==(const LweSecretKeyParam &lhs, return lhs.dimension == rhs.dimension; } -typedef std::string BootstrapKeyID; +typedef uint64_t BootstrapKeyID; struct BootstrapKeyParam { LweSecretKeyID inputSecretKeyID; LweSecretKeyID outputSecretKeyID; @@ -74,6 +72,8 @@ struct BootstrapKeyParam { DecompositionBaseLog baseLog; GlweDimension glweDimension; Variance variance; + PolynomialSize polynomialSize; + LweDimension inputLweDimension; void hash(size_t &seed); @@ -90,7 +90,7 @@ static inline bool operator==(const BootstrapKeyParam &lhs, lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance; } -typedef std::string KeyswitchKeyID; +typedef uint64_t KeyswitchKeyID; struct KeyswitchKeyParam { LweSecretKeyID inputSecretKeyID; LweSecretKeyID outputSecretKeyID; @@ -112,22 +112,28 @@ static inline bool operator==(const KeyswitchKeyParam &lhs, lhs.variance == rhs.variance; } -typedef std::string PackingKeySwitchID; -struct PackingKeySwitchParam { +typedef uint64_t PackingKeyswitchKeyID; +struct PackingKeyswitchKeyParam { LweSecretKeyID inputSecretKeyID; LweSecretKeyID outputSecretKeyID; - BootstrapKeyID bootstrapKeyID; DecompositionLevelCount level; DecompositionBaseLog baseLog; + GlweDimension glweDimension; + PolynomialSize polynomialSize; + LweDimension inputLweDimension; Variance variance; void hash(size_t &seed); }; -static inline bool operator==(const PackingKeySwitchParam &lhs, - const PackingKeySwitchParam &rhs) { +static inline bool operator==(const PackingKeyswitchKeyParam &lhs, + const PackingKeyswitchKeyParam &rhs) { return lhs.inputSecretKeyID == rhs.inputSecretKeyID && lhs.outputSecretKeyID == rhs.outputSecretKeyID && - lhs.level == rhs.level && lhs.baseLog == rhs.baseLog; + lhs.level == rhs.level && lhs.baseLog == rhs.baseLog && + lhs.glweDimension == rhs.glweDimension && + lhs.polynomialSize == rhs.polynomialSize && + lhs.variance == lhs.variance && + lhs.inputLweDimension == rhs.inputLweDimension; } struct Encoding { @@ -185,13 +191,13 @@ struct CircuitGate { bool isEncrypted() { return encryption.hasValue(); } /// byteSize returns the size in bytes for this gate. - size_t byteSize(std::map secretKeys) { + size_t byteSize(std::vector secretKeys) { auto width = shape.width; auto numElts = shape.size == 0 ? 1 : shape.size; if (isEncrypted()) { - auto skParam = secretKeys.find(encryption->secretKeyID); - assert(skParam != secretKeys.end()); - return 8 * skParam->second.lweSize() * numElts; + assert(encryption->secretKeyID < secretKeys.size()); + auto skParam = secretKeys[encryption->secretKeyID]; + return 8 * skParam.lweSize() * numElts; } width = bitWidthAsWord(width) / 8; return width * numElts; @@ -203,10 +209,10 @@ static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) { } struct ClientParameters { - std::map secretKeys; - std::map bootstrapKeys; - std::map keyswitchKeys; - std::map packingKeys; + std::vector secretKeys; + std::vector bootstrapKeys; + std::vector keyswitchKeys; + std::vector packingKeyswitchKeys; std::vector inputs; std::vector outputs; std::string functionName; @@ -237,12 +243,9 @@ struct ClientParameters { if (!gate.encryption.hasValue()) { return StringError("gate is not encrypted"); } - auto secretKey = secretKeys.find(gate.encryption->secretKeyID); - if (secretKey == secretKeys.end()) { - return StringError("cannot find ") - << gate.encryption->secretKeyID << " in client parameters"; - } - return secretKey->second; + assert(gate.encryption->secretKeyID < secretKeys.size()); + auto secretKey = secretKeys[gate.encryption->secretKeyID]; + return secretKey; } /// bufferSize returns the size of the whole buffer of a gate. @@ -309,8 +312,8 @@ bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path); llvm::json::Value toJSON(const KeyswitchKeyParam &); bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path); -llvm::json::Value toJSON(const PackingKeySwitchParam &); -bool fromJSON(const llvm::json::Value, PackingKeySwitchParam &, +llvm::json::Value toJSON(const PackingKeyswitchKeyParam &); +bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &, llvm::json::Path); llvm::json::Value toJSON(const Encoding &); diff --git a/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h index 44e2f2864..c6ead904d 100644 --- a/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -61,8 +61,7 @@ public: /// arguments, i.e. move all buffers to the PublicArguments and reset the /// positional counter. outcome::checked, StringError> - exportPublicArguments(ClientParameters clientParameters, - RuntimeContext runtimeContext); + exportPublicArguments(ClientParameters clientParameters); /// Check that all arguments as been pushed. // TODO: Remove public method here diff --git a/compiler/include/concretelang/ClientLib/EvaluationKeys.h b/compiler/include/concretelang/ClientLib/EvaluationKeys.h index 33c273058..092570809 100644 --- a/compiler/include/concretelang/ClientLib/EvaluationKeys.h +++ b/compiler/include/concretelang/ClientLib/EvaluationKeys.h @@ -6,113 +6,143 @@ #ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_ #define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_ +#include #include +#include -#include "concrete-core-ffi.h" +#include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/Common/Error.h" -typedef struct LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64; - -int destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64( - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *); +struct Csprng; +struct CsprngVtable; namespace concretelang { namespace clientlib { -// ============================================= +class CSPRNG { +public: + struct Csprng *ptr; + const struct CsprngVtable *vtable; -/// Wrapper for `LweKeyswitchKey64` so that it cleans up properly. + CSPRNG() = delete; + CSPRNG(CSPRNG &) = delete; + + CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) { + assert(ptr != nullptr); + other.ptr = nullptr; + }; + + CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){}; +}; + +class ConcreteCSPRNG : public CSPRNG { +public: + ConcreteCSPRNG(__uint128_t seed); + ConcreteCSPRNG() = delete; + ConcreteCSPRNG(ConcreteCSPRNG &) = delete; + ConcreteCSPRNG(ConcreteCSPRNG &&other); + ~ConcreteCSPRNG(); +}; + +/// @brief LweSecretKey implements tools for manipulating lwe secret key on +/// client. +class LweSecretKey { + std::shared_ptr> _buffer; + LweSecretKeyParam _parameters; + +public: + LweSecretKey() = delete; + LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng); + LweSecretKey(std::shared_ptr> buffer, + LweSecretKeyParam parameters) + : _buffer(buffer), _parameters(parameters){}; + + /// @brief Encrypt the plaintext to the lwe ciphertext buffer. + void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance, + CSPRNG &csprng) const; + + /// @brief Decrypt the ciphertext to the plaintext + void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const; + + /// @brief Returns the buffer that hold the keyswitch key. + const uint64_t *buffer() const { return _buffer->data(); } + size_t size() const { return _buffer->size(); } + + /// @brief Returns the parameters of the keyswicth key. + LweSecretKeyParam parameters() const { return this->_parameters; } + + /// @brief Returns the lwe dimension of the secret key. + size_t dimension() const { return parameters().dimension; } +}; + +/// @brief LweKeyswitchKey implements tools for manipulating keyswitch key on +/// client. class LweKeyswitchKey { private: - LweKeyswitchKey64 *ksk; - -protected: - friend std::ostream &operator<<(std::ostream &ostream, - const LweKeyswitchKey &wrappedKsk); - friend std::istream &operator>>(std::istream &istream, - LweKeyswitchKey &wrappedKsk); + std::shared_ptr> _buffer; + KeyswitchKeyParam _parameters; public: - LweKeyswitchKey(LweKeyswitchKey64 *ksk) : ksk{ksk} {} - LweKeyswitchKey(LweKeyswitchKey &other) = delete; - LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} { - other.ksk = nullptr; - } - ~LweKeyswitchKey() { - if (this->ksk != nullptr) { - CAPI_ASSERT_ERROR(destroy_lwe_keyswitch_key_u64(this->ksk)); + LweKeyswitchKey() = delete; + LweKeyswitchKey(KeyswitchKeyParam ¶meters, LweSecretKey &inputKey, + LweSecretKey &outputKey, CSPRNG &csprng); + LweKeyswitchKey(std::shared_ptr> buffer, + KeyswitchKeyParam parameters) + : _buffer(buffer), _parameters(parameters){}; - this->ksk = nullptr; - } - } + /// @brief Returns the buffer that hold the keyswitch key. + const uint64_t *buffer() const { return _buffer->data(); } + size_t size() const { return _buffer->size(); } - LweKeyswitchKey64 *get() { return this->ksk; } + /// @brief Returns the parameters of the keyswicth key. + KeyswitchKeyParam parameters() const { return this->_parameters; } }; -// ============================================= - -/// Wrapper for `LweBootstrapKey64` so that it cleans up properly. +/// @brief LweBootstrapKey implements tools for manipulating bootstrap key on +/// client. class LweBootstrapKey { private: - LweBootstrapKey64 *bsk; - -protected: - friend std::ostream &operator<<(std::ostream &ostream, - const LweBootstrapKey &wrappedBsk); - friend std::istream &operator>>(std::istream &istream, - LweBootstrapKey &wrappedBsk); + std::shared_ptr> _buffer; + BootstrapKeyParam _parameters; public: - LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {} - LweBootstrapKey(LweBootstrapKey &other) = delete; - LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} { - other.bsk = nullptr; - } - ~LweBootstrapKey() { - if (this->bsk != nullptr) { - CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk)); - this->bsk = nullptr; - } - } + LweBootstrapKey() = delete; + LweBootstrapKey(std::shared_ptr> buffer, + BootstrapKeyParam ¶meters) + : _buffer(buffer), _parameters(parameters){}; + LweBootstrapKey(BootstrapKeyParam ¶meters, LweSecretKey &inputKey, + LweSecretKey &outputKey, CSPRNG &csprng); - LweBootstrapKey64 *get() { return this->bsk; } + ///// @brief Returns the buffer that hold the bootstrap key. + const uint64_t *buffer() const { return _buffer->data(); } + size_t size() const { return _buffer->size(); } + + /// @brief Returns the parameters of the bootsrap key. + BootstrapKeyParam parameters() const { return this->_parameters; } }; -// ============================================= - -/// Wrapper for `LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64` so -/// that it cleans up properly. +/// @brief PackingKeyswitchKey implements tools for manipulating privat packing +/// keyswitch key on client. class PackingKeyswitchKey { private: - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key; - -protected: - friend std::ostream &operator<<(std::ostream &ostream, - const PackingKeyswitchKey &wrappedFpksk); - friend std::istream &operator>>(std::istream &istream, - PackingKeyswitchKey &wrappedFpksk); + std::shared_ptr> _buffer; + PackingKeyswitchKeyParam _parameters; public: - PackingKeyswitchKey( - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key) - : key{key} {} - PackingKeyswitchKey(PackingKeyswitchKey &other) = delete; - PackingKeyswitchKey(PackingKeyswitchKey &&other) : key{other.key} { - other.key = nullptr; - } - ~PackingKeyswitchKey() { - if (this->key != nullptr) { - CAPI_ASSERT_ERROR( - destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64( - this->key)); - this->key = nullptr; - } - } + PackingKeyswitchKey() = delete; + PackingKeyswitchKey(PackingKeyswitchKeyParam ¶meters, + LweSecretKey &inputKey, LweSecretKey &outputKey, + CSPRNG &csprng); + PackingKeyswitchKey(std::shared_ptr> buffer, + PackingKeyswitchKeyParam parameters) + : _buffer(buffer), _parameters(parameters){}; - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get() { - return this->key; - } + /// @brief Returns the buffer that hold the keyswitch key. + const uint64_t *buffer() const { return _buffer->data(); } + size_t size() const { return _buffer->size(); } + + /// @brief Returns the parameters of the keyswicth key. + PackingKeyswitchKeyParam parameters() const { return this->_parameters; } }; // ============================================= @@ -120,31 +150,40 @@ public: /// Evalution keys required for execution. class EvaluationKeys { private: - std::shared_ptr sharedKsk; - std::shared_ptr sharedBsk; - std::shared_ptr sharedFpksk; - -protected: - friend std::ostream &operator<<(std::ostream &ostream, - const EvaluationKeys &evaluationKeys); - friend std::istream &operator>>(std::istream &istream, - EvaluationKeys &evaluationKeys); + std::vector keyswitchKeys; + std::vector bootstrapKeys; + std::vector packingKeyswitchKeys; public: - EvaluationKeys() - : sharedKsk{std::shared_ptr(nullptr)}, - sharedBsk{std::shared_ptr(nullptr)} {} + EvaluationKeys() = delete; - EvaluationKeys(std::shared_ptr sharedKsk, - std::shared_ptr sharedBsk, - std::shared_ptr sharedFpksk) - : sharedKsk{sharedKsk}, sharedBsk{sharedBsk}, sharedFpksk{sharedFpksk} {} + EvaluationKeys(const std::vector keyswitchKeys, + const std::vector bootstrapKeys, + const std::vector packingKeyswitchKeys) + : keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys), + packingKeyswitchKeys(packingKeyswitchKeys) {} - LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); } - LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); } - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *getFpksk() { - return this->sharedFpksk->get(); + const LweKeyswitchKey &getKeyswitchKey(size_t id) const { + return this->keyswitchKeys[id]; + } + const std::vector getKeyswitchKeys() const { + return this->keyswitchKeys; + } + + const LweBootstrapKey &getBootstrapKey(size_t id) const { + return bootstrapKeys[id]; + } + const std::vector getBootstrapKeys() const { + return this->bootstrapKeys; + } + + const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const { + return this->packingKeyswitchKeys[id]; }; + + const std::vector getPackingKeyswitchKeys() const { + return this->packingKeyswitchKeys; + } }; // ============================================= diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index c465c11ef..0e6ce0b52 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -10,30 +10,33 @@ #include "boost/outcome.h" -#include "concrete-core-ffi.h" -#include "concretelang/Runtime/DFRuntime.hpp" -#include "concretelang/Runtime/context.h" - #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Common/Error.h" +#include "concretelang/Runtime/DFRuntime.hpp" namespace concretelang { namespace clientlib { using concretelang::error::StringError; -using RuntimeContext = mlir::concretelang::RuntimeContext; class KeySet { public: - KeySet(); - ~KeySet(); + KeySet(ClientParameters clientParameters, CSPRNG &&csprng) + : csprng(std::move(csprng)), _clientParameters(clientParameters){}; KeySet(KeySet &other) = delete; - /// allocate a KeySet according the ClientParameters. + /// Generate a KeySet from a ClientParameters specification. static outcome::checked, StringError> - generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb); + generate(ClientParameters clientParameters, CSPRNG &&csprng); + + /// Create a KeySet from a set of given keys + static outcome::checked, StringError> fromKeys( + ClientParameters clientParameters, std::vector secretKeys, + std::vector bootstrapKeys, + std::vector keyswitchKeys, + std::vector packingKeyswitchKeys, CSPRNG &&csprng); /// Returns the ClientParameters associated with the KeySet. ClientParameters clientParameters() { return _clientParameters; } @@ -41,23 +44,6 @@ public: // isInputEncrypted return true if the input at the given pos is encrypted. bool isInputEncrypted(size_t pos); - /// getInputLweSecretKeyParam returns the parameters of the lwe secret key for - /// the input at the given `pos`. - /// The input must be encrupted - LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) { - auto gate = inputGate(pos); - auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID); - return inputSk->second.first; - } - - /// getOutputLweSecretKeyParam returns the parameters of the lwe secret key - /// for the given output. - LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) { - auto gate = outputGate(pos); - auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID); - return outputSk->second.first; - } - /// allocate a lwe ciphertext buffer for the argument at argPos, set the size /// of the allocated buffer. outcome::checked @@ -80,103 +66,58 @@ public: CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); } CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); } - RuntimeContext runtimeContext() { - RuntimeContext context; - context.evaluationKeys = this->evaluationKeys(); - return context; - } + /// @brief evaluationKeys returns the evaluation keys associate to this client + /// keyset. Those evaluations keys can be safely shared publicly + EvaluationKeys evaluationKeys(); - EvaluationKeys evaluationKeys() { - if (this->bootstrapKeys.empty() && this->keyswitchKeys.empty()) { - return EvaluationKeys(); - } - auto kskIt = this->keyswitchKeys.find(clientlib::KEYSWITCH_KEY); - auto bskIt = this->bootstrapKeys.find(clientlib::BOOTSTRAP_KEY); - auto fpkskIt = this->packingKeys.find("fpksk_v0"); - if (kskIt != this->keyswitchKeys.end() && - bskIt != this->bootstrapKeys.end()) { - auto sharedKsk = std::get<1>(kskIt->second); - auto sharedBsk = std::get<1>(bskIt->second); - auto sharedFpksk = fpkskIt == this->packingKeys.end() - ? std::make_shared(nullptr) - : std::get<1>(fpkskIt->second); - return EvaluationKeys(sharedKsk, sharedBsk, sharedFpksk); - } - assert(!mlir::concretelang::dfr::_dfr_is_root_node() && - "Evaluation keys missing in KeySet (on root node)."); - return EvaluationKeys(); - } + const std::vector &getSecretKeys(); - const std::map> - &getSecretKeys(); + const std::vector &getBootstrapKeys(); - const std::map>> - &getBootstrapKeys(); + const std::vector &getKeyswitchKeys(); - const std::map>> - &getKeyswitchKeys(); - - const std::map< - LweSecretKeyID, - std::pair>> & - getPackingKeys(); + const std::vector &getPackingKeyswitchKeys(); protected: outcome::checked - generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param); + generateSecretKey(LweSecretKeyParam param); outcome::checked - generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param); + generateBootstrapKey(BootstrapKeyParam param); outcome::checked - generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param); + generateKeyswitchKey(KeyswitchKeyParam param); outcome::checked - generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param); + generatePackingKeyswitchKey(PackingKeyswitchKeyParam param); - outcome::checked - generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb, - uint64_t seed_lsb); + outcome::checked generateKeysFromParams(); - outcome::checked - setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb, - uint64_t seed_lsb); + outcome::checked setupEncryptionMaterial(); friend class KeySetCache; private: - DefaultEngine *engine; - DefaultParallelEngine *par_engine; - std::map> - secretKeys; - std::map>> - bootstrapKeys; - std::map>> - keyswitchKeys; - std::map>> - packingKeys; - std::vector> - inputs; - std::vector> - outputs; + CSPRNG csprng; - void setKeys( - std::map> - secretKeys, - std::map>> - bootstrapKeys, - std::map>> - keyswitchKeys, - std::map>> - packingKeys); + /////////////////////////////////////////////// + // Keys mappings + std::vector secretKeys; + std::vector bootstrapKeys; + std::vector keyswitchKeys; + std::vector packingKeyswitchKeys; + + outcome::checked findLweSecretKey(LweSecretKeyID); + + /////////////////////////////////////////////// + // Convenient positional mapping between positional gate en secret key + typedef std::vector>> + SecretKeyGateMapping; + outcome::checked + mapCircuitGateLweSecretKey(std::vector gates); + + SecretKeyGateMapping inputs; + SecretKeyGateMapping outputs; clientlib::ClientParameters _clientParameters; }; diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 97afa2ca8..5af887a0a 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -14,7 +14,6 @@ #include "concretelang/ClientLib/EncryptedArguments.h" #include "concretelang/ClientLib/Types.h" #include "concretelang/Common/Error.h" -#include "concretelang/Runtime/context.h" namespace concretelang { namespace serverlib { diff --git a/compiler/include/concretelang/ClientLib/Serializers.h b/compiler/include/concretelang/ClientLib/Serializers.h index 76e006fc8..5e12ff503 100644 --- a/compiler/include/concretelang/ClientLib/Serializers.h +++ b/compiler/include/concretelang/ClientLib/Serializers.h @@ -12,13 +12,10 @@ #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/Types.h" -#include "concretelang/Runtime/context.h" namespace concretelang { namespace clientlib { -using RuntimeContext = mlir::concretelang::RuntimeContext; - // integers are not serialized as binary values even on a binary stream // so we cannot rely on << operator directly template @@ -62,10 +59,6 @@ template bool incorrectMode(Stream &stream) { return !binary; } -std::ostream &operator<<(std::ostream &ostream, - const RuntimeContext &runtimeContext); -std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext); - std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream); outcome::checked @@ -105,17 +98,24 @@ outcome::checked unserializeScalarOrTensorData(const std::vector &expectedSizes, std::istream &istream); +std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk); +LweSecretKey readLweSecretKey(std::istream &istream); + std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &wrappedKsk); -std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk); +LweKeyswitchKey readLweKeyswitchKey(std::istream &istream); std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &wrappedBsk); -std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk); +LweBootstrapKey readLweBootstrapKey(std::istream &istream); + +std::ostream &operator<<(std::ostream &ostream, + const PackingKeyswitchKey &wrappedKsk); +PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream); std::ostream &operator<<(std::ostream &ostream, const EvaluationKeys &evaluationKeys); -std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys); +EvaluationKeys readEvaluationKeys(std::istream &istream); } // namespace clientlib } // namespace concretelang diff --git a/compiler/include/concretelang/Common/Error.h b/compiler/include/concretelang/Common/Error.h index 74cdc0dc8..508c175a1 100644 --- a/compiler/include/concretelang/Common/Error.h +++ b/compiler/include/concretelang/Common/Error.h @@ -7,12 +7,6 @@ #include -#define CAPI_ASSERT_ERROR(instr) \ - { \ - int err = instr; \ - assert(err == 0); \ - } - namespace concretelang { namespace error { diff --git a/compiler/include/concretelang/Conversion/FHEToTFHECrt/Pass.h b/compiler/include/concretelang/Conversion/FHEToTFHECrt/Pass.h index 131e18490..b3d173797 100644 --- a/compiler/include/concretelang/Conversion/FHEToTFHECrt/Pass.h +++ b/compiler/include/concretelang/Conversion/FHEToTFHECrt/Pass.h @@ -8,6 +8,7 @@ #include "mlir/Pass/Pass.h" #include "llvm/Support/Casting.h" +#include #include namespace mlir { @@ -19,11 +20,9 @@ struct CrtLoweringParameters { size_t nMods; size_t modsProd; size_t bitsTotal; - size_t polynomialSize; - size_t lutSize; + size_t singleLutSize; - CrtLoweringParameters(mlir::SmallVector mods, size_t polySize) - : mods(mods), polynomialSize(polySize) { + CrtLoweringParameters(mlir::SmallVector mods) : mods(mods) { nMods = mods.size(); modsProd = 1; bitsTotal = 0; @@ -35,9 +34,7 @@ struct CrtLoweringParameters { bits.push_back(nbits); bitsTotal += nbits; } - size_t lutCrtSize = size_t(1) << bitsTotal; - lutCrtSize = std::max(lutCrtSize, polynomialSize); - lutSize = mods.size() * lutCrtSize; + singleLutSize = size_t(1) << bitsTotal; } }; diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 101811186..5d816835c 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -15,12 +15,14 @@ include "concretelang/Dialect/RT/IR/RTTypes.td" def Concrete_LweTensor : 1DTensorOf<[I64]>; def Concrete_LutTensor : 1DTensorOf<[I64]>; +def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>; def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>; def Concrete_LweCRTTensor : 2DTensorOf<[I64]>; def Concrete_BatchLweTensor : 2DTensorOf<[I64]>; def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>; def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>; +def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>; def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>; def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>; def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>; @@ -126,7 +128,7 @@ def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lu ); } -def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> { +def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> { let summary = "Encode and expand a lookup table so that it can be used for a wop pbs"; @@ -134,24 +136,22 @@ def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_f Concrete_LutTensor : $input_lookup_table, I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, - I32Attr : $polySize, I32Attr : $modulusProduct, BoolAttr: $isSigned ); - let results = (outs Concrete_LutTensor : $result); + let results = (outs Concrete_CrtLutsTensor : $result); } -def Concrete_EncodeExpandLutForWopPBSBufferOp : Concrete_Op<"encode_expand_lut_for_woppbs_buffer"> { +def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_woppbs_buffer"> { let summary = - "Encode and expand a lookup table so that it can be used for a wop pbs"; + "Encode and expand a lookup table so that it can be used for a crt wop pbs"; let arguments = (ins - Concrete_LutBuffer : $result, + Concrete_CrtLutsBuffer : $result, Concrete_LutBuffer : $input_lookup_table, I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, - I32Attr : $polySize, I32Attr : $modulusProduct, BoolAttr: $isSigned ); @@ -347,7 +347,7 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> { let arguments = (ins Concrete_LweCRTTensor:$ciphertext, - Concrete_LutTensor:$lookupTable, + Concrete_CrtLutsTensor:$lookupTable, // Bootstrap parameters I32Attr : $bootstrapLevel, I32Attr : $bootstrapBaseLog, @@ -370,7 +370,7 @@ def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> { let arguments = (ins Concrete_LweCRTBuffer:$result, Concrete_LweCRTBuffer:$ciphertext, - Concrete_LutBuffer:$lookup_table, + Concrete_CrtLutsBuffer:$lookup_table, // Bootstrap parameters I32Attr : $bootstrapLevel, I32Attr : $bootstrapBaseLog, diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index d7d6b4ce4..9cfc2cbea 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -33,7 +33,7 @@ def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstra let results = (outs 1DTensorOf<[I64]> : $result); } -def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> { +def TFHE_EncodeLutForCrtWopPBSOp : TFHE_Op<"encode_lut_for_crt_woppbs"> { let summary = "Encode and expand a lookup table so that it can be used for a wop pbs."; @@ -41,12 +41,11 @@ def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> { 1DTensorOf<[I64]> : $input_lookup_table, I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, - I32Attr : $polySize, I32Attr : $modulusProduct, BoolAttr: $isSigned ); - let results = (outs 1DTensorOf<[I64]> : $result); + let results = (outs 2DTensorOf<[I64]> : $result); } def TFHE_EncodePlaintextWithCrtOp : TFHE_Op<"encode_plaintext_with_crt"> { @@ -158,7 +157,7 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> { let arguments = (ins Type.predicate, HasStaticShapePred]>>: $ciphertexts, - 1DTensorOf<[I64]> : $lookupTable, + 2DTensorOf<[I64]> : $lookupTable, // Bootstrap parameters I32Attr : $bootstrapLevel, I32Attr : $bootstrapBaseLog, diff --git a/compiler/include/concretelang/Runtime/context.h b/compiler/include/concretelang/Runtime/context.h index ae7aa63dd..3e4294616 100644 --- a/compiler/include/concretelang/Runtime/context.h +++ b/compiler/include/concretelang/Runtime/context.h @@ -12,11 +12,10 @@ #include #include "concretelang/ClientLib/EvaluationKeys.h" -#include "concretelang/Runtime/seeder.h" - -#include "concrete-core-ffi.h" #include "concretelang/Common/Error.h" +#include "concrete-cpu.h" + #ifdef CONCRETELANG_CUDA_SUPPORT #include "bootstrap.h" #include "device.h" @@ -26,27 +25,23 @@ namespace mlir { namespace concretelang { +typedef struct FFT { + FFT() = delete; + FFT(size_t polynomial_size); + FFT(FFT &other) = delete; + FFT(FFT &&other); + ~FFT(); + + struct Fft *fft; + size_t polynomial_size; +} FFT; + typedef struct RuntimeContext { - RuntimeContext() { - CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine)); -#ifdef CONCRETELANG_CUDA_SUPPORT - bsk_gpu = nullptr; - ksk_gpu = nullptr; -#endif - } - - /// Ensure that the engines map is not copied - RuntimeContext(const RuntimeContext &ctx){}; + RuntimeContext() = delete; + RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys); ~RuntimeContext() { - CAPI_ASSERT_ERROR(destroy_default_engine(default_engine)); - for (const auto &key : fft_engines) { - CAPI_ASSERT_ERROR(destroy_fft_engine(key.second)); - } - if (fbsk != nullptr) { - CAPI_ASSERT_ERROR(destroy_fft_fourier_lwe_bootstrap_key_u64(fbsk)); - } #ifdef CONCRETELANG_CUDA_SUPPORT if (bsk_gpu != nullptr) { cuda_drop(bsk_gpu, 0); @@ -55,44 +50,33 @@ typedef struct RuntimeContext { cuda_drop(ksk_gpu, 0); } #endif + }; + + const uint64_t *keyswitch_key_buffer(size_t keyId) { + return evaluationKeys.getKeyswitchKey(keyId).buffer(); } - FftEngine *get_fft_engine() { - pthread_t threadId = pthread_self(); - std::lock_guard guard(engines_map_guard); - auto engineIt = fft_engines.find(threadId); - if (engineIt == fft_engines.end()) { - FftEngine *fft_engine = nullptr; - - CAPI_ASSERT_ERROR(new_fft_engine(&fft_engine)); - - engineIt = - fft_engines - .insert(std::pair(threadId, fft_engine)) - .first; - } - assert(engineIt->second && "No engine available in context"); - return engineIt->second; + const double *fourier_bootstrap_key_buffer(size_t keyId) { + return fourier_bootstrap_keys[keyId]->data(); } - DefaultEngine *get_default_engine() { return default_engine; } - - FftFourierLweBootstrapKey64 *get_fft_fourier_bsk() { - - if (fbsk != nullptr) { - return fbsk; - } - - const std::lock_guard guard(fbskMutex); - if (fbsk == nullptr) { - CAPI_ASSERT_ERROR( - fft_engine_convert_lwe_bootstrap_key_to_fft_fourier_lwe_bootstrap_key_u64( - get_fft_engine(), evaluationKeys.getBsk(), &fbsk)); - } - return fbsk; + const uint64_t *fp_keyswitch_key_buffer(size_t keyId) { + return evaluationKeys.getPackingKeyswitchKey(keyId).buffer(); } + const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; } + + const ::concretelang::clientlib::EvaluationKeys getKeys() const { + return evaluationKeys; + } + +private: + ::concretelang::clientlib::EvaluationKeys evaluationKeys; + std::vector>> fourier_bootstrap_keys; + std::vector ffts; + #ifdef CONCRETELANG_CUDA_SUPPORT +public: void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t glwe_dim, uint32_t gpu_idx, void *stream) { @@ -104,25 +88,21 @@ typedef struct RuntimeContext { if (bsk_gpu != nullptr) { return bsk_gpu; } - LweBootstrapKey64 *bsk = get_bsk(); - size_t bsk_buffer_len = - input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level; - size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t); - uint64_t *bsk_buffer = - (uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size); + + auto bsk = evaluationKeys.getBootstrapKey(0); + + size_t bsk_buffer_len = bsk.size(); size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double); + void *bsk_gpu_tmp = cuda_malloc(bsk_gpu_buffer_size, gpu_idx); - CAPI_ASSERT_ERROR( - default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers( - default_engine, bsk, bsk_buffer)); cuda_initialize_twiddles(poly_size, gpu_idx); - cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, bsk_buffer, stream, gpu_idx, - input_lwe_dim, glwe_dim, level, + cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, (void *)bsk.buffer(), stream, + gpu_idx, input_lwe_dim, glwe_dim, level, poly_size); - // This is currently not 100% async as we have to free CPU memory after + // This is currently not 100% async as + // we have to free CPU memory after // conversion cuda_synchronize_device(gpu_idx); - free(bsk_buffer); bsk_gpu = bsk_gpu_tmp; return bsk_gpu; } @@ -138,49 +118,23 @@ typedef struct RuntimeContext { if (ksk_gpu != nullptr) { return ksk_gpu; } - LweKeyswitchKey64 *ksk = get_ksk(); - size_t ksk_buffer_len = input_lwe_dim * (output_lwe_dim + 1) * level; - size_t ksk_buffer_size = sizeof(uint64_t) * ksk_buffer_len; - uint64_t *ksk_buffer = - (uint64_t *)aligned_alloc(U64_ALIGNMENT, ksk_buffer_size); + auto ksk = evaluationKeys.getKeyswitchKey(0); + + size_t ksk_buffer_size = sizeof(uint64_t) * ksk.size(); + void *ksk_gpu_tmp = cuda_malloc(ksk_buffer_size, gpu_idx); - CAPI_ASSERT_ERROR( - default_engine_discard_convert_lwe_keyswitch_key_to_lwe_keyswitch_key_mut_view_u64_raw_ptr_buffers( - default_engine, ksk, ksk_buffer)); - cuda_memcpy_async_to_gpu(ksk_gpu_tmp, ksk_buffer, ksk_buffer_size, stream, - gpu_idx); - // This is currently not 100% async as we have to free CPU memory after + + cuda_memcpy_async_to_gpu(ksk_gpu_tmp, (void *)ksk.buffer(), ksk_buffer_size, + stream, gpu_idx); + // This is currently not 100% async as + // we have to free CPU memory after // conversion cuda_synchronize_device(gpu_idx); - free(ksk_buffer); ksk_gpu = ksk_gpu_tmp; return ksk_gpu; } -#endif - - LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); } - - LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); } - - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get_fpksk() { - return evaluationKeys.getFpksk(); - } - - RuntimeContext &operator=(const RuntimeContext &rhs) { - this->evaluationKeys = rhs.evaluationKeys; - return *this; - } - - ::concretelang::clientlib::EvaluationKeys evaluationKeys; private: - std::mutex fbskMutex; - FftFourierLweBootstrapKey64 *fbsk = nullptr; - DefaultEngine *default_engine; - std::map fft_engines; - std::mutex engines_map_guard; - -#ifdef CONCRETELANG_CUDA_SUPPORT std::mutex bsk_gpu_mutex; void *bsk_gpu; std::mutex ksk_gpu_mutex; @@ -192,18 +146,4 @@ private: } // namespace concretelang } // namespace mlir -extern "C" { -LweKeyswitchKey64 * -get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context); - -FftFourierLweBootstrapKey64 * -get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context); - -LweBootstrapKey64 * -get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context); - -DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context); - -FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context); -} #endif diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index a5287443f..c19f4f807 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -14,17 +14,16 @@ #include #include +#include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Runtime/context.h" -#include "concrete-core-ffi.h" #include "concretelang/Common/Error.h" namespace mlir { namespace concretelang { namespace dfr { -template struct KeyManager; struct RuntimeContextManager; namespace { static void *dl_handle; @@ -32,103 +31,83 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager; } // namespace template struct KeyWrapper { - LweKeyType *key; - Buffer buffer; + std::vector keys; - KeyWrapper() : key(nullptr) {} - KeyWrapper(KeyWrapper &&moved) noexcept - : key(moved.key), buffer(moved.buffer) {} - KeyWrapper(LweKeyType *key); - KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {} + KeyWrapper() {} + KeyWrapper(KeyWrapper &&moved) noexcept : keys(moved.keys) {} + KeyWrapper(const KeyWrapper &kw) : keys(kw.keys) {} KeyWrapper &operator=(const KeyWrapper &rhs) { - this->key = rhs.key; - this->buffer = rhs.buffer; + this->keys = rhs.keys; return *this; } + KeyWrapper(std::vector keyvec) : keys(keyvec) {} friend class hpx::serialization::access; + // template + // void save(Archive &ar, const unsigned int version) const; template - void save(Archive &ar, const unsigned int version) const; - template void load(Archive &ar, const unsigned int version); - HPX_SERIALIZATION_SPLIT_MEMBER() + void serialize(Archive &ar, const unsigned int version) const {} + // template void load(Archive &ar, const unsigned int + // version); HPX_SERIALIZATION_SPLIT_MEMBER() }; -template <> -KeyWrapper::KeyWrapper(LweKeyswitchKey64 *key) : key(key) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key, - &buffer)); -} -template <> -KeyWrapper::KeyWrapper(LweBootstrapKey64 *key) : key(key) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, - &buffer)); -} - template bool operator==(const KeyWrapper &lhs, const KeyWrapper &rhs) { - return lhs.key == rhs.key; + if (lhs.keys.size() != rhs.keys.size()) + return false; + for (size_t i = 0; i < lhs.keys.size(); ++i) + if (lhs.keys[i].buffer() != rhs.keys[i].buffer()) + return false; + return true; } -template <> -template -void KeyWrapper::save(Archive &ar, - const unsigned int version) const { - ar << buffer.length; - ar << hpx::serialization::make_array(buffer.pointer, buffer.length); -} -template <> -template -void KeyWrapper::load(Archive &ar, - const unsigned int version) { - DefaultSerializationEngine *engine; +// template <> +// template +// void KeyWrapper::save(Archive &ar, +// const unsigned int version) const { +// ar << buffer.length; +// ar << hpx::serialization::make_array(buffer.pointer, buffer.length); +// } +// template <> +// template +// void KeyWrapper::load(Archive &ar, +// const unsigned int version) { +// DefaultSerializationEngine *engine; - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); +// // No Freeing as it doesn't allocate anything. +// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - ar >> buffer.length; - buffer.pointer = new uint8_t[buffer.length]; - ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); - CAPI_ASSERT_ERROR( - default_serialization_engine_deserialize_lwe_bootstrap_key_u64( - engine, {buffer.pointer, buffer.length}, &key)); -} +// ar >> buffer.length; +// buffer.pointer = new uint8_t[buffer.length]; +// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); +// CAPI_ASSERT_ERROR( +// default_serialization_engine_deserialize_lwe_bootstrap_key_u64( +// engine, {buffer.pointer, buffer.length}, &key)); +// } -template <> -template -void KeyWrapper::save(Archive &ar, - const unsigned int version) const { - ar << buffer.length; - ar << hpx::serialization::make_array(buffer.pointer, buffer.length); -} -template <> -template -void KeyWrapper::load(Archive &ar, - const unsigned int version) { - DefaultSerializationEngine *engine; +// template <> +// template +// void KeyWrapper::save(Archive &ar, +// const unsigned int version) const { +// ar << buffer.length; +// ar << hpx::serialization::make_array(buffer.pointer, buffer.length); +// } +// template <> +// template +// void KeyWrapper::load(Archive &ar, +// const unsigned int version) { +// DefaultSerializationEngine *engine; - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); +// // No Freeing as it doesn't allocate anything. +// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - ar >> buffer.length; - buffer.pointer = new uint8_t[buffer.length]; - ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); - CAPI_ASSERT_ERROR( - default_serialization_engine_deserialize_lwe_keyswitch_key_u64( - engine, {buffer.pointer, buffer.length}, &key)); -} +// ar >> buffer.length; +// buffer.pointer = new uint8_t[buffer.length]; +// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); +// CAPI_ASSERT_ERROR( +// default_serialization_engine_deserialize_lwe_keyswitch_key_u64( +// engine, {buffer.pointer, buffer.length}, &key)); +// } /************************/ /* Context management. */ @@ -152,34 +131,36 @@ struct RuntimeContextManager { // instantiates a local RuntimeContext. if (_dfr_is_root_node()) { RuntimeContext *context = (RuntimeContext *)ctx; - LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context); - LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context); - KeyWrapper kskw(ksk); - KeyWrapper bskw(bsk); + KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw( + context->getKeys().getKeyswitchKeys()); + KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw( + context->getKeys().getBootstrapKeys()); + KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw( + context->getKeys().getPackingKeyswitchKeys()); hpx::collectives::broadcast_to("ksk_keystore", kskw); hpx::collectives::broadcast_to("bsk_keystore", bskw); + hpx::collectives::broadcast_to("pksk_keystore", pkskw); } else { - auto kskFut = - hpx::collectives::broadcast_from>( - "ksk_keystore"); - auto bskFut = - hpx::collectives::broadcast_from>( - "bsk_keystore"); + auto kskFut = hpx::collectives::broadcast_from< + KeyWrapper<::concretelang::clientlib::LweKeyswitchKey>>( + "ksk_keystore"); + auto bskFut = hpx::collectives::broadcast_from< + KeyWrapper<::concretelang::clientlib::LweBootstrapKey>>( + "bsk_keystore"); + auto pkskFut = hpx::collectives::broadcast_from< + KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey>>( + "pksk_keystore"); - KeyWrapper kskw = kskFut.get(); - KeyWrapper bskw = bskFut.get(); - context = new mlir::concretelang::RuntimeContext(); - // TODO - packing keyswitch key for distributed - context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys( - std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>( - new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)), - std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>( - new ::concretelang::clientlib::LweBootstrapKey(bskw.key)), - std::shared_ptr<::concretelang::clientlib::PackingKeyswitchKey>( - nullptr)); - delete (kskw.buffer.pointer); - delete (bskw.buffer.pointer); + KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw = + kskFut.get(); + KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw = + bskFut.get(); + KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw = + pkskFut.get(); + context = new mlir::concretelang::RuntimeContext( + ::concretelang::clientlib::EvaluationKeys(kskw.keys, bskw.keys, + pkskw.keys)); } } diff --git a/compiler/include/concretelang/Runtime/seeder.h b/compiler/include/concretelang/Runtime/seeder.h deleted file mode 100644 index 8e8076202..000000000 --- a/compiler/include/concretelang/Runtime/seeder.h +++ /dev/null @@ -1,13 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_RUNTIME_SEEDER_H -#define CONCRETELANG_RUNTIME_SEEDER_H - -#include "concrete-core-ffi.h" - -extern SeederBuilder *best_seeder; - -#endif diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index 387708ba7..f930defb2 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -29,18 +29,19 @@ void memref_encode_expand_lut_for_bootstrap( uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size, uint32_t out_MESSAGE_BITS, bool is_signed); -void memref_encode_expand_lut_for_woppbs( +void memref_encode_lut_for_crt_woppbs( uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, - uint64_t output_lut_offset, uint64_t output_lut_size, - uint64_t output_lut_stride, uint64_t *input_lut_allocated, + uint64_t output_lut_offset, uint64_t output_lut_size0, + uint64_t output_lut_size1, uint64_t output_lut_stride0, + uint64_t output_lut_stride1, uint64_t *input_lut_allocated, uint64_t *input_lut_aligned, uint64_t input_lut_offset, uint64_t input_lut_size, uint64_t input_lut_stride, uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned, uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size, uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned, uint64_t crt_bits_offset, - uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size, - uint32_t modulus_product, bool is_signed); + uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t modulus_product, + bool is_signed); void memref_encode_plaintext_with_crt( uint64_t *output_allocated, uint64_t *output_aligned, @@ -149,13 +150,16 @@ void memref_wop_pbs_crt_buffer( uint64_t in_stride_1, // clear text lut uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned, - uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride, + uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1, + uint64_t lut_ct_stride0, uint64_t lut_ct_stride1, // CRT decomposition uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned, uint64_t crt_decomp_offset, uint64_t crt_decomp_size, uint64_t crt_decomp_stride, // Additional crypto parameters uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log, + uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count, + uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log, uint32_t polynomial_size, // runtime context that hold evluation keys mlir::concretelang::RuntimeContext *context); diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index 68dfb6d5d..11294212d 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -337,8 +337,8 @@ public: if (check.has_error()) { return StreamStringError(check.error().mesg); } - auto publicArguments = encryptedArgs->exportPublicArguments( - clientParameters, keySet.runtimeContext()); + auto publicArguments = + encryptedArgs->exportPublicArguments(clientParameters); if (publicArguments.has_error()) { return StreamStringError(publicArguments.error().mesg); } @@ -485,8 +485,8 @@ public: if (encryptedArgs.has_error()) { return StreamStringError(encryptedArgs.error().mesg); } - auto publicArguments = encryptedArgs.value()->exportPublicArguments( - clientParameters, keySet->runtimeContext()); + auto publicArguments = + encryptedArgs.value()->exportPublicArguments(clientParameters); if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); } @@ -506,8 +506,8 @@ public: if (encryptedArgs.has_error()) { return StreamStringError(encryptedArgs.error().mesg); } - auto publicArguments = encryptedArgs.value()->exportPublicArguments( - clientParameters, keySet->runtimeContext()); + auto publicArguments = + encryptedArgs.value()->exportPublicArguments(clientParameters); if (publicArguments.has_error()) { return StreamStringError(publicArguments.error().mesg); } diff --git a/compiler/include/concretelang/TestLib/TestTypedLambda.h b/compiler/include/concretelang/TestLib/TestTypedLambda.h index 49dee3334..de093ae70 100644 --- a/compiler/include/concretelang/TestLib/TestTypedLambda.h +++ b/compiler/include/concretelang/TestLib/TestTypedLambda.h @@ -73,8 +73,7 @@ public: OUTCOME_TRY(auto encryptedArgs, clientlib::EncryptedArguments::create(*keySet, args...)); OUTCOME_TRY(auto publicArgument, - encryptedArgs->exportPublicArguments(this->clientParameters, - keySet->runtimeContext())); + encryptedArgs->exportPublicArguments(this->clientParameters)); // client argument serialization // publicArgument->serialize(clientOuput); // message = clientOuput.str(); diff --git a/compiler/lib/Bindings/Python/CompilerEngine.cpp b/compiler/lib/Bindings/Python/CompilerEngine.cpp index b050b38b0..c0af992a2 100644 --- a/compiler/lib/Bindings/Python/CompilerEngine.cpp +++ b/compiler/lib/Bindings/Python/CompilerEngine.cpp @@ -218,8 +218,8 @@ MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys evaluationKeysUnserialize(const std::string &buffer) { std::stringstream istream(buffer); - concretelang::clientlib::EvaluationKeys evaluationKeys; - concretelang::clientlib::operator>>(istream, evaluationKeys); + concretelang::clientlib::EvaluationKeys evaluationKeys = + concretelang::clientlib::readEvaluationKeys(istream); if (istream.fail()) { throw std::runtime_error("Cannot read evaluation keys"); diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index e853ce5a2..7432d31f7 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -394,8 +394,15 @@ void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)} KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb, uint64_t seed_lsb) { + + __uint128_t seed = seed_msb; + seed <<= 64; + seed += seed_lsb; + + auto csprng = concretelang::clientlib::ConcreteCSPRNG(seed); + auto keySet = mlir::concretelang::clientlib::KeySet::generate( - *unwrap(params), seed_msb, seed_lsb); + *unwrap(params), std::move(csprng)); if (keySet.has_error()) { return wrap((mlir::concretelang::clientlib::KeySet *)NULL, keySet.error().mesg); @@ -445,8 +452,8 @@ BufferRef evaluationKeysSerialize(EvaluationKeys keys) { EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) { std::stringstream istream(std::string(buffer.data, buffer.length)); - concretelang::clientlib::EvaluationKeys evaluationKeys; - concretelang::clientlib::operator>>(istream, evaluationKeys); + concretelang::clientlib::EvaluationKeys evaluationKeys = + concretelang::clientlib::readEvaluationKeys(istream); if (istream.fail()) { return wrap((concretelang::clientlib::EvaluationKeys *)NULL, "input stream failure during evaluation keys unserialization"); diff --git a/compiler/lib/ClientLib/CMakeLists.txt b/compiler/lib/ClientLib/CMakeLists.txt index 6ce6c0463..3e431ca73 100644 --- a/compiler/lib/ClientLib/CMakeLists.txt +++ b/compiler/lib/ClientLib/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library( ConcretelangClientLib ClientLambda.cpp ClientParameters.cpp + EvaluationKeys.cpp CRT.cpp EncryptedArguments.cpp KeySet.cpp @@ -12,4 +13,6 @@ add_mlir_library( ${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib LINK_LIBS PUBLIC - Concrete) + concrete_cpu) + +target_include_directories(ConcretelangClientLib PUBLIC ${CONCRETE_CPU_INCLUDE_DIR}) diff --git a/compiler/lib/ClientLib/ClientParameters.cpp b/compiler/lib/ClientLib/ClientParameters.cpp index 6df51747e..4e147c55a 100644 --- a/compiler/lib/ClientLib/ClientParameters.cpp +++ b/compiler/lib/ClientLib/ClientParameters.cpp @@ -39,28 +39,25 @@ void KeyswitchKeyParam::hash(size_t &seed) { double_to_bits(variance)); } -void PackingKeySwitchParam::hash(size_t &seed) { - hash_(seed, inputSecretKeyID, outputSecretKeyID, bootstrapKeyID, level, - baseLog, double_to_bits(variance)); +void PackingKeyswitchKeyParam::hash(size_t &seed) { + hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog, + glweDimension, polynomialSize, inputLweDimension, + double_to_bits(variance)); } std::size_t ClientParameters::hash() { std::size_t currentHash = 1; for (auto secretKeyParam : secretKeys) { - hash_(currentHash, secretKeyParam.first); - secretKeyParam.second.hash(currentHash); + secretKeyParam.hash(currentHash); } for (auto bootstrapKeyParam : bootstrapKeys) { - hash_(currentHash, bootstrapKeyParam.first); - bootstrapKeyParam.second.hash(currentHash); + bootstrapKeyParam.hash(currentHash); } for (auto keyswitchParam : keyswitchKeys) { - hash_(currentHash, keyswitchParam.first); - keyswitchParam.second.hash(currentHash); + keyswitchParam.hash(currentHash); } - for (auto packingParam : packingKeys) { - hash_(currentHash, packingParam.first); - packingParam.second.hash(currentHash); + for (auto packingKeyswitchKeyParam : packingKeyswitchKeys) { + packingKeyswitchKeyParam.hash(currentHash); } return currentHash; } @@ -74,18 +71,8 @@ llvm::json::Value toJSON(const LweSecretKeyParam &v) { bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); - return false; - } - auto dimension = obj->getInteger("dimension"); - if (!dimension.hasValue()) { - p.report("missing size field"); - return false; - } - v.dimension = *dimension; - return true; + llvm::json::ObjectMapper O(j, p); + return O && O.map("dimension", v.dimension); } llvm::json::Value toJSON(const BootstrapKeyParam &v) { @@ -96,54 +83,22 @@ llvm::json::Value toJSON(const BootstrapKeyParam &v) { {"glweDimension", v.glweDimension}, {"baseLog", v.baseLog}, {"variance", v.variance}, + {"polynomialSize", v.polynomialSize}, + {"inputLweDimension", v.inputLweDimension}, }; return object; } bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); - return false; - } - auto inputSecretKeyID = obj->getString("inputSecretKeyID"); - if (!inputSecretKeyID.hasValue()) { - p.report("missing inputSecretKeyID field"); - return false; - } - auto outputSecretKeyID = obj->getString("outputSecretKeyID"); - if (!outputSecretKeyID.hasValue()) { - p.report("missing outputSecretKeyID field"); - return false; - } - auto level = obj->getInteger("level"); - if (!level.hasValue()) { - p.report("missing level field"); - return false; - } - auto baseLog = obj->getInteger("baseLog"); - if (!baseLog.hasValue()) { - p.report("missing baseLog field"); - return false; - } - auto glweDimension = obj->getInteger("glweDimension"); - if (!glweDimension.hasValue()) { - p.report("missing glweDimension field"); - return false; - } - auto variance = obj->getNumber("variance"); - if (!variance.hasValue()) { - p.report("missing variance field"); - return false; - } - v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue(); - v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue(); - v.level = level.getValue(); - v.baseLog = baseLog.getValue(); - v.glweDimension = glweDimension.getValue(); - v.variance = variance.getValue(); - return true; + llvm::json::ObjectMapper O(j, p); + return O && O.map("inputSecretKeyID", v.inputSecretKeyID) && + O.map("outputSecretKeyID", v.outputSecretKeyID) && + O.map("level", v.level) && O.map("baseLog", v.baseLog) && + O.map("glweDimension", v.glweDimension) && + O.map("variance", v.variance) && + O.map("polynomialSize", v.polynomialSize) && + O.map("inputLweDimension", v.inputLweDimension); } llvm::json::Value toJSON(const KeyswitchKeyParam &v) { @@ -158,99 +113,37 @@ llvm::json::Value toJSON(const KeyswitchKeyParam &v) { } bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); - return false; - } - auto inputSecretKeyID = obj->getString("inputSecretKeyID"); - if (!inputSecretKeyID.hasValue()) { - p.report("missing inputSecretKeyID field"); - return false; - } - auto outputSecretKeyID = obj->getString("outputSecretKeyID"); - if (!outputSecretKeyID.hasValue()) { - p.report("missing outputSecretKeyID field"); - return false; - } - auto level = obj->getInteger("level"); - if (!level.hasValue()) { - p.report("missing level field"); - return false; - } - auto baseLog = obj->getInteger("baseLog"); - if (!baseLog.hasValue()) { - p.report("missing baseLog field"); - return false; - } - auto variance = obj->getNumber("variance"); - if (!variance.hasValue()) { - p.report("missing variance field"); - return false; - } - v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue(); - v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue(); - v.level = level.getValue(); - v.baseLog = baseLog.getValue(); - v.variance = variance.getValue(); - return true; + llvm::json::ObjectMapper O(j, p); + return O && O.map("inputSecretKeyID", v.inputSecretKeyID) && + O.map("outputSecretKeyID", v.outputSecretKeyID) && + O.map("level", v.level) && O.map("baseLog", v.baseLog) && + O.map("variance", v.variance); } -llvm::json::Value toJSON(const PackingKeySwitchParam &v) { +llvm::json::Value toJSON(const PackingKeyswitchKeyParam &v) { llvm::json::Object object{ {"inputSecretKeyID", v.inputSecretKeyID}, {"outputSecretKeyID", v.outputSecretKeyID}, - {"bootstrapKeyID", v.bootstrapKeyID}, {"level", v.level}, {"baseLog", v.baseLog}, + {"glweDimension", v.glweDimension}, + {"polynomialSize", v.polynomialSize}, + {"inputLweDimension", v.inputLweDimension}, {"variance", v.variance}, }; return object; } -bool fromJSON(const llvm::json::Value j, PackingKeySwitchParam &v, +bool fromJSON(const llvm::json::Value j, PackingKeyswitchKeyParam &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); - return false; - } - auto inputSecretKeyID = obj->getString("inputSecretKeyID"); - if (!inputSecretKeyID.hasValue()) { - p.report("missing inputSecretKeyID field"); - return false; - } - auto outputSecretKeyID = obj->getString("outputSecretKeyID"); - if (!outputSecretKeyID.hasValue()) { - p.report("missing outputSecretKeyID field"); - return false; - } - auto bootstrapKeyID = obj->getString("bootstrapKeyID"); - if (!bootstrapKeyID.hasValue()) { - p.report("missing bootstrapKeyID field"); - return false; - } - auto level = obj->getInteger("level"); - if (!level.hasValue()) { - p.report("missing level field"); - return false; - } - auto baseLog = obj->getInteger("baseLog"); - if (!baseLog.hasValue()) { - p.report("missing baseLog field"); - return false; - } - auto variance = obj->getNumber("variance"); - if (!variance.hasValue()) { - p.report("missing variance field"); - return false; - } - v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue(); - v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue(); - v.bootstrapKeyID = (std::string)bootstrapKeyID.getValue(); - v.level = level.getValue(); - v.baseLog = baseLog.getValue(); - v.variance = variance.getValue(); - return true; + + llvm::json::ObjectMapper O(j, p); + return O && O.map("inputSecretKeyID", v.inputSecretKeyID) && + O.map("outputSecretKeyID", v.outputSecretKeyID) && + O.map("level", v.level) && O.map("baseLog", v.baseLog) && + O.map("glweDimension", v.glweDimension) && + O.map("polynomialSize", v.polynomialSize) && + O.map("inputLweDimension", v.inputLweDimension) && + O.map("variance", v.variance); } llvm::json::Value toJSON(const CircuitGateShape &v) { @@ -264,43 +157,10 @@ llvm::json::Value toJSON(const CircuitGateShape &v) { } bool fromJSON(const llvm::json::Value j, CircuitGateShape &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); - return false; - } - auto width = obj->getInteger("width"); - if (!width.hasValue()) { - p.report("missing width field"); - return false; - } - auto dimensions = obj->getArray("dimensions"); - if (dimensions == nullptr) { - p.report("missing dimensions field"); - return false; - } - for (auto dim : *dimensions) { - auto iDim = dim.getAsInteger(); - if (!iDim.hasValue()) { - p.report("dimensions must be integer"); - return false; - } - v.dimensions.push_back(iDim.getValue()); - } - auto size = obj->getInteger("size"); - if (!size.hasValue()) { - p.report("missing size field"); - return false; - } - auto sign = obj->getBoolean("sign"); - if (!sign.hasValue()) { - p.report("missing sign field"); - return false; - } - v.width = width.getValue(); - v.size = size.getValue(); - v.sign = sign.getValue(); - return true; + + llvm::json::ObjectMapper O(j, p); + return O && O.map("width", v.width) && O.map("size", v.size) && + O.map("dimensions", v.dimensions) && O.map("sign", v.sign); } llvm::json::Value toJSON(const Encoding &v) { @@ -314,35 +174,13 @@ llvm::json::Value toJSON(const Encoding &v) { return object; } bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); + llvm::json::ObjectMapper O(j, p); + if (!(O && O.map("precision", v.precision) && + O.map("isSigned", v.isSigned))) { return false; } - auto precision = obj->getInteger("precision"); - if (!precision.hasValue()) { - p.report("missing precision field"); - return false; - } - v.precision = precision.getValue(); - auto isSigned = obj->getBoolean("isSigned"); - if (!isSigned.hasValue()) { - p.report("missing isSigned field"); - return false; - } - v.isSigned = isSigned.getValue(); - auto crt = obj->getArray("crt"); - if (crt != nullptr) { - for (auto dim : *crt) { - auto iDim = dim.getAsInteger(); - if (!iDim.hasValue()) { - p.report("dimensions must be integer"); - return false; - } - v.crt.push_back(iDim.getValue()); - } - } - + // TODO: check this is correct for an optional field + O.map("crt", v.crt); return true; } @@ -356,32 +194,9 @@ llvm::json::Value toJSON(const EncryptionGate &v) { } bool fromJSON(const llvm::json::Value j, EncryptionGate &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - if (obj == nullptr) { - p.report("should be an object"); - return false; - } - auto secretKeyID = obj->getString("secretKeyID"); - if (!secretKeyID.hasValue()) { - p.report("missing secretKeyID field"); - return false; - } - v.secretKeyID = (std::string)secretKeyID.getValue(); - auto variance = obj->getNumber("variance"); - if (!variance.hasValue()) { - p.report("missing variance field"); - return false; - } - v.variance = variance.getValue(); - auto encoding = obj->get("encoding"); - if (encoding == nullptr) { - p.report("missing encoding field"); - return false; - } - if (!fromJSON(*encoding, v.encoding, p.field("encoding"))) { - return false; - } - return true; + llvm::json::ObjectMapper O(j, p); + return O && O.map("secretKeyID", v.secretKeyID) && + O.map("variance", v.variance) && O.map("encoding", v.encoding); } llvm::json::Value toJSON(const CircuitGate &v) { @@ -392,40 +207,16 @@ llvm::json::Value toJSON(const CircuitGate &v) { return object; } bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) { - auto obj = j.getAsObject(); - auto encryption = obj->get("encryption"); - if (encryption == nullptr) { - p.report("missing encryption field"); - return false; - } - if (!fromJSON(*encryption, v.encryption, p.field("encryption"))) { - return false; - } - auto shape = obj->get("shape"); - if (shape == nullptr) { - p.report("missing shape field"); - return false; - } - if (!fromJSON(*shape, v.shape, p.field("shape"))) { - return false; - } - return true; -} - -template llvm::json::Value toJson(std::map map) { - llvm::json::Object obj; - for (auto entry : map) { - obj[entry.first] = entry.second; - } - return obj; + llvm::json::ObjectMapper O(j, p); + return O && O.map("encryption", v.encryption) && O.map("shape", v.shape); } llvm::json::Value toJSON(const ClientParameters &v) { llvm::json::Object object{ - {"secretKeys", toJson(v.secretKeys)}, - {"bootstrapKeys", toJson(v.bootstrapKeys)}, - {"keyswitchKeys", toJson(v.keyswitchKeys)}, - {"packingKeys", toJson(v.packingKeys)}, + {"secretKeys", v.secretKeys}, + {"bootstrapKeys", v.bootstrapKeys}, + {"keyswitchKeys", v.keyswitchKeys}, + {"packingKeyswitchKeys", v.packingKeyswitchKeys}, {"inputs", v.inputs}, {"outputs", v.outputs}, {"functionName", v.functionName}, @@ -434,64 +225,13 @@ llvm::json::Value toJSON(const ClientParameters &v) { } bool fromJSON(const llvm::json::Value j, ClientParameters &v, llvm::json::Path p) { - - auto obj = j.getAsObject(); - auto secretkeys = obj->get("secretKeys"); - if (secretkeys == nullptr) { - p.report("missing secretKeys field"); - return false; - } - if (!fromJSON(*secretkeys, v.secretKeys, p.field("secretKeys"))) { - return false; - } - auto bootstrapKeys = obj->get("bootstrapKeys"); - if (bootstrapKeys == nullptr) { - p.report("missing bootstrapKeys field"); - return false; - } - if (!fromJSON(*bootstrapKeys, v.bootstrapKeys, p.field("bootstrapKeys"))) { - return false; - } - auto keyswitchKeys = obj->get("keyswitchKeys"); - if (keyswitchKeys == nullptr) { - p.report("missing keyswitchKeys field"); - return false; - } - if (!fromJSON(*keyswitchKeys, v.keyswitchKeys, p.field("keyswitchKeys"))) { - return false; - } - auto packingKeys = obj->get("packingKeys"); - if (packingKeys == nullptr) { - p.report("missing packingKeys field"); - return false; - } - if (!fromJSON(*packingKeys, v.packingKeys, p.field("packingKeys"))) { - return false; - } - auto inputs = obj->get("inputs"); - if (inputs == nullptr) { - p.report("missing inputs field"); - return false; - } - if (!fromJSON(*inputs, v.inputs, p.field("inputs"))) { - return false; - } - auto outputs = obj->get("outputs"); - if (outputs == nullptr) { - p.report("missing outputs field"); - return false; - } - if (!fromJSON(*outputs, v.outputs, p.field("outputs"))) { - return false; - } - auto functionName = obj->getString("functionName"); - if (!functionName.hasValue()) { - p.report("missing functionName field"); - return false; - } - v.functionName = (std::string)functionName.getValue(); - - return true; + llvm::json::ObjectMapper O(j, p); + return O && O.map("secretKeys", v.secretKeys) && + O.map("bootstrapKeys", v.bootstrapKeys) && + O.map("keyswitchKeys", v.keyswitchKeys) && + O.map("packingKeyswitchKeys", v.packingKeyswitchKeys) && + O.map("inputs", v.inputs) && O.map("outputs", v.outputs) && + O.map("functionName", v.functionName); } std::string ClientParameters::getClientParametersPath(std::string path) { diff --git a/compiler/lib/ClientLib/EncryptedArguments.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp index 999a9cadf..9a5bcdabd 100644 --- a/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -12,8 +12,7 @@ namespace clientlib { using StringError = concretelang::error::StringError; outcome::checked, StringError> -EncryptedArguments::exportPublicArguments(ClientParameters clientParameters, - RuntimeContext runtimeContext) { +EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) { return std::make_unique( clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers)); } diff --git a/compiler/lib/ClientLib/EvaluationKeys.cpp b/compiler/lib/ClientLib/EvaluationKeys.cpp new file mode 100644 index 000000000..1e53c2ee2 --- /dev/null +++ b/compiler/lib/ClientLib/EvaluationKeys.cpp @@ -0,0 +1,137 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/ClientLib/EvaluationKeys.h" +#include "concrete-cpu.h" +#include "concretelang/ClientLib/ClientParameters.h" + +namespace concretelang { +namespace clientlib { + +ConcreteCSPRNG::ConcreteCSPRNG(__uint128_t seed) + : CSPRNG(nullptr, &CONCRETE_CSPRNG_VTABLE) { + ptr = (Csprng *)aligned_alloc(CONCRETE_CSPRNG_ALIGN, CONCRETE_CSPRNG_SIZE); + struct Uint128 u128; + if (seed == 0) { + switch (concrete_cpu_crypto_secure_random_128(&u128)) { + case 1: + break; + case -1: + llvm::errs() + << "WARNING: The generated random seed is not crypto secure\n"; + break; + default: + assert(false && "Cannot instantiate a random seed"); + } + + } else { + for (int i = 0; i < 16; i++) { + u128.little_endian_bytes[i] = seed >> (8 * i); + } + } + concrete_cpu_construct_concrete_csprng(ptr, u128); +} + +ConcreteCSPRNG::ConcreteCSPRNG(ConcreteCSPRNG &&other) + : CSPRNG(other.ptr, &CONCRETE_CSPRNG_VTABLE) { + assert(ptr != nullptr); + other.ptr = nullptr; +} + +ConcreteCSPRNG::~ConcreteCSPRNG() { + if (ptr != nullptr) { + concrete_cpu_destroy_concrete_csprng(ptr); + free(ptr); + } +} + +LweSecretKey::LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng) + : _parameters(parameters) { + // Allocate the buffer + _buffer = std::make_shared>(); + _buffer->resize(parameters.dimension); + // Initialize the lwe secret key buffer + concrete_cpu_init_lwe_secret_key_u64(_buffer->data(), parameters.dimension, + csprng.ptr, csprng.vtable); +} + +void LweSecretKey::encrypt(uint64_t *ciphertext, uint64_t plaintext, + double variance, CSPRNG &csprng) const { + concrete_cpu_encrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext, + plaintext, parameters().dimension, + variance, csprng.ptr, csprng.vtable); +} + +void LweSecretKey::decrypt(const uint64_t *ciphertext, + uint64_t &plaintext) const { + concrete_cpu_decrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext, + parameters().dimension, &plaintext); +} + +LweKeyswitchKey::LweKeyswitchKey(KeyswitchKeyParam ¶meters, + LweSecretKey &inputKey, + LweSecretKey &outputKey, CSPRNG &csprng) + : _parameters(parameters) { + // Allocate the buffer + auto size = concrete_cpu_keyswitch_key_size_u64( + _parameters.level, _parameters.baseLog, inputKey.dimension(), + outputKey.dimension()); + _buffer = std::make_shared>(); + _buffer->resize(size); + + // Initialize the keyswitch key buffer + concrete_cpu_init_lwe_keyswitch_key_u64( + _buffer->data(), inputKey.buffer(), outputKey.buffer(), + inputKey.dimension(), outputKey.dimension(), _parameters.level, + _parameters.baseLog, _parameters.variance, csprng.ptr, csprng.vtable); +} + +LweBootstrapKey::LweBootstrapKey(BootstrapKeyParam ¶meters, + LweSecretKey &inputKey, + LweSecretKey &outputKey, CSPRNG &csprng) + : _parameters(parameters) { + // TODO + size_t polynomial_size = outputKey.dimension() / _parameters.glweDimension; + // Allocate the buffer + auto size = concrete_cpu_bootstrap_key_size_u64( + _parameters.level, _parameters.glweDimension, polynomial_size, + inputKey.dimension()); + _buffer = std::make_shared>(); + _buffer->resize(size); + + // Initialize the keyswitch key buffer + concrete_cpu_init_lwe_bootstrap_key_u64( + _buffer->data(), inputKey.buffer(), outputKey.buffer(), + inputKey.dimension(), polynomial_size, _parameters.glweDimension, + _parameters.level, _parameters.baseLog, _parameters.variance, + Parallelism::Rayon, csprng.ptr, csprng.vtable); +} + +PackingKeyswitchKey::PackingKeyswitchKey(PackingKeyswitchKeyParam ¶ms, + LweSecretKey &inputKey, + LweSecretKey &outputKey, + CSPRNG &csprng) + : _parameters(params) { + assert(_parameters.inputLweDimension == inputKey.dimension()); + assert(_parameters.glweDimension * _parameters.polynomialSize == + outputKey.dimension()); + + // Allocate the buffer + auto size = concrete_cpu_lwe_packing_keyswitch_key_size( + _parameters.glweDimension, _parameters.polynomialSize, _parameters.level, + _parameters.inputLweDimension); + _buffer = std::make_shared>(); + _buffer->resize(size * (_parameters.glweDimension + 1)); + + // Initialize the keyswitch key buffer + concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64( + _buffer->data(), inputKey.buffer(), outputKey.buffer(), + _parameters.inputLweDimension, _parameters.polynomialSize, + _parameters.glweDimension, _parameters.level, _parameters.baseLog, + _parameters.variance, Parallelism::Rayon, csprng.ptr, csprng.vtable); +} + +} // namespace clientlib +} // namespace concretelang diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index f4bb8fcab..c0f3e3150 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -6,269 +6,142 @@ #include "concretelang/ClientLib/KeySet.h" #include "concretelang/ClientLib/CRT.h" #include "concretelang/Common/Error.h" -#include "concretelang/Runtime/seeder.h" #include "concretelang/Support/Error.h" - -#define CAPI_ERR_TO_STRINGERROR(instr, msg) \ - { \ - int err; \ - instr; \ - if (err != 0) { \ - return concretelang::error::StringError(msg); \ - } \ - } - -int clone_transform_lwe_secret_key_to_glwe_secret_key_u64( - DefaultEngine *default_engine, LweSecretKey64 *output_lwe_sk, - size_t poly_size, GlweSecretKey64 **output_glwe_sk) { - LweSecretKey64 *output_lwe_sk_clone = NULL; - int lwe_out_sk_clone_ok = - clone_lwe_secret_key_u64(output_lwe_sk, &output_lwe_sk_clone); - if (lwe_out_sk_clone_ok != 0) { - return 1; - } - - int glwe_sk_ok = - default_engine_transform_lwe_secret_key_to_glwe_secret_key_u64( - default_engine, &output_lwe_sk_clone, poly_size, output_glwe_sk); - if (glwe_sk_ok != 0) { - return 1; - } - - if (output_lwe_sk_clone != NULL) { - return 1; - } - - return 0; -} +#include +#include +#include namespace concretelang { namespace clientlib { -KeySet::KeySet() { - - CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &engine)); - - CAPI_ASSERT_ERROR(new_default_parallel_engine(best_seeder, &par_engine)); -} - -KeySet::~KeySet() { - for (auto it : secretKeys) { - CAPI_ASSERT_ERROR(destroy_lwe_secret_key_u64(it.second.second)); - } - - CAPI_ASSERT_ERROR(destroy_default_engine(engine)); - CAPI_ASSERT_ERROR(destroy_default_parallel_engine(par_engine)); -} - outcome::checked, StringError> -KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, - uint64_t seed_lsb) { - auto keySet = std::make_unique(); - OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb)); - OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)); +KeySet::generate(ClientParameters clientParameters, CSPRNG &&csprng) { + auto keySet = std::make_unique(clientParameters, std::move(csprng)); + OUTCOME_TRYV(keySet->generateKeysFromParams()); + OUTCOME_TRYV(keySet->setupEncryptionMaterial()); return std::move(keySet); } -outcome::checked -KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb, - uint64_t seed_lsb) { - _clientParameters = params; +outcome::checked, StringError> KeySet::fromKeys( + ClientParameters clientParameters, std::vector secretKeys, + std::vector bootstrapKeys, + std::vector keyswitchKeys, + std::vector packingKeyswitchKeys, CSPRNG &&csprng) { - // Set inputs and outputs LWE secret keys - { - for (auto param : params.inputs) { - LweSecretKeyParam secretKeyParam = {0}; - LweSecretKey64 *secretKey = nullptr; - if (param.encryption.hasValue()) { - auto inputSk = this->secretKeys.find(param.encryption->secretKeyID); - if (inputSk == this->secretKeys.end()) { - return StringError("input encryption secret key (") - << param.encryption->secretKeyID << ") does not exist "; - } - secretKeyParam = inputSk->second.first; - secretKey = inputSk->second.second; - } - std::tuple input = { - param, secretKeyParam, secretKey}; - this->inputs.push_back(input); - } - for (auto param : params.outputs) { - LweSecretKeyParam secretKeyParam = {0}; - LweSecretKey64 *secretKey = nullptr; - if (param.encryption.hasValue()) { - auto outputSk = this->secretKeys.find(param.encryption->secretKeyID); - if (outputSk == this->secretKeys.end()) { - return StringError( - "cannot find output key to generate bootstrap key"); - } - secretKeyParam = outputSk->second.first; - secretKey = outputSk->second.second; - } - std::tuple output = { - param, secretKeyParam, secretKey}; - this->outputs.push_back(output); + auto keySet = std::make_unique(clientParameters, std::move(csprng)); + keySet->secretKeys = secretKeys; + keySet->bootstrapKeys = bootstrapKeys; + keySet->keyswitchKeys = keyswitchKeys; + keySet->packingKeyswitchKeys = packingKeyswitchKeys; + OUTCOME_TRYV(keySet->setupEncryptionMaterial()); + return std::move(keySet); +} + +EvaluationKeys KeySet::evaluationKeys() { + return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys); +} + +outcome::checked +KeySet::mapCircuitGateLweSecretKey(std::vector gates) { + SecretKeyGateMapping mapping; + for (auto gate : gates) { + if (gate.encryption.hasValue()) { + assert(gate.encryption->secretKeyID < this->secretKeys.size()); + auto skIt = this->secretKeys[gate.encryption->secretKeyID]; + + std::pair> input = {gate, skIt}; + mapping.push_back(input); + } else { + std::pair> input = {gate, + llvm::None}; + mapping.push_back(input); } } + return mapping; +} +outcome::checked KeySet::setupEncryptionMaterial() { + OUTCOME_TRY(this->inputs, + mapCircuitGateLweSecretKey(_clientParameters.inputs)); + OUTCOME_TRY(this->outputs, + mapCircuitGateLweSecretKey(_clientParameters.outputs)); return outcome::success(); } -outcome::checked -KeySet::generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb, - uint64_t seed_lsb) { +outcome::checked KeySet::generateKeysFromParams() { - { - // Generate LWE secret keys - for (auto secretKeyParam : params.secretKeys) { - OUTCOME_TRYV( - this->generateSecretKey(secretKeyParam.first, secretKeyParam.second)); - } + // Generate LWE secret keys + for (auto secretKeyParam : _clientParameters.secretKeys) { + OUTCOME_TRYV(this->generateSecretKey(secretKeyParam)); } - // Generate bootstrap, keyswitch and packing keyswitch keys - { - for (auto bootstrapKeyParam : params.bootstrapKeys) { - OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first, - bootstrapKeyParam.second)); - } - for (auto keyswitchParam : params.keyswitchKeys) { - OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first, - keyswitchParam.second)); - } - for (auto packingParam : params.packingKeys) { - OUTCOME_TRYV( - this->generatePackingKey(packingParam.first, packingParam.second)); - } + // Generate bootstrap keys + for (auto bootstrapKeyParam : _clientParameters.bootstrapKeys) { + OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam)); + } + // Generate keyswitch key + for (auto keyswitchParam : _clientParameters.keyswitchKeys) { + OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam)); + } + // Generate packing keyswitch key + for (auto packingKeyswitchKeyParam : _clientParameters.packingKeyswitchKeys) { + OUTCOME_TRYV(this->generatePackingKeyswitchKey(packingKeyswitchKeyParam)); } return outcome::success(); } -void KeySet::setKeys( - std::map> - secretKeys, - std::map>> - bootstrapKeys, - std::map>> - keyswitchKeys, - std::map>> - packingKeys) { - this->secretKeys = secretKeys; - this->bootstrapKeys = bootstrapKeys; - this->keyswitchKeys = keyswitchKeys; - this->packingKeys = packingKeys; -} - outcome::checked -KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param) { - LweSecretKey64 *sk; - CAPI_ASSERT_ERROR(default_engine_generate_new_lwe_secret_key_u64( - engine, param.dimension, &sk)); - - secretKeys[id] = {param, sk}; - +KeySet::generateSecretKey(LweSecretKeyParam param) { + // Init the lwe secret key + LweSecretKey sk(param, csprng); + // Store the lwe secret key + secretKeys.push_back(sk); return outcome::success(); } +outcome::checked +KeySet::findLweSecretKey(LweSecretKeyID keyID) { + assert(keyID < secretKeys.size()); + auto secretKey = secretKeys[keyID]; + + return secretKey; +} + outcome::checked -KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) { +KeySet::generateBootstrapKey(BootstrapKeyParam param) { // Finding input and output secretKeys - auto inputSk = secretKeys.find(param.inputSecretKeyID); - if (inputSk == secretKeys.end()) { - return StringError("cannot find input key to generate bootstrap key"); - } - auto outputSk = secretKeys.find(param.outputSecretKeyID); - if (outputSk == secretKeys.end()) { - return StringError("cannot find output key to generate bootstrap key"); - } - // Allocate the bootstrap key - LweBootstrapKey64 *bsk; - - uint64_t total_dimension = outputSk->second.first.dimension; - - assert(total_dimension % param.glweDimension == 0); - - uint64_t polynomialSize = total_dimension / param.glweDimension; - - GlweSecretKey64 *output_glwe_sk = nullptr; - - // This is not part of the C FFI but rather is a C util exposed for - // convenience in tests. - CAPI_ASSERT_ERROR(clone_transform_lwe_secret_key_to_glwe_secret_key_u64( - engine, outputSk->second.second, polynomialSize, &output_glwe_sk)); - - CAPI_ASSERT_ERROR(default_parallel_engine_generate_new_lwe_bootstrap_key_u64( - par_engine, inputSk->second.second, output_glwe_sk, param.baseLog, - param.level, param.variance, &bsk)); - - CAPI_ASSERT_ERROR(destroy_glwe_secret_key_u64(output_glwe_sk)); - + OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID)); + OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID)); + // Initialize the bootstrap key + LweBootstrapKey bootstrapKey(param, inputKey, outputKey, csprng); // Store the bootstrap key - bootstrapKeys[id] = {param, std::make_shared(bsk)}; - + bootstrapKeys.push_back(std::move(bootstrapKey)); return outcome::success(); } outcome::checked -KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) { +KeySet::generateKeyswitchKey(KeyswitchKeyParam param) { // Finding input and output secretKeys - auto inputSk = secretKeys.find(param.inputSecretKeyID); - if (inputSk == secretKeys.end()) { - return StringError("cannot find input key to generate keyswitch key"); - } - auto outputSk = secretKeys.find(param.outputSecretKeyID); - if (outputSk == secretKeys.end()) { - return StringError("cannot find output key to generate keyswitch key"); - } - // Allocate the keyswitch key - LweKeyswitchKey64 *ksk; - - CAPI_ASSERT_ERROR(default_engine_generate_new_lwe_keyswitch_key_u64( - engine, inputSk->second.second, outputSk->second.second, param.level, - param.baseLog, param.variance, &ksk)); - + OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID)); + OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID)); + // Initialize the bootstrap key + LweKeyswitchKey keyswitchKey(param, inputKey, outputKey, csprng); // Store the keyswitch key - keyswitchKeys[id] = {param, std::make_shared(ksk)}; - + keyswitchKeys.push_back(keyswitchKey); return outcome::success(); } outcome::checked -KeySet::generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param) { +KeySet::generatePackingKeyswitchKey(PackingKeyswitchKeyParam param) { // Finding input secretKeys - auto inputSk = secretKeys.find(param.inputSecretKeyID); - if (inputSk == secretKeys.end()) { - return StringError( - "cannot find input key to generate packing keyswitch key"); - } - auto bsk = bootstrapKeys.find(param.bootstrapKeyID); - if (bsk == bootstrapKeys.end()) { - return StringError( - "cannot find input key to generate packing keyswitch key"); - } + assert(param.inputSecretKeyID < secretKeys.size()); + auto inputSk = secretKeys[param.inputSecretKeyID]; - // This is not part of the C FFI but rather is a C util exposed for - // convenience in tests. - GlweSecretKey64 *output_glwe_sk = nullptr; - - auto lweDimension = - inputSk->second.first.lweDimension() / bsk->second.first.glweDimension; - - CAPI_ASSERT_ERROR(clone_transform_lwe_secret_key_to_glwe_secret_key_u64( - engine, inputSk->second.second, lweDimension, &output_glwe_sk)); - - // Allocate the packing keyswitch key - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *fpksk; - CAPI_ASSERT_ERROR( - default_parallel_engine_generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked_u64( - par_engine, inputSk->second.second, output_glwe_sk, param.baseLog, - param.level, param.variance, &fpksk)); + assert(param.outputSecretKeyID < secretKeys.size()); + auto outputSk = secretKeys[param.outputSecretKeyID]; + PackingKeyswitchKey packingKeyswitchKey(param, inputSk, outputSk, csprng); // Store the keyswitch key - packingKeys[id] = {param, std::make_shared(fpksk)}; - + packingKeyswitchKeys.push_back(packingKeyswitchKey); return outcome::success(); } @@ -285,8 +158,9 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) { } auto numBlocks = encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size(); + assert(inputSk.second.has_value()); - size = std::get<1>(inputSk).lweSize(); + size = inputSk.second->parameters().lweSize(); *ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks); return outcome::success(); } @@ -315,8 +189,9 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { return StringError("encrypt_lwe the positional argument is not encrypted"); } auto encoding = encryption->encoding; - auto lweSecretKeyParam = std::get<1>(inputSk); - auto lweSecretKey = std::get<2>(inputSk); + assert(inputSk.second.has_value()); + auto lweSecretKey = *inputSk.second; + auto lweSecretKeyParam = lweSecretKey.parameters(); // CRT encoding - N blocks with crt encoding auto crt = encryption->encoding.crt; if (!crt.empty()) { @@ -324,11 +199,7 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { auto product = crt::productOfModuli(crt); for (auto modulus : crt) { auto plaintext = crt::encode(input, modulus, product); - CAPI_ASSERT_ERROR( - default_engine_discard_encrypt_lwe_ciphertext_u64_raw_ptr_buffers( - engine, lweSecretKey, ciphertext, plaintext, - encryption->variance)); - + lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng); ciphertext = ciphertext + lweSecretKeyParam.lweSize(); } return outcome::success(); @@ -336,10 +207,7 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { // Simple TFHE integers - 1 blocks with one padding bits // TODO we could check if the input value is in the right range uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1)); - CAPI_ASSERT_ERROR( - default_engine_discard_encrypt_lwe_ciphertext_u64_raw_ptr_buffers( - engine, lweSecretKey, ciphertext, plaintext, encryption->variance)); - + lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng); return outcome::success(); } @@ -349,8 +217,9 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { return StringError("decrypt_lwe: position of argument is too high"); } auto outputSk = outputs[argPos]; - auto lweSecretKey = std::get<2>(outputSk); - auto lweSecretKeyParam = std::get<1>(outputSk); + assert(outputSk.second.has_value()); + auto lweSecretKey = *outputSk.second; + auto lweSecretKeyParam = lweSecretKey.parameters(); auto encryption = std::get<0>(outputSk).encryption; if (!encryption.hasValue()) { return StringError("decrypt_lwe: the positional argument is not encrypted"); @@ -358,15 +227,15 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { auto crt = encryption->encoding.crt; - if (!crt.empty()) { // The ciphertext used the crt strategy. + if (!crt.empty()) { + // CRT encoded TFHE integers // Decrypt and decode remainders std::vector remainders; for (auto modulus : crt) { - uint64_t decrypted; - CAPI_ASSERT_ERROR( - default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers( - engine, lweSecretKey, ciphertext, &decrypted)); + uint64_t decrypted = 0; + lweSecretKey.decrypt(ciphertext, decrypted); + auto plaintext = crt::decode(decrypted, modulus); remainders.push_back(plaintext); ciphertext = ciphertext + lweSecretKeyParam.lweSize(); @@ -386,12 +255,10 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { output -= maxPos * 2; } } - } else { // The ciphertext used the scalar strategy - - // Decrypt - uint64_t plaintext; - CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers( - engine, lweSecretKey, ciphertext, &plaintext)); + } else { + // Native encoded TFHE integers - 1 blocks with one padding bits + uint64_t plaintext = 0; + lweSecretKey.decrypt(ciphertext, plaintext); // Decode unsigned integer uint64_t precision = encryption->encoding.precision; @@ -415,27 +282,18 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { return outcome::success(); } -const std::map> & -KeySet::getSecretKeys() { - return secretKeys; -} +const std::vector &KeySet::getSecretKeys() { return secretKeys; } -const std::map>> & -KeySet::getBootstrapKeys() { +const std::vector &KeySet::getBootstrapKeys() { return bootstrapKeys; } -const std::map>> & -KeySet::getKeyswitchKeys() { +const std::vector &KeySet::getKeyswitchKeys() { return keyswitchKeys; } -const std::map>> - &KeySet::getPackingKeys() { - return packingKeys; +const std::vector &KeySet::getPackingKeyswitchKeys() { + return packingKeyswitchKeys; } } // namespace clientlib diff --git a/compiler/lib/ClientLib/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp index b1dd36317..211c61090 100644 --- a/compiler/lib/ClientLib/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -7,6 +7,8 @@ #include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/KeySetCache.h" +#include "concretelang/ClientLib/Serializers.h" + #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" @@ -16,216 +18,101 @@ #include #include -#include "concrete-core-ffi.h" - namespace concretelang { namespace clientlib { using StringError = concretelang::error::StringError; -template -outcome::checked -load(llvm::SmallString<0> &path, - int (*deser)(Engine *, BufferView buffer, Key **), Engine *engine) { +template +outcome::checked loadKey(llvm::SmallString<0> &path, + Key(deser)(std::istream &istream)) { std::ifstream in((std::string)path, std::ofstream::binary); if (in.fail()) { return StringError("Cannot access " + (std::string)path); } - std::stringstream sbuffer; - sbuffer << in.rdbuf(); - if (in.fail()) { - return StringError("Cannot read " + (std::string)path); + auto key = deser(in); + if (in.bad()) { + return StringError("Cannot load key at path(") << (std::string)path << ")"; } - auto content = sbuffer.str(); - BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; - Key *result = nullptr; - - int error_code = deser(engine, buffer, &result); - - if (result == nullptr || error_code != 0) { - return StringError("Cannot deserialize " + (std::string)path); - } - return result; + return key; } -static void writeFile(llvm::SmallString<0> &path, Buffer content) { +template +outcome::checked saveKey(llvm::SmallString<0> &path, + Key key) { std::ofstream out((std::string)path, std::ofstream::binary); - out.write((const char *)content.pointer, content.length); + if (out.fail()) { + return StringError("Cannot access " + (std::string)path); + } + out << key; + if (out.bad()) { + return StringError("Cannot save key at path(") << (std::string)path << ")"; + } out.close(); -} - -outcome::checked -loadSecretKey(llvm::SmallString<0> &path) { - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - return load(path, default_serialization_engine_deserialize_lwe_secret_key_u64, - engine); -} - -outcome::checked -loadKeyswitchKey(llvm::SmallString<0> &path) { - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - return load(path, - default_serialization_engine_deserialize_lwe_keyswitch_key_u64, - engine); -} - -outcome::checked -loadBootstrapKey(llvm::SmallString<0> &path) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - return load(path, - default_serialization_engine_deserialize_lwe_bootstrap_key_u64, - engine); -} - -outcome::checked -loadPackingKey(llvm::SmallString<0> &path) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - return load( - path, - default_serialization_engine_deserialize_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64, - engine); -} - -void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - Buffer buffer; - - CAPI_ASSERT_ERROR(default_serialization_engine_serialize_lwe_secret_key_u64( - engine, key, &buffer)); - - writeFile(path, buffer); - free(buffer.pointer); -} - -void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey64 *key) { - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - Buffer buffer; - - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, - &buffer)); - writeFile(path, buffer); - free(buffer.pointer); -} - -void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey64 *key) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - Buffer buffer; - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key, - &buffer)); - writeFile(path, buffer); - free(buffer.pointer); -} - -void savePackingKey( - llvm::SmallString<0> &path, - LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key) { - - DefaultSerializationEngine *engine; - - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - Buffer buffer; - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64( - engine, key, &buffer)); - - writeFile(path, buffer); - free(buffer.pointer); + return outcome::success(); } outcome::checked, StringError> KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb, std::string folderPath) { - // TODO: text dump of all parameter in /hash - auto key_set = std::make_unique(); + // Mark the folder as recently use. // e.g. so the CI can do some cleanup of unused keys. utime(folderPath.c_str(), nullptr); - std::map> - secretKeys; - std::map>> - bootstrapKeys; - std::map>> - keyswitchKeys; - std::map>> - packingKeys; + std::vector secretKeys; + std::vector bootstrapKeys; + std::vector keyswitchKeys; + std::vector packingKeyswitchKeys; - // Load LWE secret keys - for (auto secretKeyParam : params.secretKeys) { - auto id = secretKeyParam.first; - auto param = secretKeyParam.second; + // Load secret keys + for (auto p : llvm::enumerate(params.secretKeys)) { + // TODO - Check parameters? + // auto param = secretKeyParam.second; llvm::SmallString<0> path(folderPath); - llvm::sys::path::append(path, "secretKey_" + id); - OUTCOME_TRY(LweSecretKey64 * sk, loadSecretKey(path)); - secretKeys[id] = {param, sk}; + llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index())); + OUTCOME_TRY(auto key, loadKey(path, readLweSecretKey)); + secretKeys.push_back(key); } // Load bootstrap keys - for (auto bootstrapKeyParam : params.bootstrapKeys) { - auto id = bootstrapKeyParam.first; - auto param = bootstrapKeyParam.second; + for (auto p : llvm::enumerate(params.bootstrapKeys)) { + // TODO - Check parameters? + // auto param = p.value(); llvm::SmallString<0> path(folderPath); - llvm::sys::path::append(path, "pbsKey_" + id); - OUTCOME_TRY(LweBootstrapKey64 * bsk, loadBootstrapKey(path)); - bootstrapKeys[id] = {param, std::make_shared(bsk)}; + llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index())); + OUTCOME_TRY(auto key, loadKey(path, readLweBootstrapKey)); + bootstrapKeys.push_back(key); } // Load keyswitch keys - for (auto keyswitchParam : params.keyswitchKeys) { - auto id = keyswitchParam.first; - auto param = keyswitchParam.second; + for (auto p : llvm::enumerate(params.keyswitchKeys)) { + // TODO - Check parameters? + // auto param = p.value(); llvm::SmallString<0> path(folderPath); - llvm::sys::path::append(path, "ksKey_" + id); - OUTCOME_TRY(LweKeyswitchKey64 * ksk, loadKeyswitchKey(path)); - keyswitchKeys[id] = {param, std::make_shared(ksk)}; - } - // Load packing keys - for (auto packingParam : params.packingKeys) { - auto id = packingParam.first; - auto param = packingParam.second; - llvm::SmallString<0> path(folderPath); - llvm::sys::path::append(path, "fpksKey_" + id); - OUTCOME_TRY(LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 * - ksk, - loadPackingKey(path)); - packingKeys[id] = {param, std::make_shared(ksk)}; + llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index())); + OUTCOME_TRY(auto key, loadKey(path, readLweKeyswitchKey)); + keyswitchKeys.push_back(key); } - key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys, packingKeys); + for (auto p : llvm::enumerate(params.packingKeyswitchKeys)) { + // TODO - Check parameters? + // auto param = p.value(); + llvm::SmallString<0> path(folderPath); + llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index())); + OUTCOME_TRY(auto key, loadKey(path, readPackingKeyswitchKey)); + packingKeyswitchKeys.push_back(key); + } - OUTCOME_TRYV(key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb)); + __uint128_t seed = seed_msb; + seed <<= 64; + seed += seed_lsb; - return std::move(key_set); + auto csprng = ConcreteCSPRNG(seed); + + OUTCOME_TRY(auto keySet, + KeySet::fromKeys(params, secretKeys, bootstrapKeys, keyswitchKeys, + packingKeyswitchKeys, std::move(csprng))); + + return std::move(keySet); } outcome::checked saveKeys(KeySet &key_set, @@ -239,38 +126,30 @@ outcome::checked saveKeys(KeySet &key_set, return StringError("Cannot create directory \"") << std::string(folderIncompletePath) << "\": " << err.message(); } - // Save LWE secret keys - for (auto secretKeyParam : key_set.getSecretKeys()) { - auto id = secretKeyParam.first; - auto key = secretKeyParam.second.second; + for (auto p : llvm::enumerate(key_set.getSecretKeys())) { + llvm::SmallString<0> path = folderIncompletePath; - llvm::sys::path::append(path, "secretKey_" + id); - saveSecretKey(path, key); + llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index())); + OUTCOME_TRYV(saveKey(path, p.value())); } // Save bootstrap keys - for (auto bootstrapKeyParam : key_set.getBootstrapKeys()) { - auto id = bootstrapKeyParam.first; - auto key = bootstrapKeyParam.second.second; + for (auto p : llvm::enumerate(key_set.getBootstrapKeys())) { llvm::SmallString<0> path = folderIncompletePath; - llvm::sys::path::append(path, "pbsKey_" + id); - saveBootstrapKey(path, key->get()); + llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index())); + OUTCOME_TRYV(saveKey(path, p.value())); } // Save keyswitch keys - for (auto keyswitchParam : key_set.getKeyswitchKeys()) { - auto id = keyswitchParam.first; - auto key = keyswitchParam.second.second; + for (auto p : llvm::enumerate(key_set.getKeyswitchKeys())) { llvm::SmallString<0> path = folderIncompletePath; - llvm::sys::path::append(path, "ksKey_" + id); - saveKeyswitchKey(path, key->get()); + llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index())); + OUTCOME_TRYV(saveKey(path, p.value())); } - // Save packing keys - for (auto keyswitchParam : key_set.getPackingKeys()) { - auto id = keyswitchParam.first; - auto key = keyswitchParam.second.second; + // Save packing keyswitch keys + for (auto p : llvm::enumerate(key_set.getPackingKeyswitchKeys())) { llvm::SmallString<0> path = folderIncompletePath; - llvm::sys::path::append(path, "fpksKey_" + id); - savePackingKey(path, key->get()); + llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index())); + OUTCOME_TRYV(saveKey(path, p.value())); } err = llvm::sys::fs::rename(folderIncompletePath, folderPath); @@ -338,7 +217,14 @@ KeySetCache::loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb, std::cerr << "KeySetCache: miss, regenerating " << std::string(folderPath) << "\n"; - OUTCOME_TRY(auto key_set, KeySet::generate(params, seed_msb, seed_lsb)); + + __uint128_t seed = seed_msb; + seed <<= 64; + seed += seed_lsb; + + auto csprng = ConcreteCSPRNG(seed); + + OUTCOME_TRY(auto key_set, KeySet::generate(params, std::move(csprng))); OUTCOME_TRYV(saveKeys(*key_set, folderPath)); @@ -349,8 +235,13 @@ outcome::checked, StringError> KeySetCache::generate(std::shared_ptr cache, ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb) { + __uint128_t seed = seed_msb; + seed <<= 64; + seed += seed_lsb; + + auto csprng = ConcreteCSPRNG(seed); return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb) - : KeySet::generate(params, seed_msb, seed_lsb); + : KeySet::generate(params, std::move(csprng)); } outcome::checked, StringError> diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp index 0a19839d4..2f0946140 100644 --- a/compiler/lib/ClientLib/Serializers.cpp +++ b/compiler/lib/ClientLib/Serializers.cpp @@ -7,8 +7,6 @@ #include #include -#include "concrete-core-ffi.h" - #include "concretelang/ClientLib/PublicArguments.h" #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Common/Error.h" @@ -16,139 +14,226 @@ namespace concretelang { namespace clientlib { -template -Result read_deser(std::istream &istream, - int (*deser)(Engine *, BufferView, Result *), - Engine *engine) { - size_t length; - readSize(istream, length); - // buffer is too big to be allocated on stack - // vector ensures everything is deallocated w.r.t. new - std::vector buffer(length); - istream.read((char *)buffer.data(), length); - assert(istream.good()); - Result result; - - CAPI_ASSERT_ERROR(deser(engine, {buffer.data(), length}, &result)); - - return result; -} - -template -std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) { - writeSize(ostream, buffer.length); - ostream.write((const char *)buffer.pointer, buffer.length); +template +std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) { + writeSize(ostream, (uint64_t)buffer.size()); + ostream.write((const char *)buffer.buffer(), + buffer.size() * sizeof(uint64_t)); assert(ostream.good()); return ostream; } -std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) { - DefaultSerializationEngine *engine; - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - Buffer b; - - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key, - &b)); - - writeBufferLike(ostream, b); - free((void *)b.pointer); - b.pointer = nullptr; - return ostream; -} - -std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) { - DefaultSerializationEngine *engine; - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - Buffer b; - - CAPI_ASSERT_ERROR( - default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key, - &b)) - - writeBufferLike(ostream, b); - free((void *)b.pointer); - b.pointer = nullptr; - return ostream; -} - -std::ostream &operator<<(std::ostream &ostream, - const FftFourierLweBootstrapKey64 *key) { - FftSerializationEngine *engine; - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine)); - - Buffer b; - - CAPI_ASSERT_ERROR( - fft_serialization_engine_serialize_fft_fourier_lwe_bootstrap_key_u64( - engine, key, &b)) - - writeBufferLike(ostream, b); - free((void *)b.pointer); - b.pointer = nullptr; - return ostream; -} - -std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) { - DefaultSerializationEngine *engine; - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - key = read_deser( - istream, default_serialization_engine_deserialize_lwe_keyswitch_key_u64, - engine); - return istream; -} - -std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) { - DefaultSerializationEngine *engine; - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine)); - - key = read_deser( - istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64, - engine); - return istream; -} - std::istream &operator>>(std::istream &istream, - FftFourierLweBootstrapKey64 *&key) { - FftSerializationEngine *engine; - - // No Freeing as it doesn't allocate anything. - CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine)); - - key = read_deser( - istream, - fft_serialization_engine_deserialize_fft_fourier_lwe_bootstrap_key_u64, - engine); - return istream; -} - -std::istream &operator>>(std::istream &istream, - RuntimeContext &runtimeContext) { - istream >> runtimeContext.evaluationKeys; + std::shared_ptr> &vec) { + // TODO assertion on size? + uint64_t size; + readSize(istream, size); + vec->resize(size); + istream.read((char *)vec->data(), size * sizeof(uint64_t)); assert(istream.good()); return istream; } +// LweSecretKey //////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) { + writeWord(ostream, param.dimension); + return ostream; +} + +std::istream &operator>>(std::istream &istream, LweSecretKeyParam ¶m) { + readWord(istream, param.dimension); + return istream; +} + +// LweSecretKey ///////////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) { + ostream << key.parameters(); + writeUInt64KeyBuffer(ostream, key); + return ostream; +} + +LweSecretKey readLweSecretKey(std::istream &istream) { + LweSecretKeyParam param; + istream >> param; + auto buffer = std::make_shared>(); + istream >> buffer; + return LweSecretKey(buffer, param); +} + +// KeyswitchKeyParam //////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) { + // TODO keys id + writeWord(ostream, param.level); + writeWord(ostream, param.baseLog); + writeWord(ostream, param.variance); + return ostream; +} + +std::istream &operator>>(std::istream &istream, KeyswitchKeyParam ¶m) { + // TODO keys id + param.outputSecretKeyID = 1234; + param.inputSecretKeyID = 1234; + readWord(istream, param.level); + readWord(istream, param.baseLog); + readWord(istream, param.variance); + return istream; +} + +// LweKeyswitchKey ////////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) { + ostream << key.parameters(); + writeUInt64KeyBuffer(ostream, key); + return ostream; +} + +LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) { + KeyswitchKeyParam param; + istream >> param; + auto buffer = std::make_shared>(); + istream >> buffer; + return LweKeyswitchKey(buffer, param); +} + +// BootstrapKeyParam //////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) { + // TODO keys id + writeWord(ostream, param.level); + writeWord(ostream, param.baseLog); + writeWord(ostream, param.glweDimension); + writeWord(ostream, param.variance); + writeWord(ostream, param.polynomialSize); + writeWord(ostream, param.inputLweDimension); + return ostream; +} + +std::istream &operator>>(std::istream &istream, BootstrapKeyParam ¶m) { + // TODO keys id + readWord(istream, param.level); + readWord(istream, param.baseLog); + readWord(istream, param.glweDimension); + readWord(istream, param.variance); + readWord(istream, param.polynomialSize); + readWord(istream, param.inputLweDimension); + return istream; +} + +// LweBootstrapKey ////////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) { + ostream << key.parameters(); + writeUInt64KeyBuffer(ostream, key); + return ostream; +} + +LweBootstrapKey readLweBootstrapKey(std::istream &istream) { + BootstrapKeyParam param; + istream >> param; + auto buffer = std::make_shared>(); + istream >> buffer; + return LweBootstrapKey(buffer, param); +} + +// PackingKeyswitchKeyParam //////////////////////////// + std::ostream &operator<<(std::ostream &ostream, - const RuntimeContext &runtimeContext) { - ostream << runtimeContext.evaluationKeys; + const PackingKeyswitchKeyParam param) { + + // TODO keys id + writeWord(ostream, param.level); + writeWord(ostream, param.baseLog); + writeWord(ostream, param.glweDimension); + writeWord(ostream, param.polynomialSize); + writeWord(ostream, param.inputLweDimension); + writeWord(ostream, param.variance); + + return ostream; +} + +std::istream &operator>>(std::istream &istream, + PackingKeyswitchKeyParam ¶m) { + + // TODO keys id + param.outputSecretKeyID = 1234; + param.inputSecretKeyID = 1234; + readWord(istream, param.level); + readWord(istream, param.baseLog); + readWord(istream, param.glweDimension); + readWord(istream, param.polynomialSize); + readWord(istream, param.inputLweDimension); + readWord(istream, param.variance); + + return istream; +} + +// PackingKeyswitchKey ////////////////////////////// + +std::ostream &operator<<(std::ostream &ostream, + const PackingKeyswitchKey &key) { + ostream << key.parameters(); + writeUInt64KeyBuffer(ostream, key); + return ostream; +} + +PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) { + PackingKeyswitchKeyParam param; + istream >> param; + auto buffer = std::make_shared>(); + istream >> buffer; + auto b = PackingKeyswitchKey(buffer, param); + + return b; +} + +// EvaluationKey //////////////////////////////// + +EvaluationKeys readEvaluationKeys(std::istream &istream) { + uint64_t nbKey; + readSize(istream, nbKey); + std::vector bootstrapKeys; + for (uint64_t i = 0; i < nbKey; i++) { + bootstrapKeys.push_back(readLweBootstrapKey(istream)); + } + readSize(istream, nbKey); + std::vector keyswitchKeys; + for (uint64_t i = 0; i < nbKey; i++) { + keyswitchKeys.push_back(readLweKeyswitchKey(istream)); + } + std::vector packingKeyswitchKeys; + readSize(istream, nbKey); + for (uint64_t i = 0; i < nbKey; i++) { + packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream)); + } + return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys); +} + +std::ostream &operator<<(std::ostream &ostream, + const EvaluationKeys &evaluationKeys) { + auto bootstrapKeys = evaluationKeys.getBootstrapKeys(); + writeSize(ostream, bootstrapKeys.size()); + for (auto bsk : bootstrapKeys) { + ostream << bsk; + } + auto keyswitchKeys = evaluationKeys.getKeyswitchKeys(); + writeSize(ostream, keyswitchKeys.size()); + for (auto ksk : keyswitchKeys) { + ostream << ksk; + } + auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys(); + writeSize(ostream, packingKeyswitchKeys.size()); + for (auto pksk : packingKeyswitchKeys) { + ostream << pksk; + } assert(ostream.good()); return ostream; } +// TensorData /////////////////////////////////// + template std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) { writeWord(ostream, sizeof(T) * 8); @@ -399,70 +484,5 @@ unserializeScalarOrTensorData(const std::vector &expectedSizes, } } -std::ostream &operator<<(std::ostream &ostream, - const LweKeyswitchKey &wrappedKsk) { - ostream << wrappedKsk.ksk; - assert(ostream.good()); - return ostream; -} -std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) { - istream >> wrappedKsk.ksk; - assert(istream.good()); - return istream; -} - -std::ostream &operator<<(std::ostream &ostream, - const LweBootstrapKey &wrappedBsk) { - ostream << wrappedBsk.bsk; - assert(ostream.good()); - return ostream; -} -std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) { - istream >> wrappedBsk.bsk; - assert(istream.good()); - return istream; -} - -std::ostream &operator<<(std::ostream &ostream, - const EvaluationKeys &evaluationKeys) { - bool has_ksk = (bool)evaluationKeys.sharedKsk; - writeWord(ostream, has_ksk); - if (has_ksk) { - ostream << *evaluationKeys.sharedKsk; - } - - bool has_bsk = (bool)evaluationKeys.sharedBsk; - writeWord(ostream, has_bsk); - if (has_bsk) { - ostream << *evaluationKeys.sharedBsk; - } - assert(ostream.good()); - return ostream; -} - -std::istream &operator>>(std::istream &istream, - EvaluationKeys &evaluationKeys) { - bool has_ksk; - readWord(istream, has_ksk); - if (has_ksk) { - auto sharedKsk = LweKeyswitchKey(nullptr); - istream >> sharedKsk; - evaluationKeys.sharedKsk = - std::make_shared(std::move(sharedKsk)); - } - - bool has_bsk; - readWord(istream, has_bsk); - if (has_bsk) { - auto sharedBsk = LweBootstrapKey(nullptr); - istream >> sharedBsk; - evaluationKeys.sharedBsk = - std::make_shared(std::move(sharedBsk)); - } - - assert(istream.good()); - return istream; -} - } // namespace clientlib } // namespace concretelang diff --git a/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp b/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp index be9fc9aaf..4081b2caa 100644 --- a/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp +++ b/compiler/lib/Conversion/ConcreteToCAPI/ConcreteToCAPI.cpp @@ -47,8 +47,7 @@ char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer"; char memref_encode_plaintext_with_crt[] = "memref_encode_plaintext_with_crt"; char memref_encode_expand_lut_for_bootstrap[] = "memref_encode_expand_lut_for_bootstrap"; -char memref_encode_expand_lut_for_woppbs[] = - "memref_encode_expand_lut_for_woppbs"; +char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs"; char memref_trace[] = "memref_trace"; mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, @@ -158,10 +157,16 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( } else if (funcName == memref_wop_pbs_crt_buffer) { funcType = mlir::FunctionType::get(rewriter.getContext(), { + memref2DType, memref2DType, memref2DType, memref1DType, - memref1DType, + rewriter.getI32Type(), + rewriter.getI32Type(), + rewriter.getI32Type(), + rewriter.getI32Type(), + rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI32Type(), @@ -181,11 +186,11 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( {memref1DType, memref1DType, rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()}, {}); - } else if (funcName == memref_encode_expand_lut_for_woppbs) { + } else if (funcName == memref_encode_lut_for_crt_woppbs) { funcType = mlir::FunctionType::get( rewriter.getContext(), - {memref1DType, memref1DType, memref1DType, memref1DType, - rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()}, + {memref2DType, memref1DType, memref1DType, memref1DType, + rewriter.getI32Type(), rewriter.getI1Type()}, {}); } else if (funcName == memref_trace) { funcType = mlir::FunctionType::get( @@ -326,9 +331,32 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op, // cbs_base_log operands.push_back(rewriter.create( op.getLoc(), op.circuitBootstrapBaseLogAttr())); + + // ksk_level_count + operands.push_back(rewriter.create( + op.getLoc(), op.keyswitchLevelAttr())); + // ksk_base_log + operands.push_back(rewriter.create( + op.getLoc(), op.keyswitchBaseLogAttr())); + + // bsk_level_count + operands.push_back(rewriter.create( + op.getLoc(), op.bootstrapLevelAttr())); + // bsk_base_log + operands.push_back(rewriter.create( + op.getLoc(), op.bootstrapBaseLogAttr())); + + // fpksk_level_count + operands.push_back(rewriter.create( + op.getLoc(), op.packingKeySwitchLevelAttr())); + // fpksk_base_log + operands.push_back(rewriter.create( + op.getLoc(), op.packingKeySwitchBaseLogAttr())); + // polynomial_size operands.push_back(rewriter.create( op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr())); + // context operands.push_back(getContextArgument(op)); } @@ -372,9 +400,9 @@ void encodeExpandLutForBootstrapAddOperands( rewriter.create(op.getLoc(), op.isSignedAttr())); } -void encodeExpandLutForWopPBSAddOperands( - Concrete::EncodeExpandLutForWopPBSBufferOp op, - mlir::SmallVector &operands, mlir::RewriterBase &rewriter) { +void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op, + mlir::SmallVector &operands, + mlir::RewriterBase &rewriter) { // crt_decomposition mlir::Type crtDecompositionType = mlir::RankedTensorType::get( @@ -414,9 +442,6 @@ void encodeExpandLutForWopPBSAddOperands( op.getLoc(), (*crtBitsGlobalMemref).type(), (*crtBitsGlobalMemref).getName()); operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef)); - // poly_size - operands.push_back( - rewriter.create(op.getLoc(), op.polySizeAttr())); // modulus_product operands.push_back(rewriter.create( op.getLoc(), op.modulusProductAttr())); @@ -467,10 +492,10 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase { ConcreteToCAPICallPattern>( &getContext(), encodeExpandLutForBootstrapAddOperands); - patterns.add< - ConcreteToCAPICallPattern>( - &getContext(), encodeExpandLutForWopPBSAddOperands); + patterns + .add>( + &getContext(), encodeLutForWopPBSAddOperands); if (gpu) { patterns.add>( diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index f27dccfa6..a79ea9796 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -560,17 +560,18 @@ struct ApplyLookupTableEintOpPattern mlir::Value newLut = rewriter - .create( + .create( op.getLoc(), mlir::RankedTensorType::get( - mlir::ArrayRef(loweringParameters.lutSize), + mlir::ArrayRef{ + (int64_t)loweringParameters.nMods, + (int64_t)loweringParameters.singleLutSize}, rewriter.getI64Type()), adaptor.lut(), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.mods)), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.bits)), - rewriter.getI32IntegerAttr(loweringParameters.polynomialSize), rewriter.getI32IntegerAttr(loweringParameters.modsProd), rewriter.getBoolAttr(originalInputType.isSigned())) .getResult(); diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index b95b725e0..e9d957206 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -664,8 +664,8 @@ void TFHEToConcretePass::runOnOperation() { mlir::concretelang::Concrete::EncodeExpandLutForBootstrapTensorOp, true>, mlir::concretelang::GenericOneToOneOpConversionPattern< - mlir::concretelang::TFHE::EncodeExpandLutForWopPBSOp, - mlir::concretelang::Concrete::EncodeExpandLutForWopPBSTensorOp, true>, + mlir::concretelang::TFHE::EncodeLutForCrtWopPBSOp, + mlir::concretelang::Concrete::EncodeLutForCrtWopPBSTensorOp, true>, mlir::concretelang::GenericOneToOneOpConversionPattern< mlir::concretelang::TFHE::EncodePlaintextWithCrtOp, mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>, diff --git a/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp index cc8e84beb..c00cf26fc 100644 --- a/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/Concrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -143,10 +143,10 @@ void mlir::concretelang::Concrete:: Concrete::EncodeExpandLutForBootstrapTensorOp::attachInterface< TensorToMemrefOp>(*ctx); - // encode_expand_lut_for_woppbs_tensor => - // encode_expand_lut_for_woppbs_buffer - Concrete::EncodeExpandLutForWopPBSTensorOp::attachInterface< - TensorToMemrefOp>(*ctx); + // encode_lut_for_crt_woppbs_tensor => + // encode_lut_for_crt_woppbs_buffer + Concrete::EncodeLutForCrtWopPBSTensorOp::attachInterface< + TensorToMemrefOp>(*ctx); }); } diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt index bb5f5492e..7c4d2f2bf 100644 --- a/compiler/lib/Runtime/CMakeLists.txt +++ b/compiler/lib/Runtime/CMakeLists.txt @@ -1,4 +1,6 @@ -add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp seeder.cpp) +add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp) + +add_dependencies(ConcretelangRuntime concrete_cpu) if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED) target_link_libraries(ConcretelangRuntime PRIVATE HPX::hpx HPX::iostreams_component) @@ -25,7 +27,8 @@ if(APPLE) target_link_libraries(ConcretelangRuntime LINK_PUBLIC ${SECURITY_FRAMEWORK}) endif() -target_link_libraries(ConcretelangRuntime PUBLIC Concrete ConcretelangClientLib pthread m dl +target_include_directories(ConcretelangRuntime PUBLIC ${CONCRETE_CPU_INCLUDE_DIR}) +target_link_libraries(ConcretelangRuntime PUBLIC concrete_cpu ConcretelangClientLib pthread m dl $) install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime) diff --git a/compiler/lib/Runtime/context.cpp b/compiler/lib/Runtime/context.cpp index c5a7bdd98..18feb159b 100644 --- a/compiler/lib/Runtime/context.cpp +++ b/compiler/lib/Runtime/context.cpp @@ -5,29 +5,77 @@ #include "concretelang/Runtime/context.h" #include "concretelang/Common/Error.h" -#include "concretelang/Runtime/seeder.h" #include #include -LweKeyswitchKey64 * -get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->get_ksk(); +namespace clientlib = ::concretelang::clientlib; +namespace mlir { +namespace concretelang { + +FFT::FFT(size_t polynomial_size) + : fft(nullptr), polynomial_size(polynomial_size) { + fft = (struct Fft *)aligned_alloc(CONCRETE_FFT_ALIGN, CONCRETE_FFT_SIZE); + concrete_cpu_construct_concrete_fft(fft, polynomial_size); } -LweBootstrapKey64 * -get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->get_bsk(); +FFT::FFT(FFT &&other) : fft(other.fft), polynomial_size(other.polynomial_size) { + other.fft = nullptr; } -FftFourierLweBootstrapKey64 * -get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) { - return context->get_fft_fourier_bsk(); +FFT::~FFT() { + if (fft != nullptr) { + concrete_cpu_destroy_concrete_fft(fft); + free(fft); + } } -DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context) { - return context->get_default_engine(); +RuntimeContext::RuntimeContext(clientlib::EvaluationKeys evaluationKeys) + : evaluationKeys(evaluationKeys) { + { + + // Initialize for each bootstrap key the fourier one + for (auto bsk : evaluationKeys.getBootstrapKeys()) { + auto param = bsk.parameters(); + + size_t decomposition_level_count = param.level; + size_t decomposition_base_log = param.baseLog; + size_t glwe_dimension = param.glweDimension; + size_t polynomial_size = param.polynomialSize; + size_t input_lwe_dimension = param.inputLweDimension; + + // Create the FFT + FFT fft(polynomial_size); + + // Allocate scratch for key conversion + size_t scratch_size; + size_t scratch_align; + concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch( + &scratch_size, &scratch_align, fft.fft); + auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size); + + // Allocate the fourier_bootstrap_key + auto fourier_data = std::make_shared>(); + fourier_data->resize(bsk.size()); + auto bsk_data = bsk.buffer(); + + // Convert bootstrap_key to the fourier domain + concrete_cpu_bootstrap_key_convert_u64_to_fourier( + bsk_data, fourier_data->data(), decomposition_level_count, + decomposition_base_log, glwe_dimension, polynomial_size, + input_lwe_dimension, fft.fft, scratch, scratch_size); + + // Store the fourier_bootstrap_key in the context + fourier_bootstrap_keys.push_back(fourier_data); + ffts.push_back(std::move(fft)); + free(scratch); + } + +#ifdef CONCRETELANG_CUDA_SUPPORT + bsk_gpu = nullptr; + ksk_gpu = nullptr; +#endif + } } -FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context) { - return context->get_fft_engine(); -} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Runtime/seeder.cpp b/compiler/lib/Runtime/seeder.cpp deleted file mode 100644 index d67441e10..000000000 --- a/compiler/lib/Runtime/seeder.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include "concretelang/Runtime/seeder.h" -#include -#include - -#include "concrete-core-ffi.h" -#include "concretelang/Common/Error.h" - -SeederBuilder *get_best_seeder() { - SeederBuilder *builder = NULL; - -#if defined(__x86_64__) || defined(_M_X64) - bool rdseed_seeder_available = false; - CAPI_ASSERT_ERROR(rdseed_seeder_is_available(&rdseed_seeder_available)); - if (rdseed_seeder_available) { - CAPI_ASSERT_ERROR(get_rdseed_seeder_builder(&builder)); - return builder; - } -#endif - -#if __APPLE__ - bool apple_seeder_available = false; - CAPI_ASSERT_ERROR( - apple_secure_enclave_seeder_is_available(&apple_seeder_available)); - if (apple_seeder_available) { - CAPI_ASSERT_ERROR(get_apple_secure_enclave_seeder_builder(&builder)); - return builder; - } -#endif - - bool unix_seeder_available = false; - CAPI_ASSERT_ERROR(unix_seeder_is_available(&unix_seeder_available)); - - if (unix_seeder_available) { - // Security depends on /dev/random security - uint64_t secret_high_64 = 0; - uint64_t secret_low_64 = 0; - CAPI_ASSERT_ERROR( - get_unix_seeder_builder(secret_high_64, secret_low_64, &builder)); - - return builder; - } - - std::cout << "No available seeder." << std::endl; - return builder; -} - -SeederBuilder *best_seeder = get_best_seeder(); diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index 3fc914557..7ac85ac01 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -4,8 +4,8 @@ // for license information. #include "concretelang/Runtime/wrappers.h" +#include "concrete-cpu.h" #include "concretelang/Common/Error.h" -#include "concretelang/Runtime/seeder.h" #include #include #include @@ -16,15 +16,6 @@ #include #include -static DefaultEngine *levelled_engine = nullptr; - -DefaultEngine *get_levelled_engine() { - if (levelled_engine == nullptr) { - CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &levelled_engine)); - } - return levelled_engine; -} - #include "concretelang/ClientLib/CRT.h" #include "concretelang/Runtime/wrappers.h" @@ -184,18 +175,21 @@ void memref_batched_bootstrap_lwe_cuda_u64( // Construct the glwe accumulator (on CPU) // TODO: Should be done outside of the bootstrap call, compile time if // possible. Refactor in progress - uint64_t glwe_ct_len = poly_size * (glwe_dim + 1); - uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t); - uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size); + uint64_t glwe_ct_size = poly_size * (glwe_dim + 1); + uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t)); + auto tlu = tlu_aligned + tlu_offset; - CAPI_ASSERT_ERROR( - default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers( - get_levelled_engine(), glwe_ct, glwe_ct_len, tlu_aligned + tlu_offset, - poly_size)); + // Glwe trivial encryption + for (size_t i = 0; i < poly_size * glwe_dim; i++) { + glwe_ct[i] = 0; + } + for (size_t i = 0; i < poly_size; i++) { + glwe_ct[poly_size * glwe_dim + i] = tlu[i]; + } // Move the glwe accumulator to the GPU void *glwe_ct_gpu = - alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_len, gpu_idx, stream); + alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_size, gpu_idx, stream); // Move test vector indexes to the GPU, the test vector indexes is set of 0 uint32_t num_test_vectors = 1, lwe_idx = 0, @@ -313,11 +307,12 @@ void memref_encode_expand_lut_for_bootstrap( return; } -void memref_encode_expand_lut_for_woppbs( +void memref_encode_lut_for_crt_woppbs( // Output encoded/expanded lut uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, - uint64_t output_lut_offset, uint64_t output_lut_size, - uint64_t output_lut_stride, + uint64_t output_lut_offset, uint64_t output_lut_size0, + uint64_t output_lut_size1, uint64_t output_lut_stride0, + uint64_t output_lut_stride1, // Input lut uint64_t *input_lut_allocated, uint64_t *input_lut_aligned, uint64_t input_lut_offset, uint64_t input_lut_size, @@ -330,13 +325,22 @@ void memref_encode_expand_lut_for_woppbs( uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned, uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride, // Crypto parameters - uint32_t poly_size, uint32_t modulus_product, bool is_signed) { + uint32_t modulus_product, bool is_signed) { assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check " - "memref_encode_expand_lut_woppbs"); - assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check " - "memref_encode_expand_lut_woppbs"); - assert(modulus_product > input_lut_size); + "memref_encode_lut_woppbs"); + assert(output_lut_stride0 == output_lut_size1 && + "Runtime: out dim stride not equal to in_dim size, check " + "memref_encode_lut_woppbs"); + assert(output_lut_stride1 == 1 && "Runtime: stride not equal to 1, check " + "memref_encode_lut_woppbs"); + + assert(modulus_product >= input_lut_size); + + // Initialize lut cases not supposed to be reached + for (uint64_t i = 0; i < output_lut_size0 * output_lut_size1; i++) { + output_lut_aligned[output_lut_offset + i] = 0; + } // When the woppbs is executed on encrypted signed integers, the index of the // lut elements must be adapted to fit the way signed are encrypted in CRT @@ -393,25 +397,40 @@ void memref_encode_expand_lut_for_woppbs( }; } - uint64_t lut_crt_size = output_lut_size / crt_decomposition_size; + uint64_t log_lut_crt_size = 0; - for (uint64_t index = 0; index < input_lut_size; index++) { - uint64_t index_lut = 0; - uint64_t tmp = 1; + for (size_t in_block = 0; in_block < crt_decomposition_size; in_block++) { + auto bits_count = crt_bits_aligned[crt_bits_offset + in_block]; + log_lut_crt_size += bits_count; + } - for (size_t block = 0; block < crt_decomposition_size; block++) { - auto base = crt_decomposition_aligned[crt_decomposition_offset + block]; - auto bits = crt_bits_aligned[crt_bits_offset + block]; - index_lut += (((indexMap(index) % base) << bits) / base) * tmp; - tmp <<= bits; + uint64_t lut_crt_size = 1 << log_lut_crt_size; + assert(lut_crt_size == output_lut_size1); + assert(crt_decomposition_size == output_lut_size0); + + for (uint64_t in_index = 0; in_index < input_lut_size; in_index++) { + uint64_t out_index = 0; + + { + uint64_t total_bit_count = 0; + for (size_t in_block = 0; in_block < crt_decomposition_size; in_block++) { + auto in_base = + crt_decomposition_aligned[crt_decomposition_offset + in_block]; + auto bits_count = crt_bits_aligned[crt_bits_offset + in_block]; + out_index += (((indexMap(in_index) % in_base) << bits_count) / in_base) + << total_bit_count; + total_bit_count += bits_count; + } } - for (size_t block = 0; block < crt_decomposition_size; block++) { - auto base = crt_decomposition_aligned[crt_decomposition_offset + block]; - auto v = encode_crt(input_lut_aligned[input_lut_offset + index], base, - modulus_product); - output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] = - v; + for (size_t out_block = 0; out_block < crt_decomposition_size; + out_block++) { + auto out_base = + crt_decomposition_aligned[crt_decomposition_offset + out_block]; + auto v = encode_crt(input_lut_aligned[input_lut_offset + in_index], + out_base, modulus_product); + output_lut_aligned[output_lut_offset + out_block * lut_crt_size + + out_index] = v; } } } @@ -424,11 +443,10 @@ void memref_add_lwe_ciphertexts_u64( uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) { assert(out_size == ct0_size && out_size == ct1_size && "size of lwe buffer are incompatible"); - size_t lwe_dimension = {out_size - 1}; - CAPI_ASSERT_ERROR( - default_engine_discard_add_lwe_ciphertext_u64_raw_ptr_buffers( - get_levelled_engine(), out_aligned + out_offset, - ct0_aligned + ct0_offset, ct1_aligned + ct1_offset, lwe_dimension)); + size_t lwe_dimension = out_size - 1; + concrete_cpu_add_lwe_ciphertext_u64(out_aligned + out_offset, + ct0_aligned + ct0_offset, + ct1_aligned + ct1_offset, lwe_dimension); } void memref_add_plaintext_lwe_ciphertext_u64( @@ -437,11 +455,10 @@ void memref_add_plaintext_lwe_ciphertext_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, uint64_t plaintext) { assert(out_size == ct0_size && "size of lwe buffer are incompatible"); - size_t lwe_dimension = {out_size - 1}; - CAPI_ASSERT_ERROR( - default_engine_discard_add_lwe_ciphertext_plaintext_u64_raw_ptr_buffers( - get_levelled_engine(), out_aligned + out_offset, - ct0_aligned + ct0_offset, lwe_dimension, plaintext)); + size_t lwe_dimension = out_size - 1; + concrete_cpu_add_plaintext_lwe_ciphertext_u64(out_aligned + out_offset, + ct0_aligned + ct0_offset, + plaintext, lwe_dimension); } void memref_mul_cleartext_lwe_ciphertext_u64( @@ -450,11 +467,10 @@ void memref_mul_cleartext_lwe_ciphertext_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, uint64_t cleartext) { assert(out_size == ct0_size && "size of lwe buffer are incompatible"); - size_t lwe_dimension = {out_size - 1}; - CAPI_ASSERT_ERROR( - default_engine_discard_mul_lwe_ciphertext_cleartext_u64_raw_ptr_buffers( - get_levelled_engine(), out_aligned + out_offset, - ct0_aligned + ct0_offset, lwe_dimension, cleartext)); + size_t lwe_dimension = out_size - 1; + concrete_cpu_mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset, + ct0_aligned + ct0_offset, + cleartext, lwe_dimension); } void memref_negate_lwe_ciphertext_u64( @@ -464,24 +480,25 @@ void memref_negate_lwe_ciphertext_u64( uint64_t ct0_stride) { assert(out_size == ct0_size && "size of lwe buffer are incompatible"); size_t lwe_dimension = {out_size - 1}; - CAPI_ASSERT_ERROR( - default_engine_discard_opp_lwe_ciphertext_u64_raw_ptr_buffers( - get_levelled_engine(), out_aligned + out_offset, - ct0_aligned + ct0_offset, lwe_dimension)); + concrete_cpu_negate_lwe_ciphertext_u64( + out_aligned + out_offset, ct0_aligned + ct0_offset, lwe_dimension); } -void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned, - uint64_t out_offset, uint64_t out_size, - uint64_t out_stride, uint64_t *ct0_allocated, - uint64_t *ct0_aligned, uint64_t ct0_offset, - uint64_t ct0_size, uint64_t ct0_stride, - uint32_t level, uint32_t base_log, - uint32_t input_lwe_dim, uint32_t output_lwe_dim, - mlir::concretelang::RuntimeContext *context) { - CAPI_ASSERT_ERROR( - default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers( - get_engine(context), get_keyswitch_key_u64(context), - out_aligned + out_offset, ct0_aligned + ct0_offset)); +void memref_keyswitch_lwe_u64( + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset, + uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated, + uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, uint32_t decomposition_level_count, + uint32_t decomposition_base_log, uint32_t input_dimension, + uint32_t output_dimension, mlir::concretelang::RuntimeContext *context) { + assert(out_stride == 1 && ct0_stride == 1); + // Get keyswitch key - TODO Give a non hardcoded keyID + const uint64_t *keyswitch_key = context->keyswitch_key_buffer(0); + // Get stack parameter + concrete_cpu_keyswitch_lwe_ciphertext_u64( + out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key, + decomposition_level_count, decomposition_base_log, input_dimension, + output_dimension); } void memref_batched_keyswitch_lwe_u64( @@ -507,24 +524,44 @@ void memref_bootstrap_lwe_u64( uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size, uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride, - uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, - uint32_t base_log, uint32_t glwe_dim, uint32_t precision, + uint32_t input_lwe_dimension, uint32_t polynomial_size, + uint32_t decomposition_level_count, uint32_t decomposition_base_log, + uint32_t glwe_dimension, uint32_t precision, mlir::concretelang::RuntimeContext *context) { - uint64_t glwe_ct_size = poly_size * (glwe_dim + 1); + uint64_t glwe_ct_size = polynomial_size * (glwe_dimension + 1); uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t)); + auto tlu = tlu_aligned + tlu_offset; - CAPI_ASSERT_ERROR( - default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers( - get_levelled_engine(), glwe_ct, glwe_ct_size, - tlu_aligned + tlu_offset, poly_size)); + // Glwe trivial encryption + for (size_t i = 0; i < polynomial_size * glwe_dimension; i++) { + glwe_ct[i] = 0; + } + for (size_t i = 0; i < polynomial_size; i++) { + glwe_ct[polynomial_size * glwe_dimension + i] = tlu[i]; + } + + // Get fourrier bootstrap key - TODO Give a non hardcoded keyID + size_t keyId = 0; + const auto &fft = context->fft(keyId); + auto bootstrap_key = context->fourier_bootstrap_key_buffer(keyId); + // Get stack parameter + size_t scratch_size; + size_t scratch_align; + concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch( + &scratch_size, &scratch_align, glwe_dimension, polynomial_size, fft); + // Allocate scratch + auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size); + + // Bootstrap + concrete_cpu_bootstrap_lwe_ciphertext_u64( + out_aligned + out_offset, ct0_aligned + ct0_offset, glwe_ct, + bootstrap_key, decomposition_level_count, decomposition_base_log, + glwe_dimension, polynomial_size, input_lwe_dimension, fft, scratch, + scratch_size); - CAPI_ASSERT_ERROR( - fft_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers( - get_fft_engine(context), get_engine(context), - get_fft_fourier_bootstrap_key_u64(context), out_aligned + out_offset, - ct0_aligned + ct0_offset, glwe_ct)); free(glwe_ct); + free(scratch); } void memref_batched_bootstrap_lwe_u64( @@ -563,13 +600,16 @@ void memref_wop_pbs_crt_buffer( uint64_t in_stride_1, // clear text lut 1D memref uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned, - uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride, + uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1, + uint64_t lut_ct_stride0, uint64_t lut_ct_stride1, // CRT decomposition 1D memref uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned, uint64_t crt_decomp_offset, uint64_t crt_decomp_size, uint64_t crt_decomp_stride, // Additional crypto parameters uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log, + uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count, + uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log, uint32_t polynomial_size, // runtime context that hold evluation keys mlir::concretelang::RuntimeContext *context) { @@ -585,7 +625,15 @@ void memref_wop_pbs_crt_buffer( // Check for the size S assert(out_size_1 == in_size_1); + uint64_t lwe_small_dim = lwe_small_size - 1; + + assert(out_size_1 == in_size_1); uint64_t lwe_big_size = in_size_1; + uint64_t lwe_big_dim = lwe_big_size - 1; + assert(lwe_big_dim % polynomial_size == 0); + + assert(lwe_big_dim % polynomial_size == 0); + uint64_t glwe_dim = lwe_big_dim / polynomial_size; // Compute the numbers of bits to extract for each block and the total one. uint64_t total_number_of_bits_per_block = 0; @@ -609,38 +657,83 @@ void memref_wop_pbs_crt_buffer( new uint64_t[lwe_small_size * total_number_of_bits_per_block]{0}; // We make a private copy to apply a subtraction on the body - auto first_cyphertext = in_aligned + in_offset; + auto first_ciphertext = in_aligned + in_offset; auto copy_size = crt_decomp_size * lwe_big_size; - std::vector in_copy(first_cyphertext, first_cyphertext + copy_size); + std::vector in_copy(first_ciphertext, first_ciphertext + copy_size); // Extraction of each bit for each block + + size_t fftKeyId = 0; + const auto &fft = context->fft(fftKeyId); + size_t bskKeyId = 0; + auto bootstrap_key = context->fourier_bootstrap_key_buffer(bskKeyId); + + size_t kskKeyId = 0; + auto keyswicth_key = context->keyswitch_key_buffer(kskKeyId); + for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0; extract_bits_output_offset += number_of_bits_per_block[i--]) { auto nb_bits_to_extract = number_of_bits_per_block[i]; - auto delta_log = 64 - nb_bits_to_extract; + size_t delta_log = 64 - nb_bits_to_extract; + auto in_block = &in_copy[lwe_big_size * i]; // trick ( ct - delta/2 + delta/2^4 ) uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) - (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5)); in_block[lwe_big_size - 1] -= sub; - CAPI_ASSERT_ERROR( - fft_engine_lwe_ciphertext_discarding_bit_extraction_unchecked_u64_raw_ptr_buffers( - context->get_fft_engine(), context->get_default_engine(), - context->get_fft_fourier_bsk(), context->get_ksk(), - &extract_bits_output_buffer[lwe_small_size * - extract_bits_output_offset], - in_block, nb_bits_to_extract, delta_log)); + + size_t scratch_size; + size_t scratch_align; + concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch( + &scratch_size, &scratch_align, lwe_small_dim, lwe_big_dim, glwe_dim, + polynomial_size, fft); + // Allocate scratch + auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size); + + concrete_cpu_extract_bit_lwe_ciphertext_u64( + &extract_bits_output_buffer[lwe_small_size * + extract_bits_output_offset], + in_block, bootstrap_key, keyswicth_key, lwe_small_dim, + nb_bits_to_extract, lwe_big_dim, nb_bits_to_extract, delta_log, + bsk_level_count, bsk_base_log, glwe_dim, polynomial_size, lwe_small_dim, + ksk_level_count, ksk_base_log, lwe_big_dim, lwe_small_dim, fft, scratch, + scratch_size); + + free(scratch); } + size_t ct_in_count = total_number_of_bits_per_block; + size_t lut_size = 1 << ct_in_count; + size_t ct_out_count = out_size_0; + size_t lut_count = ct_out_count; + + assert(lut_ct_size0 == lut_count); + assert(lut_ct_size1 == lut_size); + // Vertical packing - CAPI_ASSERT_ERROR( - fft_engine_lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing_u64_raw_ptr_buffers( - context->get_fft_engine(), context->get_default_engine(), - context->get_fft_fourier_bsk(), out_aligned, lwe_big_size, - crt_decomp_size, extract_bits_output_buffer, lwe_small_size, - total_number_of_bits_per_block, lut_ct_aligned + lut_ct_offset, - lut_ct_size, cbs_level_count, cbs_base_log, context->get_fpksk())); + size_t scratch_size; + size_t scratch_align; + concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch( + &scratch_size, &scratch_align, ct_out_count, lwe_small_dim, ct_in_count, + lut_size, lut_count, glwe_dim, polynomial_size, polynomial_size, + cbs_level_count, fft); + + auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size); + + size_t fpkskKeyId = 0; + auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(fpkskKeyId); + + concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64( + out_aligned + out_offset, extract_bits_output_buffer, + lut_ct_aligned + lut_ct_offset, bootstrap_key, fp_keyswicth_key, + lwe_big_dim, ct_out_count, lwe_small_dim, ct_in_count, lut_size, + lut_count, bsk_level_count, bsk_base_log, glwe_dim, polynomial_size, + lwe_small_dim, fpksk_level_count, fpksk_base_log, lwe_big_dim, glwe_dim, + polynomial_size, glwe_dim + 1, cbs_level_count, cbs_base_log, fft, + scratch, scratch_size); + + free(scratch); } void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, diff --git a/compiler/lib/ServerLib/ServerLambda.cpp b/compiler/lib/ServerLib/ServerLambda.cpp index 302d8cee2..af0a819e6 100644 --- a/compiler/lib/ServerLib/ServerLambda.cpp +++ b/compiler/lib/ServerLib/ServerLambda.cpp @@ -9,6 +9,7 @@ #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Common/Error.h" +#include "concretelang/Runtime/context.h" #include "concretelang/ServerLib/DynamicArityCall.h" #include "concretelang/ServerLib/DynamicModule.h" #include "concretelang/ServerLib/DynamicRankCall.h" @@ -23,8 +24,8 @@ using concretelang::clientlib::CircuitGate; using concretelang::clientlib::CircuitGateShape; using concretelang::clientlib::EvaluationKeys; using concretelang::clientlib::PublicArguments; -using concretelang::clientlib::RuntimeContext; using concretelang::error::StringError; +using mlir::concretelang::RuntimeContext; outcome::checked ServerLambda::loadFromModule(std::shared_ptr module, @@ -70,8 +71,7 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) { std::vector preparedArgs(args.preparedArgs.begin(), args.preparedArgs.end()); - RuntimeContext runtimeContext; - runtimeContext.evaluationKeys = evaluationKeys; + RuntimeContext runtimeContext(evaluationKeys); preparedArgs.push_back((void *)&runtimeContext); assert(clientParameters.outputs.size() == 1 && diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 7a4cd81ba..66e647840 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -37,5 +37,6 @@ add_mlir_library( ${LLVM_PTHREAD_LIB} ConcretelangRuntime ConcretelangClientLib - ConcretelangServerLib - Concrete) + ConcretelangServerLib) + +target_include_directories(ConcretelangSupport PUBLIC ${CONCRETE_CPU_INCLUDE_DIR}) diff --git a/compiler/lib/Support/CompilationFeedback.cpp b/compiler/lib/Support/CompilationFeedback.cpp index fb9a7b8f5..d5aca8cd3 100644 --- a/compiler/lib/Support/CompilationFeedback.cpp +++ b/compiler/lib/Support/CompilationFeedback.cpp @@ -3,6 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include #include #include "boost/outcome.h" @@ -18,30 +19,29 @@ void CompilationFeedback::fillFromClientParameters( // Compute the size of secret keys totalSecretKeysSize = 0; for (auto sk : params.secretKeys) { - totalSecretKeysSize += sk.second.byteSize(); + totalSecretKeysSize += sk.byteSize(); } // Compute the boostrap keys size totalBootstrapKeysSize = 0; - for (auto bsk : params.bootstrapKeys) { - auto bskParam = bsk.second; - auto inputKey = params.secretKeys.find(bskParam.inputSecretKeyID); - assert(inputKey != params.secretKeys.end()); - auto outputKey = params.secretKeys.find(bskParam.outputSecretKeyID); - assert(outputKey != params.secretKeys.end()); + for (auto bskParam : params.bootstrapKeys) { + assert(bskParam.inputSecretKeyID < params.secretKeys.size()); + auto inputKey = params.secretKeys[bskParam.inputSecretKeyID]; - totalBootstrapKeysSize += bskParam.byteSize(inputKey->second.lweSize(), - outputKey->second.lweSize()); + assert(bskParam.outputSecretKeyID < params.secretKeys.size()); + auto outputKey = params.secretKeys[bskParam.outputSecretKeyID]; + + totalBootstrapKeysSize += + bskParam.byteSize(inputKey.lweSize(), outputKey.lweSize()); } // Compute the keyswitch keys size totalKeyswitchKeysSize = 0; - for (auto ksk : params.keyswitchKeys) { - auto kskParam = ksk.second; - auto inputKey = params.secretKeys.find(kskParam.inputSecretKeyID); - assert(inputKey != params.secretKeys.end()); - auto outputKey = params.secretKeys.find(kskParam.outputSecretKeyID); - assert(outputKey != params.secretKeys.end()); - totalKeyswitchKeysSize += kskParam.byteSize(inputKey->second.lweSize(), - outputKey->second.lweSize()); + for (auto kskParam : params.keyswitchKeys) { + assert(kskParam.inputSecretKeyID < params.secretKeys.size()); + auto inputKey = params.secretKeys[kskParam.inputSecretKeyID]; + assert(kskParam.outputSecretKeyID < params.secretKeys.size()); + auto outputKey = params.secretKeys[kskParam.outputSecretKeyID]; + totalKeyswitchKeysSize += + kskParam.byteSize(inputKey.lweSize(), outputKey.lweSize()); } // Compute the size of inputs totalInputsSize = 0; diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index d4bffa1fa..aefa96bf0 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -13,10 +13,11 @@ #include #include "concretelang/Common/BitsSize.h" -#include -#include -#include -#include +#include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/Runtime/context.h" +#include "concretelang/Support/Error.h" +#include "concretelang/Support/Jit.h" +#include "concretelang/Support/logging.h" namespace mlir { namespace concretelang { @@ -133,8 +134,7 @@ JITLambda::call(clientlib::PublicArguments &args, rawArgs[i++] = &arg; } - RuntimeContext runtimeContext; - runtimeContext.evaluationKeys = evaluationKeys; + mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys); // Pointer on runtime context, the rawArgs take pointer on actual value that // is passed to the compiled function. auto rtCtxPtr = &runtimeContext; diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 99afc590a..56a909710 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -250,11 +250,10 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, auto dec = fheContext.value().parameter.largeInteger.value().crtDecomposition; auto mods = mlir::SmallVector(dec.begin(), dec.end()); - auto polySize = fheContext.value().parameter.getPolynomialSize(); addPotentiallyNestedPass( pm, mlir::concretelang::createConvertFHEToTFHECrtPass( - mlir::concretelang::CrtLoweringParameters(mods, polySize)), + mlir::concretelang::CrtLoweringParameters(mods)), enablePass); } else if (fheContext.hasValue()) { pipelinePrinting("FHEToTFHEScalar", pm, context); diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index 00125da8a..1871e5a2f 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -2,6 +2,7 @@ // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include #include #include @@ -161,59 +162,56 @@ createClientParametersForV0(V0FHEContext fheContext, Variance keyswitchKeyVariance = v0Curve->getVariance(1, v0Param.nSmall, 64); // Static client parameters from global parameters for v0 ClientParameters c; - c.secretKeys = { - {clientlib::BIG_KEY, {/*.size = */ v0Param.getNBigLweDimension()}}, - }; + + assert(c.secretKeys.size() == clientlib::BIG_KEY); + clientlib::LweSecretKeyParam skParam; + skParam.dimension = v0Param.getNBigLweDimension(); + c.secretKeys.push_back(skParam); + bool has_small_key = v0Param.nSmall != 0; bool has_bootstrap = v0Param.brLevel != 0; if (has_small_key) { - c.secretKeys.insert({clientlib::SMALL_KEY, {/*.size = */ v0Param.nSmall}}); + assert(c.secretKeys.size() == clientlib::SMALL_KEY); + clientlib::LweSecretKeyParam skParam2; + skParam2.dimension = v0Param.nSmall; + c.secretKeys.push_back(skParam2); } if (has_bootstrap) { auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY; - c.bootstrapKeys = { - { - clientlib::BOOTSTRAP_KEY, - { - /*.inputSecretKeyID = */ inputKey, - /*.outputSecretKeyID = */ clientlib::BIG_KEY, - /*.level = */ v0Param.brLevel, - /*.baseLog = */ v0Param.brLogBase, - /*.glweDimension = */ v0Param.glweDimension, - /*.variance = */ bootstrapKeyVariance, - }, - }, - }; + clientlib::BootstrapKeyParam bskParam; + bskParam.inputSecretKeyID = inputKey; + bskParam.outputSecretKeyID = clientlib::BIG_KEY; + bskParam.level = v0Param.brLevel; + bskParam.baseLog = v0Param.brLogBase; + bskParam.glweDimension = v0Param.glweDimension; + bskParam.variance = bootstrapKeyVariance; + bskParam.polynomialSize = v0Param.getPolynomialSize(); + bskParam.inputLweDimension = v0Param.nSmall; + c.bootstrapKeys.push_back(bskParam); } if (v0Param.largeInteger.hasValue()) { - clientlib::PackingKeySwitchParam param; + clientlib::PackingKeyswitchKeyParam param; param.inputSecretKeyID = clientlib::BIG_KEY; param.outputSecretKeyID = clientlib::BIG_KEY; param.level = v0Param.largeInteger->wopPBS.packingKeySwitch.level; param.baseLog = v0Param.largeInteger->wopPBS.packingKeySwitch.baseLog; - param.bootstrapKeyID = clientlib::BOOTSTRAP_KEY; + + param.glweDimension = v0Param.glweDimension; + param.polynomialSize = v0Param.getPolynomialSize(); + param.inputLweDimension = v0Param.getNBigLweDimension(); param.variance = v0Curve->getVariance(v0Param.glweDimension, v0Param.getPolynomialSize(), 64); - c.packingKeys = { - { - "fpksk_v0", - param, - }, - }; + + c.packingKeyswitchKeys.push_back(param); } if (has_small_key) { - c.keyswitchKeys = { - { - clientlib::KEYSWITCH_KEY, - { - /*.inputSecretKeyID = */ clientlib::BIG_KEY, - /*.outputSecretKeyID = */ clientlib::SMALL_KEY, - /*.level = */ v0Param.ksLevel, - /*.baseLog = */ v0Param.ksLogBase, - /*.variance = */ keyswitchKeyVariance, - }, - }, - }; + clientlib::KeyswitchKeyParam kskParam; + kskParam.inputSecretKeyID = clientlib::BIG_KEY; + kskParam.outputSecretKeyID = clientlib::SMALL_KEY; + kskParam.level = v0Param.ksLevel; + kskParam.baseLog = v0Param.ksLogBase; + kskParam.variance = keyswitchKeyVariance; + c.keyswitchKeys.push_back(kskParam); } c.functionName = (std::string)functionName; diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir index 9e9c17899..c229754da 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir @@ -1,8 +1,8 @@ // RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>> -// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> -// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>> +// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64> +// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>> // CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>> func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> { %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>) diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir index 1ad7e30f6..a7358906b 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir @@ -2,8 +2,8 @@ // CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64> -// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<128xi64>) -> tensor<5x8192xi64> +// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> // CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir deleted file mode 100644 index b37612ac3..000000000 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s - -// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { -// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> -// CHECK-NEXT: return %0 : tensor<40960xi64> -// CHECK-NEXT: } -func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> { - %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> - return %0: tensor<40960xi64> -} diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir new file mode 100644 index 000000000..d647061e0 --- /dev/null +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_lut_for_woppbs.mlir @@ -0,0 +1,10 @@ +// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s + +// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<5x8192xi64> { +// CHECK-NEXT: %0 = "Concrete.encode_lut_for_crt_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64> +// CHECK-NEXT: return %0 : tensor<5x8192xi64> +// CHECK-NEXT: } +func.func @main(%arg1: tensor<4xi64>) -> tensor<5x8192xi64> { + %0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64> + return %0: tensor<5x8192xi64> +} diff --git a/compiler/tests/end_to_end_tests/end_to_end_test.cc b/compiler/tests/end_to_end_tests/end_to_end_test.cc index 790d43bfa..87f1f0b47 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_test.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_test.cc @@ -3,6 +3,7 @@ #include #include +#include "concretelang/ClientLib/Serializers.h" #include "concretelang/Support/CompilationFeedback.h" #include "concretelang/Support/JITSupport.h" #include "concretelang/Support/LibrarySupport.h" @@ -72,6 +73,13 @@ public: void testOnce() { auto evaluationKeys = keySet->evaluationKeys(); + + /* Serialize and unserialize evaluation keys */ + std::stringstream stream; + stream << evaluationKeys; + stream.seekg(0, std::ios::beg); + evaluationKeys = concretelang::clientlib::readEvaluationKeys(stream); + /* Call the server lambda */ auto publicResult = support.serverCall(serverLambda, *publicArguments, evaluationKeys); diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp index bd170a5b6..c1fbcced0 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp @@ -1,3 +1,4 @@ +#include #include #include "concretelang/ClientLib/ClientParameters.h" @@ -8,38 +9,42 @@ namespace clientlib = concretelang::clientlib; TEST(Support, client_parameters_json_serde) { clientlib::ClientParameters params0; - params0.secretKeys = { - {clientlib::SMALL_KEY, {/*.size = */ 12}}, - {clientlib::BIG_KEY, {/*.size = */ 14}}, - }; - params0.bootstrapKeys = { - {clientlib::BOOTSTRAP_KEY, - {/*.inputSecretKeyID = */ clientlib::SMALL_KEY, - /*.outputSecretKeyID = */ clientlib::BIG_KEY, - /*.level = */ 1, - /*.baseLog = */ 2, - /*.glweDimension = */ 3, - /*.variance = */ 0.001}}, - {"wtf_bsk_v0", - { - /*.inputSecretKeyID = */ clientlib::BIG_KEY, - /*.outputSecretKeyID = */ clientlib::SMALL_KEY, - /*.level = */ 3, - /*.baseLog = */ 2, - /*.glweDimension = */ 1, - /*.variance = */ 0.0001, - }}, - }; - params0.keyswitchKeys = {{clientlib::KEYSWITCH_KEY, - { - /*.inputSecretKeyID = */ - clientlib::BIG_KEY, - /*.outputSecretKeyID = */ - clientlib::SMALL_KEY, - /*.level = */ 1, - /*.baseLog = */ 2, - /*.variance = */ 3, - }}}; + assert(params0.secretKeys.size() == clientlib::BIG_KEY); + params0.secretKeys.push_back({14}); + + assert(params0.secretKeys.size() == clientlib::SMALL_KEY); + params0.secretKeys.push_back({12}); + + params0.bootstrapKeys.push_back({ + /*.inputSecretKeyID = */ clientlib::SMALL_KEY, + /*.outputSecretKeyID = */ clientlib::BIG_KEY, + /*.level = */ 1, + /*.baseLog = */ 2, + /*.glweDimension = */ 3, + /*.variance = */ 0.001, + /*.polynomialSize = */ 1024, + /*.inputLweDimension = */ 600, + }); + + params0.bootstrapKeys.push_back({ + /*.inputSecretKeyID = */ clientlib::BIG_KEY, + /*.outputSecretKeyID = */ clientlib::SMALL_KEY, + /*.level = */ 3, + /*.baseLog = */ 2, + /*.glweDimension = */ 1, + /*.variance = */ 0.0001, + /*.polynomialSize = */ 1024, + /*.inputLweDimension = */ 600, + }); + params0.keyswitchKeys.push_back({ + /*.inputSecretKeyID = */ + clientlib::BIG_KEY, + /*.outputSecretKeyID = */ + clientlib::SMALL_KEY, + /*.level = */ 1, + /*.baseLog = */ 2, + /*.variance = */ 3, + }); params0.inputs = { { /*.encryption = */ { diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp index ab9ffd027..00d66eb70 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp @@ -3,6 +3,7 @@ #include "concrete/curves.h" #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/EncryptedArguments.h" +#include "concretelang/ClientLib/EvaluationKeys.h" #include "tests_tools/assert.h" namespace clientlib = concretelang::clientlib; @@ -19,9 +20,13 @@ TEST_P(KeySetTest, encrypt_decrypt) { auto clientParameters = GetParam(); + __uint128_t seed = 0; + // Generate the client keySet ASSERT_ASSIGN_OUTCOME_VALUE( - keySet, clientlib::KeySet::generate(clientParameters, 0, 0)); + keySet, + clientlib::KeySet::generate( + clientParameters, concretelang::clientlib::ConcreteCSPRNG(seed))); // Allocate the ciphertext uint64_t *ciphertext = nullptr; @@ -50,13 +55,13 @@ clientlib::ClientParameters generateClientParameterOneScalarOneScalar( clientlib::CRTDecomposition crtDecomposition) { // One secret key with the given dimension clientlib::ClientParameters params; - params.secretKeys.insert({clientlib::SMALL_KEY, {/*.dimension =*/dimension}}); + params.secretKeys.push_back({/*.dimension =*/dimension}); // One input and output encryption gate on the same secret key and encoded // with the same precision const auto v0Curve = concrete::getSecurityCurve(128, concrete::BINARY); clientlib::EncryptionGate encryption; - encryption.secretKeyID = clientlib::SMALL_KEY; + encryption.secretKeyID = clientlib::BIG_KEY; encryption.encoding.precision = precision; encryption.encoding.crt = crtDecomposition; encryption.variance = v0Curve->getVariance(1, dimension, 64);