mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: Integrate concrete-cpu and remove concrete-core
Co-authored-by: Mayeul@Zama <mayeul.debellabre@zama.ai>
This commit is contained in:
33
.github/workflows/aws_build_gpu.yml
vendored
33
.github/workflows/aws_build_gpu.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
||||
- name: Create build dir
|
||||
run: mkdir build
|
||||
|
||||
- name: Build compiler
|
||||
- name: Build and test compiler
|
||||
uses: addnab/docker-run-action@v3
|
||||
id: build-compiler
|
||||
with:
|
||||
@@ -94,34 +94,7 @@ jobs:
|
||||
set -e
|
||||
cd /compiler
|
||||
rm -rf /build/*
|
||||
make DATAFLOW_EXECUTION_ENABLED=ON CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC BUILD_DIR=/build CUDA_SUPPORT=ON CUDA_PATH=${{ env.CUDA_PATH }} all python-package build-end-to-end-dataflow-tests
|
||||
mkdir -p /tmp/concrete_compiler/gpu_tests/
|
||||
make BINDINGS_PYTHON_ENABLED=OFF CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC CUDA_SUPPORT=ON CUDA_PATH=${{ env.CUDA_PATH }} run-end-to-end-tests-gpu
|
||||
echo "Debug: ccache statistics (after the build):"
|
||||
ccache -s
|
||||
|
||||
- name: Archive python package
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: concrete-compiler-gpu
|
||||
path: build/wheels/*.whl
|
||||
|
||||
- name: Test compiler
|
||||
uses: addnab/docker-run-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
image: ${{ env.DOCKER_IMAGE_TEST }}
|
||||
username: ${{ secrets.GHCR_LOGIN }}
|
||||
password: ${{ secrets.GHCR_PASSWORD }}
|
||||
options: >-
|
||||
-v ${{ github.workspace }}/llvm-project:/llvm-project
|
||||
-v ${{ github.workspace }}/compiler:/compiler
|
||||
-v ${{ github.workspace }}/build:/build
|
||||
--gpus all
|
||||
shell: bash
|
||||
run: |
|
||||
set -e
|
||||
cd /compiler
|
||||
pip install pytest
|
||||
sed "s/pytest/python -m pytest/g" -i Makefile
|
||||
mkdir -p /tmp/concrete_compiler/gpu_tests/
|
||||
make DATAFLOW_EXECUTION_ENABLED=ON CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC BUILD_DIR=/build run-end-to-end-tests-gpu
|
||||
chmod -R ugo+rwx /tmp/KeySetCache
|
||||
|
||||
5
.gitmodules
vendored
5
.gitmodules
vendored
@@ -13,3 +13,8 @@
|
||||
[submodule "compiler/parameter-curves"]
|
||||
path = compiler/parameter-curves
|
||||
url = git@github.com:zama-ai/parameter-curves.git
|
||||
shallow = true
|
||||
[submodule "compiler/concrete-cpu"]
|
||||
path = compiler/concrete-cpu
|
||||
url = git@github.com:zama-ai/concrete-cpu.git
|
||||
shallow = true
|
||||
|
||||
3
compiler/.gitignore
vendored
3
compiler/.gitignore
vendored
@@ -4,9 +4,6 @@ build*/
|
||||
*.mlir.script
|
||||
*.lit_test_times.txt
|
||||
|
||||
# core ffi artifacts
|
||||
concrete-core-ffi*
|
||||
|
||||
# Test-generated artifacts
|
||||
concrete-compiler_compilation_artifacts/
|
||||
py_test_lib_compile_and_run_custom_perror/
|
||||
|
||||
@@ -61,11 +61,23 @@ set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/parameter-curves/concrete-security-curves-cpp/include)
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Concrete FFI Configuration
|
||||
# Concrete CPU Configuration
|
||||
# -------------------------------------------------------------------------------
|
||||
include_directories(${CONCRETE_FFI_RELEASE})
|
||||
add_library(Concrete STATIC IMPORTED)
|
||||
set_target_properties(Concrete PROPERTIES IMPORTED_LOCATION ${CONCRETE_FFI_RELEASE}/libconcrete_core_ffi.a)
|
||||
set(CONCRETE_CPU_STATIC_LIB "${PROJECT_SOURCE_DIR}/concrete-cpu/target/release/libconcrete_cpu.a")
|
||||
ExternalProject_Add(
|
||||
concrete_cpu_rust
|
||||
DOWNLOAD_COMMAND ""
|
||||
CONFIGURE_COMMAND "" OUTPUT "${CONCRETE_CPU_STATIC_LIB}"
|
||||
BUILD_COMMAND cargo build
|
||||
COMMAND cargo build --release
|
||||
BINARY_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu"
|
||||
INSTALL_COMMAND ""
|
||||
LOG_BUILD ON)
|
||||
add_library(concrete_cpu STATIC IMPORTED)
|
||||
# TODO - Change that to a location in the release dir
|
||||
set(CONCRETE_CPU_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu/concrete-cpu")
|
||||
set_target_properties(concrete_cpu PROPERTIES IMPORTED_LOCATION "${CONCRETE_CPU_STATIC_LIB}")
|
||||
add_dependencies(concrete_cpu concrete_cpu_rust)
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Concrete Cuda Configuration
|
||||
|
||||
@@ -24,8 +24,6 @@ HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz
|
||||
HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION)
|
||||
HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR)/build
|
||||
|
||||
CONCRETE_CORE_FFI_VERSION?=0.2.0
|
||||
|
||||
ML_BENCH_SUBSET_ID=
|
||||
|
||||
# Find OS
|
||||
@@ -50,8 +48,6 @@ else
|
||||
ARCHITECTURE=amd64
|
||||
endif
|
||||
|
||||
CONCRETE_CORE_FFI_TARBALL=concrete-core-ffi_$(CONCRETE_CORE_FFI_VERSION)_$(OS)_$(ARCHITECTURE).tar.gz
|
||||
|
||||
export PATH := $(abspath $(BUILD_DIR))/bin:$(PATH)
|
||||
|
||||
ifeq ($(shell which ccache),)
|
||||
@@ -87,23 +83,6 @@ endif
|
||||
|
||||
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc rust-bindings
|
||||
|
||||
# concrete-core-ffi #######################################
|
||||
|
||||
CONCRETE_CORE_FFI_FOLDER = $(shell pwd)/concrete-core-ffi-$(CONCRETE_CORE_FFI_VERSION)
|
||||
|
||||
CONCRETE_CORE_FFI_ARTIFACTS=$(CONCRETE_CORE_FFI_FOLDER)/libconcrete_core_ffi.a $(CONCRETE_CORE_FFI_FOLDER)/concrete-core-ffi.h
|
||||
|
||||
concrete-core-ffi: $(CONCRETE_CORE_FFI_ARTIFACTS)
|
||||
|
||||
$(CONCRETE_CORE_FFI_ARTIFACTS): $(CONCRETE_CORE_FFI_FOLDER)
|
||||
|
||||
$(CONCRETE_CORE_FFI_FOLDER): $(CONCRETE_CORE_FFI_TARBALL)
|
||||
mkdir -p $(CONCRETE_CORE_FFI_FOLDER)
|
||||
tar -xvzf $(CONCRETE_CORE_FFI_TARBALL) --directory $(CONCRETE_CORE_FFI_FOLDER)
|
||||
|
||||
$(CONCRETE_CORE_FFI_TARBALL):
|
||||
curl -L https://github.com/zama-ai/concrete-core/releases/download/concrete-core-ffi-$(CONCRETE_CORE_FFI_VERSION)/$(CONCRETE_CORE_FFI_TARBALL) -o $(CONCRETE_CORE_FFI_TARBALL)
|
||||
|
||||
# concrete-optimizer ######################################
|
||||
|
||||
LIB_CONCRETE_OPTIMIZER_CPP = $(CONCRETE_OPTIMIZER_DIR)/target/libconcrete_optimizer_cpp.a
|
||||
@@ -143,7 +122,6 @@ $(BUILD_DIR)/configured.stamp:
|
||||
-DCONCRETELANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \
|
||||
-DCONCRETELANG_DATAFLOW_EXECUTION_ENABLED=$(DATAFLOW_EXECUTION_ENABLED) \
|
||||
-DCONCRETELANG_TIMING_ENABLED=$(TIMING_ENABLED) \
|
||||
-DCONCRETE_FFI_RELEASE=$(CONCRETE_CORE_FFI_FOLDER) \
|
||||
-DHPX_DIR=${HPX_INSTALL_DIR}/lib/cmake/HPX \
|
||||
-DLLVM_EXTERNAL_PROJECTS=concretelang \
|
||||
-DLLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR=. \
|
||||
@@ -154,7 +132,7 @@ $(BUILD_DIR)/configured.stamp:
|
||||
-DCUDAToolkit_ROOT=$(CUDA_PATH)
|
||||
touch $@
|
||||
|
||||
build-initialized: concrete-optimizer-lib concrete-core-ffi $(BUILD_DIR)/configured.stamp
|
||||
build-initialized: concrete-optimizer-lib $(BUILD_DIR)/configured.stamp
|
||||
|
||||
doc: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target mlir-doc
|
||||
@@ -330,7 +308,7 @@ $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end
|
||||
$(BENCHMARK_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py
|
||||
$(Python3_EXECUTABLE) $< > $@
|
||||
|
||||
generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_leveled.yaml
|
||||
generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml
|
||||
|
||||
SECURITY_TO_BENCH=128
|
||||
run-cpu-benchmarks: build-benchmarks generate-cpu-benchmarks
|
||||
@@ -530,7 +508,6 @@ install: concretecompiler concrete-optimizer-lib CAPI install-deps
|
||||
python-format \
|
||||
check-python-format \
|
||||
concrete-optimizer-lib \
|
||||
concrete-core-ffi \
|
||||
build-tests \
|
||||
run-tests \
|
||||
run-check-tests \
|
||||
@@ -540,7 +517,6 @@ install: concretecompiler concrete-optimizer-lib CAPI install-deps
|
||||
build-end-to-end-tests \
|
||||
build-end-to-end-dataflow-tests \
|
||||
run-end-to-end-dataflow-tests \
|
||||
concrete-core-ffi \
|
||||
opt \
|
||||
mlir-opt \
|
||||
mlir-cpu-runner \
|
||||
|
||||
1
compiler/concrete-cpu
Submodule
1
compiler/concrete-cpu
Submodule
Submodule compiler/concrete-cpu added at db262714cd
@@ -92,8 +92,7 @@ public:
|
||||
OUTCOME_TRY(auto clientArguments,
|
||||
EncryptedArguments::create(keySet, args...));
|
||||
|
||||
return clientArguments->exportPublicArguments(clientParameters,
|
||||
keySet.runtimeContext());
|
||||
return clientArguments->exportPublicArguments(clientParameters);
|
||||
}
|
||||
|
||||
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
|
||||
|
||||
@@ -35,10 +35,8 @@ namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
const std::string SMALL_KEY = "small";
|
||||
const std::string BIG_KEY = "big";
|
||||
const std::string BOOTSTRAP_KEY = "bsk_v0";
|
||||
const std::string KEYSWITCH_KEY = "ksk_v0";
|
||||
const uint64_t SMALL_KEY = 1;
|
||||
const uint64_t BIG_KEY = 0;
|
||||
|
||||
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
|
||||
|
||||
@@ -52,7 +50,7 @@ typedef std::vector<int64_t> CRTDecomposition;
|
||||
typedef uint64_t LweDimension;
|
||||
typedef uint64_t GlweDimension;
|
||||
|
||||
typedef std::string LweSecretKeyID;
|
||||
typedef uint64_t LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweDimension dimension;
|
||||
|
||||
@@ -66,7 +64,7 @@ static bool operator==(const LweSecretKeyParam &lhs,
|
||||
return lhs.dimension == rhs.dimension;
|
||||
}
|
||||
|
||||
typedef std::string BootstrapKeyID;
|
||||
typedef uint64_t BootstrapKeyID;
|
||||
struct BootstrapKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
@@ -74,6 +72,8 @@ struct BootstrapKeyParam {
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
Variance variance;
|
||||
PolynomialSize polynomialSize;
|
||||
LweDimension inputLweDimension;
|
||||
|
||||
void hash(size_t &seed);
|
||||
|
||||
@@ -90,7 +90,7 @@ static inline bool operator==(const BootstrapKeyParam &lhs,
|
||||
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef std::string KeyswitchKeyID;
|
||||
typedef uint64_t KeyswitchKeyID;
|
||||
struct KeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
@@ -112,22 +112,28 @@ static inline bool operator==(const KeyswitchKeyParam &lhs,
|
||||
lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef std::string PackingKeySwitchID;
|
||||
struct PackingKeySwitchParam {
|
||||
typedef uint64_t PackingKeyswitchKeyID;
|
||||
struct PackingKeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
BootstrapKeyID bootstrapKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
PolynomialSize polynomialSize;
|
||||
LweDimension inputLweDimension;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
static inline bool operator==(const PackingKeySwitchParam &lhs,
|
||||
const PackingKeySwitchParam &rhs) {
|
||||
static inline bool operator==(const PackingKeyswitchKeyParam &lhs,
|
||||
const PackingKeyswitchKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog;
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.glweDimension == rhs.glweDimension &&
|
||||
lhs.polynomialSize == rhs.polynomialSize &&
|
||||
lhs.variance == lhs.variance &&
|
||||
lhs.inputLweDimension == rhs.inputLweDimension;
|
||||
}
|
||||
|
||||
struct Encoding {
|
||||
@@ -185,13 +191,13 @@ struct CircuitGate {
|
||||
bool isEncrypted() { return encryption.hasValue(); }
|
||||
|
||||
/// byteSize returns the size in bytes for this gate.
|
||||
size_t byteSize(std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys) {
|
||||
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
|
||||
auto width = shape.width;
|
||||
auto numElts = shape.size == 0 ? 1 : shape.size;
|
||||
if (isEncrypted()) {
|
||||
auto skParam = secretKeys.find(encryption->secretKeyID);
|
||||
assert(skParam != secretKeys.end());
|
||||
return 8 * skParam->second.lweSize() * numElts;
|
||||
assert(encryption->secretKeyID < secretKeys.size());
|
||||
auto skParam = secretKeys[encryption->secretKeyID];
|
||||
return 8 * skParam.lweSize() * numElts;
|
||||
}
|
||||
width = bitWidthAsWord(width) / 8;
|
||||
return width * numElts;
|
||||
@@ -203,10 +209,10 @@ static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
|
||||
}
|
||||
|
||||
struct ClientParameters {
|
||||
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
|
||||
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
|
||||
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
|
||||
std::map<PackingKeySwitchID, PackingKeySwitchParam> packingKeys;
|
||||
std::vector<LweSecretKeyParam> secretKeys;
|
||||
std::vector<BootstrapKeyParam> bootstrapKeys;
|
||||
std::vector<KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
std::string functionName;
|
||||
@@ -237,12 +243,9 @@ struct ClientParameters {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("gate is not encrypted");
|
||||
}
|
||||
auto secretKey = secretKeys.find(gate.encryption->secretKeyID);
|
||||
if (secretKey == secretKeys.end()) {
|
||||
return StringError("cannot find ")
|
||||
<< gate.encryption->secretKeyID << " in client parameters";
|
||||
}
|
||||
return secretKey->second;
|
||||
assert(gate.encryption->secretKeyID < secretKeys.size());
|
||||
auto secretKey = secretKeys[gate.encryption->secretKeyID];
|
||||
return secretKey;
|
||||
}
|
||||
|
||||
/// bufferSize returns the size of the whole buffer of a gate.
|
||||
@@ -309,8 +312,8 @@ bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const PackingKeySwitchParam &);
|
||||
bool fromJSON(const llvm::json::Value, PackingKeySwitchParam &,
|
||||
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &,
|
||||
llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &);
|
||||
|
||||
@@ -61,8 +61,7 @@ public:
|
||||
/// arguments, i.e. move all buffers to the PublicArguments and reset the
|
||||
/// positional counter.
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext);
|
||||
exportPublicArguments(ClientParameters clientParameters);
|
||||
|
||||
/// Check that all arguments as been pushed.
|
||||
// TODO: Remove public method here
|
||||
|
||||
@@ -6,113 +6,143 @@
|
||||
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
typedef struct LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64;
|
||||
|
||||
int destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *);
|
||||
struct Csprng;
|
||||
struct CsprngVtable;
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
// =============================================
|
||||
class CSPRNG {
|
||||
public:
|
||||
struct Csprng *ptr;
|
||||
const struct CsprngVtable *vtable;
|
||||
|
||||
/// Wrapper for `LweKeyswitchKey64` so that it cleans up properly.
|
||||
CSPRNG() = delete;
|
||||
CSPRNG(CSPRNG &) = delete;
|
||||
|
||||
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
|
||||
assert(ptr != nullptr);
|
||||
other.ptr = nullptr;
|
||||
};
|
||||
|
||||
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
|
||||
};
|
||||
|
||||
class ConcreteCSPRNG : public CSPRNG {
|
||||
public:
|
||||
ConcreteCSPRNG(__uint128_t seed);
|
||||
ConcreteCSPRNG() = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &&other);
|
||||
~ConcreteCSPRNG();
|
||||
};
|
||||
|
||||
/// @brief LweSecretKey implements tools for manipulating lwe secret key on
|
||||
/// client.
|
||||
class LweSecretKey {
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
LweSecretKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweSecretKey() = delete;
|
||||
LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng);
|
||||
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
LweSecretKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Encrypt the plaintext to the lwe ciphertext buffer.
|
||||
void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance,
|
||||
CSPRNG &csprng) const;
|
||||
|
||||
/// @brief Decrypt the ciphertext to the plaintext
|
||||
void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const;
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
LweSecretKeyParam parameters() const { return this->_parameters; }
|
||||
|
||||
/// @brief Returns the lwe dimension of the secret key.
|
||||
size_t dimension() const { return parameters().dimension; }
|
||||
};
|
||||
|
||||
/// @brief LweKeyswitchKey implements tools for manipulating keyswitch key on
|
||||
/// client.
|
||||
class LweKeyswitchKey {
|
||||
private:
|
||||
LweKeyswitchKey64 *ksk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
LweKeyswitchKey &wrappedKsk);
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
KeyswitchKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweKeyswitchKey(LweKeyswitchKey64 *ksk) : ksk{ksk} {}
|
||||
LweKeyswitchKey(LweKeyswitchKey &other) = delete;
|
||||
LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} {
|
||||
other.ksk = nullptr;
|
||||
}
|
||||
~LweKeyswitchKey() {
|
||||
if (this->ksk != nullptr) {
|
||||
CAPI_ASSERT_ERROR(destroy_lwe_keyswitch_key_u64(this->ksk));
|
||||
LweKeyswitchKey() = delete;
|
||||
LweKeyswitchKey(KeyswitchKeyParam ¶meters, LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
KeyswitchKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
this->ksk = nullptr;
|
||||
}
|
||||
}
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
LweKeyswitchKey64 *get() { return this->ksk; }
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
KeyswitchKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Wrapper for `LweBootstrapKey64` so that it cleans up properly.
|
||||
/// @brief LweBootstrapKey implements tools for manipulating bootstrap key on
|
||||
/// client.
|
||||
class LweBootstrapKey {
|
||||
private:
|
||||
LweBootstrapKey64 *bsk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
LweBootstrapKey &wrappedBsk);
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
BootstrapKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {}
|
||||
LweBootstrapKey(LweBootstrapKey &other) = delete;
|
||||
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
|
||||
other.bsk = nullptr;
|
||||
}
|
||||
~LweBootstrapKey() {
|
||||
if (this->bsk != nullptr) {
|
||||
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk));
|
||||
this->bsk = nullptr;
|
||||
}
|
||||
}
|
||||
LweBootstrapKey() = delete;
|
||||
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
BootstrapKeyParam ¶meters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
LweBootstrapKey(BootstrapKeyParam ¶meters, LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
|
||||
LweBootstrapKey64 *get() { return this->bsk; }
|
||||
///// @brief Returns the buffer that hold the bootstrap key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the bootsrap key.
|
||||
BootstrapKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Wrapper for `LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64` so
|
||||
/// that it cleans up properly.
|
||||
/// @brief PackingKeyswitchKey implements tools for manipulating privat packing
|
||||
/// keyswitch key on client.
|
||||
class PackingKeyswitchKey {
|
||||
private:
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKey &wrappedFpksk);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
PackingKeyswitchKey &wrappedFpksk);
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
PackingKeyswitchKeyParam _parameters;
|
||||
|
||||
public:
|
||||
PackingKeyswitchKey(
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key)
|
||||
: key{key} {}
|
||||
PackingKeyswitchKey(PackingKeyswitchKey &other) = delete;
|
||||
PackingKeyswitchKey(PackingKeyswitchKey &&other) : key{other.key} {
|
||||
other.key = nullptr;
|
||||
}
|
||||
~PackingKeyswitchKey() {
|
||||
if (this->key != nullptr) {
|
||||
CAPI_ASSERT_ERROR(
|
||||
destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
|
||||
this->key));
|
||||
this->key = nullptr;
|
||||
}
|
||||
}
|
||||
PackingKeyswitchKey() = delete;
|
||||
PackingKeyswitchKey(PackingKeyswitchKeyParam ¶meters,
|
||||
LweSecretKey &inputKey, LweSecretKey &outputKey,
|
||||
CSPRNG &csprng);
|
||||
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
PackingKeyswitchKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get() {
|
||||
return this->key;
|
||||
}
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
PackingKeyswitchKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
@@ -120,31 +150,40 @@ public:
|
||||
/// Evalution keys required for execution.
|
||||
class EvaluationKeys {
|
||||
private:
|
||||
std::shared_ptr<LweKeyswitchKey> sharedKsk;
|
||||
std::shared_ptr<LweBootstrapKey> sharedBsk;
|
||||
std::shared_ptr<PackingKeyswitchKey> sharedFpksk;
|
||||
|
||||
protected:
|
||||
friend std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys);
|
||||
friend std::istream &operator>>(std::istream &istream,
|
||||
EvaluationKeys &evaluationKeys);
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
public:
|
||||
EvaluationKeys()
|
||||
: sharedKsk{std::shared_ptr<LweKeyswitchKey>(nullptr)},
|
||||
sharedBsk{std::shared_ptr<LweBootstrapKey>(nullptr)} {}
|
||||
EvaluationKeys() = delete;
|
||||
|
||||
EvaluationKeys(std::shared_ptr<LweKeyswitchKey> sharedKsk,
|
||||
std::shared_ptr<LweBootstrapKey> sharedBsk,
|
||||
std::shared_ptr<PackingKeyswitchKey> sharedFpksk)
|
||||
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk}, sharedFpksk{sharedFpksk} {}
|
||||
EvaluationKeys(const std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
const std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
const std::vector<PackingKeyswitchKey> packingKeyswitchKeys)
|
||||
: keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys),
|
||||
packingKeyswitchKeys(packingKeyswitchKeys) {}
|
||||
|
||||
LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); }
|
||||
LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *getFpksk() {
|
||||
return this->sharedFpksk->get();
|
||||
const LweKeyswitchKey &getKeyswitchKey(size_t id) const {
|
||||
return this->keyswitchKeys[id];
|
||||
}
|
||||
const std::vector<LweKeyswitchKey> getKeyswitchKeys() const {
|
||||
return this->keyswitchKeys;
|
||||
}
|
||||
|
||||
const LweBootstrapKey &getBootstrapKey(size_t id) const {
|
||||
return bootstrapKeys[id];
|
||||
}
|
||||
const std::vector<LweBootstrapKey> getBootstrapKeys() const {
|
||||
return this->bootstrapKeys;
|
||||
}
|
||||
|
||||
const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const {
|
||||
return this->packingKeyswitchKeys[id];
|
||||
};
|
||||
|
||||
const std::vector<PackingKeyswitchKey> getPackingKeyswitchKeys() const {
|
||||
return this->packingKeyswitchKeys;
|
||||
}
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
@@ -10,30 +10,33 @@
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
|
||||
class KeySet {
|
||||
public:
|
||||
KeySet();
|
||||
~KeySet();
|
||||
KeySet(ClientParameters clientParameters, CSPRNG &&csprng)
|
||||
: csprng(std::move(csprng)), _clientParameters(clientParameters){};
|
||||
KeySet(KeySet &other) = delete;
|
||||
|
||||
/// allocate a KeySet according the ClientParameters.
|
||||
/// Generate a KeySet from a ClientParameters specification.
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
generate(ClientParameters clientParameters, CSPRNG &&csprng);
|
||||
|
||||
/// Create a KeySet from a set of given keys
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError> fromKeys(
|
||||
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
|
||||
std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng);
|
||||
|
||||
/// Returns the ClientParameters associated with the KeySet.
|
||||
ClientParameters clientParameters() { return _clientParameters; }
|
||||
@@ -41,23 +44,6 @@ public:
|
||||
// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
bool isInputEncrypted(size_t pos);
|
||||
|
||||
/// getInputLweSecretKeyParam returns the parameters of the lwe secret key for
|
||||
/// the input at the given `pos`.
|
||||
/// The input must be encrupted
|
||||
LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) {
|
||||
auto gate = inputGate(pos);
|
||||
auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID);
|
||||
return inputSk->second.first;
|
||||
}
|
||||
|
||||
/// getOutputLweSecretKeyParam returns the parameters of the lwe secret key
|
||||
/// for the given output.
|
||||
LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) {
|
||||
auto gate = outputGate(pos);
|
||||
auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID);
|
||||
return outputSk->second.first;
|
||||
}
|
||||
|
||||
/// allocate a lwe ciphertext buffer for the argument at argPos, set the size
|
||||
/// of the allocated buffer.
|
||||
outcome::checked<void, StringError>
|
||||
@@ -80,103 +66,58 @@ public:
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
RuntimeContext runtimeContext() {
|
||||
RuntimeContext context;
|
||||
context.evaluationKeys = this->evaluationKeys();
|
||||
return context;
|
||||
}
|
||||
/// @brief evaluationKeys returns the evaluation keys associate to this client
|
||||
/// keyset. Those evaluations keys can be safely shared publicly
|
||||
EvaluationKeys evaluationKeys();
|
||||
|
||||
EvaluationKeys evaluationKeys() {
|
||||
if (this->bootstrapKeys.empty() && this->keyswitchKeys.empty()) {
|
||||
return EvaluationKeys();
|
||||
}
|
||||
auto kskIt = this->keyswitchKeys.find(clientlib::KEYSWITCH_KEY);
|
||||
auto bskIt = this->bootstrapKeys.find(clientlib::BOOTSTRAP_KEY);
|
||||
auto fpkskIt = this->packingKeys.find("fpksk_v0");
|
||||
if (kskIt != this->keyswitchKeys.end() &&
|
||||
bskIt != this->bootstrapKeys.end()) {
|
||||
auto sharedKsk = std::get<1>(kskIt->second);
|
||||
auto sharedBsk = std::get<1>(bskIt->second);
|
||||
auto sharedFpksk = fpkskIt == this->packingKeys.end()
|
||||
? std::make_shared<PackingKeyswitchKey>(nullptr)
|
||||
: std::get<1>(fpkskIt->second);
|
||||
return EvaluationKeys(sharedKsk, sharedBsk, sharedFpksk);
|
||||
}
|
||||
assert(!mlir::concretelang::dfr::_dfr_is_root_node() &&
|
||||
"Evaluation keys missing in KeySet (on root node).");
|
||||
return EvaluationKeys();
|
||||
}
|
||||
const std::vector<LweSecretKey> &getSecretKeys();
|
||||
|
||||
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
|
||||
&getSecretKeys();
|
||||
const std::vector<LweBootstrapKey> &getBootstrapKeys();
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
&getBootstrapKeys();
|
||||
const std::vector<LweKeyswitchKey> &getKeyswitchKeys();
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
&getKeyswitchKeys();
|
||||
|
||||
const std::map<
|
||||
LweSecretKeyID,
|
||||
std::pair<PackingKeySwitchParam, std::shared_ptr<PackingKeyswitchKey>>> &
|
||||
getPackingKeys();
|
||||
const std::vector<PackingKeyswitchKey> &getPackingKeyswitchKeys();
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError>
|
||||
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param);
|
||||
generateSecretKey(LweSecretKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param);
|
||||
generateBootstrapKey(BootstrapKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param);
|
||||
generateKeyswitchKey(KeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param);
|
||||
generatePackingKeyswitchKey(PackingKeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
outcome::checked<void, StringError> generateKeysFromParams();
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
outcome::checked<void, StringError> setupEncryptionMaterial();
|
||||
|
||||
friend class KeySetCache;
|
||||
|
||||
private:
|
||||
DefaultEngine *engine;
|
||||
DefaultParallelEngine *par_engine;
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys;
|
||||
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
|
||||
std::shared_ptr<PackingKeyswitchKey>>>
|
||||
packingKeys;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *>>
|
||||
inputs;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *>>
|
||||
outputs;
|
||||
CSPRNG csprng;
|
||||
|
||||
void setKeys(
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
|
||||
secretKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys,
|
||||
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
|
||||
std::shared_ptr<PackingKeyswitchKey>>>
|
||||
packingKeys);
|
||||
///////////////////////////////////////////////
|
||||
// Keys mappings
|
||||
std::vector<LweSecretKey> secretKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
outcome::checked<LweSecretKey, StringError> findLweSecretKey(LweSecretKeyID);
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// Convenient positional mapping between positional gate en secret key
|
||||
typedef std::vector<std::pair<CircuitGate, llvm::Optional<LweSecretKey>>>
|
||||
SecretKeyGateMapping;
|
||||
outcome::checked<SecretKeyGateMapping, StringError>
|
||||
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
|
||||
|
||||
SecretKeyGateMapping inputs;
|
||||
SecretKeyGateMapping outputs;
|
||||
|
||||
clientlib::ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
@@ -12,13 +12,10 @@
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using RuntimeContext = mlir::concretelang::RuntimeContext;
|
||||
|
||||
// integers are not serialized as binary values even on a binary stream
|
||||
// so we cannot rely on << operator directly
|
||||
template <typename Word>
|
||||
@@ -62,10 +59,6 @@ template <typename Stream> bool incorrectMode(Stream &stream) {
|
||||
return !binary;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext);
|
||||
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
|
||||
|
||||
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream);
|
||||
|
||||
outcome::checked<ScalarData, StringError>
|
||||
@@ -105,17 +98,24 @@ outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
|
||||
LweSecretKey readLweSecretKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk);
|
||||
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk);
|
||||
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk);
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk);
|
||||
LweBootstrapKey readLweBootstrapKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKey &wrappedKsk);
|
||||
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys);
|
||||
std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys);
|
||||
EvaluationKeys readEvaluationKeys(std::istream &istream);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -7,12 +7,6 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#define CAPI_ASSERT_ERROR(instr) \
|
||||
{ \
|
||||
int err = instr; \
|
||||
assert(err == 0); \
|
||||
}
|
||||
|
||||
namespace concretelang {
|
||||
namespace error {
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <cstddef>
|
||||
#include <list>
|
||||
|
||||
namespace mlir {
|
||||
@@ -19,11 +20,9 @@ struct CrtLoweringParameters {
|
||||
size_t nMods;
|
||||
size_t modsProd;
|
||||
size_t bitsTotal;
|
||||
size_t polynomialSize;
|
||||
size_t lutSize;
|
||||
size_t singleLutSize;
|
||||
|
||||
CrtLoweringParameters(mlir::SmallVector<int64_t> mods, size_t polySize)
|
||||
: mods(mods), polynomialSize(polySize) {
|
||||
CrtLoweringParameters(mlir::SmallVector<int64_t> mods) : mods(mods) {
|
||||
nMods = mods.size();
|
||||
modsProd = 1;
|
||||
bitsTotal = 0;
|
||||
@@ -35,9 +34,7 @@ struct CrtLoweringParameters {
|
||||
bits.push_back(nbits);
|
||||
bitsTotal += nbits;
|
||||
}
|
||||
size_t lutCrtSize = size_t(1) << bitsTotal;
|
||||
lutCrtSize = std::max(lutCrtSize, polynomialSize);
|
||||
lutSize = mods.size() * lutCrtSize;
|
||||
singleLutSize = size_t(1) << bitsTotal;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -15,12 +15,14 @@ include "concretelang/Dialect/RT/IR/RTTypes.td"
|
||||
|
||||
def Concrete_LweTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_LutTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
|
||||
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
|
||||
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
|
||||
|
||||
def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
|
||||
@@ -126,7 +128,7 @@ def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lu
|
||||
);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> {
|
||||
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs";
|
||||
|
||||
@@ -134,24 +136,22 @@ def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_f
|
||||
Concrete_LutTensor : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs Concrete_LutTensor : $result);
|
||||
let results = (outs Concrete_CrtLutsTensor : $result);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForWopPBSBufferOp : Concrete_Op<"encode_expand_lut_for_woppbs_buffer"> {
|
||||
def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_woppbs_buffer"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs";
|
||||
"Encode and expand a lookup table so that it can be used for a crt wop pbs";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LutBuffer : $result,
|
||||
Concrete_CrtLutsBuffer : $result,
|
||||
Concrete_LutBuffer : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
@@ -347,7 +347,7 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
|
||||
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
Concrete_LweCRTTensor:$ciphertext,
|
||||
Concrete_LutTensor:$lookupTable,
|
||||
Concrete_CrtLutsTensor:$lookupTable,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
@@ -370,7 +370,7 @@ def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
Concrete_LweCRTBuffer:$result,
|
||||
Concrete_LweCRTBuffer:$ciphertext,
|
||||
Concrete_LutBuffer:$lookup_table,
|
||||
Concrete_CrtLutsBuffer:$lookup_table,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
|
||||
@@ -33,7 +33,7 @@ def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstra
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
|
||||
def TFHE_EncodeLutForCrtWopPBSOp : TFHE_Op<"encode_lut_for_crt_woppbs"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs.";
|
||||
|
||||
@@ -41,12 +41,11 @@ def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
let results = (outs 2DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_EncodePlaintextWithCrtOp : TFHE_Op<"encode_plaintext_with_crt"> {
|
||||
@@ -158,7 +157,7 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>: $ciphertexts,
|
||||
1DTensorOf<[I64]> : $lookupTable,
|
||||
2DTensorOf<[I64]> : $lookupTable,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
|
||||
@@ -12,11 +12,10 @@
|
||||
#include <pthread.h>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
#include "concrete-cpu.h"
|
||||
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
#include "bootstrap.h"
|
||||
#include "device.h"
|
||||
@@ -26,27 +25,23 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef struct FFT {
|
||||
FFT() = delete;
|
||||
FFT(size_t polynomial_size);
|
||||
FFT(FFT &other) = delete;
|
||||
FFT(FFT &&other);
|
||||
~FFT();
|
||||
|
||||
struct Fft *fft;
|
||||
size_t polynomial_size;
|
||||
} FFT;
|
||||
|
||||
typedef struct RuntimeContext {
|
||||
|
||||
RuntimeContext() {
|
||||
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
bsk_gpu = nullptr;
|
||||
ksk_gpu = nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Ensure that the engines map is not copied
|
||||
RuntimeContext(const RuntimeContext &ctx){};
|
||||
RuntimeContext() = delete;
|
||||
|
||||
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
|
||||
~RuntimeContext() {
|
||||
CAPI_ASSERT_ERROR(destroy_default_engine(default_engine));
|
||||
for (const auto &key : fft_engines) {
|
||||
CAPI_ASSERT_ERROR(destroy_fft_engine(key.second));
|
||||
}
|
||||
if (fbsk != nullptr) {
|
||||
CAPI_ASSERT_ERROR(destroy_fft_fourier_lwe_bootstrap_key_u64(fbsk));
|
||||
}
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
if (bsk_gpu != nullptr) {
|
||||
cuda_drop(bsk_gpu, 0);
|
||||
@@ -55,44 +50,33 @@ typedef struct RuntimeContext {
|
||||
cuda_drop(ksk_gpu, 0);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
const uint64_t *keyswitch_key_buffer(size_t keyId) {
|
||||
return evaluationKeys.getKeyswitchKey(keyId).buffer();
|
||||
}
|
||||
|
||||
FftEngine *get_fft_engine() {
|
||||
pthread_t threadId = pthread_self();
|
||||
std::lock_guard<std::mutex> guard(engines_map_guard);
|
||||
auto engineIt = fft_engines.find(threadId);
|
||||
if (engineIt == fft_engines.end()) {
|
||||
FftEngine *fft_engine = nullptr;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_fft_engine(&fft_engine));
|
||||
|
||||
engineIt =
|
||||
fft_engines
|
||||
.insert(std::pair<pthread_t, FftEngine *>(threadId, fft_engine))
|
||||
.first;
|
||||
}
|
||||
assert(engineIt->second && "No engine available in context");
|
||||
return engineIt->second;
|
||||
const double *fourier_bootstrap_key_buffer(size_t keyId) {
|
||||
return fourier_bootstrap_keys[keyId]->data();
|
||||
}
|
||||
|
||||
DefaultEngine *get_default_engine() { return default_engine; }
|
||||
|
||||
FftFourierLweBootstrapKey64 *get_fft_fourier_bsk() {
|
||||
|
||||
if (fbsk != nullptr) {
|
||||
return fbsk;
|
||||
}
|
||||
|
||||
const std::lock_guard<std::mutex> guard(fbskMutex);
|
||||
if (fbsk == nullptr) {
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_convert_lwe_bootstrap_key_to_fft_fourier_lwe_bootstrap_key_u64(
|
||||
get_fft_engine(), evaluationKeys.getBsk(), &fbsk));
|
||||
}
|
||||
return fbsk;
|
||||
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
|
||||
return evaluationKeys.getPackingKeyswitchKey(keyId).buffer();
|
||||
}
|
||||
|
||||
const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
|
||||
|
||||
const ::concretelang::clientlib::EvaluationKeys getKeys() const {
|
||||
return evaluationKeys;
|
||||
}
|
||||
|
||||
private:
|
||||
::concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
std::vector<std::shared_ptr<std::vector<double>>> fourier_bootstrap_keys;
|
||||
std::vector<FFT> ffts;
|
||||
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
public:
|
||||
void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t glwe_dim, uint32_t gpu_idx, void *stream) {
|
||||
|
||||
@@ -104,25 +88,21 @@ typedef struct RuntimeContext {
|
||||
if (bsk_gpu != nullptr) {
|
||||
return bsk_gpu;
|
||||
}
|
||||
LweBootstrapKey64 *bsk = get_bsk();
|
||||
size_t bsk_buffer_len =
|
||||
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level;
|
||||
size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t);
|
||||
uint64_t *bsk_buffer =
|
||||
(uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size);
|
||||
|
||||
auto bsk = evaluationKeys.getBootstrapKey(0);
|
||||
|
||||
size_t bsk_buffer_len = bsk.size();
|
||||
size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
|
||||
|
||||
void *bsk_gpu_tmp = cuda_malloc(bsk_gpu_buffer_size, gpu_idx);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers(
|
||||
default_engine, bsk, bsk_buffer));
|
||||
cuda_initialize_twiddles(poly_size, gpu_idx);
|
||||
cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, bsk_buffer, stream, gpu_idx,
|
||||
input_lwe_dim, glwe_dim, level,
|
||||
cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, (void *)bsk.buffer(), stream,
|
||||
gpu_idx, input_lwe_dim, glwe_dim, level,
|
||||
poly_size);
|
||||
// This is currently not 100% async as we have to free CPU memory after
|
||||
// This is currently not 100% async as
|
||||
// we have to free CPU memory after
|
||||
// conversion
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
free(bsk_buffer);
|
||||
bsk_gpu = bsk_gpu_tmp;
|
||||
return bsk_gpu;
|
||||
}
|
||||
@@ -138,49 +118,23 @@ typedef struct RuntimeContext {
|
||||
if (ksk_gpu != nullptr) {
|
||||
return ksk_gpu;
|
||||
}
|
||||
LweKeyswitchKey64 *ksk = get_ksk();
|
||||
size_t ksk_buffer_len = input_lwe_dim * (output_lwe_dim + 1) * level;
|
||||
size_t ksk_buffer_size = sizeof(uint64_t) * ksk_buffer_len;
|
||||
uint64_t *ksk_buffer =
|
||||
(uint64_t *)aligned_alloc(U64_ALIGNMENT, ksk_buffer_size);
|
||||
auto ksk = evaluationKeys.getKeyswitchKey(0);
|
||||
|
||||
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.size();
|
||||
|
||||
void *ksk_gpu_tmp = cuda_malloc(ksk_buffer_size, gpu_idx);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_convert_lwe_keyswitch_key_to_lwe_keyswitch_key_mut_view_u64_raw_ptr_buffers(
|
||||
default_engine, ksk, ksk_buffer));
|
||||
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, ksk_buffer, ksk_buffer_size, stream,
|
||||
gpu_idx);
|
||||
// This is currently not 100% async as we have to free CPU memory after
|
||||
|
||||
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, (void *)ksk.buffer(), ksk_buffer_size,
|
||||
stream, gpu_idx);
|
||||
// This is currently not 100% async as
|
||||
// we have to free CPU memory after
|
||||
// conversion
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
free(ksk_buffer);
|
||||
ksk_gpu = ksk_gpu_tmp;
|
||||
return ksk_gpu;
|
||||
}
|
||||
#endif
|
||||
|
||||
LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); }
|
||||
|
||||
LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); }
|
||||
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get_fpksk() {
|
||||
return evaluationKeys.getFpksk();
|
||||
}
|
||||
|
||||
RuntimeContext &operator=(const RuntimeContext &rhs) {
|
||||
this->evaluationKeys = rhs.evaluationKeys;
|
||||
return *this;
|
||||
}
|
||||
|
||||
::concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
|
||||
private:
|
||||
std::mutex fbskMutex;
|
||||
FftFourierLweBootstrapKey64 *fbsk = nullptr;
|
||||
DefaultEngine *default_engine;
|
||||
std::map<pthread_t, FftEngine *> fft_engines;
|
||||
std::mutex engines_map_guard;
|
||||
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
std::mutex bsk_gpu_mutex;
|
||||
void *bsk_gpu;
|
||||
std::mutex ksk_gpu_mutex;
|
||||
@@ -192,18 +146,4 @@ private:
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
extern "C" {
|
||||
LweKeyswitchKey64 *
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
FftFourierLweBootstrapKey64 *
|
||||
get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
LweBootstrapKey64 *
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -14,17 +14,16 @@
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
#include <hpx/modules/serialization.hpp>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace dfr {
|
||||
|
||||
template <typename T> struct KeyManager;
|
||||
struct RuntimeContextManager;
|
||||
namespace {
|
||||
static void *dl_handle;
|
||||
@@ -32,103 +31,83 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
|
||||
} // namespace
|
||||
|
||||
template <typename LweKeyType> struct KeyWrapper {
|
||||
LweKeyType *key;
|
||||
Buffer buffer;
|
||||
std::vector<LweKeyType> keys;
|
||||
|
||||
KeyWrapper() : key(nullptr) {}
|
||||
KeyWrapper(KeyWrapper &&moved) noexcept
|
||||
: key(moved.key), buffer(moved.buffer) {}
|
||||
KeyWrapper(LweKeyType *key);
|
||||
KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {}
|
||||
KeyWrapper() {}
|
||||
KeyWrapper(KeyWrapper &&moved) noexcept : keys(moved.keys) {}
|
||||
KeyWrapper(const KeyWrapper &kw) : keys(kw.keys) {}
|
||||
KeyWrapper &operator=(const KeyWrapper &rhs) {
|
||||
this->key = rhs.key;
|
||||
this->buffer = rhs.buffer;
|
||||
this->keys = rhs.keys;
|
||||
return *this;
|
||||
}
|
||||
KeyWrapper(std::vector<LweKeyType> keyvec) : keys(keyvec) {}
|
||||
friend class hpx::serialization::access;
|
||||
// template <class Archive>
|
||||
// void save(Archive &ar, const unsigned int version) const;
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version);
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
void serialize(Archive &ar, const unsigned int version) const {}
|
||||
// template <class Archive> void load(Archive &ar, const unsigned int
|
||||
// version); HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
};
|
||||
|
||||
template <>
|
||||
KeyWrapper<LweKeyswitchKey64>::KeyWrapper(LweKeyswitchKey64 *key) : key(key) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
|
||||
&buffer));
|
||||
}
|
||||
template <>
|
||||
KeyWrapper<LweBootstrapKey64>::KeyWrapper(LweBootstrapKey64 *key) : key(key) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
|
||||
&buffer));
|
||||
}
|
||||
|
||||
template <typename LweKeyType>
|
||||
bool operator==(const KeyWrapper<LweKeyType> &lhs,
|
||||
const KeyWrapper<LweKeyType> &rhs) {
|
||||
return lhs.key == rhs.key;
|
||||
if (lhs.keys.size() != rhs.keys.size())
|
||||
return false;
|
||||
for (size_t i = 0; i < lhs.keys.size(); ++i)
|
||||
if (lhs.keys[i].buffer() != rhs.keys[i].buffer())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
DefaultSerializationEngine *engine;
|
||||
// template <>
|
||||
// template <class Archive>
|
||||
// void KeyWrapper<LweBootstrapKey>::save(Archive &ar,
|
||||
// const unsigned int version) const {
|
||||
// ar << buffer.length;
|
||||
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
// }
|
||||
// template <>
|
||||
// template <class Archive>
|
||||
// void KeyWrapper<LweBootstrapKey>::load(Archive &ar,
|
||||
// const unsigned int version) {
|
||||
// DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
// // No Freeing as it doesn't allocate anything.
|
||||
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
ar >> buffer.length;
|
||||
buffer.pointer = new uint8_t[buffer.length];
|
||||
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
|
||||
engine, {buffer.pointer, buffer.length}, &key));
|
||||
}
|
||||
// ar >> buffer.length;
|
||||
// buffer.pointer = new uint8_t[buffer.length];
|
||||
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
// CAPI_ASSERT_ERROR(
|
||||
// default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
|
||||
// engine, {buffer.pointer, buffer.length}, &key));
|
||||
// }
|
||||
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweKeyswitchKey64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweKeyswitchKey64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
DefaultSerializationEngine *engine;
|
||||
// template <>
|
||||
// template <class Archive>
|
||||
// void KeyWrapper<LweKeyswitchKey>::save(Archive &ar,
|
||||
// const unsigned int version) const {
|
||||
// ar << buffer.length;
|
||||
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
// }
|
||||
// template <>
|
||||
// template <class Archive>
|
||||
// void KeyWrapper<LweKeyswitchKey>::load(Archive &ar,
|
||||
// const unsigned int version) {
|
||||
// DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
// // No Freeing as it doesn't allocate anything.
|
||||
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
ar >> buffer.length;
|
||||
buffer.pointer = new uint8_t[buffer.length];
|
||||
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
|
||||
engine, {buffer.pointer, buffer.length}, &key));
|
||||
}
|
||||
// ar >> buffer.length;
|
||||
// buffer.pointer = new uint8_t[buffer.length];
|
||||
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
// CAPI_ASSERT_ERROR(
|
||||
// default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
|
||||
// engine, {buffer.pointer, buffer.length}, &key));
|
||||
// }
|
||||
|
||||
/************************/
|
||||
/* Context management. */
|
||||
@@ -152,34 +131,36 @@ struct RuntimeContextManager {
|
||||
// instantiates a local RuntimeContext.
|
||||
if (_dfr_is_root_node()) {
|
||||
RuntimeContext *context = (RuntimeContext *)ctx;
|
||||
LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context);
|
||||
LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
KeyWrapper<LweKeyswitchKey64> kskw(ksk);
|
||||
KeyWrapper<LweBootstrapKey64> bskw(bsk);
|
||||
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw(
|
||||
context->getKeys().getKeyswitchKeys());
|
||||
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw(
|
||||
context->getKeys().getBootstrapKeys());
|
||||
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw(
|
||||
context->getKeys().getPackingKeyswitchKeys());
|
||||
hpx::collectives::broadcast_to("ksk_keystore", kskw);
|
||||
hpx::collectives::broadcast_to("bsk_keystore", bskw);
|
||||
hpx::collectives::broadcast_to("pksk_keystore", pkskw);
|
||||
} else {
|
||||
auto kskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey64>>(
|
||||
"ksk_keystore");
|
||||
auto bskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey64>>(
|
||||
"bsk_keystore");
|
||||
auto kskFut = hpx::collectives::broadcast_from<
|
||||
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey>>(
|
||||
"ksk_keystore");
|
||||
auto bskFut = hpx::collectives::broadcast_from<
|
||||
KeyWrapper<::concretelang::clientlib::LweBootstrapKey>>(
|
||||
"bsk_keystore");
|
||||
auto pkskFut = hpx::collectives::broadcast_from<
|
||||
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey>>(
|
||||
"pksk_keystore");
|
||||
|
||||
KeyWrapper<LweKeyswitchKey64> kskw = kskFut.get();
|
||||
KeyWrapper<LweBootstrapKey64> bskw = bskFut.get();
|
||||
context = new mlir::concretelang::RuntimeContext();
|
||||
// TODO - packing keyswitch key for distributed
|
||||
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)),
|
||||
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
|
||||
new ::concretelang::clientlib::LweBootstrapKey(bskw.key)),
|
||||
std::shared_ptr<::concretelang::clientlib::PackingKeyswitchKey>(
|
||||
nullptr));
|
||||
delete (kskw.buffer.pointer);
|
||||
delete (bskw.buffer.pointer);
|
||||
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw =
|
||||
kskFut.get();
|
||||
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw =
|
||||
bskFut.get();
|
||||
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw =
|
||||
pkskFut.get();
|
||||
context = new mlir::concretelang::RuntimeContext(
|
||||
::concretelang::clientlib::EvaluationKeys(kskw.keys, bskw.keys,
|
||||
pkskw.keys));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_RUNTIME_SEEDER_H
|
||||
#define CONCRETELANG_RUNTIME_SEEDER_H
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
|
||||
extern SeederBuilder *best_seeder;
|
||||
|
||||
#endif
|
||||
@@ -29,18 +29,19 @@ void memref_encode_expand_lut_for_bootstrap(
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
|
||||
uint32_t out_MESSAGE_BITS, bool is_signed);
|
||||
|
||||
void memref_encode_expand_lut_for_woppbs(
|
||||
void memref_encode_lut_for_crt_woppbs(
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size,
|
||||
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size0,
|
||||
uint64_t output_lut_size1, uint64_t output_lut_stride0,
|
||||
uint64_t output_lut_stride1, uint64_t *input_lut_allocated,
|
||||
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride,
|
||||
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
|
||||
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
|
||||
uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated,
|
||||
uint64_t *crt_bits_aligned, uint64_t crt_bits_offset,
|
||||
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size,
|
||||
uint32_t modulus_product, bool is_signed);
|
||||
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t modulus_product,
|
||||
bool is_signed);
|
||||
|
||||
void memref_encode_plaintext_with_crt(
|
||||
uint64_t *output_allocated, uint64_t *output_aligned,
|
||||
@@ -149,13 +150,16 @@ void memref_wop_pbs_crt_buffer(
|
||||
uint64_t in_stride_1,
|
||||
// clear text lut
|
||||
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
|
||||
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
|
||||
// CRT decomposition
|
||||
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
|
||||
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
|
||||
uint64_t crt_decomp_stride,
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size,
|
||||
// runtime context that hold evluation keys
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
@@ -337,8 +337,8 @@ public:
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
auto publicArguments = encryptedArgs->exportPublicArguments(
|
||||
clientParameters, keySet.runtimeContext());
|
||||
auto publicArguments =
|
||||
encryptedArgs->exportPublicArguments(clientParameters);
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
@@ -485,8 +485,8 @@ public:
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
auto publicArguments =
|
||||
encryptedArgs.value()->exportPublicArguments(clientParameters);
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
@@ -506,8 +506,8 @@ public:
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
auto publicArguments =
|
||||
encryptedArgs.value()->exportPublicArguments(clientParameters);
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
|
||||
@@ -73,8 +73,7 @@ public:
|
||||
OUTCOME_TRY(auto encryptedArgs,
|
||||
clientlib::EncryptedArguments::create(*keySet, args...));
|
||||
OUTCOME_TRY(auto publicArgument,
|
||||
encryptedArgs->exportPublicArguments(this->clientParameters,
|
||||
keySet->runtimeContext()));
|
||||
encryptedArgs->exportPublicArguments(this->clientParameters));
|
||||
// client argument serialization
|
||||
// publicArgument->serialize(clientOuput);
|
||||
// message = clientOuput.str();
|
||||
|
||||
@@ -218,8 +218,8 @@ MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
concretelang::clientlib::operator>>(istream, evaluationKeys);
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys =
|
||||
concretelang::clientlib::readEvaluationKeys(istream);
|
||||
|
||||
if (istream.fail()) {
|
||||
throw std::runtime_error("Cannot read evaluation keys");
|
||||
|
||||
@@ -394,8 +394,15 @@ void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)}
|
||||
|
||||
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = concretelang::clientlib::ConcreteCSPRNG(seed);
|
||||
|
||||
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
|
||||
*unwrap(params), seed_msb, seed_lsb);
|
||||
*unwrap(params), std::move(csprng));
|
||||
if (keySet.has_error()) {
|
||||
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
|
||||
keySet.error().mesg);
|
||||
@@ -445,8 +452,8 @@ BufferRef evaluationKeysSerialize(EvaluationKeys keys) {
|
||||
|
||||
EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) {
|
||||
std::stringstream istream(std::string(buffer.data, buffer.length));
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
concretelang::clientlib::operator>>(istream, evaluationKeys);
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys =
|
||||
concretelang::clientlib::readEvaluationKeys(istream);
|
||||
if (istream.fail()) {
|
||||
return wrap((concretelang::clientlib::EvaluationKeys *)NULL,
|
||||
"input stream failure during evaluation keys unserialization");
|
||||
|
||||
@@ -2,6 +2,7 @@ add_mlir_library(
|
||||
ConcretelangClientLib
|
||||
ClientLambda.cpp
|
||||
ClientParameters.cpp
|
||||
EvaluationKeys.cpp
|
||||
CRT.cpp
|
||||
EncryptedArguments.cpp
|
||||
KeySet.cpp
|
||||
@@ -12,4 +13,6 @@ add_mlir_library(
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
Concrete)
|
||||
concrete_cpu)
|
||||
|
||||
target_include_directories(ConcretelangClientLib PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
|
||||
|
||||
@@ -39,28 +39,25 @@ void KeyswitchKeyParam::hash(size_t &seed) {
|
||||
double_to_bits(variance));
|
||||
}
|
||||
|
||||
void PackingKeySwitchParam::hash(size_t &seed) {
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, bootstrapKeyID, level,
|
||||
baseLog, double_to_bits(variance));
|
||||
void PackingKeyswitchKeyParam::hash(size_t &seed) {
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
|
||||
glweDimension, polynomialSize, inputLweDimension,
|
||||
double_to_bits(variance));
|
||||
}
|
||||
|
||||
std::size_t ClientParameters::hash() {
|
||||
std::size_t currentHash = 1;
|
||||
for (auto secretKeyParam : secretKeys) {
|
||||
hash_(currentHash, secretKeyParam.first);
|
||||
secretKeyParam.second.hash(currentHash);
|
||||
secretKeyParam.hash(currentHash);
|
||||
}
|
||||
for (auto bootstrapKeyParam : bootstrapKeys) {
|
||||
hash_(currentHash, bootstrapKeyParam.first);
|
||||
bootstrapKeyParam.second.hash(currentHash);
|
||||
bootstrapKeyParam.hash(currentHash);
|
||||
}
|
||||
for (auto keyswitchParam : keyswitchKeys) {
|
||||
hash_(currentHash, keyswitchParam.first);
|
||||
keyswitchParam.second.hash(currentHash);
|
||||
keyswitchParam.hash(currentHash);
|
||||
}
|
||||
for (auto packingParam : packingKeys) {
|
||||
hash_(currentHash, packingParam.first);
|
||||
packingParam.second.hash(currentHash);
|
||||
for (auto packingKeyswitchKeyParam : packingKeyswitchKeys) {
|
||||
packingKeyswitchKeyParam.hash(currentHash);
|
||||
}
|
||||
return currentHash;
|
||||
}
|
||||
@@ -74,18 +71,8 @@ llvm::json::Value toJSON(const LweSecretKeyParam &v) {
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto dimension = obj->getInteger("dimension");
|
||||
if (!dimension.hasValue()) {
|
||||
p.report("missing size field");
|
||||
return false;
|
||||
}
|
||||
v.dimension = *dimension;
|
||||
return true;
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("dimension", v.dimension);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const BootstrapKeyParam &v) {
|
||||
@@ -96,54 +83,22 @@ llvm::json::Value toJSON(const BootstrapKeyParam &v) {
|
||||
{"glweDimension", v.glweDimension},
|
||||
{"baseLog", v.baseLog},
|
||||
{"variance", v.variance},
|
||||
{"polynomialSize", v.polynomialSize},
|
||||
{"inputLweDimension", v.inputLweDimension},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
|
||||
if (!inputSecretKeyID.hasValue()) {
|
||||
p.report("missing inputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
|
||||
if (!outputSecretKeyID.hasValue()) {
|
||||
p.report("missing outputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto level = obj->getInteger("level");
|
||||
if (!level.hasValue()) {
|
||||
p.report("missing level field");
|
||||
return false;
|
||||
}
|
||||
auto baseLog = obj->getInteger("baseLog");
|
||||
if (!baseLog.hasValue()) {
|
||||
p.report("missing baseLog field");
|
||||
return false;
|
||||
}
|
||||
auto glweDimension = obj->getInteger("glweDimension");
|
||||
if (!glweDimension.hasValue()) {
|
||||
p.report("missing glweDimension field");
|
||||
return false;
|
||||
}
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
|
||||
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
|
||||
v.level = level.getValue();
|
||||
v.baseLog = baseLog.getValue();
|
||||
v.glweDimension = glweDimension.getValue();
|
||||
v.variance = variance.getValue();
|
||||
return true;
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
|
||||
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
|
||||
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
|
||||
O.map("glweDimension", v.glweDimension) &&
|
||||
O.map("variance", v.variance) &&
|
||||
O.map("polynomialSize", v.polynomialSize) &&
|
||||
O.map("inputLweDimension", v.inputLweDimension);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
|
||||
@@ -158,99 +113,37 @@ llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
|
||||
if (!inputSecretKeyID.hasValue()) {
|
||||
p.report("missing inputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
|
||||
if (!outputSecretKeyID.hasValue()) {
|
||||
p.report("missing outputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto level = obj->getInteger("level");
|
||||
if (!level.hasValue()) {
|
||||
p.report("missing level field");
|
||||
return false;
|
||||
}
|
||||
auto baseLog = obj->getInteger("baseLog");
|
||||
if (!baseLog.hasValue()) {
|
||||
p.report("missing baseLog field");
|
||||
return false;
|
||||
}
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
|
||||
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
|
||||
v.level = level.getValue();
|
||||
v.baseLog = baseLog.getValue();
|
||||
v.variance = variance.getValue();
|
||||
return true;
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
|
||||
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
|
||||
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
|
||||
O.map("variance", v.variance);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const PackingKeySwitchParam &v) {
|
||||
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"inputSecretKeyID", v.inputSecretKeyID},
|
||||
{"outputSecretKeyID", v.outputSecretKeyID},
|
||||
{"bootstrapKeyID", v.bootstrapKeyID},
|
||||
{"level", v.level},
|
||||
{"baseLog", v.baseLog},
|
||||
{"glweDimension", v.glweDimension},
|
||||
{"polynomialSize", v.polynomialSize},
|
||||
{"inputLweDimension", v.inputLweDimension},
|
||||
{"variance", v.variance},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, PackingKeySwitchParam &v,
|
||||
bool fromJSON(const llvm::json::Value j, PackingKeyswitchKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
|
||||
if (!inputSecretKeyID.hasValue()) {
|
||||
p.report("missing inputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
|
||||
if (!outputSecretKeyID.hasValue()) {
|
||||
p.report("missing outputSecretKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto bootstrapKeyID = obj->getString("bootstrapKeyID");
|
||||
if (!bootstrapKeyID.hasValue()) {
|
||||
p.report("missing bootstrapKeyID field");
|
||||
return false;
|
||||
}
|
||||
auto level = obj->getInteger("level");
|
||||
if (!level.hasValue()) {
|
||||
p.report("missing level field");
|
||||
return false;
|
||||
}
|
||||
auto baseLog = obj->getInteger("baseLog");
|
||||
if (!baseLog.hasValue()) {
|
||||
p.report("missing baseLog field");
|
||||
return false;
|
||||
}
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
|
||||
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
|
||||
v.bootstrapKeyID = (std::string)bootstrapKeyID.getValue();
|
||||
v.level = level.getValue();
|
||||
v.baseLog = baseLog.getValue();
|
||||
v.variance = variance.getValue();
|
||||
return true;
|
||||
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
|
||||
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
|
||||
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
|
||||
O.map("glweDimension", v.glweDimension) &&
|
||||
O.map("polynomialSize", v.polynomialSize) &&
|
||||
O.map("inputLweDimension", v.inputLweDimension) &&
|
||||
O.map("variance", v.variance);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGateShape &v) {
|
||||
@@ -264,43 +157,10 @@ llvm::json::Value toJSON(const CircuitGateShape &v) {
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto width = obj->getInteger("width");
|
||||
if (!width.hasValue()) {
|
||||
p.report("missing width field");
|
||||
return false;
|
||||
}
|
||||
auto dimensions = obj->getArray("dimensions");
|
||||
if (dimensions == nullptr) {
|
||||
p.report("missing dimensions field");
|
||||
return false;
|
||||
}
|
||||
for (auto dim : *dimensions) {
|
||||
auto iDim = dim.getAsInteger();
|
||||
if (!iDim.hasValue()) {
|
||||
p.report("dimensions must be integer");
|
||||
return false;
|
||||
}
|
||||
v.dimensions.push_back(iDim.getValue());
|
||||
}
|
||||
auto size = obj->getInteger("size");
|
||||
if (!size.hasValue()) {
|
||||
p.report("missing size field");
|
||||
return false;
|
||||
}
|
||||
auto sign = obj->getBoolean("sign");
|
||||
if (!sign.hasValue()) {
|
||||
p.report("missing sign field");
|
||||
return false;
|
||||
}
|
||||
v.width = width.getValue();
|
||||
v.size = size.getValue();
|
||||
v.sign = sign.getValue();
|
||||
return true;
|
||||
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("width", v.width) && O.map("size", v.size) &&
|
||||
O.map("dimensions", v.dimensions) && O.map("sign", v.sign);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &v) {
|
||||
@@ -314,35 +174,13 @@ llvm::json::Value toJSON(const Encoding &v) {
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
if (!(O && O.map("precision", v.precision) &&
|
||||
O.map("isSigned", v.isSigned))) {
|
||||
return false;
|
||||
}
|
||||
auto precision = obj->getInteger("precision");
|
||||
if (!precision.hasValue()) {
|
||||
p.report("missing precision field");
|
||||
return false;
|
||||
}
|
||||
v.precision = precision.getValue();
|
||||
auto isSigned = obj->getBoolean("isSigned");
|
||||
if (!isSigned.hasValue()) {
|
||||
p.report("missing isSigned field");
|
||||
return false;
|
||||
}
|
||||
v.isSigned = isSigned.getValue();
|
||||
auto crt = obj->getArray("crt");
|
||||
if (crt != nullptr) {
|
||||
for (auto dim : *crt) {
|
||||
auto iDim = dim.getAsInteger();
|
||||
if (!iDim.hasValue()) {
|
||||
p.report("dimensions must be integer");
|
||||
return false;
|
||||
}
|
||||
v.crt.push_back(iDim.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check this is correct for an optional field
|
||||
O.map("crt", v.crt);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -356,32 +194,9 @@ llvm::json::Value toJSON(const EncryptionGate &v) {
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, EncryptionGate &v,
|
||||
llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
if (obj == nullptr) {
|
||||
p.report("should be an object");
|
||||
return false;
|
||||
}
|
||||
auto secretKeyID = obj->getString("secretKeyID");
|
||||
if (!secretKeyID.hasValue()) {
|
||||
p.report("missing secretKeyID field");
|
||||
return false;
|
||||
}
|
||||
v.secretKeyID = (std::string)secretKeyID.getValue();
|
||||
auto variance = obj->getNumber("variance");
|
||||
if (!variance.hasValue()) {
|
||||
p.report("missing variance field");
|
||||
return false;
|
||||
}
|
||||
v.variance = variance.getValue();
|
||||
auto encoding = obj->get("encoding");
|
||||
if (encoding == nullptr) {
|
||||
p.report("missing encoding field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*encoding, v.encoding, p.field("encoding"))) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("secretKeyID", v.secretKeyID) &&
|
||||
O.map("variance", v.variance) && O.map("encoding", v.encoding);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGate &v) {
|
||||
@@ -392,40 +207,16 @@ llvm::json::Value toJSON(const CircuitGate &v) {
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) {
|
||||
auto obj = j.getAsObject();
|
||||
auto encryption = obj->get("encryption");
|
||||
if (encryption == nullptr) {
|
||||
p.report("missing encryption field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*encryption, v.encryption, p.field("encryption"))) {
|
||||
return false;
|
||||
}
|
||||
auto shape = obj->get("shape");
|
||||
if (shape == nullptr) {
|
||||
p.report("missing shape field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*shape, v.shape, p.field("shape"))) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T> llvm::json::Value toJson(std::map<std::string, T> map) {
|
||||
llvm::json::Object obj;
|
||||
for (auto entry : map) {
|
||||
obj[entry.first] = entry.second;
|
||||
}
|
||||
return obj;
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("encryption", v.encryption) && O.map("shape", v.shape);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &v) {
|
||||
llvm::json::Object object{
|
||||
{"secretKeys", toJson(v.secretKeys)},
|
||||
{"bootstrapKeys", toJson(v.bootstrapKeys)},
|
||||
{"keyswitchKeys", toJson(v.keyswitchKeys)},
|
||||
{"packingKeys", toJson(v.packingKeys)},
|
||||
{"secretKeys", v.secretKeys},
|
||||
{"bootstrapKeys", v.bootstrapKeys},
|
||||
{"keyswitchKeys", v.keyswitchKeys},
|
||||
{"packingKeyswitchKeys", v.packingKeyswitchKeys},
|
||||
{"inputs", v.inputs},
|
||||
{"outputs", v.outputs},
|
||||
{"functionName", v.functionName},
|
||||
@@ -434,64 +225,13 @@ llvm::json::Value toJSON(const ClientParameters &v) {
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, ClientParameters &v,
|
||||
llvm::json::Path p) {
|
||||
|
||||
auto obj = j.getAsObject();
|
||||
auto secretkeys = obj->get("secretKeys");
|
||||
if (secretkeys == nullptr) {
|
||||
p.report("missing secretKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*secretkeys, v.secretKeys, p.field("secretKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto bootstrapKeys = obj->get("bootstrapKeys");
|
||||
if (bootstrapKeys == nullptr) {
|
||||
p.report("missing bootstrapKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*bootstrapKeys, v.bootstrapKeys, p.field("bootstrapKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto keyswitchKeys = obj->get("keyswitchKeys");
|
||||
if (keyswitchKeys == nullptr) {
|
||||
p.report("missing keyswitchKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*keyswitchKeys, v.keyswitchKeys, p.field("keyswitchKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto packingKeys = obj->get("packingKeys");
|
||||
if (packingKeys == nullptr) {
|
||||
p.report("missing packingKeys field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*packingKeys, v.packingKeys, p.field("packingKeys"))) {
|
||||
return false;
|
||||
}
|
||||
auto inputs = obj->get("inputs");
|
||||
if (inputs == nullptr) {
|
||||
p.report("missing inputs field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*inputs, v.inputs, p.field("inputs"))) {
|
||||
return false;
|
||||
}
|
||||
auto outputs = obj->get("outputs");
|
||||
if (outputs == nullptr) {
|
||||
p.report("missing outputs field");
|
||||
return false;
|
||||
}
|
||||
if (!fromJSON(*outputs, v.outputs, p.field("outputs"))) {
|
||||
return false;
|
||||
}
|
||||
auto functionName = obj->getString("functionName");
|
||||
if (!functionName.hasValue()) {
|
||||
p.report("missing functionName field");
|
||||
return false;
|
||||
}
|
||||
v.functionName = (std::string)functionName.getValue();
|
||||
|
||||
return true;
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("secretKeys", v.secretKeys) &&
|
||||
O.map("bootstrapKeys", v.bootstrapKeys) &&
|
||||
O.map("keyswitchKeys", v.keyswitchKeys) &&
|
||||
O.map("packingKeyswitchKeys", v.packingKeyswitchKeys) &&
|
||||
O.map("inputs", v.inputs) && O.map("outputs", v.outputs) &&
|
||||
O.map("functionName", v.functionName);
|
||||
}
|
||||
|
||||
std::string ClientParameters::getClientParametersPath(std::string path) {
|
||||
|
||||
@@ -12,8 +12,7 @@ namespace clientlib {
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
|
||||
return std::make_unique<PublicArguments>(
|
||||
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
137
compiler/lib/ClientLib/EvaluationKeys.cpp
Normal file
137
compiler/lib/ClientLib/EvaluationKeys.cpp
Normal file
@@ -0,0 +1,137 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concrete-cpu.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
ConcreteCSPRNG::ConcreteCSPRNG(__uint128_t seed)
|
||||
: CSPRNG(nullptr, &CONCRETE_CSPRNG_VTABLE) {
|
||||
ptr = (Csprng *)aligned_alloc(CONCRETE_CSPRNG_ALIGN, CONCRETE_CSPRNG_SIZE);
|
||||
struct Uint128 u128;
|
||||
if (seed == 0) {
|
||||
switch (concrete_cpu_crypto_secure_random_128(&u128)) {
|
||||
case 1:
|
||||
break;
|
||||
case -1:
|
||||
llvm::errs()
|
||||
<< "WARNING: The generated random seed is not crypto secure\n";
|
||||
break;
|
||||
default:
|
||||
assert(false && "Cannot instantiate a random seed");
|
||||
}
|
||||
|
||||
} else {
|
||||
for (int i = 0; i < 16; i++) {
|
||||
u128.little_endian_bytes[i] = seed >> (8 * i);
|
||||
}
|
||||
}
|
||||
concrete_cpu_construct_concrete_csprng(ptr, u128);
|
||||
}
|
||||
|
||||
ConcreteCSPRNG::ConcreteCSPRNG(ConcreteCSPRNG &&other)
|
||||
: CSPRNG(other.ptr, &CONCRETE_CSPRNG_VTABLE) {
|
||||
assert(ptr != nullptr);
|
||||
other.ptr = nullptr;
|
||||
}
|
||||
|
||||
ConcreteCSPRNG::~ConcreteCSPRNG() {
|
||||
if (ptr != nullptr) {
|
||||
concrete_cpu_destroy_concrete_csprng(ptr);
|
||||
free(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
LweSecretKey::LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng)
|
||||
: _parameters(parameters) {
|
||||
// Allocate the buffer
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(parameters.dimension);
|
||||
// Initialize the lwe secret key buffer
|
||||
concrete_cpu_init_lwe_secret_key_u64(_buffer->data(), parameters.dimension,
|
||||
csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
void LweSecretKey::encrypt(uint64_t *ciphertext, uint64_t plaintext,
|
||||
double variance, CSPRNG &csprng) const {
|
||||
concrete_cpu_encrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
|
||||
plaintext, parameters().dimension,
|
||||
variance, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
void LweSecretKey::decrypt(const uint64_t *ciphertext,
|
||||
uint64_t &plaintext) const {
|
||||
concrete_cpu_decrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
|
||||
parameters().dimension, &plaintext);
|
||||
}
|
||||
|
||||
LweKeyswitchKey::LweKeyswitchKey(KeyswitchKeyParam ¶meters,
|
||||
LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng)
|
||||
: _parameters(parameters) {
|
||||
// Allocate the buffer
|
||||
auto size = concrete_cpu_keyswitch_key_size_u64(
|
||||
_parameters.level, _parameters.baseLog, inputKey.dimension(),
|
||||
outputKey.dimension());
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(size);
|
||||
|
||||
// Initialize the keyswitch key buffer
|
||||
concrete_cpu_init_lwe_keyswitch_key_u64(
|
||||
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
|
||||
inputKey.dimension(), outputKey.dimension(), _parameters.level,
|
||||
_parameters.baseLog, _parameters.variance, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
LweBootstrapKey::LweBootstrapKey(BootstrapKeyParam ¶meters,
|
||||
LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng)
|
||||
: _parameters(parameters) {
|
||||
// TODO
|
||||
size_t polynomial_size = outputKey.dimension() / _parameters.glweDimension;
|
||||
// Allocate the buffer
|
||||
auto size = concrete_cpu_bootstrap_key_size_u64(
|
||||
_parameters.level, _parameters.glweDimension, polynomial_size,
|
||||
inputKey.dimension());
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(size);
|
||||
|
||||
// Initialize the keyswitch key buffer
|
||||
concrete_cpu_init_lwe_bootstrap_key_u64(
|
||||
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
|
||||
inputKey.dimension(), polynomial_size, _parameters.glweDimension,
|
||||
_parameters.level, _parameters.baseLog, _parameters.variance,
|
||||
Parallelism::Rayon, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
PackingKeyswitchKey::PackingKeyswitchKey(PackingKeyswitchKeyParam ¶ms,
|
||||
LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey,
|
||||
CSPRNG &csprng)
|
||||
: _parameters(params) {
|
||||
assert(_parameters.inputLweDimension == inputKey.dimension());
|
||||
assert(_parameters.glweDimension * _parameters.polynomialSize ==
|
||||
outputKey.dimension());
|
||||
|
||||
// Allocate the buffer
|
||||
auto size = concrete_cpu_lwe_packing_keyswitch_key_size(
|
||||
_parameters.glweDimension, _parameters.polynomialSize, _parameters.level,
|
||||
_parameters.inputLweDimension);
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(size * (_parameters.glweDimension + 1));
|
||||
|
||||
// Initialize the keyswitch key buffer
|
||||
concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
|
||||
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
|
||||
_parameters.inputLweDimension, _parameters.polynomialSize,
|
||||
_parameters.glweDimension, _parameters.level, _parameters.baseLog,
|
||||
_parameters.variance, Parallelism::Rayon, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -6,269 +6,142 @@
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
#define CAPI_ERR_TO_STRINGERROR(instr, msg) \
|
||||
{ \
|
||||
int err; \
|
||||
instr; \
|
||||
if (err != 0) { \
|
||||
return concretelang::error::StringError(msg); \
|
||||
} \
|
||||
}
|
||||
|
||||
int clone_transform_lwe_secret_key_to_glwe_secret_key_u64(
|
||||
DefaultEngine *default_engine, LweSecretKey64 *output_lwe_sk,
|
||||
size_t poly_size, GlweSecretKey64 **output_glwe_sk) {
|
||||
LweSecretKey64 *output_lwe_sk_clone = NULL;
|
||||
int lwe_out_sk_clone_ok =
|
||||
clone_lwe_secret_key_u64(output_lwe_sk, &output_lwe_sk_clone);
|
||||
if (lwe_out_sk_clone_ok != 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int glwe_sk_ok =
|
||||
default_engine_transform_lwe_secret_key_to_glwe_secret_key_u64(
|
||||
default_engine, &output_lwe_sk_clone, poly_size, output_glwe_sk);
|
||||
if (glwe_sk_ok != 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (output_lwe_sk_clone != NULL) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
KeySet::KeySet() {
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &engine));
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_parallel_engine(best_seeder, &par_engine));
|
||||
}
|
||||
|
||||
KeySet::~KeySet() {
|
||||
for (auto it : secretKeys) {
|
||||
CAPI_ASSERT_ERROR(destroy_lwe_secret_key_u64(it.second.second));
|
||||
}
|
||||
|
||||
CAPI_ASSERT_ERROR(destroy_default_engine(engine));
|
||||
CAPI_ASSERT_ERROR(destroy_default_parallel_engine(par_engine));
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto keySet = std::make_unique<KeySet>();
|
||||
OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb));
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb));
|
||||
KeySet::generate(ClientParameters clientParameters, CSPRNG &&csprng) {
|
||||
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
|
||||
OUTCOME_TRYV(keySet->generateKeysFromParams());
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
_clientParameters = params;
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError> KeySet::fromKeys(
|
||||
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
|
||||
std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng) {
|
||||
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
LweSecretKeyParam secretKeyParam = {0};
|
||||
LweSecretKey64 *secretKey = nullptr;
|
||||
if (param.encryption.hasValue()) {
|
||||
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (inputSk == this->secretKeys.end()) {
|
||||
return StringError("input encryption secret key (")
|
||||
<< param.encryption->secretKeyID << ") does not exist ";
|
||||
}
|
||||
secretKeyParam = inputSk->second.first;
|
||||
secretKey = inputSk->second.second;
|
||||
}
|
||||
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *> input = {
|
||||
param, secretKeyParam, secretKey};
|
||||
this->inputs.push_back(input);
|
||||
}
|
||||
for (auto param : params.outputs) {
|
||||
LweSecretKeyParam secretKeyParam = {0};
|
||||
LweSecretKey64 *secretKey = nullptr;
|
||||
if (param.encryption.hasValue()) {
|
||||
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (outputSk == this->secretKeys.end()) {
|
||||
return StringError(
|
||||
"cannot find output key to generate bootstrap key");
|
||||
}
|
||||
secretKeyParam = outputSk->second.first;
|
||||
secretKey = outputSk->second.second;
|
||||
}
|
||||
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *> output = {
|
||||
param, secretKeyParam, secretKey};
|
||||
this->outputs.push_back(output);
|
||||
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
|
||||
keySet->secretKeys = secretKeys;
|
||||
keySet->bootstrapKeys = bootstrapKeys;
|
||||
keySet->keyswitchKeys = keyswitchKeys;
|
||||
keySet->packingKeyswitchKeys = packingKeyswitchKeys;
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
EvaluationKeys KeySet::evaluationKeys() {
|
||||
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
|
||||
}
|
||||
|
||||
outcome::checked<KeySet::SecretKeyGateMapping, StringError>
|
||||
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
|
||||
SecretKeyGateMapping mapping;
|
||||
for (auto gate : gates) {
|
||||
if (gate.encryption.hasValue()) {
|
||||
assert(gate.encryption->secretKeyID < this->secretKeys.size());
|
||||
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
|
||||
|
||||
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate, skIt};
|
||||
mapping.push_back(input);
|
||||
} else {
|
||||
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate,
|
||||
llvm::None};
|
||||
mapping.push_back(input);
|
||||
}
|
||||
}
|
||||
return mapping;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> KeySet::setupEncryptionMaterial() {
|
||||
OUTCOME_TRY(this->inputs,
|
||||
mapCircuitGateLweSecretKey(_clientParameters.inputs));
|
||||
OUTCOME_TRY(this->outputs,
|
||||
mapCircuitGateLweSecretKey(_clientParameters.outputs));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateKeysFromParams(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
outcome::checked<void, StringError> KeySet::generateKeysFromParams() {
|
||||
|
||||
{
|
||||
// Generate LWE secret keys
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
OUTCOME_TRYV(
|
||||
this->generateSecretKey(secretKeyParam.first, secretKeyParam.second));
|
||||
}
|
||||
// Generate LWE secret keys
|
||||
for (auto secretKeyParam : _clientParameters.secretKeys) {
|
||||
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam));
|
||||
}
|
||||
// Generate bootstrap, keyswitch and packing keyswitch keys
|
||||
{
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second));
|
||||
}
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second));
|
||||
}
|
||||
for (auto packingParam : params.packingKeys) {
|
||||
OUTCOME_TRYV(
|
||||
this->generatePackingKey(packingParam.first, packingParam.second));
|
||||
}
|
||||
// Generate bootstrap keys
|
||||
for (auto bootstrapKeyParam : _clientParameters.bootstrapKeys) {
|
||||
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam));
|
||||
}
|
||||
// Generate keyswitch key
|
||||
for (auto keyswitchParam : _clientParameters.keyswitchKeys) {
|
||||
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam));
|
||||
}
|
||||
// Generate packing keyswitch key
|
||||
for (auto packingKeyswitchKeyParam : _clientParameters.packingKeyswitchKeys) {
|
||||
OUTCOME_TRYV(this->generatePackingKeyswitchKey(packingKeyswitchKeyParam));
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
void KeySet::setKeys(
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
|
||||
secretKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys,
|
||||
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
|
||||
std::shared_ptr<PackingKeyswitchKey>>>
|
||||
packingKeys) {
|
||||
this->secretKeys = secretKeys;
|
||||
this->bootstrapKeys = bootstrapKeys;
|
||||
this->keyswitchKeys = keyswitchKeys;
|
||||
this->packingKeys = packingKeys;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param) {
|
||||
LweSecretKey64 *sk;
|
||||
CAPI_ASSERT_ERROR(default_engine_generate_new_lwe_secret_key_u64(
|
||||
engine, param.dimension, &sk));
|
||||
|
||||
secretKeys[id] = {param, sk};
|
||||
|
||||
KeySet::generateSecretKey(LweSecretKeyParam param) {
|
||||
// Init the lwe secret key
|
||||
LweSecretKey sk(param, csprng);
|
||||
// Store the lwe secret key
|
||||
secretKeys.push_back(sk);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<LweSecretKey, StringError>
|
||||
KeySet::findLweSecretKey(LweSecretKeyID keyID) {
|
||||
assert(keyID < secretKeys.size());
|
||||
auto secretKey = secretKeys[keyID];
|
||||
|
||||
return secretKey;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
|
||||
KeySet::generateBootstrapKey(BootstrapKeyParam param) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return StringError("cannot find input key to generate bootstrap key");
|
||||
}
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return StringError("cannot find output key to generate bootstrap key");
|
||||
}
|
||||
// Allocate the bootstrap key
|
||||
LweBootstrapKey64 *bsk;
|
||||
|
||||
uint64_t total_dimension = outputSk->second.first.dimension;
|
||||
|
||||
assert(total_dimension % param.glweDimension == 0);
|
||||
|
||||
uint64_t polynomialSize = total_dimension / param.glweDimension;
|
||||
|
||||
GlweSecretKey64 *output_glwe_sk = nullptr;
|
||||
|
||||
// This is not part of the C FFI but rather is a C util exposed for
|
||||
// convenience in tests.
|
||||
CAPI_ASSERT_ERROR(clone_transform_lwe_secret_key_to_glwe_secret_key_u64(
|
||||
engine, outputSk->second.second, polynomialSize, &output_glwe_sk));
|
||||
|
||||
CAPI_ASSERT_ERROR(default_parallel_engine_generate_new_lwe_bootstrap_key_u64(
|
||||
par_engine, inputSk->second.second, output_glwe_sk, param.baseLog,
|
||||
param.level, param.variance, &bsk));
|
||||
|
||||
CAPI_ASSERT_ERROR(destroy_glwe_secret_key_u64(output_glwe_sk));
|
||||
|
||||
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
|
||||
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
|
||||
// Initialize the bootstrap key
|
||||
LweBootstrapKey bootstrapKey(param, inputKey, outputKey, csprng);
|
||||
// Store the bootstrap key
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
|
||||
|
||||
bootstrapKeys.push_back(std::move(bootstrapKey));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) {
|
||||
KeySet::generateKeyswitchKey(KeyswitchKeyParam param) {
|
||||
// Finding input and output secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return StringError("cannot find input key to generate keyswitch key");
|
||||
}
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return StringError("cannot find output key to generate keyswitch key");
|
||||
}
|
||||
// Allocate the keyswitch key
|
||||
LweKeyswitchKey64 *ksk;
|
||||
|
||||
CAPI_ASSERT_ERROR(default_engine_generate_new_lwe_keyswitch_key_u64(
|
||||
engine, inputSk->second.second, outputSk->second.second, param.level,
|
||||
param.baseLog, param.variance, &ksk));
|
||||
|
||||
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
|
||||
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
|
||||
// Initialize the bootstrap key
|
||||
LweKeyswitchKey keyswitchKey(param, inputKey, outputKey, csprng);
|
||||
// Store the keyswitch key
|
||||
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
|
||||
|
||||
keyswitchKeys.push_back(keyswitchKey);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param) {
|
||||
KeySet::generatePackingKeyswitchKey(PackingKeyswitchKeyParam param) {
|
||||
// Finding input secretKeys
|
||||
auto inputSk = secretKeys.find(param.inputSecretKeyID);
|
||||
if (inputSk == secretKeys.end()) {
|
||||
return StringError(
|
||||
"cannot find input key to generate packing keyswitch key");
|
||||
}
|
||||
auto bsk = bootstrapKeys.find(param.bootstrapKeyID);
|
||||
if (bsk == bootstrapKeys.end()) {
|
||||
return StringError(
|
||||
"cannot find input key to generate packing keyswitch key");
|
||||
}
|
||||
assert(param.inputSecretKeyID < secretKeys.size());
|
||||
auto inputSk = secretKeys[param.inputSecretKeyID];
|
||||
|
||||
// This is not part of the C FFI but rather is a C util exposed for
|
||||
// convenience in tests.
|
||||
GlweSecretKey64 *output_glwe_sk = nullptr;
|
||||
|
||||
auto lweDimension =
|
||||
inputSk->second.first.lweDimension() / bsk->second.first.glweDimension;
|
||||
|
||||
CAPI_ASSERT_ERROR(clone_transform_lwe_secret_key_to_glwe_secret_key_u64(
|
||||
engine, inputSk->second.second, lweDimension, &output_glwe_sk));
|
||||
|
||||
// Allocate the packing keyswitch key
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *fpksk;
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_parallel_engine_generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked_u64(
|
||||
par_engine, inputSk->second.second, output_glwe_sk, param.baseLog,
|
||||
param.level, param.variance, &fpksk));
|
||||
assert(param.outputSecretKeyID < secretKeys.size());
|
||||
auto outputSk = secretKeys[param.outputSecretKeyID];
|
||||
|
||||
PackingKeyswitchKey packingKeyswitchKey(param, inputSk, outputSk, csprng);
|
||||
// Store the keyswitch key
|
||||
packingKeys[id] = {param, std::make_shared<PackingKeyswitchKey>(fpksk)};
|
||||
|
||||
packingKeyswitchKeys.push_back(packingKeyswitchKey);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -285,8 +158,9 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
|
||||
}
|
||||
auto numBlocks =
|
||||
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
|
||||
assert(inputSk.second.has_value());
|
||||
|
||||
size = std::get<1>(inputSk).lweSize();
|
||||
size = inputSk.second->parameters().lweSize();
|
||||
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
|
||||
return outcome::success();
|
||||
}
|
||||
@@ -315,8 +189,9 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
return StringError("encrypt_lwe the positional argument is not encrypted");
|
||||
}
|
||||
auto encoding = encryption->encoding;
|
||||
auto lweSecretKeyParam = std::get<1>(inputSk);
|
||||
auto lweSecretKey = std::get<2>(inputSk);
|
||||
assert(inputSk.second.has_value());
|
||||
auto lweSecretKey = *inputSk.second;
|
||||
auto lweSecretKeyParam = lweSecretKey.parameters();
|
||||
// CRT encoding - N blocks with crt encoding
|
||||
auto crt = encryption->encoding.crt;
|
||||
if (!crt.empty()) {
|
||||
@@ -324,11 +199,7 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
auto product = crt::productOfModuli(crt);
|
||||
for (auto modulus : crt) {
|
||||
auto plaintext = crt::encode(input, modulus, product);
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_encrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, plaintext,
|
||||
encryption->variance));
|
||||
|
||||
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
}
|
||||
return outcome::success();
|
||||
@@ -336,10 +207,7 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
// Simple TFHE integers - 1 blocks with one padding bits
|
||||
// TODO we could check if the input value is in the right range
|
||||
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_encrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, plaintext, encryption->variance));
|
||||
|
||||
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -349,8 +217,9 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
return StringError("decrypt_lwe: position of argument is too high");
|
||||
}
|
||||
auto outputSk = outputs[argPos];
|
||||
auto lweSecretKey = std::get<2>(outputSk);
|
||||
auto lweSecretKeyParam = std::get<1>(outputSk);
|
||||
assert(outputSk.second.has_value());
|
||||
auto lweSecretKey = *outputSk.second;
|
||||
auto lweSecretKeyParam = lweSecretKey.parameters();
|
||||
auto encryption = std::get<0>(outputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
@@ -358,15 +227,15 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
|
||||
auto crt = encryption->encoding.crt;
|
||||
|
||||
if (!crt.empty()) { // The ciphertext used the crt strategy.
|
||||
if (!crt.empty()) {
|
||||
// CRT encoded TFHE integers
|
||||
|
||||
// Decrypt and decode remainders
|
||||
std::vector<int64_t> remainders;
|
||||
for (auto modulus : crt) {
|
||||
uint64_t decrypted;
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, &decrypted));
|
||||
uint64_t decrypted = 0;
|
||||
lweSecretKey.decrypt(ciphertext, decrypted);
|
||||
|
||||
auto plaintext = crt::decode(decrypted, modulus);
|
||||
remainders.push_back(plaintext);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
@@ -386,12 +255,10 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
output -= maxPos * 2;
|
||||
}
|
||||
}
|
||||
} else { // The ciphertext used the scalar strategy
|
||||
|
||||
// Decrypt
|
||||
uint64_t plaintext;
|
||||
CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, &plaintext));
|
||||
} else {
|
||||
// Native encoded TFHE integers - 1 blocks with one padding bits
|
||||
uint64_t plaintext = 0;
|
||||
lweSecretKey.decrypt(ciphertext, plaintext);
|
||||
|
||||
// Decode unsigned integer
|
||||
uint64_t precision = encryption->encoding.precision;
|
||||
@@ -415,27 +282,18 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>> &
|
||||
KeySet::getSecretKeys() {
|
||||
return secretKeys;
|
||||
}
|
||||
const std::vector<LweSecretKey> &KeySet::getSecretKeys() { return secretKeys; }
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>> &
|
||||
KeySet::getBootstrapKeys() {
|
||||
const std::vector<LweBootstrapKey> &KeySet::getBootstrapKeys() {
|
||||
return bootstrapKeys;
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>> &
|
||||
KeySet::getKeyswitchKeys() {
|
||||
const std::vector<LweKeyswitchKey> &KeySet::getKeyswitchKeys() {
|
||||
return keyswitchKeys;
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
|
||||
std::shared_ptr<PackingKeyswitchKey>>>
|
||||
&KeySet::getPackingKeys() {
|
||||
return packingKeys;
|
||||
const std::vector<PackingKeyswitchKey> &KeySet::getPackingKeyswitchKeys() {
|
||||
return packingKeyswitchKeys;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
@@ -16,216 +18,101 @@
|
||||
#include <string>
|
||||
#include <utime.h>
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
template <class Engine, class Key>
|
||||
outcome::checked<Key *, StringError>
|
||||
load(llvm::SmallString<0> &path,
|
||||
int (*deser)(Engine *, BufferView buffer, Key **), Engine *engine) {
|
||||
template <class Key>
|
||||
outcome::checked<Key, StringError> loadKey(llvm::SmallString<0> &path,
|
||||
Key(deser)(std::istream &istream)) {
|
||||
std::ifstream in((std::string)path, std::ofstream::binary);
|
||||
if (in.fail()) {
|
||||
return StringError("Cannot access " + (std::string)path);
|
||||
}
|
||||
std::stringstream sbuffer;
|
||||
sbuffer << in.rdbuf();
|
||||
if (in.fail()) {
|
||||
return StringError("Cannot read " + (std::string)path);
|
||||
auto key = deser(in);
|
||||
if (in.bad()) {
|
||||
return StringError("Cannot load key at path(") << (std::string)path << ")";
|
||||
}
|
||||
auto content = sbuffer.str();
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
Key *result = nullptr;
|
||||
|
||||
int error_code = deser(engine, buffer, &result);
|
||||
|
||||
if (result == nullptr || error_code != 0) {
|
||||
return StringError("Cannot deserialize " + (std::string)path);
|
||||
}
|
||||
return result;
|
||||
return key;
|
||||
}
|
||||
|
||||
static void writeFile(llvm::SmallString<0> &path, Buffer content) {
|
||||
template <class Key>
|
||||
outcome::checked<void, StringError> saveKey(llvm::SmallString<0> &path,
|
||||
Key key) {
|
||||
std::ofstream out((std::string)path, std::ofstream::binary);
|
||||
out.write((const char *)content.pointer, content.length);
|
||||
if (out.fail()) {
|
||||
return StringError("Cannot access " + (std::string)path);
|
||||
}
|
||||
out << key;
|
||||
if (out.bad()) {
|
||||
return StringError("Cannot save key at path(") << (std::string)path << ")";
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
outcome::checked<LweSecretKey64 *, StringError>
|
||||
loadSecretKey(llvm::SmallString<0> &path) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
return load(path, default_serialization_engine_deserialize_lwe_secret_key_u64,
|
||||
engine);
|
||||
}
|
||||
|
||||
outcome::checked<LweKeyswitchKey64 *, StringError>
|
||||
loadKeyswitchKey(llvm::SmallString<0> &path) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
return load(path,
|
||||
default_serialization_engine_deserialize_lwe_keyswitch_key_u64,
|
||||
engine);
|
||||
}
|
||||
|
||||
outcome::checked<LweBootstrapKey64 *, StringError>
|
||||
loadBootstrapKey(llvm::SmallString<0> &path) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
return load(path,
|
||||
default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
|
||||
engine);
|
||||
}
|
||||
|
||||
outcome::checked<LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *,
|
||||
StringError>
|
||||
loadPackingKey(llvm::SmallString<0> &path) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
return load(
|
||||
path,
|
||||
default_serialization_engine_deserialize_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64,
|
||||
engine);
|
||||
}
|
||||
|
||||
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer buffer;
|
||||
|
||||
CAPI_ASSERT_ERROR(default_serialization_engine_serialize_lwe_secret_key_u64(
|
||||
engine, key, &buffer));
|
||||
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey64 *key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer buffer;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
|
||||
&buffer));
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey64 *key) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer buffer;
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
|
||||
&buffer));
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
void savePackingKey(
|
||||
llvm::SmallString<0> &path,
|
||||
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key) {
|
||||
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer buffer;
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
|
||||
engine, key, &buffer));
|
||||
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb, std::string folderPath) {
|
||||
// TODO: text dump of all parameter in /hash
|
||||
auto key_set = std::make_unique<KeySet>();
|
||||
|
||||
// Mark the folder as recently use.
|
||||
// e.g. so the CI can do some cleanup of unused keys.
|
||||
utime(folderPath.c_str(), nullptr);
|
||||
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys;
|
||||
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
|
||||
std::shared_ptr<PackingKeyswitchKey>>>
|
||||
packingKeys;
|
||||
std::vector<LweSecretKey> secretKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
// Load LWE secret keys
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
auto id = secretKeyParam.first;
|
||||
auto param = secretKeyParam.second;
|
||||
// Load secret keys
|
||||
for (auto p : llvm::enumerate(params.secretKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = secretKeyParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "secretKey_" + id);
|
||||
OUTCOME_TRY(LweSecretKey64 * sk, loadSecretKey(path));
|
||||
secretKeys[id] = {param, sk};
|
||||
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readLweSecretKey));
|
||||
secretKeys.push_back(key);
|
||||
}
|
||||
// Load bootstrap keys
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
auto id = bootstrapKeyParam.first;
|
||||
auto param = bootstrapKeyParam.second;
|
||||
for (auto p : llvm::enumerate(params.bootstrapKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = p.value();
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
OUTCOME_TRY(LweBootstrapKey64 * bsk, loadBootstrapKey(path));
|
||||
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
|
||||
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readLweBootstrapKey));
|
||||
bootstrapKeys.push_back(key);
|
||||
}
|
||||
// Load keyswitch keys
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
auto id = keyswitchParam.first;
|
||||
auto param = keyswitchParam.second;
|
||||
for (auto p : llvm::enumerate(params.keyswitchKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = p.value();
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
OUTCOME_TRY(LweKeyswitchKey64 * ksk, loadKeyswitchKey(path));
|
||||
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
|
||||
}
|
||||
// Load packing keys
|
||||
for (auto packingParam : params.packingKeys) {
|
||||
auto id = packingParam.first;
|
||||
auto param = packingParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "fpksKey_" + id);
|
||||
OUTCOME_TRY(LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *
|
||||
ksk,
|
||||
loadPackingKey(path));
|
||||
packingKeys[id] = {param, std::make_shared<PackingKeyswitchKey>(ksk)};
|
||||
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readLweKeyswitchKey));
|
||||
keyswitchKeys.push_back(key);
|
||||
}
|
||||
|
||||
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys, packingKeys);
|
||||
for (auto p : llvm::enumerate(params.packingKeyswitchKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = p.value();
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readPackingKeyswitchKey));
|
||||
packingKeyswitchKeys.push_back(key);
|
||||
}
|
||||
|
||||
OUTCOME_TRYV(key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb));
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
return std::move(key_set);
|
||||
auto csprng = ConcreteCSPRNG(seed);
|
||||
|
||||
OUTCOME_TRY(auto keySet,
|
||||
KeySet::fromKeys(params, secretKeys, bootstrapKeys, keyswitchKeys,
|
||||
packingKeyswitchKeys, std::move(csprng)));
|
||||
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> saveKeys(KeySet &key_set,
|
||||
@@ -239,38 +126,30 @@ outcome::checked<void, StringError> saveKeys(KeySet &key_set,
|
||||
return StringError("Cannot create directory \"")
|
||||
<< std::string(folderIncompletePath) << "\": " << err.message();
|
||||
}
|
||||
|
||||
// Save LWE secret keys
|
||||
for (auto secretKeyParam : key_set.getSecretKeys()) {
|
||||
auto id = secretKeyParam.first;
|
||||
auto key = secretKeyParam.second.second;
|
||||
for (auto p : llvm::enumerate(key_set.getSecretKeys())) {
|
||||
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "secretKey_" + id);
|
||||
saveSecretKey(path, key);
|
||||
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
// Save bootstrap keys
|
||||
for (auto bootstrapKeyParam : key_set.getBootstrapKeys()) {
|
||||
auto id = bootstrapKeyParam.first;
|
||||
auto key = bootstrapKeyParam.second.second;
|
||||
for (auto p : llvm::enumerate(key_set.getBootstrapKeys())) {
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
saveBootstrapKey(path, key->get());
|
||||
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
// Save keyswitch keys
|
||||
for (auto keyswitchParam : key_set.getKeyswitchKeys()) {
|
||||
auto id = keyswitchParam.first;
|
||||
auto key = keyswitchParam.second.second;
|
||||
for (auto p : llvm::enumerate(key_set.getKeyswitchKeys())) {
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
saveKeyswitchKey(path, key->get());
|
||||
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
// Save packing keys
|
||||
for (auto keyswitchParam : key_set.getPackingKeys()) {
|
||||
auto id = keyswitchParam.first;
|
||||
auto key = keyswitchParam.second.second;
|
||||
// Save packing keyswitch keys
|
||||
for (auto p : llvm::enumerate(key_set.getPackingKeyswitchKeys())) {
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "fpksKey_" + id);
|
||||
savePackingKey(path, key->get());
|
||||
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
|
||||
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
|
||||
@@ -338,7 +217,14 @@ KeySetCache::loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
|
||||
std::cerr << "KeySetCache: miss, regenerating " << std::string(folderPath)
|
||||
<< "\n";
|
||||
OUTCOME_TRY(auto key_set, KeySet::generate(params, seed_msb, seed_lsb));
|
||||
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = ConcreteCSPRNG(seed);
|
||||
|
||||
OUTCOME_TRY(auto key_set, KeySet::generate(params, std::move(csprng)));
|
||||
|
||||
OUTCOME_TRYV(saveKeys(*key_set, folderPath));
|
||||
|
||||
@@ -349,8 +235,13 @@ outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
|
||||
ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = ConcreteCSPRNG(seed);
|
||||
return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb)
|
||||
: KeySet::generate(params, seed_msb, seed_lsb);
|
||||
: KeySet::generate(params, std::move(csprng));
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
@@ -16,139 +14,226 @@
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
template <typename Engine, typename Result>
|
||||
Result read_deser(std::istream &istream,
|
||||
int (*deser)(Engine *, BufferView, Result *),
|
||||
Engine *engine) {
|
||||
size_t length;
|
||||
readSize(istream, length);
|
||||
// buffer is too big to be allocated on stack
|
||||
// vector ensures everything is deallocated w.r.t. new
|
||||
std::vector<uint8_t> buffer(length);
|
||||
istream.read((char *)buffer.data(), length);
|
||||
assert(istream.good());
|
||||
Result result;
|
||||
|
||||
CAPI_ASSERT_ERROR(deser(engine, {buffer.data(), length}, &result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BufferLike>
|
||||
std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) {
|
||||
writeSize(ostream, buffer.length);
|
||||
ostream.write((const char *)buffer.pointer, buffer.length);
|
||||
template <typename Key>
|
||||
std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) {
|
||||
writeSize(ostream, (uint64_t)buffer.size());
|
||||
ostream.write((const char *)buffer.buffer(),
|
||||
buffer.size() * sizeof(uint64_t));
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer b;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
|
||||
&b));
|
||||
|
||||
writeBufferLike(ostream, b);
|
||||
free((void *)b.pointer);
|
||||
b.pointer = nullptr;
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
Buffer b;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
|
||||
&b))
|
||||
|
||||
writeBufferLike(ostream, b);
|
||||
free((void *)b.pointer);
|
||||
b.pointer = nullptr;
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const FftFourierLweBootstrapKey64 *key) {
|
||||
FftSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine));
|
||||
|
||||
Buffer b;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_serialization_engine_serialize_fft_fourier_lwe_bootstrap_key_u64(
|
||||
engine, key, &b))
|
||||
|
||||
writeBufferLike(ostream, b);
|
||||
free((void *)b.pointer);
|
||||
b.pointer = nullptr;
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
key = read_deser(
|
||||
istream, default_serialization_engine_deserialize_lwe_keyswitch_key_u64,
|
||||
engine);
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) {
|
||||
DefaultSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
|
||||
|
||||
key = read_deser(
|
||||
istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
|
||||
engine);
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
FftFourierLweBootstrapKey64 *&key) {
|
||||
FftSerializationEngine *engine;
|
||||
|
||||
// No Freeing as it doesn't allocate anything.
|
||||
CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine));
|
||||
|
||||
key = read_deser(
|
||||
istream,
|
||||
fft_serialization_engine_deserialize_fft_fourier_lwe_bootstrap_key_u64,
|
||||
engine);
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
RuntimeContext &runtimeContext) {
|
||||
istream >> runtimeContext.evaluationKeys;
|
||||
std::shared_ptr<std::vector<uint64_t>> &vec) {
|
||||
// TODO assertion on size?
|
||||
uint64_t size;
|
||||
readSize(istream, size);
|
||||
vec->resize(size);
|
||||
istream.read((char *)vec->data(), size * sizeof(uint64_t));
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweSecretKey ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) {
|
||||
writeWord(ostream, param.dimension);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweSecretKeyParam ¶m) {
|
||||
readWord(istream, param.dimension);
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweSecretKey /////////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
LweSecretKey readLweSecretKey(std::istream &istream) {
|
||||
LweSecretKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
return LweSecretKey(buffer, param);
|
||||
}
|
||||
|
||||
// KeyswitchKeyParam ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) {
|
||||
// TODO keys id
|
||||
writeWord(ostream, param.level);
|
||||
writeWord(ostream, param.baseLog);
|
||||
writeWord(ostream, param.variance);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, KeyswitchKeyParam ¶m) {
|
||||
// TODO keys id
|
||||
param.outputSecretKeyID = 1234;
|
||||
param.inputSecretKeyID = 1234;
|
||||
readWord(istream, param.level);
|
||||
readWord(istream, param.baseLog);
|
||||
readWord(istream, param.variance);
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweKeyswitchKey //////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) {
|
||||
KeyswitchKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
return LweKeyswitchKey(buffer, param);
|
||||
}
|
||||
|
||||
// BootstrapKeyParam ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) {
|
||||
// TODO keys id
|
||||
writeWord(ostream, param.level);
|
||||
writeWord(ostream, param.baseLog);
|
||||
writeWord(ostream, param.glweDimension);
|
||||
writeWord(ostream, param.variance);
|
||||
writeWord(ostream, param.polynomialSize);
|
||||
writeWord(ostream, param.inputLweDimension);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, BootstrapKeyParam ¶m) {
|
||||
// TODO keys id
|
||||
readWord(istream, param.level);
|
||||
readWord(istream, param.baseLog);
|
||||
readWord(istream, param.glweDimension);
|
||||
readWord(istream, param.variance);
|
||||
readWord(istream, param.polynomialSize);
|
||||
readWord(istream, param.inputLweDimension);
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweBootstrapKey //////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
LweBootstrapKey readLweBootstrapKey(std::istream &istream) {
|
||||
BootstrapKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
return LweBootstrapKey(buffer, param);
|
||||
}
|
||||
|
||||
// PackingKeyswitchKeyParam ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const RuntimeContext &runtimeContext) {
|
||||
ostream << runtimeContext.evaluationKeys;
|
||||
const PackingKeyswitchKeyParam param) {
|
||||
|
||||
// TODO keys id
|
||||
writeWord(ostream, param.level);
|
||||
writeWord(ostream, param.baseLog);
|
||||
writeWord(ostream, param.glweDimension);
|
||||
writeWord(ostream, param.polynomialSize);
|
||||
writeWord(ostream, param.inputLweDimension);
|
||||
writeWord(ostream, param.variance);
|
||||
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
PackingKeyswitchKeyParam ¶m) {
|
||||
|
||||
// TODO keys id
|
||||
param.outputSecretKeyID = 1234;
|
||||
param.inputSecretKeyID = 1234;
|
||||
readWord(istream, param.level);
|
||||
readWord(istream, param.baseLog);
|
||||
readWord(istream, param.glweDimension);
|
||||
readWord(istream, param.polynomialSize);
|
||||
readWord(istream, param.inputLweDimension);
|
||||
readWord(istream, param.variance);
|
||||
|
||||
return istream;
|
||||
}
|
||||
|
||||
// PackingKeyswitchKey //////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) {
|
||||
PackingKeyswitchKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
auto b = PackingKeyswitchKey(buffer, param);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
// EvaluationKey ////////////////////////////////
|
||||
|
||||
EvaluationKeys readEvaluationKeys(std::istream &istream) {
|
||||
uint64_t nbKey;
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
bootstrapKeys.push_back(readLweBootstrapKey(istream));
|
||||
}
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
keyswitchKeys.push_back(readLweKeyswitchKey(istream));
|
||||
}
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
readSize(istream, nbKey);
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream));
|
||||
}
|
||||
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys) {
|
||||
auto bootstrapKeys = evaluationKeys.getBootstrapKeys();
|
||||
writeSize(ostream, bootstrapKeys.size());
|
||||
for (auto bsk : bootstrapKeys) {
|
||||
ostream << bsk;
|
||||
}
|
||||
auto keyswitchKeys = evaluationKeys.getKeyswitchKeys();
|
||||
writeSize(ostream, keyswitchKeys.size());
|
||||
for (auto ksk : keyswitchKeys) {
|
||||
ostream << ksk;
|
||||
}
|
||||
auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys();
|
||||
writeSize(ostream, packingKeyswitchKeys.size());
|
||||
for (auto pksk : packingKeyswitchKeys) {
|
||||
ostream << pksk;
|
||||
}
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
// TensorData ///////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) {
|
||||
writeWord<uint64_t>(ostream, sizeof(T) * 8);
|
||||
@@ -399,70 +484,5 @@ unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk) {
|
||||
ostream << wrappedKsk.ksk;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) {
|
||||
istream >> wrappedKsk.ksk;
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk) {
|
||||
ostream << wrappedBsk.bsk;
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) {
|
||||
istream >> wrappedBsk.bsk;
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys) {
|
||||
bool has_ksk = (bool)evaluationKeys.sharedKsk;
|
||||
writeWord(ostream, has_ksk);
|
||||
if (has_ksk) {
|
||||
ostream << *evaluationKeys.sharedKsk;
|
||||
}
|
||||
|
||||
bool has_bsk = (bool)evaluationKeys.sharedBsk;
|
||||
writeWord(ostream, has_bsk);
|
||||
if (has_bsk) {
|
||||
ostream << *evaluationKeys.sharedBsk;
|
||||
}
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
EvaluationKeys &evaluationKeys) {
|
||||
bool has_ksk;
|
||||
readWord(istream, has_ksk);
|
||||
if (has_ksk) {
|
||||
auto sharedKsk = LweKeyswitchKey(nullptr);
|
||||
istream >> sharedKsk;
|
||||
evaluationKeys.sharedKsk =
|
||||
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
|
||||
}
|
||||
|
||||
bool has_bsk;
|
||||
readWord(istream, has_bsk);
|
||||
if (has_bsk) {
|
||||
auto sharedBsk = LweBootstrapKey(nullptr);
|
||||
istream >> sharedBsk;
|
||||
evaluationKeys.sharedBsk =
|
||||
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
|
||||
}
|
||||
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -47,8 +47,7 @@ char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
|
||||
char memref_encode_plaintext_with_crt[] = "memref_encode_plaintext_with_crt";
|
||||
char memref_encode_expand_lut_for_bootstrap[] =
|
||||
"memref_encode_expand_lut_for_bootstrap";
|
||||
char memref_encode_expand_lut_for_woppbs[] =
|
||||
"memref_encode_expand_lut_for_woppbs";
|
||||
char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs";
|
||||
char memref_trace[] = "memref_trace";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
@@ -158,10 +157,16 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
} else if (funcName == memref_wop_pbs_crt_buffer) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
memref2DType,
|
||||
memref2DType,
|
||||
memref2DType,
|
||||
memref1DType,
|
||||
memref1DType,
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
@@ -181,11 +186,11 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
{memref1DType, memref1DType, rewriter.getI32Type(),
|
||||
rewriter.getI32Type(), rewriter.getI1Type()},
|
||||
{});
|
||||
} else if (funcName == memref_encode_expand_lut_for_woppbs) {
|
||||
} else if (funcName == memref_encode_lut_for_crt_woppbs) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()},
|
||||
{memref2DType, memref1DType, memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI1Type()},
|
||||
{});
|
||||
} else if (funcName == memref_trace) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
@@ -326,9 +331,32 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
|
||||
// cbs_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
|
||||
|
||||
// ksk_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.keyswitchLevelAttr()));
|
||||
// ksk_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.keyswitchBaseLogAttr()));
|
||||
|
||||
// bsk_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.bootstrapLevelAttr()));
|
||||
// bsk_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.bootstrapBaseLogAttr()));
|
||||
|
||||
// fpksk_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchLevelAttr()));
|
||||
// fpksk_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchBaseLogAttr()));
|
||||
|
||||
// polynomial_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
|
||||
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
@@ -372,9 +400,9 @@ void encodeExpandLutForBootstrapAddOperands(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
|
||||
}
|
||||
|
||||
void encodeExpandLutForWopPBSAddOperands(
|
||||
Concrete::EncodeExpandLutForWopPBSBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
|
||||
void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
|
||||
// crt_decomposition
|
||||
mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
|
||||
@@ -414,9 +442,6 @@ void encodeExpandLutForWopPBSAddOperands(
|
||||
op.getLoc(), (*crtBitsGlobalMemref).type(),
|
||||
(*crtBitsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// modulus_product
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.modulusProductAttr()));
|
||||
@@ -467,10 +492,10 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
|
||||
ConcreteToCAPICallPattern<Concrete::EncodeExpandLutForBootstrapBufferOp,
|
||||
memref_encode_expand_lut_for_bootstrap>>(
|
||||
&getContext(), encodeExpandLutForBootstrapAddOperands);
|
||||
patterns.add<
|
||||
ConcreteToCAPICallPattern<Concrete::EncodeExpandLutForWopPBSBufferOp,
|
||||
memref_encode_expand_lut_for_woppbs>>(
|
||||
&getContext(), encodeExpandLutForWopPBSAddOperands);
|
||||
patterns
|
||||
.add<ConcreteToCAPICallPattern<Concrete::EncodeLutForCrtWopPBSBufferOp,
|
||||
memref_encode_lut_for_crt_woppbs>>(
|
||||
&getContext(), encodeLutForWopPBSAddOperands);
|
||||
if (gpu) {
|
||||
patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_cuda_u64>>(
|
||||
|
||||
@@ -560,17 +560,18 @@ struct ApplyLookupTableEintOpPattern
|
||||
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
.create<TFHE::EncodeExpandLutForWopPBSOp>(
|
||||
.create<TFHE::EncodeLutForCrtWopPBSOp>(
|
||||
op.getLoc(),
|
||||
mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.lutSize),
|
||||
mlir::ArrayRef<int64_t>{
|
||||
(int64_t)loweringParameters.nMods,
|
||||
(int64_t)loweringParameters.singleLutSize},
|
||||
rewriter.getI64Type()),
|
||||
adaptor.lut(),
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.bits)),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.modsProd),
|
||||
rewriter.getBoolAttr(originalInputType.isSigned()))
|
||||
.getResult();
|
||||
|
||||
@@ -664,8 +664,8 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
mlir::concretelang::Concrete::EncodeExpandLutForBootstrapTensorOp,
|
||||
true>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::EncodeExpandLutForWopPBSOp,
|
||||
mlir::concretelang::Concrete::EncodeExpandLutForWopPBSTensorOp, true>,
|
||||
mlir::concretelang::TFHE::EncodeLutForCrtWopPBSOp,
|
||||
mlir::concretelang::Concrete::EncodeLutForCrtWopPBSTensorOp, true>,
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
|
||||
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>,
|
||||
|
||||
@@ -143,10 +143,10 @@ void mlir::concretelang::Concrete::
|
||||
Concrete::EncodeExpandLutForBootstrapTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::EncodeExpandLutForBootstrapTensorOp,
|
||||
Concrete::EncodeExpandLutForBootstrapBufferOp>>(*ctx);
|
||||
// encode_expand_lut_for_woppbs_tensor =>
|
||||
// encode_expand_lut_for_woppbs_buffer
|
||||
Concrete::EncodeExpandLutForWopPBSTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::EncodeExpandLutForWopPBSTensorOp,
|
||||
Concrete::EncodeExpandLutForWopPBSBufferOp>>(*ctx);
|
||||
// encode_lut_for_crt_woppbs_tensor =>
|
||||
// encode_lut_for_crt_woppbs_buffer
|
||||
Concrete::EncodeLutForCrtWopPBSTensorOp::attachInterface<
|
||||
TensorToMemrefOp<Concrete::EncodeLutForCrtWopPBSTensorOp,
|
||||
Concrete::EncodeLutForCrtWopPBSBufferOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp seeder.cpp)
|
||||
add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp)
|
||||
|
||||
add_dependencies(ConcretelangRuntime concrete_cpu)
|
||||
|
||||
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
|
||||
target_link_libraries(ConcretelangRuntime PRIVATE HPX::hpx HPX::iostreams_component)
|
||||
@@ -25,7 +27,8 @@ if(APPLE)
|
||||
target_link_libraries(ConcretelangRuntime LINK_PUBLIC ${SECURITY_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
target_link_libraries(ConcretelangRuntime PUBLIC Concrete ConcretelangClientLib pthread m dl
|
||||
target_include_directories(ConcretelangRuntime PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
|
||||
target_link_libraries(ConcretelangRuntime PUBLIC concrete_cpu ConcretelangClientLib pthread m dl
|
||||
$<TARGET_OBJECTS:mlir_c_runner_utils>)
|
||||
|
||||
install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime)
|
||||
|
||||
@@ -5,29 +5,77 @@
|
||||
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
|
||||
LweKeyswitchKey64 *
|
||||
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->get_ksk();
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
FFT::FFT(size_t polynomial_size)
|
||||
: fft(nullptr), polynomial_size(polynomial_size) {
|
||||
fft = (struct Fft *)aligned_alloc(CONCRETE_FFT_ALIGN, CONCRETE_FFT_SIZE);
|
||||
concrete_cpu_construct_concrete_fft(fft, polynomial_size);
|
||||
}
|
||||
|
||||
LweBootstrapKey64 *
|
||||
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->get_bsk();
|
||||
FFT::FFT(FFT &&other) : fft(other.fft), polynomial_size(other.polynomial_size) {
|
||||
other.fft = nullptr;
|
||||
}
|
||||
|
||||
FftFourierLweBootstrapKey64 *
|
||||
get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->get_fft_fourier_bsk();
|
||||
FFT::~FFT() {
|
||||
if (fft != nullptr) {
|
||||
concrete_cpu_destroy_concrete_fft(fft);
|
||||
free(fft);
|
||||
}
|
||||
}
|
||||
|
||||
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->get_default_engine();
|
||||
RuntimeContext::RuntimeContext(clientlib::EvaluationKeys evaluationKeys)
|
||||
: evaluationKeys(evaluationKeys) {
|
||||
{
|
||||
|
||||
// Initialize for each bootstrap key the fourier one
|
||||
for (auto bsk : evaluationKeys.getBootstrapKeys()) {
|
||||
auto param = bsk.parameters();
|
||||
|
||||
size_t decomposition_level_count = param.level;
|
||||
size_t decomposition_base_log = param.baseLog;
|
||||
size_t glwe_dimension = param.glweDimension;
|
||||
size_t polynomial_size = param.polynomialSize;
|
||||
size_t input_lwe_dimension = param.inputLweDimension;
|
||||
|
||||
// Create the FFT
|
||||
FFT fft(polynomial_size);
|
||||
|
||||
// Allocate scratch for key conversion
|
||||
size_t scratch_size;
|
||||
size_t scratch_align;
|
||||
concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch(
|
||||
&scratch_size, &scratch_align, fft.fft);
|
||||
auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
|
||||
|
||||
// Allocate the fourier_bootstrap_key
|
||||
auto fourier_data = std::make_shared<std::vector<double>>();
|
||||
fourier_data->resize(bsk.size());
|
||||
auto bsk_data = bsk.buffer();
|
||||
|
||||
// Convert bootstrap_key to the fourier domain
|
||||
concrete_cpu_bootstrap_key_convert_u64_to_fourier(
|
||||
bsk_data, fourier_data->data(), decomposition_level_count,
|
||||
decomposition_base_log, glwe_dimension, polynomial_size,
|
||||
input_lwe_dimension, fft.fft, scratch, scratch_size);
|
||||
|
||||
// Store the fourier_bootstrap_key in the context
|
||||
fourier_bootstrap_keys.push_back(fourier_data);
|
||||
ffts.push_back(std::move(fft));
|
||||
free(scratch);
|
||||
}
|
||||
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
bsk_gpu = nullptr;
|
||||
ksk_gpu = nullptr;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context) {
|
||||
return context->get_fft_engine();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "concrete-core-ffi.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
SeederBuilder *get_best_seeder() {
|
||||
SeederBuilder *builder = NULL;
|
||||
|
||||
#if defined(__x86_64__) || defined(_M_X64)
|
||||
bool rdseed_seeder_available = false;
|
||||
CAPI_ASSERT_ERROR(rdseed_seeder_is_available(&rdseed_seeder_available));
|
||||
if (rdseed_seeder_available) {
|
||||
CAPI_ASSERT_ERROR(get_rdseed_seeder_builder(&builder));
|
||||
return builder;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if __APPLE__
|
||||
bool apple_seeder_available = false;
|
||||
CAPI_ASSERT_ERROR(
|
||||
apple_secure_enclave_seeder_is_available(&apple_seeder_available));
|
||||
if (apple_seeder_available) {
|
||||
CAPI_ASSERT_ERROR(get_apple_secure_enclave_seeder_builder(&builder));
|
||||
return builder;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool unix_seeder_available = false;
|
||||
CAPI_ASSERT_ERROR(unix_seeder_is_available(&unix_seeder_available));
|
||||
|
||||
if (unix_seeder_available) {
|
||||
// Security depends on /dev/random security
|
||||
uint64_t secret_high_64 = 0;
|
||||
uint64_t secret_low_64 = 0;
|
||||
CAPI_ASSERT_ERROR(
|
||||
get_unix_seeder_builder(secret_high_64, secret_low_64, &builder));
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
std::cout << "No available seeder." << std::endl;
|
||||
return builder;
|
||||
}
|
||||
|
||||
SeederBuilder *best_seeder = get_best_seeder();
|
||||
@@ -4,8 +4,8 @@
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
#include "concrete-cpu.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
#include <assert.h>
|
||||
#include <bitset>
|
||||
#include <cmath>
|
||||
@@ -16,15 +16,6 @@
|
||||
#include <string.h>
|
||||
#include <vector>
|
||||
|
||||
static DefaultEngine *levelled_engine = nullptr;
|
||||
|
||||
DefaultEngine *get_levelled_engine() {
|
||||
if (levelled_engine == nullptr) {
|
||||
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &levelled_engine));
|
||||
}
|
||||
return levelled_engine;
|
||||
}
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
|
||||
@@ -184,18 +175,21 @@ void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
// Construct the glwe accumulator (on CPU)
|
||||
// TODO: Should be done outside of the bootstrap call, compile time if
|
||||
// possible. Refactor in progress
|
||||
uint64_t glwe_ct_len = poly_size * (glwe_dim + 1);
|
||||
uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
|
||||
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
auto tlu = tlu_aligned + tlu_offset;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_len, tlu_aligned + tlu_offset,
|
||||
poly_size));
|
||||
// Glwe trivial encryption
|
||||
for (size_t i = 0; i < poly_size * glwe_dim; i++) {
|
||||
glwe_ct[i] = 0;
|
||||
}
|
||||
for (size_t i = 0; i < poly_size; i++) {
|
||||
glwe_ct[poly_size * glwe_dim + i] = tlu[i];
|
||||
}
|
||||
|
||||
// Move the glwe accumulator to the GPU
|
||||
void *glwe_ct_gpu =
|
||||
alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_len, gpu_idx, stream);
|
||||
alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_size, gpu_idx, stream);
|
||||
|
||||
// Move test vector indexes to the GPU, the test vector indexes is set of 0
|
||||
uint32_t num_test_vectors = 1, lwe_idx = 0,
|
||||
@@ -313,11 +307,12 @@ void memref_encode_expand_lut_for_bootstrap(
|
||||
return;
|
||||
}
|
||||
|
||||
void memref_encode_expand_lut_for_woppbs(
|
||||
void memref_encode_lut_for_crt_woppbs(
|
||||
// Output encoded/expanded lut
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size,
|
||||
uint64_t output_lut_stride,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size0,
|
||||
uint64_t output_lut_size1, uint64_t output_lut_stride0,
|
||||
uint64_t output_lut_stride1,
|
||||
// Input lut
|
||||
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
|
||||
uint64_t input_lut_offset, uint64_t input_lut_size,
|
||||
@@ -330,13 +325,22 @@ void memref_encode_expand_lut_for_woppbs(
|
||||
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
|
||||
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
|
||||
// Crypto parameters
|
||||
uint32_t poly_size, uint32_t modulus_product, bool is_signed) {
|
||||
uint32_t modulus_product, bool is_signed) {
|
||||
|
||||
assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_woppbs");
|
||||
assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_woppbs");
|
||||
assert(modulus_product > input_lut_size);
|
||||
"memref_encode_lut_woppbs");
|
||||
assert(output_lut_stride0 == output_lut_size1 &&
|
||||
"Runtime: out dim stride not equal to in_dim size, check "
|
||||
"memref_encode_lut_woppbs");
|
||||
assert(output_lut_stride1 == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_lut_woppbs");
|
||||
|
||||
assert(modulus_product >= input_lut_size);
|
||||
|
||||
// Initialize lut cases not supposed to be reached
|
||||
for (uint64_t i = 0; i < output_lut_size0 * output_lut_size1; i++) {
|
||||
output_lut_aligned[output_lut_offset + i] = 0;
|
||||
}
|
||||
|
||||
// When the woppbs is executed on encrypted signed integers, the index of the
|
||||
// lut elements must be adapted to fit the way signed are encrypted in CRT
|
||||
@@ -393,25 +397,40 @@ void memref_encode_expand_lut_for_woppbs(
|
||||
};
|
||||
}
|
||||
|
||||
uint64_t lut_crt_size = output_lut_size / crt_decomposition_size;
|
||||
uint64_t log_lut_crt_size = 0;
|
||||
|
||||
for (uint64_t index = 0; index < input_lut_size; index++) {
|
||||
uint64_t index_lut = 0;
|
||||
uint64_t tmp = 1;
|
||||
for (size_t in_block = 0; in_block < crt_decomposition_size; in_block++) {
|
||||
auto bits_count = crt_bits_aligned[crt_bits_offset + in_block];
|
||||
log_lut_crt_size += bits_count;
|
||||
}
|
||||
|
||||
for (size_t block = 0; block < crt_decomposition_size; block++) {
|
||||
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
|
||||
auto bits = crt_bits_aligned[crt_bits_offset + block];
|
||||
index_lut += (((indexMap(index) % base) << bits) / base) * tmp;
|
||||
tmp <<= bits;
|
||||
uint64_t lut_crt_size = 1 << log_lut_crt_size;
|
||||
assert(lut_crt_size == output_lut_size1);
|
||||
assert(crt_decomposition_size == output_lut_size0);
|
||||
|
||||
for (uint64_t in_index = 0; in_index < input_lut_size; in_index++) {
|
||||
uint64_t out_index = 0;
|
||||
|
||||
{
|
||||
uint64_t total_bit_count = 0;
|
||||
for (size_t in_block = 0; in_block < crt_decomposition_size; in_block++) {
|
||||
auto in_base =
|
||||
crt_decomposition_aligned[crt_decomposition_offset + in_block];
|
||||
auto bits_count = crt_bits_aligned[crt_bits_offset + in_block];
|
||||
out_index += (((indexMap(in_index) % in_base) << bits_count) / in_base)
|
||||
<< total_bit_count;
|
||||
total_bit_count += bits_count;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t block = 0; block < crt_decomposition_size; block++) {
|
||||
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
|
||||
auto v = encode_crt(input_lut_aligned[input_lut_offset + index], base,
|
||||
modulus_product);
|
||||
output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] =
|
||||
v;
|
||||
for (size_t out_block = 0; out_block < crt_decomposition_size;
|
||||
out_block++) {
|
||||
auto out_base =
|
||||
crt_decomposition_aligned[crt_decomposition_offset + out_block];
|
||||
auto v = encode_crt(input_lut_aligned[input_lut_offset + in_index],
|
||||
out_base, modulus_product);
|
||||
output_lut_aligned[output_lut_offset + out_block * lut_crt_size +
|
||||
out_index] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -424,11 +443,10 @@ void memref_add_lwe_ciphertexts_u64(
|
||||
uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
|
||||
assert(out_size == ct0_size && out_size == ct1_size &&
|
||||
"size of lwe buffer are incompatible");
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_add_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, ct1_aligned + ct1_offset, lwe_dimension));
|
||||
size_t lwe_dimension = out_size - 1;
|
||||
concrete_cpu_add_lwe_ciphertext_u64(out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset,
|
||||
ct1_aligned + ct1_offset, lwe_dimension);
|
||||
}
|
||||
|
||||
void memref_add_plaintext_lwe_ciphertext_u64(
|
||||
@@ -437,11 +455,10 @@ void memref_add_plaintext_lwe_ciphertext_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t plaintext) {
|
||||
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_add_lwe_ciphertext_plaintext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, lwe_dimension, plaintext));
|
||||
size_t lwe_dimension = out_size - 1;
|
||||
concrete_cpu_add_plaintext_lwe_ciphertext_u64(out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset,
|
||||
plaintext, lwe_dimension);
|
||||
}
|
||||
|
||||
void memref_mul_cleartext_lwe_ciphertext_u64(
|
||||
@@ -450,11 +467,10 @@ void memref_mul_cleartext_lwe_ciphertext_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t cleartext) {
|
||||
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_mul_lwe_ciphertext_cleartext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, lwe_dimension, cleartext));
|
||||
size_t lwe_dimension = out_size - 1;
|
||||
concrete_cpu_mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset,
|
||||
cleartext, lwe_dimension);
|
||||
}
|
||||
|
||||
void memref_negate_lwe_ciphertext_u64(
|
||||
@@ -464,24 +480,25 @@ void memref_negate_lwe_ciphertext_u64(
|
||||
uint64_t ct0_stride) {
|
||||
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
|
||||
size_t lwe_dimension = {out_size - 1};
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_opp_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, lwe_dimension));
|
||||
concrete_cpu_negate_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset, lwe_dimension);
|
||||
}
|
||||
|
||||
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_offset, uint64_t out_size,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_engine(context), get_keyswitch_key_u64(context),
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset));
|
||||
void memref_keyswitch_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint32_t decomposition_level_count,
|
||||
uint32_t decomposition_base_log, uint32_t input_dimension,
|
||||
uint32_t output_dimension, mlir::concretelang::RuntimeContext *context) {
|
||||
assert(out_stride == 1 && ct0_stride == 1);
|
||||
// Get keyswitch key - TODO Give a non hardcoded keyID
|
||||
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(0);
|
||||
// Get stack parameter
|
||||
concrete_cpu_keyswitch_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key,
|
||||
decomposition_level_count, decomposition_base_log, input_dimension,
|
||||
output_dimension);
|
||||
}
|
||||
|
||||
void memref_batched_keyswitch_lwe_u64(
|
||||
@@ -507,24 +524,44 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
|
||||
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
|
||||
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
uint32_t input_lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t decomposition_level_count, uint32_t decomposition_base_log,
|
||||
uint32_t glwe_dimension, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
|
||||
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
|
||||
uint64_t glwe_ct_size = polynomial_size * (glwe_dimension + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
auto tlu = tlu_aligned + tlu_offset;
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_size,
|
||||
tlu_aligned + tlu_offset, poly_size));
|
||||
// Glwe trivial encryption
|
||||
for (size_t i = 0; i < polynomial_size * glwe_dimension; i++) {
|
||||
glwe_ct[i] = 0;
|
||||
}
|
||||
for (size_t i = 0; i < polynomial_size; i++) {
|
||||
glwe_ct[polynomial_size * glwe_dimension + i] = tlu[i];
|
||||
}
|
||||
|
||||
// Get fourrier bootstrap key - TODO Give a non hardcoded keyID
|
||||
size_t keyId = 0;
|
||||
const auto &fft = context->fft(keyId);
|
||||
auto bootstrap_key = context->fourier_bootstrap_key_buffer(keyId);
|
||||
// Get stack parameter
|
||||
size_t scratch_size;
|
||||
size_t scratch_align;
|
||||
concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch(
|
||||
&scratch_size, &scratch_align, glwe_dimension, polynomial_size, fft);
|
||||
// Allocate scratch
|
||||
auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
|
||||
|
||||
// Bootstrap
|
||||
concrete_cpu_bootstrap_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, ct0_aligned + ct0_offset, glwe_ct,
|
||||
bootstrap_key, decomposition_level_count, decomposition_base_log,
|
||||
glwe_dimension, polynomial_size, input_lwe_dimension, fft, scratch,
|
||||
scratch_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
|
||||
get_fft_engine(context), get_engine(context),
|
||||
get_fft_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
|
||||
ct0_aligned + ct0_offset, glwe_ct));
|
||||
free(glwe_ct);
|
||||
free(scratch);
|
||||
}
|
||||
|
||||
void memref_batched_bootstrap_lwe_u64(
|
||||
@@ -563,13 +600,16 @@ void memref_wop_pbs_crt_buffer(
|
||||
uint64_t in_stride_1,
|
||||
// clear text lut 1D memref
|
||||
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
|
||||
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
|
||||
// CRT decomposition 1D memref
|
||||
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
|
||||
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
|
||||
uint64_t crt_decomp_stride,
|
||||
// Additional crypto parameters
|
||||
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
|
||||
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
|
||||
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
|
||||
uint32_t polynomial_size,
|
||||
// runtime context that hold evluation keys
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
@@ -585,7 +625,15 @@ void memref_wop_pbs_crt_buffer(
|
||||
// Check for the size S
|
||||
assert(out_size_1 == in_size_1);
|
||||
|
||||
uint64_t lwe_small_dim = lwe_small_size - 1;
|
||||
|
||||
assert(out_size_1 == in_size_1);
|
||||
uint64_t lwe_big_size = in_size_1;
|
||||
uint64_t lwe_big_dim = lwe_big_size - 1;
|
||||
assert(lwe_big_dim % polynomial_size == 0);
|
||||
|
||||
assert(lwe_big_dim % polynomial_size == 0);
|
||||
uint64_t glwe_dim = lwe_big_dim / polynomial_size;
|
||||
|
||||
// Compute the numbers of bits to extract for each block and the total one.
|
||||
uint64_t total_number_of_bits_per_block = 0;
|
||||
@@ -609,38 +657,83 @@ void memref_wop_pbs_crt_buffer(
|
||||
new uint64_t[lwe_small_size * total_number_of_bits_per_block]{0};
|
||||
|
||||
// We make a private copy to apply a subtraction on the body
|
||||
auto first_cyphertext = in_aligned + in_offset;
|
||||
auto first_ciphertext = in_aligned + in_offset;
|
||||
auto copy_size = crt_decomp_size * lwe_big_size;
|
||||
std::vector<uint64_t> in_copy(first_cyphertext, first_cyphertext + copy_size);
|
||||
std::vector<uint64_t> in_copy(first_ciphertext, first_ciphertext + copy_size);
|
||||
// Extraction of each bit for each block
|
||||
|
||||
size_t fftKeyId = 0;
|
||||
const auto &fft = context->fft(fftKeyId);
|
||||
size_t bskKeyId = 0;
|
||||
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bskKeyId);
|
||||
|
||||
size_t kskKeyId = 0;
|
||||
auto keyswicth_key = context->keyswitch_key_buffer(kskKeyId);
|
||||
|
||||
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
|
||||
extract_bits_output_offset += number_of_bits_per_block[i--]) {
|
||||
auto nb_bits_to_extract = number_of_bits_per_block[i];
|
||||
|
||||
auto delta_log = 64 - nb_bits_to_extract;
|
||||
size_t delta_log = 64 - nb_bits_to_extract;
|
||||
|
||||
auto in_block = &in_copy[lwe_big_size * i];
|
||||
|
||||
// trick ( ct - delta/2 + delta/2^4 )
|
||||
uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) -
|
||||
(uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5));
|
||||
in_block[lwe_big_size - 1] -= sub;
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_lwe_ciphertext_discarding_bit_extraction_unchecked_u64_raw_ptr_buffers(
|
||||
context->get_fft_engine(), context->get_default_engine(),
|
||||
context->get_fft_fourier_bsk(), context->get_ksk(),
|
||||
&extract_bits_output_buffer[lwe_small_size *
|
||||
extract_bits_output_offset],
|
||||
in_block, nb_bits_to_extract, delta_log));
|
||||
|
||||
size_t scratch_size;
|
||||
size_t scratch_align;
|
||||
concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch(
|
||||
&scratch_size, &scratch_align, lwe_small_dim, lwe_big_dim, glwe_dim,
|
||||
polynomial_size, fft);
|
||||
// Allocate scratch
|
||||
auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
|
||||
|
||||
concrete_cpu_extract_bit_lwe_ciphertext_u64(
|
||||
&extract_bits_output_buffer[lwe_small_size *
|
||||
extract_bits_output_offset],
|
||||
in_block, bootstrap_key, keyswicth_key, lwe_small_dim,
|
||||
nb_bits_to_extract, lwe_big_dim, nb_bits_to_extract, delta_log,
|
||||
bsk_level_count, bsk_base_log, glwe_dim, polynomial_size, lwe_small_dim,
|
||||
ksk_level_count, ksk_base_log, lwe_big_dim, lwe_small_dim, fft, scratch,
|
||||
scratch_size);
|
||||
|
||||
free(scratch);
|
||||
}
|
||||
|
||||
size_t ct_in_count = total_number_of_bits_per_block;
|
||||
size_t lut_size = 1 << ct_in_count;
|
||||
size_t ct_out_count = out_size_0;
|
||||
size_t lut_count = ct_out_count;
|
||||
|
||||
assert(lut_ct_size0 == lut_count);
|
||||
assert(lut_ct_size1 == lut_size);
|
||||
|
||||
// Vertical packing
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing_u64_raw_ptr_buffers(
|
||||
context->get_fft_engine(), context->get_default_engine(),
|
||||
context->get_fft_fourier_bsk(), out_aligned, lwe_big_size,
|
||||
crt_decomp_size, extract_bits_output_buffer, lwe_small_size,
|
||||
total_number_of_bits_per_block, lut_ct_aligned + lut_ct_offset,
|
||||
lut_ct_size, cbs_level_count, cbs_base_log, context->get_fpksk()));
|
||||
size_t scratch_size;
|
||||
size_t scratch_align;
|
||||
concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch(
|
||||
&scratch_size, &scratch_align, ct_out_count, lwe_small_dim, ct_in_count,
|
||||
lut_size, lut_count, glwe_dim, polynomial_size, polynomial_size,
|
||||
cbs_level_count, fft);
|
||||
|
||||
auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
|
||||
|
||||
size_t fpkskKeyId = 0;
|
||||
auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(fpkskKeyId);
|
||||
|
||||
concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
|
||||
out_aligned + out_offset, extract_bits_output_buffer,
|
||||
lut_ct_aligned + lut_ct_offset, bootstrap_key, fp_keyswicth_key,
|
||||
lwe_big_dim, ct_out_count, lwe_small_dim, ct_in_count, lut_size,
|
||||
lut_count, bsk_level_count, bsk_base_log, glwe_dim, polynomial_size,
|
||||
lwe_small_dim, fpksk_level_count, fpksk_base_log, lwe_big_dim, glwe_dim,
|
||||
polynomial_size, glwe_dim + 1, cbs_level_count, cbs_base_log, fft,
|
||||
scratch, scratch_size);
|
||||
|
||||
free(scratch);
|
||||
}
|
||||
|
||||
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/ServerLib/DynamicArityCall.h"
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
#include "concretelang/ServerLib/DynamicRankCall.h"
|
||||
@@ -23,8 +24,8 @@ using concretelang::clientlib::CircuitGate;
|
||||
using concretelang::clientlib::CircuitGateShape;
|
||||
using concretelang::clientlib::EvaluationKeys;
|
||||
using concretelang::clientlib::PublicArguments;
|
||||
using concretelang::clientlib::RuntimeContext;
|
||||
using concretelang::error::StringError;
|
||||
using mlir::concretelang::RuntimeContext;
|
||||
|
||||
outcome::checked<ServerLambda, StringError>
|
||||
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
|
||||
@@ -70,8 +71,7 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
|
||||
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
|
||||
args.preparedArgs.end());
|
||||
|
||||
RuntimeContext runtimeContext;
|
||||
runtimeContext.evaluationKeys = evaluationKeys;
|
||||
RuntimeContext runtimeContext(evaluationKeys);
|
||||
preparedArgs.push_back((void *)&runtimeContext);
|
||||
|
||||
assert(clientParameters.outputs.size() == 1 &&
|
||||
|
||||
@@ -37,5 +37,6 @@ add_mlir_library(
|
||||
${LLVM_PTHREAD_LIB}
|
||||
ConcretelangRuntime
|
||||
ConcretelangClientLib
|
||||
ConcretelangServerLib
|
||||
Concrete)
|
||||
ConcretelangServerLib)
|
||||
|
||||
target_include_directories(ConcretelangSupport PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
@@ -18,30 +19,29 @@ void CompilationFeedback::fillFromClientParameters(
|
||||
// Compute the size of secret keys
|
||||
totalSecretKeysSize = 0;
|
||||
for (auto sk : params.secretKeys) {
|
||||
totalSecretKeysSize += sk.second.byteSize();
|
||||
totalSecretKeysSize += sk.byteSize();
|
||||
}
|
||||
// Compute the boostrap keys size
|
||||
totalBootstrapKeysSize = 0;
|
||||
for (auto bsk : params.bootstrapKeys) {
|
||||
auto bskParam = bsk.second;
|
||||
auto inputKey = params.secretKeys.find(bskParam.inputSecretKeyID);
|
||||
assert(inputKey != params.secretKeys.end());
|
||||
auto outputKey = params.secretKeys.find(bskParam.outputSecretKeyID);
|
||||
assert(outputKey != params.secretKeys.end());
|
||||
for (auto bskParam : params.bootstrapKeys) {
|
||||
assert(bskParam.inputSecretKeyID < params.secretKeys.size());
|
||||
auto inputKey = params.secretKeys[bskParam.inputSecretKeyID];
|
||||
|
||||
totalBootstrapKeysSize += bskParam.byteSize(inputKey->second.lweSize(),
|
||||
outputKey->second.lweSize());
|
||||
assert(bskParam.outputSecretKeyID < params.secretKeys.size());
|
||||
auto outputKey = params.secretKeys[bskParam.outputSecretKeyID];
|
||||
|
||||
totalBootstrapKeysSize +=
|
||||
bskParam.byteSize(inputKey.lweSize(), outputKey.lweSize());
|
||||
}
|
||||
// Compute the keyswitch keys size
|
||||
totalKeyswitchKeysSize = 0;
|
||||
for (auto ksk : params.keyswitchKeys) {
|
||||
auto kskParam = ksk.second;
|
||||
auto inputKey = params.secretKeys.find(kskParam.inputSecretKeyID);
|
||||
assert(inputKey != params.secretKeys.end());
|
||||
auto outputKey = params.secretKeys.find(kskParam.outputSecretKeyID);
|
||||
assert(outputKey != params.secretKeys.end());
|
||||
totalKeyswitchKeysSize += kskParam.byteSize(inputKey->second.lweSize(),
|
||||
outputKey->second.lweSize());
|
||||
for (auto kskParam : params.keyswitchKeys) {
|
||||
assert(kskParam.inputSecretKeyID < params.secretKeys.size());
|
||||
auto inputKey = params.secretKeys[kskParam.inputSecretKeyID];
|
||||
assert(kskParam.outputSecretKeyID < params.secretKeys.size());
|
||||
auto outputKey = params.secretKeys[kskParam.outputSecretKeyID];
|
||||
totalKeyswitchKeysSize +=
|
||||
kskParam.byteSize(inputKey.lweSize(), outputKey.lweSize());
|
||||
}
|
||||
// Compute the size of inputs
|
||||
totalInputsSize = 0;
|
||||
|
||||
@@ -13,10 +13,11 @@
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/logging.h>
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -133,8 +134,7 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
rawArgs[i++] = &arg;
|
||||
}
|
||||
|
||||
RuntimeContext runtimeContext;
|
||||
runtimeContext.evaluationKeys = evaluationKeys;
|
||||
mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys);
|
||||
// Pointer on runtime context, the rawArgs take pointer on actual value that
|
||||
// is passed to the compiled function.
|
||||
auto rtCtxPtr = &runtimeContext;
|
||||
|
||||
@@ -250,11 +250,10 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
auto dec =
|
||||
fheContext.value().parameter.largeInteger.value().crtDecomposition;
|
||||
auto mods = mlir::SmallVector<int64_t>(dec.begin(), dec.end());
|
||||
auto polySize = fheContext.value().parameter.getPolynomialSize();
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createConvertFHEToTFHECrtPass(
|
||||
mlir::concretelang::CrtLoweringParameters(mods, polySize)),
|
||||
mlir::concretelang::CrtLoweringParameters(mods)),
|
||||
enablePass);
|
||||
} else if (fheContext.hasValue()) {
|
||||
pipelinePrinting("FHEToTFHEScalar", pm, context);
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
|
||||
#include <llvm/ADT/Optional.h>
|
||||
@@ -161,59 +162,56 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
Variance keyswitchKeyVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c;
|
||||
c.secretKeys = {
|
||||
{clientlib::BIG_KEY, {/*.size = */ v0Param.getNBigLweDimension()}},
|
||||
};
|
||||
|
||||
assert(c.secretKeys.size() == clientlib::BIG_KEY);
|
||||
clientlib::LweSecretKeyParam skParam;
|
||||
skParam.dimension = v0Param.getNBigLweDimension();
|
||||
c.secretKeys.push_back(skParam);
|
||||
|
||||
bool has_small_key = v0Param.nSmall != 0;
|
||||
bool has_bootstrap = v0Param.brLevel != 0;
|
||||
if (has_small_key) {
|
||||
c.secretKeys.insert({clientlib::SMALL_KEY, {/*.size = */ v0Param.nSmall}});
|
||||
assert(c.secretKeys.size() == clientlib::SMALL_KEY);
|
||||
clientlib::LweSecretKeyParam skParam2;
|
||||
skParam2.dimension = v0Param.nSmall;
|
||||
c.secretKeys.push_back(skParam2);
|
||||
}
|
||||
if (has_bootstrap) {
|
||||
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
|
||||
c.bootstrapKeys = {
|
||||
{
|
||||
clientlib::BOOTSTRAP_KEY,
|
||||
{
|
||||
/*.inputSecretKeyID = */ inputKey,
|
||||
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.level = */ v0Param.brLevel,
|
||||
/*.baseLog = */ v0Param.brLogBase,
|
||||
/*.glweDimension = */ v0Param.glweDimension,
|
||||
/*.variance = */ bootstrapKeyVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
clientlib::BootstrapKeyParam bskParam;
|
||||
bskParam.inputSecretKeyID = inputKey;
|
||||
bskParam.outputSecretKeyID = clientlib::BIG_KEY;
|
||||
bskParam.level = v0Param.brLevel;
|
||||
bskParam.baseLog = v0Param.brLogBase;
|
||||
bskParam.glweDimension = v0Param.glweDimension;
|
||||
bskParam.variance = bootstrapKeyVariance;
|
||||
bskParam.polynomialSize = v0Param.getPolynomialSize();
|
||||
bskParam.inputLweDimension = v0Param.nSmall;
|
||||
c.bootstrapKeys.push_back(bskParam);
|
||||
}
|
||||
if (v0Param.largeInteger.hasValue()) {
|
||||
clientlib::PackingKeySwitchParam param;
|
||||
clientlib::PackingKeyswitchKeyParam param;
|
||||
param.inputSecretKeyID = clientlib::BIG_KEY;
|
||||
param.outputSecretKeyID = clientlib::BIG_KEY;
|
||||
param.level = v0Param.largeInteger->wopPBS.packingKeySwitch.level;
|
||||
param.baseLog = v0Param.largeInteger->wopPBS.packingKeySwitch.baseLog;
|
||||
param.bootstrapKeyID = clientlib::BOOTSTRAP_KEY;
|
||||
|
||||
param.glweDimension = v0Param.glweDimension;
|
||||
param.polynomialSize = v0Param.getPolynomialSize();
|
||||
param.inputLweDimension = v0Param.getNBigLweDimension();
|
||||
param.variance = v0Curve->getVariance(v0Param.glweDimension,
|
||||
v0Param.getPolynomialSize(), 64);
|
||||
c.packingKeys = {
|
||||
{
|
||||
"fpksk_v0",
|
||||
param,
|
||||
},
|
||||
};
|
||||
|
||||
c.packingKeyswitchKeys.push_back(param);
|
||||
}
|
||||
if (has_small_key) {
|
||||
c.keyswitchKeys = {
|
||||
{
|
||||
clientlib::KEYSWITCH_KEY,
|
||||
{
|
||||
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.level = */ v0Param.ksLevel,
|
||||
/*.baseLog = */ v0Param.ksLogBase,
|
||||
/*.variance = */ keyswitchKeyVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
clientlib::KeyswitchKeyParam kskParam;
|
||||
kskParam.inputSecretKeyID = clientlib::BIG_KEY;
|
||||
kskParam.outputSecretKeyID = clientlib::SMALL_KEY;
|
||||
kskParam.level = v0Param.ksLevel;
|
||||
kskParam.baseLog = v0Param.ksLogBase;
|
||||
kskParam.variance = keyswitchKeyVariance;
|
||||
c.keyswitchKeys.push_back(kskParam);
|
||||
}
|
||||
|
||||
c.functionName = (std::string)functionName;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<128xi64>) -> tensor<5x8192xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<40960xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
%0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
return %0: tensor<40960xi64>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<5x8192xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_lut_for_crt_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<5x8192xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg1: tensor<4xi64>) -> tensor<5x8192xi64> {
|
||||
%0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
|
||||
return %0: tensor<5x8192xi64>
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/LibrarySupport.h"
|
||||
@@ -72,6 +73,13 @@ public:
|
||||
|
||||
void testOnce() {
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
|
||||
/* Serialize and unserialize evaluation keys */
|
||||
std::stringstream stream;
|
||||
stream << evaluationKeys;
|
||||
stream.seekg(0, std::ios::beg);
|
||||
evaluationKeys = concretelang::clientlib::readEvaluationKeys(stream);
|
||||
|
||||
/* Call the server lambda */
|
||||
auto publicResult =
|
||||
support.serverCall(serverLambda, *publicArguments, evaluationKeys);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <cassert>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
@@ -8,38 +9,42 @@ namespace clientlib = concretelang::clientlib;
|
||||
|
||||
TEST(Support, client_parameters_json_serde) {
|
||||
clientlib::ClientParameters params0;
|
||||
params0.secretKeys = {
|
||||
{clientlib::SMALL_KEY, {/*.size = */ 12}},
|
||||
{clientlib::BIG_KEY, {/*.size = */ 14}},
|
||||
};
|
||||
params0.bootstrapKeys = {
|
||||
{clientlib::BOOTSTRAP_KEY,
|
||||
{/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 3,
|
||||
/*.variance = */ 0.001}},
|
||||
{"wtf_bsk_v0",
|
||||
{
|
||||
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.level = */ 3,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 1,
|
||||
/*.variance = */ 0.0001,
|
||||
}},
|
||||
};
|
||||
params0.keyswitchKeys = {{clientlib::KEYSWITCH_KEY,
|
||||
{
|
||||
/*.inputSecretKeyID = */
|
||||
clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */
|
||||
clientlib::SMALL_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.variance = */ 3,
|
||||
}}};
|
||||
assert(params0.secretKeys.size() == clientlib::BIG_KEY);
|
||||
params0.secretKeys.push_back({14});
|
||||
|
||||
assert(params0.secretKeys.size() == clientlib::SMALL_KEY);
|
||||
params0.secretKeys.push_back({12});
|
||||
|
||||
params0.bootstrapKeys.push_back({
|
||||
/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 3,
|
||||
/*.variance = */ 0.001,
|
||||
/*.polynomialSize = */ 1024,
|
||||
/*.inputLweDimension = */ 600,
|
||||
});
|
||||
|
||||
params0.bootstrapKeys.push_back({
|
||||
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
|
||||
/*.level = */ 3,
|
||||
/*.baseLog = */ 2,
|
||||
/*.glweDimension = */ 1,
|
||||
/*.variance = */ 0.0001,
|
||||
/*.polynomialSize = */ 1024,
|
||||
/*.inputLweDimension = */ 600,
|
||||
});
|
||||
params0.keyswitchKeys.push_back({
|
||||
/*.inputSecretKeyID = */
|
||||
clientlib::BIG_KEY,
|
||||
/*.outputSecretKeyID = */
|
||||
clientlib::SMALL_KEY,
|
||||
/*.level = */ 1,
|
||||
/*.baseLog = */ 2,
|
||||
/*.variance = */ 3,
|
||||
});
|
||||
params0.inputs = {
|
||||
{
|
||||
/*.encryption = */ {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "concrete/curves.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "tests_tools/assert.h"
|
||||
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
@@ -19,9 +20,13 @@ TEST_P(KeySetTest, encrypt_decrypt) {
|
||||
|
||||
auto clientParameters = GetParam();
|
||||
|
||||
__uint128_t seed = 0;
|
||||
|
||||
// Generate the client keySet
|
||||
ASSERT_ASSIGN_OUTCOME_VALUE(
|
||||
keySet, clientlib::KeySet::generate(clientParameters, 0, 0));
|
||||
keySet,
|
||||
clientlib::KeySet::generate(
|
||||
clientParameters, concretelang::clientlib::ConcreteCSPRNG(seed)));
|
||||
|
||||
// Allocate the ciphertext
|
||||
uint64_t *ciphertext = nullptr;
|
||||
@@ -50,13 +55,13 @@ clientlib::ClientParameters generateClientParameterOneScalarOneScalar(
|
||||
clientlib::CRTDecomposition crtDecomposition) {
|
||||
// One secret key with the given dimension
|
||||
clientlib::ClientParameters params;
|
||||
params.secretKeys.insert({clientlib::SMALL_KEY, {/*.dimension =*/dimension}});
|
||||
params.secretKeys.push_back({/*.dimension =*/dimension});
|
||||
// One input and output encryption gate on the same secret key and encoded
|
||||
// with the same precision
|
||||
const auto v0Curve = concrete::getSecurityCurve(128, concrete::BINARY);
|
||||
|
||||
clientlib::EncryptionGate encryption;
|
||||
encryption.secretKeyID = clientlib::SMALL_KEY;
|
||||
encryption.secretKeyID = clientlib::BIG_KEY;
|
||||
encryption.encoding.precision = precision;
|
||||
encryption.encoding.crt = crtDecomposition;
|
||||
encryption.variance = v0Curve->getVariance(1, dimension, 64);
|
||||
|
||||
Reference in New Issue
Block a user