refactor: Integrate concrete-cpu and remove concrete-core

Co-authored-by: Mayeul@Zama <mayeul.debellabre@zama.ai>
This commit is contained in:
Quentin Bourgerie
2023-02-03 15:54:51 +01:00
parent 9c784a2243
commit a62b3b1d74
53 changed files with 1502 additions and 1868 deletions

View File

@@ -73,7 +73,7 @@ jobs:
- name: Create build dir
run: mkdir build
- name: Build compiler
- name: Build and test compiler
uses: addnab/docker-run-action@v3
id: build-compiler
with:
@@ -94,34 +94,7 @@ jobs:
set -e
cd /compiler
rm -rf /build/*
make DATAFLOW_EXECUTION_ENABLED=ON CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC BUILD_DIR=/build CUDA_SUPPORT=ON CUDA_PATH=${{ env.CUDA_PATH }} all python-package build-end-to-end-dataflow-tests
mkdir -p /tmp/concrete_compiler/gpu_tests/
make BINDINGS_PYTHON_ENABLED=OFF CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC CUDA_SUPPORT=ON CUDA_PATH=${{ env.CUDA_PATH }} run-end-to-end-tests-gpu
echo "Debug: ccache statistics (after the build):"
ccache -s
- name: Archive python package
uses: actions/upload-artifact@v3
with:
name: concrete-compiler-gpu
path: build/wheels/*.whl
- name: Test compiler
uses: addnab/docker-run-action@v3
with:
registry: ghcr.io
image: ${{ env.DOCKER_IMAGE_TEST }}
username: ${{ secrets.GHCR_LOGIN }}
password: ${{ secrets.GHCR_PASSWORD }}
options: >-
-v ${{ github.workspace }}/llvm-project:/llvm-project
-v ${{ github.workspace }}/compiler:/compiler
-v ${{ github.workspace }}/build:/build
--gpus all
shell: bash
run: |
set -e
cd /compiler
pip install pytest
sed "s/pytest/python -m pytest/g" -i Makefile
mkdir -p /tmp/concrete_compiler/gpu_tests/
make DATAFLOW_EXECUTION_ENABLED=ON CCACHE=ON Python3_EXECUTABLE=$PYTHON_EXEC BUILD_DIR=/build run-end-to-end-tests-gpu
chmod -R ugo+rwx /tmp/KeySetCache

5
.gitmodules vendored
View File

@@ -13,3 +13,8 @@
[submodule "compiler/parameter-curves"]
path = compiler/parameter-curves
url = git@github.com:zama-ai/parameter-curves.git
shallow = true
[submodule "compiler/concrete-cpu"]
path = compiler/concrete-cpu
url = git@github.com:zama-ai/concrete-cpu.git
shallow = true

3
compiler/.gitignore vendored
View File

@@ -4,9 +4,6 @@ build*/
*.mlir.script
*.lit_test_times.txt
# core ffi artifacts
concrete-core-ffi*
# Test-generated artifacts
concrete-compiler_compilation_artifacts/
py_test_lib_compile_and_run_custom_perror/

View File

@@ -61,11 +61,23 @@ set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
include_directories(${PROJECT_SOURCE_DIR}/parameter-curves/concrete-security-curves-cpp/include)
# -------------------------------------------------------------------------------
# Concrete FFI Configuration
# Concrete CPU Configuration
# -------------------------------------------------------------------------------
include_directories(${CONCRETE_FFI_RELEASE})
add_library(Concrete STATIC IMPORTED)
set_target_properties(Concrete PROPERTIES IMPORTED_LOCATION ${CONCRETE_FFI_RELEASE}/libconcrete_core_ffi.a)
set(CONCRETE_CPU_STATIC_LIB "${PROJECT_SOURCE_DIR}/concrete-cpu/target/release/libconcrete_cpu.a")
ExternalProject_Add(
concrete_cpu_rust
DOWNLOAD_COMMAND ""
CONFIGURE_COMMAND "" OUTPUT "${CONCRETE_CPU_STATIC_LIB}"
BUILD_COMMAND cargo build
COMMAND cargo build --release
BINARY_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu"
INSTALL_COMMAND ""
LOG_BUILD ON)
add_library(concrete_cpu STATIC IMPORTED)
# TODO - Change that to a location in the release dir
set(CONCRETE_CPU_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu/concrete-cpu")
set_target_properties(concrete_cpu PROPERTIES IMPORTED_LOCATION "${CONCRETE_CPU_STATIC_LIB}")
add_dependencies(concrete_cpu concrete_cpu_rust)
# --------------------------------------------------------------------------------
# Concrete Cuda Configuration

View File

@@ -24,8 +24,6 @@ HPX_TARBALL=$(shell pwd)/hpx-$(HPX_VERSION).tar.gz
HPX_LOCAL_DIR=$(shell pwd)/hpx-$(HPX_VERSION)
HPX_INSTALL_DIR?=$(HPX_LOCAL_DIR)/build
CONCRETE_CORE_FFI_VERSION?=0.2.0
ML_BENCH_SUBSET_ID=
# Find OS
@@ -50,8 +48,6 @@ else
ARCHITECTURE=amd64
endif
CONCRETE_CORE_FFI_TARBALL=concrete-core-ffi_$(CONCRETE_CORE_FFI_VERSION)_$(OS)_$(ARCHITECTURE).tar.gz
export PATH := $(abspath $(BUILD_DIR))/bin:$(PATH)
ifeq ($(shell which ccache),)
@@ -87,23 +83,6 @@ endif
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc rust-bindings
# concrete-core-ffi #######################################
CONCRETE_CORE_FFI_FOLDER = $(shell pwd)/concrete-core-ffi-$(CONCRETE_CORE_FFI_VERSION)
CONCRETE_CORE_FFI_ARTIFACTS=$(CONCRETE_CORE_FFI_FOLDER)/libconcrete_core_ffi.a $(CONCRETE_CORE_FFI_FOLDER)/concrete-core-ffi.h
concrete-core-ffi: $(CONCRETE_CORE_FFI_ARTIFACTS)
$(CONCRETE_CORE_FFI_ARTIFACTS): $(CONCRETE_CORE_FFI_FOLDER)
$(CONCRETE_CORE_FFI_FOLDER): $(CONCRETE_CORE_FFI_TARBALL)
mkdir -p $(CONCRETE_CORE_FFI_FOLDER)
tar -xvzf $(CONCRETE_CORE_FFI_TARBALL) --directory $(CONCRETE_CORE_FFI_FOLDER)
$(CONCRETE_CORE_FFI_TARBALL):
curl -L https://github.com/zama-ai/concrete-core/releases/download/concrete-core-ffi-$(CONCRETE_CORE_FFI_VERSION)/$(CONCRETE_CORE_FFI_TARBALL) -o $(CONCRETE_CORE_FFI_TARBALL)
# concrete-optimizer ######################################
LIB_CONCRETE_OPTIMIZER_CPP = $(CONCRETE_OPTIMIZER_DIR)/target/libconcrete_optimizer_cpp.a
@@ -143,7 +122,6 @@ $(BUILD_DIR)/configured.stamp:
-DCONCRETELANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \
-DCONCRETELANG_DATAFLOW_EXECUTION_ENABLED=$(DATAFLOW_EXECUTION_ENABLED) \
-DCONCRETELANG_TIMING_ENABLED=$(TIMING_ENABLED) \
-DCONCRETE_FFI_RELEASE=$(CONCRETE_CORE_FFI_FOLDER) \
-DHPX_DIR=${HPX_INSTALL_DIR}/lib/cmake/HPX \
-DLLVM_EXTERNAL_PROJECTS=concretelang \
-DLLVM_EXTERNAL_CONCRETELANG_SOURCE_DIR=. \
@@ -154,7 +132,7 @@ $(BUILD_DIR)/configured.stamp:
-DCUDAToolkit_ROOT=$(CUDA_PATH)
touch $@
build-initialized: concrete-optimizer-lib concrete-core-ffi $(BUILD_DIR)/configured.stamp
build-initialized: concrete-optimizer-lib $(BUILD_DIR)/configured.stamp
doc: build-initialized
cmake --build $(BUILD_DIR) --target mlir-doc
@@ -330,7 +308,7 @@ $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml: tests/end_to_end
$(BENCHMARK_CPU_DIR)/%.yaml: tests/end_to_end_fixture/%_gen.py
$(Python3_EXECUTABLE) $< > $@
generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_leveled.yaml
generate-cpu-benchmarks: $(BENCHMARK_CPU_DIR) $(BENCHMARK_CPU_DIR)/end_to_end_linalg_apply_lookup_table.yaml $(BENCHMARK_CPU_DIR)/end_to_end_apply_lookup_table.yaml
SECURITY_TO_BENCH=128
run-cpu-benchmarks: build-benchmarks generate-cpu-benchmarks
@@ -530,7 +508,6 @@ install: concretecompiler concrete-optimizer-lib CAPI install-deps
python-format \
check-python-format \
concrete-optimizer-lib \
concrete-core-ffi \
build-tests \
run-tests \
run-check-tests \
@@ -540,7 +517,6 @@ install: concretecompiler concrete-optimizer-lib CAPI install-deps
build-end-to-end-tests \
build-end-to-end-dataflow-tests \
run-end-to-end-dataflow-tests \
concrete-core-ffi \
opt \
mlir-opt \
mlir-cpu-runner \

1
compiler/concrete-cpu Submodule

Submodule compiler/concrete-cpu added at db262714cd

View File

@@ -92,8 +92,7 @@ public:
OUTCOME_TRY(auto clientArguments,
EncryptedArguments::create(keySet, args...));
return clientArguments->exportPublicArguments(clientParameters,
keySet.runtimeContext());
return clientArguments->exportPublicArguments(clientParameters);
}
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,

View File

@@ -35,10 +35,8 @@ namespace clientlib {
using concretelang::error::StringError;
const std::string SMALL_KEY = "small";
const std::string BIG_KEY = "big";
const std::string BOOTSTRAP_KEY = "bsk_v0";
const std::string KEYSWITCH_KEY = "ksk_v0";
const uint64_t SMALL_KEY = 1;
const uint64_t BIG_KEY = 0;
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
@@ -52,7 +50,7 @@ typedef std::vector<int64_t> CRTDecomposition;
typedef uint64_t LweDimension;
typedef uint64_t GlweDimension;
typedef std::string LweSecretKeyID;
typedef uint64_t LweSecretKeyID;
struct LweSecretKeyParam {
LweDimension dimension;
@@ -66,7 +64,7 @@ static bool operator==(const LweSecretKeyParam &lhs,
return lhs.dimension == rhs.dimension;
}
typedef std::string BootstrapKeyID;
typedef uint64_t BootstrapKeyID;
struct BootstrapKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
@@ -74,6 +72,8 @@ struct BootstrapKeyParam {
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
Variance variance;
PolynomialSize polynomialSize;
LweDimension inputLweDimension;
void hash(size_t &seed);
@@ -90,7 +90,7 @@ static inline bool operator==(const BootstrapKeyParam &lhs,
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
}
typedef std::string KeyswitchKeyID;
typedef uint64_t KeyswitchKeyID;
struct KeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
@@ -112,22 +112,28 @@ static inline bool operator==(const KeyswitchKeyParam &lhs,
lhs.variance == rhs.variance;
}
typedef std::string PackingKeySwitchID;
struct PackingKeySwitchParam {
typedef uint64_t PackingKeyswitchKeyID;
struct PackingKeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
BootstrapKeyID bootstrapKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
PolynomialSize polynomialSize;
LweDimension inputLweDimension;
Variance variance;
void hash(size_t &seed);
};
static inline bool operator==(const PackingKeySwitchParam &lhs,
const PackingKeySwitchParam &rhs) {
static inline bool operator==(const PackingKeyswitchKeyParam &lhs,
const PackingKeyswitchKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog;
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.glweDimension == rhs.glweDimension &&
lhs.polynomialSize == rhs.polynomialSize &&
lhs.variance == lhs.variance &&
lhs.inputLweDimension == rhs.inputLweDimension;
}
struct Encoding {
@@ -185,13 +191,13 @@ struct CircuitGate {
bool isEncrypted() { return encryption.hasValue(); }
/// byteSize returns the size in bytes for this gate.
size_t byteSize(std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys) {
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
auto width = shape.width;
auto numElts = shape.size == 0 ? 1 : shape.size;
if (isEncrypted()) {
auto skParam = secretKeys.find(encryption->secretKeyID);
assert(skParam != secretKeys.end());
return 8 * skParam->second.lweSize() * numElts;
assert(encryption->secretKeyID < secretKeys.size());
auto skParam = secretKeys[encryption->secretKeyID];
return 8 * skParam.lweSize() * numElts;
}
width = bitWidthAsWord(width) / 8;
return width * numElts;
@@ -203,10 +209,10 @@ static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
}
struct ClientParameters {
std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys;
std::map<BootstrapKeyID, BootstrapKeyParam> bootstrapKeys;
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
std::map<PackingKeySwitchID, PackingKeySwitchParam> packingKeys;
std::vector<LweSecretKeyParam> secretKeys;
std::vector<BootstrapKeyParam> bootstrapKeys;
std::vector<KeyswitchKeyParam> keyswitchKeys;
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
std::string functionName;
@@ -237,12 +243,9 @@ struct ClientParameters {
if (!gate.encryption.hasValue()) {
return StringError("gate is not encrypted");
}
auto secretKey = secretKeys.find(gate.encryption->secretKeyID);
if (secretKey == secretKeys.end()) {
return StringError("cannot find ")
<< gate.encryption->secretKeyID << " in client parameters";
}
return secretKey->second;
assert(gate.encryption->secretKeyID < secretKeys.size());
auto secretKey = secretKeys[gate.encryption->secretKeyID];
return secretKey;
}
/// bufferSize returns the size of the whole buffer of a gate.
@@ -309,8 +312,8 @@ bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const KeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const PackingKeySwitchParam &);
bool fromJSON(const llvm::json::Value, PackingKeySwitchParam &,
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &,
llvm::json::Path);
llvm::json::Value toJSON(const Encoding &);

View File

@@ -61,8 +61,7 @@ public:
/// arguments, i.e. move all buffers to the PublicArguments and reset the
/// positional counter.
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext);
exportPublicArguments(ClientParameters clientParameters);
/// Check that all arguments as been pushed.
// TODO: Remove public method here

View File

@@ -6,113 +6,143 @@
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#include <cassert>
#include <memory>
#include <vector>
#include "concrete-core-ffi.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
typedef struct LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64;
int destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *);
struct Csprng;
struct CsprngVtable;
namespace concretelang {
namespace clientlib {
// =============================================
class CSPRNG {
public:
struct Csprng *ptr;
const struct CsprngVtable *vtable;
/// Wrapper for `LweKeyswitchKey64` so that it cleans up properly.
CSPRNG() = delete;
CSPRNG(CSPRNG &) = delete;
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
assert(ptr != nullptr);
other.ptr = nullptr;
};
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
};
class ConcreteCSPRNG : public CSPRNG {
public:
ConcreteCSPRNG(__uint128_t seed);
ConcreteCSPRNG() = delete;
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
ConcreteCSPRNG(ConcreteCSPRNG &&other);
~ConcreteCSPRNG();
};
/// @brief LweSecretKey implements tools for manipulating lwe secret key on
/// client.
class LweSecretKey {
std::shared_ptr<std::vector<uint64_t>> _buffer;
LweSecretKeyParam _parameters;
public:
LweSecretKey() = delete;
LweSecretKey(LweSecretKeyParam &parameters, CSPRNG &csprng);
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
LweSecretKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
/// @brief Encrypt the plaintext to the lwe ciphertext buffer.
void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance,
CSPRNG &csprng) const;
/// @brief Decrypt the ciphertext to the plaintext
void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const;
/// @brief Returns the buffer that hold the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the keyswicth key.
LweSecretKeyParam parameters() const { return this->_parameters; }
/// @brief Returns the lwe dimension of the secret key.
size_t dimension() const { return parameters().dimension; }
};
/// @brief LweKeyswitchKey implements tools for manipulating keyswitch key on
/// client.
class LweKeyswitchKey {
private:
LweKeyswitchKey64 *ksk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
friend std::istream &operator>>(std::istream &istream,
LweKeyswitchKey &wrappedKsk);
std::shared_ptr<std::vector<uint64_t>> _buffer;
KeyswitchKeyParam _parameters;
public:
LweKeyswitchKey(LweKeyswitchKey64 *ksk) : ksk{ksk} {}
LweKeyswitchKey(LweKeyswitchKey &other) = delete;
LweKeyswitchKey(LweKeyswitchKey &&other) : ksk{other.ksk} {
other.ksk = nullptr;
}
~LweKeyswitchKey() {
if (this->ksk != nullptr) {
CAPI_ASSERT_ERROR(destroy_lwe_keyswitch_key_u64(this->ksk));
LweKeyswitchKey() = delete;
LweKeyswitchKey(KeyswitchKeyParam &parameters, LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng);
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
KeyswitchKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
this->ksk = nullptr;
}
}
/// @brief Returns the buffer that hold the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
LweKeyswitchKey64 *get() { return this->ksk; }
/// @brief Returns the parameters of the keyswicth key.
KeyswitchKeyParam parameters() const { return this->_parameters; }
};
// =============================================
/// Wrapper for `LweBootstrapKey64` so that it cleans up properly.
/// @brief LweBootstrapKey implements tools for manipulating bootstrap key on
/// client.
class LweBootstrapKey {
private:
LweBootstrapKey64 *bsk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
friend std::istream &operator>>(std::istream &istream,
LweBootstrapKey &wrappedBsk);
std::shared_ptr<std::vector<uint64_t>> _buffer;
BootstrapKeyParam _parameters;
public:
LweBootstrapKey(LweBootstrapKey64 *bsk) : bsk{bsk} {}
LweBootstrapKey(LweBootstrapKey &other) = delete;
LweBootstrapKey(LweBootstrapKey &&other) : bsk{other.bsk} {
other.bsk = nullptr;
}
~LweBootstrapKey() {
if (this->bsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_lwe_bootstrap_key_u64(this->bsk));
this->bsk = nullptr;
}
}
LweBootstrapKey() = delete;
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
BootstrapKeyParam &parameters)
: _buffer(buffer), _parameters(parameters){};
LweBootstrapKey(BootstrapKeyParam &parameters, LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng);
LweBootstrapKey64 *get() { return this->bsk; }
///// @brief Returns the buffer that hold the bootstrap key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the bootsrap key.
BootstrapKeyParam parameters() const { return this->_parameters; }
};
// =============================================
/// Wrapper for `LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64` so
/// that it cleans up properly.
/// @brief PackingKeyswitchKey implements tools for manipulating privat packing
/// keyswitch key on client.
class PackingKeyswitchKey {
private:
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &wrappedFpksk);
friend std::istream &operator>>(std::istream &istream,
PackingKeyswitchKey &wrappedFpksk);
std::shared_ptr<std::vector<uint64_t>> _buffer;
PackingKeyswitchKeyParam _parameters;
public:
PackingKeyswitchKey(
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key)
: key{key} {}
PackingKeyswitchKey(PackingKeyswitchKey &other) = delete;
PackingKeyswitchKey(PackingKeyswitchKey &&other) : key{other.key} {
other.key = nullptr;
}
~PackingKeyswitchKey() {
if (this->key != nullptr) {
CAPI_ASSERT_ERROR(
destroy_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
this->key));
this->key = nullptr;
}
}
PackingKeyswitchKey() = delete;
PackingKeyswitchKey(PackingKeyswitchKeyParam &parameters,
LweSecretKey &inputKey, LweSecretKey &outputKey,
CSPRNG &csprng);
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
PackingKeyswitchKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get() {
return this->key;
}
/// @brief Returns the buffer that hold the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the keyswicth key.
PackingKeyswitchKeyParam parameters() const { return this->_parameters; }
};
// =============================================
@@ -120,31 +150,40 @@ public:
/// Evalution keys required for execution.
class EvaluationKeys {
private:
std::shared_ptr<LweKeyswitchKey> sharedKsk;
std::shared_ptr<LweBootstrapKey> sharedBsk;
std::shared_ptr<PackingKeyswitchKey> sharedFpksk;
protected:
friend std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
friend std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys);
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
public:
EvaluationKeys()
: sharedKsk{std::shared_ptr<LweKeyswitchKey>(nullptr)},
sharedBsk{std::shared_ptr<LweBootstrapKey>(nullptr)} {}
EvaluationKeys() = delete;
EvaluationKeys(std::shared_ptr<LweKeyswitchKey> sharedKsk,
std::shared_ptr<LweBootstrapKey> sharedBsk,
std::shared_ptr<PackingKeyswitchKey> sharedFpksk)
: sharedKsk{sharedKsk}, sharedBsk{sharedBsk}, sharedFpksk{sharedFpksk} {}
EvaluationKeys(const std::vector<LweKeyswitchKey> keyswitchKeys,
const std::vector<LweBootstrapKey> bootstrapKeys,
const std::vector<PackingKeyswitchKey> packingKeyswitchKeys)
: keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys),
packingKeyswitchKeys(packingKeyswitchKeys) {}
LweKeyswitchKey64 *getKsk() { return this->sharedKsk->get(); }
LweBootstrapKey64 *getBsk() { return this->sharedBsk->get(); }
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *getFpksk() {
return this->sharedFpksk->get();
const LweKeyswitchKey &getKeyswitchKey(size_t id) const {
return this->keyswitchKeys[id];
}
const std::vector<LweKeyswitchKey> getKeyswitchKeys() const {
return this->keyswitchKeys;
}
const LweBootstrapKey &getBootstrapKey(size_t id) const {
return bootstrapKeys[id];
}
const std::vector<LweBootstrapKey> getBootstrapKeys() const {
return this->bootstrapKeys;
}
const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const {
return this->packingKeyswitchKeys[id];
};
const std::vector<PackingKeyswitchKey> getPackingKeyswitchKeys() const {
return this->packingKeyswitchKeys;
}
};
// =============================================

View File

@@ -10,30 +10,33 @@
#include "boost/outcome.h"
#include "concrete-core-ffi.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/DFRuntime.hpp"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
using RuntimeContext = mlir::concretelang::RuntimeContext;
class KeySet {
public:
KeySet();
~KeySet();
KeySet(ClientParameters clientParameters, CSPRNG &&csprng)
: csprng(std::move(csprng)), _clientParameters(clientParameters){};
KeySet(KeySet &other) = delete;
/// allocate a KeySet according the ClientParameters.
/// Generate a KeySet from a ClientParameters specification.
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
generate(ClientParameters clientParameters, CSPRNG &&csprng);
/// Create a KeySet from a set of given keys
static outcome::checked<std::unique_ptr<KeySet>, StringError> fromKeys(
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
std::vector<LweBootstrapKey> bootstrapKeys,
std::vector<LweKeyswitchKey> keyswitchKeys,
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng);
/// Returns the ClientParameters associated with the KeySet.
ClientParameters clientParameters() { return _clientParameters; }
@@ -41,23 +44,6 @@ public:
// isInputEncrypted return true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
/// getInputLweSecretKeyParam returns the parameters of the lwe secret key for
/// the input at the given `pos`.
/// The input must be encrupted
LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) {
auto gate = inputGate(pos);
auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID);
return inputSk->second.first;
}
/// getOutputLweSecretKeyParam returns the parameters of the lwe secret key
/// for the given output.
LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) {
auto gate = outputGate(pos);
auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID);
return outputSk->second.first;
}
/// allocate a lwe ciphertext buffer for the argument at argPos, set the size
/// of the allocated buffer.
outcome::checked<void, StringError>
@@ -80,103 +66,58 @@ public:
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
RuntimeContext runtimeContext() {
RuntimeContext context;
context.evaluationKeys = this->evaluationKeys();
return context;
}
/// @brief evaluationKeys returns the evaluation keys associate to this client
/// keyset. Those evaluations keys can be safely shared publicly
EvaluationKeys evaluationKeys();
EvaluationKeys evaluationKeys() {
if (this->bootstrapKeys.empty() && this->keyswitchKeys.empty()) {
return EvaluationKeys();
}
auto kskIt = this->keyswitchKeys.find(clientlib::KEYSWITCH_KEY);
auto bskIt = this->bootstrapKeys.find(clientlib::BOOTSTRAP_KEY);
auto fpkskIt = this->packingKeys.find("fpksk_v0");
if (kskIt != this->keyswitchKeys.end() &&
bskIt != this->bootstrapKeys.end()) {
auto sharedKsk = std::get<1>(kskIt->second);
auto sharedBsk = std::get<1>(bskIt->second);
auto sharedFpksk = fpkskIt == this->packingKeys.end()
? std::make_shared<PackingKeyswitchKey>(nullptr)
: std::get<1>(fpkskIt->second);
return EvaluationKeys(sharedKsk, sharedBsk, sharedFpksk);
}
assert(!mlir::concretelang::dfr::_dfr_is_root_node() &&
"Evaluation keys missing in KeySet (on root node).");
return EvaluationKeys();
}
const std::vector<LweSecretKey> &getSecretKeys();
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
&getSecretKeys();
const std::vector<LweBootstrapKey> &getBootstrapKeys();
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
&getBootstrapKeys();
const std::vector<LweKeyswitchKey> &getKeyswitchKeys();
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
&getKeyswitchKeys();
const std::map<
LweSecretKeyID,
std::pair<PackingKeySwitchParam, std::shared_ptr<PackingKeyswitchKey>>> &
getPackingKeys();
const std::vector<PackingKeyswitchKey> &getPackingKeyswitchKeys();
protected:
outcome::checked<void, StringError>
generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param);
generateSecretKey(LweSecretKeyParam param);
outcome::checked<void, StringError>
generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param);
generateBootstrapKey(BootstrapKeyParam param);
outcome::checked<void, StringError>
generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param);
generateKeyswitchKey(KeyswitchKeyParam param);
outcome::checked<void, StringError>
generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param);
generatePackingKeyswitchKey(PackingKeyswitchKeyParam param);
outcome::checked<void, StringError>
generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<void, StringError> generateKeysFromParams();
outcome::checked<void, StringError>
setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<void, StringError> setupEncryptionMaterial();
friend class KeySetCache;
private:
DefaultEngine *engine;
DefaultParallelEngine *par_engine;
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
secretKeys;
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys;
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys;
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
packingKeys;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *>>
inputs;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *>>
outputs;
CSPRNG csprng;
void setKeys(
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys,
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
packingKeys);
///////////////////////////////////////////////
// Keys mappings
std::vector<LweSecretKey> secretKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
outcome::checked<LweSecretKey, StringError> findLweSecretKey(LweSecretKeyID);
///////////////////////////////////////////////
// Convenient positional mapping between positional gate en secret key
typedef std::vector<std::pair<CircuitGate, llvm::Optional<LweSecretKey>>>
SecretKeyGateMapping;
outcome::checked<SecretKeyGateMapping, StringError>
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
SecretKeyGateMapping inputs;
SecretKeyGateMapping outputs;
clientlib::ClientParameters _clientParameters;
};

View File

@@ -14,7 +14,6 @@
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/context.h"
namespace concretelang {
namespace serverlib {

View File

@@ -12,13 +12,10 @@
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Runtime/context.h"
namespace concretelang {
namespace clientlib {
using RuntimeContext = mlir::concretelang::RuntimeContext;
// integers are not serialized as binary values even on a binary stream
// so we cannot rely on << operator directly
template <typename Word>
@@ -62,10 +59,6 @@ template <typename Stream> bool incorrectMode(Stream &stream) {
return !binary;
}
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream);
outcome::checked<ScalarData, StringError>
@@ -105,17 +98,24 @@ outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
std::istream &istream);
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
LweSecretKey readLweSecretKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk);
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk);
LweBootstrapKey readLweBootstrapKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &wrappedKsk);
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
std::istream &operator>>(std::istream &istream, EvaluationKeys &evaluationKeys);
EvaluationKeys readEvaluationKeys(std::istream &istream);
} // namespace clientlib
} // namespace concretelang

View File

@@ -7,12 +7,6 @@
#include <string>
#define CAPI_ASSERT_ERROR(instr) \
{ \
int err = instr; \
assert(err == 0); \
}
namespace concretelang {
namespace error {

View File

@@ -8,6 +8,7 @@
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include <cstddef>
#include <list>
namespace mlir {
@@ -19,11 +20,9 @@ struct CrtLoweringParameters {
size_t nMods;
size_t modsProd;
size_t bitsTotal;
size_t polynomialSize;
size_t lutSize;
size_t singleLutSize;
CrtLoweringParameters(mlir::SmallVector<int64_t> mods, size_t polySize)
: mods(mods), polynomialSize(polySize) {
CrtLoweringParameters(mlir::SmallVector<int64_t> mods) : mods(mods) {
nMods = mods.size();
modsProd = 1;
bitsTotal = 0;
@@ -35,9 +34,7 @@ struct CrtLoweringParameters {
bits.push_back(nbits);
bitsTotal += nbits;
}
size_t lutCrtSize = size_t(1) << bitsTotal;
lutCrtSize = std::max(lutCrtSize, polynomialSize);
lutSize = mods.size() * lutCrtSize;
singleLutSize = size_t(1) << bitsTotal;
}
};

View File

@@ -15,12 +15,14 @@ include "concretelang/Dialect/RT/IR/RTTypes.td"
def Concrete_LweTensor : 1DTensorOf<[I64]>;
def Concrete_LutTensor : 1DTensorOf<[I64]>;
def Concrete_CrtLutsTensor : 2DTensorOf<[I64]>;
def Concrete_CrtPlaintextTensor : 1DTensorOf<[I64]>;
def Concrete_LweCRTTensor : 2DTensorOf<[I64]>;
def Concrete_BatchLweTensor : 2DTensorOf<[I64]>;
def Concrete_LweBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LutBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_CrtLutsBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
def Concrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
def Concrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
@@ -126,7 +128,7 @@ def Concrete_EncodeExpandLutForBootstrapBufferOp : Concrete_Op<"encode_expand_lu
);
}
def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> {
def Concrete_EncodeLutForCrtWopPBSTensorOp : Concrete_Op<"encode_lut_for_crt_woppbs_tensor", [NoSideEffect]> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs";
@@ -134,24 +136,22 @@ def Concrete_EncodeExpandLutForWopPBSTensorOp : Concrete_Op<"encode_expand_lut_f
Concrete_LutTensor : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct,
BoolAttr: $isSigned
);
let results = (outs Concrete_LutTensor : $result);
let results = (outs Concrete_CrtLutsTensor : $result);
}
def Concrete_EncodeExpandLutForWopPBSBufferOp : Concrete_Op<"encode_expand_lut_for_woppbs_buffer"> {
def Concrete_EncodeLutForCrtWopPBSBufferOp : Concrete_Op<"encode_lut_for_crt_woppbs_buffer"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs";
"Encode and expand a lookup table so that it can be used for a crt wop pbs";
let arguments = (ins
Concrete_LutBuffer : $result,
Concrete_CrtLutsBuffer : $result,
Concrete_LutBuffer : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct,
BoolAttr: $isSigned
);
@@ -347,7 +347,7 @@ def Concrete_BatchedKeySwitchLweBufferOp : Concrete_Op<"batched_keyswitch_lwe_bu
def Concrete_WopPBSCRTLweTensorOp : Concrete_Op<"wop_pbs_crt_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
Concrete_LweCRTTensor:$ciphertext,
Concrete_LutTensor:$lookupTable,
Concrete_CrtLutsTensor:$lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,
@@ -370,7 +370,7 @@ def Concrete_WopPBSCRTLweBufferOp : Concrete_Op<"wop_pbs_crt_lwe_buffer"> {
let arguments = (ins
Concrete_LweCRTBuffer:$result,
Concrete_LweCRTBuffer:$ciphertext,
Concrete_LutBuffer:$lookup_table,
Concrete_CrtLutsBuffer:$lookup_table,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,

View File

@@ -33,7 +33,7 @@ def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstra
let results = (outs 1DTensorOf<[I64]> : $result);
}
def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
def TFHE_EncodeLutForCrtWopPBSOp : TFHE_Op<"encode_lut_for_crt_woppbs"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs.";
@@ -41,12 +41,11 @@ def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
1DTensorOf<[I64]> : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct,
BoolAttr: $isSigned
);
let results = (outs 1DTensorOf<[I64]> : $result);
let results = (outs 2DTensorOf<[I64]> : $result);
}
def TFHE_EncodePlaintextWithCrtOp : TFHE_Op<"encode_plaintext_with_crt"> {
@@ -158,7 +157,7 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
let arguments = (ins
Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>: $ciphertexts,
1DTensorOf<[I64]> : $lookupTable,
2DTensorOf<[I64]> : $lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,

View File

@@ -12,11 +12,10 @@
#include <pthread.h>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/seeder.h"
#include "concrete-core-ffi.h"
#include "concretelang/Common/Error.h"
#include "concrete-cpu.h"
#ifdef CONCRETELANG_CUDA_SUPPORT
#include "bootstrap.h"
#include "device.h"
@@ -26,27 +25,23 @@
namespace mlir {
namespace concretelang {
typedef struct FFT {
FFT() = delete;
FFT(size_t polynomial_size);
FFT(FFT &other) = delete;
FFT(FFT &&other);
~FFT();
struct Fft *fft;
size_t polynomial_size;
} FFT;
typedef struct RuntimeContext {
RuntimeContext() {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
#ifdef CONCRETELANG_CUDA_SUPPORT
bsk_gpu = nullptr;
ksk_gpu = nullptr;
#endif
}
/// Ensure that the engines map is not copied
RuntimeContext(const RuntimeContext &ctx){};
RuntimeContext() = delete;
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
~RuntimeContext() {
CAPI_ASSERT_ERROR(destroy_default_engine(default_engine));
for (const auto &key : fft_engines) {
CAPI_ASSERT_ERROR(destroy_fft_engine(key.second));
}
if (fbsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_fft_fourier_lwe_bootstrap_key_u64(fbsk));
}
#ifdef CONCRETELANG_CUDA_SUPPORT
if (bsk_gpu != nullptr) {
cuda_drop(bsk_gpu, 0);
@@ -55,44 +50,33 @@ typedef struct RuntimeContext {
cuda_drop(ksk_gpu, 0);
}
#endif
};
const uint64_t *keyswitch_key_buffer(size_t keyId) {
return evaluationKeys.getKeyswitchKey(keyId).buffer();
}
FftEngine *get_fft_engine() {
pthread_t threadId = pthread_self();
std::lock_guard<std::mutex> guard(engines_map_guard);
auto engineIt = fft_engines.find(threadId);
if (engineIt == fft_engines.end()) {
FftEngine *fft_engine = nullptr;
CAPI_ASSERT_ERROR(new_fft_engine(&fft_engine));
engineIt =
fft_engines
.insert(std::pair<pthread_t, FftEngine *>(threadId, fft_engine))
.first;
}
assert(engineIt->second && "No engine available in context");
return engineIt->second;
const double *fourier_bootstrap_key_buffer(size_t keyId) {
return fourier_bootstrap_keys[keyId]->data();
}
DefaultEngine *get_default_engine() { return default_engine; }
FftFourierLweBootstrapKey64 *get_fft_fourier_bsk() {
if (fbsk != nullptr) {
return fbsk;
}
const std::lock_guard<std::mutex> guard(fbskMutex);
if (fbsk == nullptr) {
CAPI_ASSERT_ERROR(
fft_engine_convert_lwe_bootstrap_key_to_fft_fourier_lwe_bootstrap_key_u64(
get_fft_engine(), evaluationKeys.getBsk(), &fbsk));
}
return fbsk;
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
return evaluationKeys.getPackingKeyswitchKey(keyId).buffer();
}
const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
const ::concretelang::clientlib::EvaluationKeys getKeys() const {
return evaluationKeys;
}
private:
::concretelang::clientlib::EvaluationKeys evaluationKeys;
std::vector<std::shared_ptr<std::vector<double>>> fourier_bootstrap_keys;
std::vector<FFT> ffts;
#ifdef CONCRETELANG_CUDA_SUPPORT
public:
void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t glwe_dim, uint32_t gpu_idx, void *stream) {
@@ -104,25 +88,21 @@ typedef struct RuntimeContext {
if (bsk_gpu != nullptr) {
return bsk_gpu;
}
LweBootstrapKey64 *bsk = get_bsk();
size_t bsk_buffer_len =
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level;
size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t);
uint64_t *bsk_buffer =
(uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size);
auto bsk = evaluationKeys.getBootstrapKey(0);
size_t bsk_buffer_len = bsk.size();
size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
void *bsk_gpu_tmp = cuda_malloc(bsk_gpu_buffer_size, gpu_idx);
CAPI_ASSERT_ERROR(
default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers(
default_engine, bsk, bsk_buffer));
cuda_initialize_twiddles(poly_size, gpu_idx);
cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, bsk_buffer, stream, gpu_idx,
input_lwe_dim, glwe_dim, level,
cuda_convert_lwe_bootstrap_key_64(bsk_gpu_tmp, (void *)bsk.buffer(), stream,
gpu_idx, input_lwe_dim, glwe_dim, level,
poly_size);
// This is currently not 100% async as we have to free CPU memory after
// This is currently not 100% async as
// we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
free(bsk_buffer);
bsk_gpu = bsk_gpu_tmp;
return bsk_gpu;
}
@@ -138,49 +118,23 @@ typedef struct RuntimeContext {
if (ksk_gpu != nullptr) {
return ksk_gpu;
}
LweKeyswitchKey64 *ksk = get_ksk();
size_t ksk_buffer_len = input_lwe_dim * (output_lwe_dim + 1) * level;
size_t ksk_buffer_size = sizeof(uint64_t) * ksk_buffer_len;
uint64_t *ksk_buffer =
(uint64_t *)aligned_alloc(U64_ALIGNMENT, ksk_buffer_size);
auto ksk = evaluationKeys.getKeyswitchKey(0);
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.size();
void *ksk_gpu_tmp = cuda_malloc(ksk_buffer_size, gpu_idx);
CAPI_ASSERT_ERROR(
default_engine_discard_convert_lwe_keyswitch_key_to_lwe_keyswitch_key_mut_view_u64_raw_ptr_buffers(
default_engine, ksk, ksk_buffer));
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, ksk_buffer, ksk_buffer_size, stream,
gpu_idx);
// This is currently not 100% async as we have to free CPU memory after
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, (void *)ksk.buffer(), ksk_buffer_size,
stream, gpu_idx);
// This is currently not 100% async as
// we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
free(ksk_buffer);
ksk_gpu = ksk_gpu_tmp;
return ksk_gpu;
}
#endif
LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); }
LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); }
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *get_fpksk() {
return evaluationKeys.getFpksk();
}
RuntimeContext &operator=(const RuntimeContext &rhs) {
this->evaluationKeys = rhs.evaluationKeys;
return *this;
}
::concretelang::clientlib::EvaluationKeys evaluationKeys;
private:
std::mutex fbskMutex;
FftFourierLweBootstrapKey64 *fbsk = nullptr;
DefaultEngine *default_engine;
std::map<pthread_t, FftEngine *> fft_engines;
std::mutex engines_map_guard;
#ifdef CONCRETELANG_CUDA_SUPPORT
std::mutex bsk_gpu_mutex;
void *bsk_gpu;
std::mutex ksk_gpu_mutex;
@@ -192,18 +146,4 @@ private:
} // namespace concretelang
} // namespace mlir
extern "C" {
LweKeyswitchKey64 *
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context);
FftFourierLweBootstrapKey64 *
get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
LweBootstrapKey64 *
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context);
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context);
FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context);
}
#endif

View File

@@ -14,17 +14,16 @@
#include <hpx/modules/collectives.hpp>
#include <hpx/modules/serialization.hpp>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concrete-core-ffi.h"
#include "concretelang/Common/Error.h"
namespace mlir {
namespace concretelang {
namespace dfr {
template <typename T> struct KeyManager;
struct RuntimeContextManager;
namespace {
static void *dl_handle;
@@ -32,103 +31,83 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
template <typename LweKeyType> struct KeyWrapper {
LweKeyType *key;
Buffer buffer;
std::vector<LweKeyType> keys;
KeyWrapper() : key(nullptr) {}
KeyWrapper(KeyWrapper &&moved) noexcept
: key(moved.key), buffer(moved.buffer) {}
KeyWrapper(LweKeyType *key);
KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {}
KeyWrapper() {}
KeyWrapper(KeyWrapper &&moved) noexcept : keys(moved.keys) {}
KeyWrapper(const KeyWrapper &kw) : keys(kw.keys) {}
KeyWrapper &operator=(const KeyWrapper &rhs) {
this->key = rhs.key;
this->buffer = rhs.buffer;
this->keys = rhs.keys;
return *this;
}
KeyWrapper(std::vector<LweKeyType> keyvec) : keys(keyvec) {}
friend class hpx::serialization::access;
// template <class Archive>
// void save(Archive &ar, const unsigned int version) const;
template <class Archive>
void save(Archive &ar, const unsigned int version) const;
template <class Archive> void load(Archive &ar, const unsigned int version);
HPX_SERIALIZATION_SPLIT_MEMBER()
void serialize(Archive &ar, const unsigned int version) const {}
// template <class Archive> void load(Archive &ar, const unsigned int
// version); HPX_SERIALIZATION_SPLIT_MEMBER()
};
template <>
KeyWrapper<LweKeyswitchKey64>::KeyWrapper(LweKeyswitchKey64 *key) : key(key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
&buffer));
}
template <>
KeyWrapper<LweBootstrapKey64>::KeyWrapper(LweBootstrapKey64 *key) : key(key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&buffer));
}
template <typename LweKeyType>
bool operator==(const KeyWrapper<LweKeyType> &lhs,
const KeyWrapper<LweKeyType> &rhs) {
return lhs.key == rhs.key;
if (lhs.keys.size() != rhs.keys.size())
return false;
for (size_t i = 0; i < lhs.keys.size(); ++i)
if (lhs.keys[i].buffer() != rhs.keys[i].buffer())
return false;
return true;
}
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey64>::save(Archive &ar,
const unsigned int version) const {
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey64>::load(Archive &ar,
const unsigned int version) {
DefaultSerializationEngine *engine;
// template <>
// template <class Archive>
// void KeyWrapper<LweBootstrapKey>::save(Archive &ar,
// const unsigned int version) const {
// ar << buffer.length;
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweBootstrapKey>::load(Archive &ar,
// const unsigned int version) {
// DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// // No Freeing as it doesn't allocate anything.
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
CAPI_ASSERT_ERROR(
default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
engine, {buffer.pointer, buffer.length}, &key));
}
// ar >> buffer.length;
// buffer.pointer = new uint8_t[buffer.length];
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
// CAPI_ASSERT_ERROR(
// default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
// engine, {buffer.pointer, buffer.length}, &key));
// }
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey64>::save(Archive &ar,
const unsigned int version) const {
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey64>::load(Archive &ar,
const unsigned int version) {
DefaultSerializationEngine *engine;
// template <>
// template <class Archive>
// void KeyWrapper<LweKeyswitchKey>::save(Archive &ar,
// const unsigned int version) const {
// ar << buffer.length;
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweKeyswitchKey>::load(Archive &ar,
// const unsigned int version) {
// DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// // No Freeing as it doesn't allocate anything.
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
CAPI_ASSERT_ERROR(
default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
engine, {buffer.pointer, buffer.length}, &key));
}
// ar >> buffer.length;
// buffer.pointer = new uint8_t[buffer.length];
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
// CAPI_ASSERT_ERROR(
// default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
// engine, {buffer.pointer, buffer.length}, &key));
// }
/************************/
/* Context management. */
@@ -152,34 +131,36 @@ struct RuntimeContextManager {
// instantiates a local RuntimeContext.
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;
LweKeyswitchKey64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey64 *bsk = get_bootstrap_key_u64(context);
KeyWrapper<LweKeyswitchKey64> kskw(ksk);
KeyWrapper<LweBootstrapKey64> bskw(bsk);
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw(
context->getKeys().getKeyswitchKeys());
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw(
context->getKeys().getBootstrapKeys());
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw(
context->getKeys().getPackingKeyswitchKeys());
hpx::collectives::broadcast_to("ksk_keystore", kskw);
hpx::collectives::broadcast_to("bsk_keystore", bskw);
hpx::collectives::broadcast_to("pksk_keystore", pkskw);
} else {
auto kskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey64>>(
"ksk_keystore");
auto bskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey64>>(
"bsk_keystore");
auto kskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey>>(
"ksk_keystore");
auto bskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::LweBootstrapKey>>(
"bsk_keystore");
auto pkskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey>>(
"pksk_keystore");
KeyWrapper<LweKeyswitchKey64> kskw = kskFut.get();
KeyWrapper<LweBootstrapKey64> bskw = bskFut.get();
context = new mlir::concretelang::RuntimeContext();
// TODO - packing keyswitch key for distributed
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(bskw.key)),
std::shared_ptr<::concretelang::clientlib::PackingKeyswitchKey>(
nullptr));
delete (kskw.buffer.pointer);
delete (bskw.buffer.pointer);
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw =
kskFut.get();
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw =
bskFut.get();
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw =
pkskFut.get();
context = new mlir::concretelang::RuntimeContext(
::concretelang::clientlib::EvaluationKeys(kskw.keys, bskw.keys,
pkskw.keys));
}
}

View File

@@ -1,13 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_RUNTIME_SEEDER_H
#define CONCRETELANG_RUNTIME_SEEDER_H
#include "concrete-core-ffi.h"
extern SeederBuilder *best_seeder;
#endif

View File

@@ -29,18 +29,19 @@ void memref_encode_expand_lut_for_bootstrap(
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
uint32_t out_MESSAGE_BITS, bool is_signed);
void memref_encode_expand_lut_for_woppbs(
void memref_encode_lut_for_crt_woppbs(
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size,
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
uint64_t output_lut_offset, uint64_t output_lut_size0,
uint64_t output_lut_size1, uint64_t output_lut_stride0,
uint64_t output_lut_stride1, uint64_t *input_lut_allocated,
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
uint64_t input_lut_size, uint64_t input_lut_stride,
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated,
uint64_t *crt_bits_aligned, uint64_t crt_bits_offset,
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size,
uint32_t modulus_product, bool is_signed);
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t modulus_product,
bool is_signed);
void memref_encode_plaintext_with_crt(
uint64_t *output_allocated, uint64_t *output_aligned,
@@ -149,13 +150,16 @@ void memref_wop_pbs_crt_buffer(
uint64_t in_stride_1,
// clear text lut
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
// CRT decomposition
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
uint64_t crt_decomp_stride,
// Additional crypto parameters
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size,
// runtime context that hold evluation keys
mlir::concretelang::RuntimeContext *context);

View File

@@ -337,8 +337,8 @@ public:
if (check.has_error()) {
return StreamStringError(check.error().mesg);
}
auto publicArguments = encryptedArgs->exportPublicArguments(
clientParameters, keySet.runtimeContext());
auto publicArguments =
encryptedArgs->exportPublicArguments(clientParameters);
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}
@@ -485,8 +485,8 @@ public:
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
clientParameters, keySet->runtimeContext());
auto publicArguments =
encryptedArgs.value()->exportPublicArguments(clientParameters);
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
@@ -506,8 +506,8 @@ public:
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
clientParameters, keySet->runtimeContext());
auto publicArguments =
encryptedArgs.value()->exportPublicArguments(clientParameters);
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}

View File

@@ -73,8 +73,7 @@ public:
OUTCOME_TRY(auto encryptedArgs,
clientlib::EncryptedArguments::create(*keySet, args...));
OUTCOME_TRY(auto publicArgument,
encryptedArgs->exportPublicArguments(this->clientParameters,
keySet->runtimeContext()));
encryptedArgs->exportPublicArguments(this->clientParameters));
// client argument serialization
// publicArgument->serialize(clientOuput);
// message = clientOuput.str();

View File

@@ -218,8 +218,8 @@ MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
evaluationKeysUnserialize(const std::string &buffer) {
std::stringstream istream(buffer);
concretelang::clientlib::EvaluationKeys evaluationKeys;
concretelang::clientlib::operator>>(istream, evaluationKeys);
concretelang::clientlib::EvaluationKeys evaluationKeys =
concretelang::clientlib::readEvaluationKeys(istream);
if (istream.fail()) {
throw std::runtime_error("Cannot read evaluation keys");

View File

@@ -394,8 +394,15 @@ void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)}
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
uint64_t seed_lsb) {
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = concretelang::clientlib::ConcreteCSPRNG(seed);
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
*unwrap(params), seed_msb, seed_lsb);
*unwrap(params), std::move(csprng));
if (keySet.has_error()) {
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
keySet.error().mesg);
@@ -445,8 +452,8 @@ BufferRef evaluationKeysSerialize(EvaluationKeys keys) {
EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) {
std::stringstream istream(std::string(buffer.data, buffer.length));
concretelang::clientlib::EvaluationKeys evaluationKeys;
concretelang::clientlib::operator>>(istream, evaluationKeys);
concretelang::clientlib::EvaluationKeys evaluationKeys =
concretelang::clientlib::readEvaluationKeys(istream);
if (istream.fail()) {
return wrap((concretelang::clientlib::EvaluationKeys *)NULL,
"input stream failure during evaluation keys unserialization");

View File

@@ -2,6 +2,7 @@ add_mlir_library(
ConcretelangClientLib
ClientLambda.cpp
ClientParameters.cpp
EvaluationKeys.cpp
CRT.cpp
EncryptedArguments.cpp
KeySet.cpp
@@ -12,4 +13,6 @@ add_mlir_library(
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
LINK_LIBS
PUBLIC
Concrete)
concrete_cpu)
target_include_directories(ConcretelangClientLib PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})

View File

@@ -39,28 +39,25 @@ void KeyswitchKeyParam::hash(size_t &seed) {
double_to_bits(variance));
}
void PackingKeySwitchParam::hash(size_t &seed) {
hash_(seed, inputSecretKeyID, outputSecretKeyID, bootstrapKeyID, level,
baseLog, double_to_bits(variance));
void PackingKeyswitchKeyParam::hash(size_t &seed) {
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
glweDimension, polynomialSize, inputLweDimension,
double_to_bits(variance));
}
std::size_t ClientParameters::hash() {
std::size_t currentHash = 1;
for (auto secretKeyParam : secretKeys) {
hash_(currentHash, secretKeyParam.first);
secretKeyParam.second.hash(currentHash);
secretKeyParam.hash(currentHash);
}
for (auto bootstrapKeyParam : bootstrapKeys) {
hash_(currentHash, bootstrapKeyParam.first);
bootstrapKeyParam.second.hash(currentHash);
bootstrapKeyParam.hash(currentHash);
}
for (auto keyswitchParam : keyswitchKeys) {
hash_(currentHash, keyswitchParam.first);
keyswitchParam.second.hash(currentHash);
keyswitchParam.hash(currentHash);
}
for (auto packingParam : packingKeys) {
hash_(currentHash, packingParam.first);
packingParam.second.hash(currentHash);
for (auto packingKeyswitchKeyParam : packingKeyswitchKeys) {
packingKeyswitchKeyParam.hash(currentHash);
}
return currentHash;
}
@@ -74,18 +71,8 @@ llvm::json::Value toJSON(const LweSecretKeyParam &v) {
bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto dimension = obj->getInteger("dimension");
if (!dimension.hasValue()) {
p.report("missing size field");
return false;
}
v.dimension = *dimension;
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("dimension", v.dimension);
}
llvm::json::Value toJSON(const BootstrapKeyParam &v) {
@@ -96,54 +83,22 @@ llvm::json::Value toJSON(const BootstrapKeyParam &v) {
{"glweDimension", v.glweDimension},
{"baseLog", v.baseLog},
{"variance", v.variance},
{"polynomialSize", v.polynomialSize},
{"inputLweDimension", v.inputLweDimension},
};
return object;
}
bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
if (!inputSecretKeyID.hasValue()) {
p.report("missing inputSecretKeyID field");
return false;
}
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
if (!outputSecretKeyID.hasValue()) {
p.report("missing outputSecretKeyID field");
return false;
}
auto level = obj->getInteger("level");
if (!level.hasValue()) {
p.report("missing level field");
return false;
}
auto baseLog = obj->getInteger("baseLog");
if (!baseLog.hasValue()) {
p.report("missing baseLog field");
return false;
}
auto glweDimension = obj->getInteger("glweDimension");
if (!glweDimension.hasValue()) {
p.report("missing glweDimension field");
return false;
}
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
v.level = level.getValue();
v.baseLog = baseLog.getValue();
v.glweDimension = glweDimension.getValue();
v.variance = variance.getValue();
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
O.map("glweDimension", v.glweDimension) &&
O.map("variance", v.variance) &&
O.map("polynomialSize", v.polynomialSize) &&
O.map("inputLweDimension", v.inputLweDimension);
}
llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
@@ -158,99 +113,37 @@ llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
}
bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
if (!inputSecretKeyID.hasValue()) {
p.report("missing inputSecretKeyID field");
return false;
}
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
if (!outputSecretKeyID.hasValue()) {
p.report("missing outputSecretKeyID field");
return false;
}
auto level = obj->getInteger("level");
if (!level.hasValue()) {
p.report("missing level field");
return false;
}
auto baseLog = obj->getInteger("baseLog");
if (!baseLog.hasValue()) {
p.report("missing baseLog field");
return false;
}
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
v.level = level.getValue();
v.baseLog = baseLog.getValue();
v.variance = variance.getValue();
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
O.map("variance", v.variance);
}
llvm::json::Value toJSON(const PackingKeySwitchParam &v) {
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &v) {
llvm::json::Object object{
{"inputSecretKeyID", v.inputSecretKeyID},
{"outputSecretKeyID", v.outputSecretKeyID},
{"bootstrapKeyID", v.bootstrapKeyID},
{"level", v.level},
{"baseLog", v.baseLog},
{"glweDimension", v.glweDimension},
{"polynomialSize", v.polynomialSize},
{"inputLweDimension", v.inputLweDimension},
{"variance", v.variance},
};
return object;
}
bool fromJSON(const llvm::json::Value j, PackingKeySwitchParam &v,
bool fromJSON(const llvm::json::Value j, PackingKeyswitchKeyParam &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto inputSecretKeyID = obj->getString("inputSecretKeyID");
if (!inputSecretKeyID.hasValue()) {
p.report("missing inputSecretKeyID field");
return false;
}
auto outputSecretKeyID = obj->getString("outputSecretKeyID");
if (!outputSecretKeyID.hasValue()) {
p.report("missing outputSecretKeyID field");
return false;
}
auto bootstrapKeyID = obj->getString("bootstrapKeyID");
if (!bootstrapKeyID.hasValue()) {
p.report("missing bootstrapKeyID field");
return false;
}
auto level = obj->getInteger("level");
if (!level.hasValue()) {
p.report("missing level field");
return false;
}
auto baseLog = obj->getInteger("baseLog");
if (!baseLog.hasValue()) {
p.report("missing baseLog field");
return false;
}
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.inputSecretKeyID = (std::string)inputSecretKeyID.getValue();
v.outputSecretKeyID = (std::string)outputSecretKeyID.getValue();
v.bootstrapKeyID = (std::string)bootstrapKeyID.getValue();
v.level = level.getValue();
v.baseLog = baseLog.getValue();
v.variance = variance.getValue();
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
O.map("glweDimension", v.glweDimension) &&
O.map("polynomialSize", v.polynomialSize) &&
O.map("inputLweDimension", v.inputLweDimension) &&
O.map("variance", v.variance);
}
llvm::json::Value toJSON(const CircuitGateShape &v) {
@@ -264,43 +157,10 @@ llvm::json::Value toJSON(const CircuitGateShape &v) {
}
bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto width = obj->getInteger("width");
if (!width.hasValue()) {
p.report("missing width field");
return false;
}
auto dimensions = obj->getArray("dimensions");
if (dimensions == nullptr) {
p.report("missing dimensions field");
return false;
}
for (auto dim : *dimensions) {
auto iDim = dim.getAsInteger();
if (!iDim.hasValue()) {
p.report("dimensions must be integer");
return false;
}
v.dimensions.push_back(iDim.getValue());
}
auto size = obj->getInteger("size");
if (!size.hasValue()) {
p.report("missing size field");
return false;
}
auto sign = obj->getBoolean("sign");
if (!sign.hasValue()) {
p.report("missing sign field");
return false;
}
v.width = width.getValue();
v.size = size.getValue();
v.sign = sign.getValue();
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("width", v.width) && O.map("size", v.size) &&
O.map("dimensions", v.dimensions) && O.map("sign", v.sign);
}
llvm::json::Value toJSON(const Encoding &v) {
@@ -314,35 +174,13 @@ llvm::json::Value toJSON(const Encoding &v) {
return object;
}
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
llvm::json::ObjectMapper O(j, p);
if (!(O && O.map("precision", v.precision) &&
O.map("isSigned", v.isSigned))) {
return false;
}
auto precision = obj->getInteger("precision");
if (!precision.hasValue()) {
p.report("missing precision field");
return false;
}
v.precision = precision.getValue();
auto isSigned = obj->getBoolean("isSigned");
if (!isSigned.hasValue()) {
p.report("missing isSigned field");
return false;
}
v.isSigned = isSigned.getValue();
auto crt = obj->getArray("crt");
if (crt != nullptr) {
for (auto dim : *crt) {
auto iDim = dim.getAsInteger();
if (!iDim.hasValue()) {
p.report("dimensions must be integer");
return false;
}
v.crt.push_back(iDim.getValue());
}
}
// TODO: check this is correct for an optional field
O.map("crt", v.crt);
return true;
}
@@ -356,32 +194,9 @@ llvm::json::Value toJSON(const EncryptionGate &v) {
}
bool fromJSON(const llvm::json::Value j, EncryptionGate &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto secretKeyID = obj->getString("secretKeyID");
if (!secretKeyID.hasValue()) {
p.report("missing secretKeyID field");
return false;
}
v.secretKeyID = (std::string)secretKeyID.getValue();
auto variance = obj->getNumber("variance");
if (!variance.hasValue()) {
p.report("missing variance field");
return false;
}
v.variance = variance.getValue();
auto encoding = obj->get("encoding");
if (encoding == nullptr) {
p.report("missing encoding field");
return false;
}
if (!fromJSON(*encoding, v.encoding, p.field("encoding"))) {
return false;
}
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("secretKeyID", v.secretKeyID) &&
O.map("variance", v.variance) && O.map("encoding", v.encoding);
}
llvm::json::Value toJSON(const CircuitGate &v) {
@@ -392,40 +207,16 @@ llvm::json::Value toJSON(const CircuitGate &v) {
return object;
}
bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) {
auto obj = j.getAsObject();
auto encryption = obj->get("encryption");
if (encryption == nullptr) {
p.report("missing encryption field");
return false;
}
if (!fromJSON(*encryption, v.encryption, p.field("encryption"))) {
return false;
}
auto shape = obj->get("shape");
if (shape == nullptr) {
p.report("missing shape field");
return false;
}
if (!fromJSON(*shape, v.shape, p.field("shape"))) {
return false;
}
return true;
}
template <typename T> llvm::json::Value toJson(std::map<std::string, T> map) {
llvm::json::Object obj;
for (auto entry : map) {
obj[entry.first] = entry.second;
}
return obj;
llvm::json::ObjectMapper O(j, p);
return O && O.map("encryption", v.encryption) && O.map("shape", v.shape);
}
llvm::json::Value toJSON(const ClientParameters &v) {
llvm::json::Object object{
{"secretKeys", toJson(v.secretKeys)},
{"bootstrapKeys", toJson(v.bootstrapKeys)},
{"keyswitchKeys", toJson(v.keyswitchKeys)},
{"packingKeys", toJson(v.packingKeys)},
{"secretKeys", v.secretKeys},
{"bootstrapKeys", v.bootstrapKeys},
{"keyswitchKeys", v.keyswitchKeys},
{"packingKeyswitchKeys", v.packingKeyswitchKeys},
{"inputs", v.inputs},
{"outputs", v.outputs},
{"functionName", v.functionName},
@@ -434,64 +225,13 @@ llvm::json::Value toJSON(const ClientParameters &v) {
}
bool fromJSON(const llvm::json::Value j, ClientParameters &v,
llvm::json::Path p) {
auto obj = j.getAsObject();
auto secretkeys = obj->get("secretKeys");
if (secretkeys == nullptr) {
p.report("missing secretKeys field");
return false;
}
if (!fromJSON(*secretkeys, v.secretKeys, p.field("secretKeys"))) {
return false;
}
auto bootstrapKeys = obj->get("bootstrapKeys");
if (bootstrapKeys == nullptr) {
p.report("missing bootstrapKeys field");
return false;
}
if (!fromJSON(*bootstrapKeys, v.bootstrapKeys, p.field("bootstrapKeys"))) {
return false;
}
auto keyswitchKeys = obj->get("keyswitchKeys");
if (keyswitchKeys == nullptr) {
p.report("missing keyswitchKeys field");
return false;
}
if (!fromJSON(*keyswitchKeys, v.keyswitchKeys, p.field("keyswitchKeys"))) {
return false;
}
auto packingKeys = obj->get("packingKeys");
if (packingKeys == nullptr) {
p.report("missing packingKeys field");
return false;
}
if (!fromJSON(*packingKeys, v.packingKeys, p.field("packingKeys"))) {
return false;
}
auto inputs = obj->get("inputs");
if (inputs == nullptr) {
p.report("missing inputs field");
return false;
}
if (!fromJSON(*inputs, v.inputs, p.field("inputs"))) {
return false;
}
auto outputs = obj->get("outputs");
if (outputs == nullptr) {
p.report("missing outputs field");
return false;
}
if (!fromJSON(*outputs, v.outputs, p.field("outputs"))) {
return false;
}
auto functionName = obj->getString("functionName");
if (!functionName.hasValue()) {
p.report("missing functionName field");
return false;
}
v.functionName = (std::string)functionName.getValue();
return true;
llvm::json::ObjectMapper O(j, p);
return O && O.map("secretKeys", v.secretKeys) &&
O.map("bootstrapKeys", v.bootstrapKeys) &&
O.map("keyswitchKeys", v.keyswitchKeys) &&
O.map("packingKeyswitchKeys", v.packingKeyswitchKeys) &&
O.map("inputs", v.inputs) && O.map("outputs", v.outputs) &&
O.map("functionName", v.functionName);
}
std::string ClientParameters::getClientParametersPath(std::string path) {

View File

@@ -12,8 +12,7 @@ namespace clientlib {
using StringError = concretelang::error::StringError;
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext) {
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
return std::make_unique<PublicArguments>(
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
}

View File

@@ -0,0 +1,137 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concrete-cpu.h"
#include "concretelang/ClientLib/ClientParameters.h"
namespace concretelang {
namespace clientlib {
ConcreteCSPRNG::ConcreteCSPRNG(__uint128_t seed)
: CSPRNG(nullptr, &CONCRETE_CSPRNG_VTABLE) {
ptr = (Csprng *)aligned_alloc(CONCRETE_CSPRNG_ALIGN, CONCRETE_CSPRNG_SIZE);
struct Uint128 u128;
if (seed == 0) {
switch (concrete_cpu_crypto_secure_random_128(&u128)) {
case 1:
break;
case -1:
llvm::errs()
<< "WARNING: The generated random seed is not crypto secure\n";
break;
default:
assert(false && "Cannot instantiate a random seed");
}
} else {
for (int i = 0; i < 16; i++) {
u128.little_endian_bytes[i] = seed >> (8 * i);
}
}
concrete_cpu_construct_concrete_csprng(ptr, u128);
}
ConcreteCSPRNG::ConcreteCSPRNG(ConcreteCSPRNG &&other)
: CSPRNG(other.ptr, &CONCRETE_CSPRNG_VTABLE) {
assert(ptr != nullptr);
other.ptr = nullptr;
}
ConcreteCSPRNG::~ConcreteCSPRNG() {
if (ptr != nullptr) {
concrete_cpu_destroy_concrete_csprng(ptr);
free(ptr);
}
}
LweSecretKey::LweSecretKey(LweSecretKeyParam &parameters, CSPRNG &csprng)
: _parameters(parameters) {
// Allocate the buffer
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(parameters.dimension);
// Initialize the lwe secret key buffer
concrete_cpu_init_lwe_secret_key_u64(_buffer->data(), parameters.dimension,
csprng.ptr, csprng.vtable);
}
void LweSecretKey::encrypt(uint64_t *ciphertext, uint64_t plaintext,
double variance, CSPRNG &csprng) const {
concrete_cpu_encrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
plaintext, parameters().dimension,
variance, csprng.ptr, csprng.vtable);
}
void LweSecretKey::decrypt(const uint64_t *ciphertext,
uint64_t &plaintext) const {
concrete_cpu_decrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
parameters().dimension, &plaintext);
}
LweKeyswitchKey::LweKeyswitchKey(KeyswitchKeyParam &parameters,
LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng)
: _parameters(parameters) {
// Allocate the buffer
auto size = concrete_cpu_keyswitch_key_size_u64(
_parameters.level, _parameters.baseLog, inputKey.dimension(),
outputKey.dimension());
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(size);
// Initialize the keyswitch key buffer
concrete_cpu_init_lwe_keyswitch_key_u64(
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
inputKey.dimension(), outputKey.dimension(), _parameters.level,
_parameters.baseLog, _parameters.variance, csprng.ptr, csprng.vtable);
}
LweBootstrapKey::LweBootstrapKey(BootstrapKeyParam &parameters,
LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng)
: _parameters(parameters) {
// TODO
size_t polynomial_size = outputKey.dimension() / _parameters.glweDimension;
// Allocate the buffer
auto size = concrete_cpu_bootstrap_key_size_u64(
_parameters.level, _parameters.glweDimension, polynomial_size,
inputKey.dimension());
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(size);
// Initialize the keyswitch key buffer
concrete_cpu_init_lwe_bootstrap_key_u64(
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
inputKey.dimension(), polynomial_size, _parameters.glweDimension,
_parameters.level, _parameters.baseLog, _parameters.variance,
Parallelism::Rayon, csprng.ptr, csprng.vtable);
}
PackingKeyswitchKey::PackingKeyswitchKey(PackingKeyswitchKeyParam &params,
LweSecretKey &inputKey,
LweSecretKey &outputKey,
CSPRNG &csprng)
: _parameters(params) {
assert(_parameters.inputLweDimension == inputKey.dimension());
assert(_parameters.glweDimension * _parameters.polynomialSize ==
outputKey.dimension());
// Allocate the buffer
auto size = concrete_cpu_lwe_packing_keyswitch_key_size(
_parameters.glweDimension, _parameters.polynomialSize, _parameters.level,
_parameters.inputLweDimension);
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(size * (_parameters.glweDimension + 1));
// Initialize the keyswitch key buffer
concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
_parameters.inputLweDimension, _parameters.polynomialSize,
_parameters.glweDimension, _parameters.level, _parameters.baseLog,
_parameters.variance, Parallelism::Rayon, csprng.ptr, csprng.vtable);
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -6,269 +6,142 @@
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/seeder.h"
#include "concretelang/Support/Error.h"
#define CAPI_ERR_TO_STRINGERROR(instr, msg) \
{ \
int err; \
instr; \
if (err != 0) { \
return concretelang::error::StringError(msg); \
} \
}
int clone_transform_lwe_secret_key_to_glwe_secret_key_u64(
DefaultEngine *default_engine, LweSecretKey64 *output_lwe_sk,
size_t poly_size, GlweSecretKey64 **output_glwe_sk) {
LweSecretKey64 *output_lwe_sk_clone = NULL;
int lwe_out_sk_clone_ok =
clone_lwe_secret_key_u64(output_lwe_sk, &output_lwe_sk_clone);
if (lwe_out_sk_clone_ok != 0) {
return 1;
}
int glwe_sk_ok =
default_engine_transform_lwe_secret_key_to_glwe_secret_key_u64(
default_engine, &output_lwe_sk_clone, poly_size, output_glwe_sk);
if (glwe_sk_ok != 0) {
return 1;
}
if (output_lwe_sk_clone != NULL) {
return 1;
}
return 0;
}
#include <cassert>
#include <cstddef>
#include <cstdint>
namespace concretelang {
namespace clientlib {
KeySet::KeySet() {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &engine));
CAPI_ASSERT_ERROR(new_default_parallel_engine(best_seeder, &par_engine));
}
KeySet::~KeySet() {
for (auto it : secretKeys) {
CAPI_ASSERT_ERROR(destroy_lwe_secret_key_u64(it.second.second));
}
CAPI_ASSERT_ERROR(destroy_default_engine(engine));
CAPI_ASSERT_ERROR(destroy_default_parallel_engine(par_engine));
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySet::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto keySet = std::make_unique<KeySet>();
OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb));
OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb));
KeySet::generate(ClientParameters clientParameters, CSPRNG &&csprng) {
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
OUTCOME_TRYV(keySet->generateKeysFromParams());
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
return std::move(keySet);
}
outcome::checked<void, StringError>
KeySet::setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
_clientParameters = params;
outcome::checked<std::unique_ptr<KeySet>, StringError> KeySet::fromKeys(
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
std::vector<LweBootstrapKey> bootstrapKeys,
std::vector<LweKeyswitchKey> keyswitchKeys,
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng) {
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
LweSecretKeyParam secretKeyParam = {0};
LweSecretKey64 *secretKey = nullptr;
if (param.encryption.hasValue()) {
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (inputSk == this->secretKeys.end()) {
return StringError("input encryption secret key (")
<< param.encryption->secretKeyID << ") does not exist ";
}
secretKeyParam = inputSk->second.first;
secretKey = inputSk->second.second;
}
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *> input = {
param, secretKeyParam, secretKey};
this->inputs.push_back(input);
}
for (auto param : params.outputs) {
LweSecretKeyParam secretKeyParam = {0};
LweSecretKey64 *secretKey = nullptr;
if (param.encryption.hasValue()) {
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (outputSk == this->secretKeys.end()) {
return StringError(
"cannot find output key to generate bootstrap key");
}
secretKeyParam = outputSk->second.first;
secretKey = outputSk->second.second;
}
std::tuple<CircuitGate, LweSecretKeyParam, LweSecretKey64 *> output = {
param, secretKeyParam, secretKey};
this->outputs.push_back(output);
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
keySet->secretKeys = secretKeys;
keySet->bootstrapKeys = bootstrapKeys;
keySet->keyswitchKeys = keyswitchKeys;
keySet->packingKeyswitchKeys = packingKeyswitchKeys;
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
return std::move(keySet);
}
EvaluationKeys KeySet::evaluationKeys() {
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
}
outcome::checked<KeySet::SecretKeyGateMapping, StringError>
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
SecretKeyGateMapping mapping;
for (auto gate : gates) {
if (gate.encryption.hasValue()) {
assert(gate.encryption->secretKeyID < this->secretKeys.size());
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate, skIt};
mapping.push_back(input);
} else {
std::pair<CircuitGate, llvm::Optional<LweSecretKey>> input = {gate,
llvm::None};
mapping.push_back(input);
}
}
return mapping;
}
outcome::checked<void, StringError> KeySet::setupEncryptionMaterial() {
OUTCOME_TRY(this->inputs,
mapCircuitGateLweSecretKey(_clientParameters.inputs));
OUTCOME_TRY(this->outputs,
mapCircuitGateLweSecretKey(_clientParameters.outputs));
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateKeysFromParams(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
outcome::checked<void, StringError> KeySet::generateKeysFromParams() {
{
// Generate LWE secret keys
for (auto secretKeyParam : params.secretKeys) {
OUTCOME_TRYV(
this->generateSecretKey(secretKeyParam.first, secretKeyParam.second));
}
// Generate LWE secret keys
for (auto secretKeyParam : _clientParameters.secretKeys) {
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam));
}
// Generate bootstrap, keyswitch and packing keyswitch keys
{
for (auto bootstrapKeyParam : params.bootstrapKeys) {
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second));
}
for (auto keyswitchParam : params.keyswitchKeys) {
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second));
}
for (auto packingParam : params.packingKeys) {
OUTCOME_TRYV(
this->generatePackingKey(packingParam.first, packingParam.second));
}
// Generate bootstrap keys
for (auto bootstrapKeyParam : _clientParameters.bootstrapKeys) {
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam));
}
// Generate keyswitch key
for (auto keyswitchParam : _clientParameters.keyswitchKeys) {
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam));
}
// Generate packing keyswitch key
for (auto packingKeyswitchKeyParam : _clientParameters.packingKeyswitchKeys) {
OUTCOME_TRYV(this->generatePackingKeyswitchKey(packingKeyswitchKeyParam));
}
return outcome::success();
}
void KeySet::setKeys(
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys,
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
packingKeys) {
this->secretKeys = secretKeys;
this->bootstrapKeys = bootstrapKeys;
this->keyswitchKeys = keyswitchKeys;
this->packingKeys = packingKeys;
}
outcome::checked<void, StringError>
KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param) {
LweSecretKey64 *sk;
CAPI_ASSERT_ERROR(default_engine_generate_new_lwe_secret_key_u64(
engine, param.dimension, &sk));
secretKeys[id] = {param, sk};
KeySet::generateSecretKey(LweSecretKeyParam param) {
// Init the lwe secret key
LweSecretKey sk(param, csprng);
// Store the lwe secret key
secretKeys.push_back(sk);
return outcome::success();
}
outcome::checked<LweSecretKey, StringError>
KeySet::findLweSecretKey(LweSecretKeyID keyID) {
assert(keyID < secretKeys.size());
auto secretKey = secretKeys[keyID];
return secretKey;
}
outcome::checked<void, StringError>
KeySet::generateBootstrapKey(BootstrapKeyID id, BootstrapKeyParam param) {
KeySet::generateBootstrapKey(BootstrapKeyParam param) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return StringError("cannot find input key to generate bootstrap key");
}
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return StringError("cannot find output key to generate bootstrap key");
}
// Allocate the bootstrap key
LweBootstrapKey64 *bsk;
uint64_t total_dimension = outputSk->second.first.dimension;
assert(total_dimension % param.glweDimension == 0);
uint64_t polynomialSize = total_dimension / param.glweDimension;
GlweSecretKey64 *output_glwe_sk = nullptr;
// This is not part of the C FFI but rather is a C util exposed for
// convenience in tests.
CAPI_ASSERT_ERROR(clone_transform_lwe_secret_key_to_glwe_secret_key_u64(
engine, outputSk->second.second, polynomialSize, &output_glwe_sk));
CAPI_ASSERT_ERROR(default_parallel_engine_generate_new_lwe_bootstrap_key_u64(
par_engine, inputSk->second.second, output_glwe_sk, param.baseLog,
param.level, param.variance, &bsk));
CAPI_ASSERT_ERROR(destroy_glwe_secret_key_u64(output_glwe_sk));
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
// Initialize the bootstrap key
LweBootstrapKey bootstrapKey(param, inputKey, outputKey, csprng);
// Store the bootstrap key
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
bootstrapKeys.push_back(std::move(bootstrapKey));
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param) {
KeySet::generateKeyswitchKey(KeyswitchKeyParam param) {
// Finding input and output secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return StringError("cannot find input key to generate keyswitch key");
}
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return StringError("cannot find output key to generate keyswitch key");
}
// Allocate the keyswitch key
LweKeyswitchKey64 *ksk;
CAPI_ASSERT_ERROR(default_engine_generate_new_lwe_keyswitch_key_u64(
engine, inputSk->second.second, outputSk->second.second, param.level,
param.baseLog, param.variance, &ksk));
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
// Initialize the bootstrap key
LweKeyswitchKey keyswitchKey(param, inputKey, outputKey, csprng);
// Store the keyswitch key
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
keyswitchKeys.push_back(keyswitchKey);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generatePackingKey(PackingKeySwitchID id, PackingKeySwitchParam param) {
KeySet::generatePackingKeyswitchKey(PackingKeyswitchKeyParam param) {
// Finding input secretKeys
auto inputSk = secretKeys.find(param.inputSecretKeyID);
if (inputSk == secretKeys.end()) {
return StringError(
"cannot find input key to generate packing keyswitch key");
}
auto bsk = bootstrapKeys.find(param.bootstrapKeyID);
if (bsk == bootstrapKeys.end()) {
return StringError(
"cannot find input key to generate packing keyswitch key");
}
assert(param.inputSecretKeyID < secretKeys.size());
auto inputSk = secretKeys[param.inputSecretKeyID];
// This is not part of the C FFI but rather is a C util exposed for
// convenience in tests.
GlweSecretKey64 *output_glwe_sk = nullptr;
auto lweDimension =
inputSk->second.first.lweDimension() / bsk->second.first.glweDimension;
CAPI_ASSERT_ERROR(clone_transform_lwe_secret_key_to_glwe_secret_key_u64(
engine, inputSk->second.second, lweDimension, &output_glwe_sk));
// Allocate the packing keyswitch key
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *fpksk;
CAPI_ASSERT_ERROR(
default_parallel_engine_generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked_u64(
par_engine, inputSk->second.second, output_glwe_sk, param.baseLog,
param.level, param.variance, &fpksk));
assert(param.outputSecretKeyID < secretKeys.size());
auto outputSk = secretKeys[param.outputSecretKeyID];
PackingKeyswitchKey packingKeyswitchKey(param, inputSk, outputSk, csprng);
// Store the keyswitch key
packingKeys[id] = {param, std::make_shared<PackingKeyswitchKey>(fpksk)};
packingKeyswitchKeys.push_back(packingKeyswitchKey);
return outcome::success();
}
@@ -285,8 +158,9 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
}
auto numBlocks =
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
assert(inputSk.second.has_value());
size = std::get<1>(inputSk).lweSize();
size = inputSk.second->parameters().lweSize();
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
return outcome::success();
}
@@ -315,8 +189,9 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
return StringError("encrypt_lwe the positional argument is not encrypted");
}
auto encoding = encryption->encoding;
auto lweSecretKeyParam = std::get<1>(inputSk);
auto lweSecretKey = std::get<2>(inputSk);
assert(inputSk.second.has_value());
auto lweSecretKey = *inputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
// CRT encoding - N blocks with crt encoding
auto crt = encryption->encoding.crt;
if (!crt.empty()) {
@@ -324,11 +199,7 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
auto product = crt::productOfModuli(crt);
for (auto modulus : crt) {
auto plaintext = crt::encode(input, modulus, product);
CAPI_ASSERT_ERROR(
default_engine_discard_encrypt_lwe_ciphertext_u64_raw_ptr_buffers(
engine, lweSecretKey, ciphertext, plaintext,
encryption->variance));
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
return outcome::success();
@@ -336,10 +207,7 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
// Simple TFHE integers - 1 blocks with one padding bits
// TODO we could check if the input value is in the right range
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
CAPI_ASSERT_ERROR(
default_engine_discard_encrypt_lwe_ciphertext_u64_raw_ptr_buffers(
engine, lweSecretKey, ciphertext, plaintext, encryption->variance));
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
return outcome::success();
}
@@ -349,8 +217,9 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
return StringError("decrypt_lwe: position of argument is too high");
}
auto outputSk = outputs[argPos];
auto lweSecretKey = std::get<2>(outputSk);
auto lweSecretKeyParam = std::get<1>(outputSk);
assert(outputSk.second.has_value());
auto lweSecretKey = *outputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
auto encryption = std::get<0>(outputSk).encryption;
if (!encryption.hasValue()) {
return StringError("decrypt_lwe: the positional argument is not encrypted");
@@ -358,15 +227,15 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
auto crt = encryption->encoding.crt;
if (!crt.empty()) { // The ciphertext used the crt strategy.
if (!crt.empty()) {
// CRT encoded TFHE integers
// Decrypt and decode remainders
std::vector<int64_t> remainders;
for (auto modulus : crt) {
uint64_t decrypted;
CAPI_ASSERT_ERROR(
default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
engine, lweSecretKey, ciphertext, &decrypted));
uint64_t decrypted = 0;
lweSecretKey.decrypt(ciphertext, decrypted);
auto plaintext = crt::decode(decrypted, modulus);
remainders.push_back(plaintext);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
@@ -386,12 +255,10 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
output -= maxPos * 2;
}
}
} else { // The ciphertext used the scalar strategy
// Decrypt
uint64_t plaintext;
CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
engine, lweSecretKey, ciphertext, &plaintext));
} else {
// Native encoded TFHE integers - 1 blocks with one padding bits
uint64_t plaintext = 0;
lweSecretKey.decrypt(ciphertext, plaintext);
// Decode unsigned integer
uint64_t precision = encryption->encoding.precision;
@@ -415,27 +282,18 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
return outcome::success();
}
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>> &
KeySet::getSecretKeys() {
return secretKeys;
}
const std::vector<LweSecretKey> &KeySet::getSecretKeys() { return secretKeys; }
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>> &
KeySet::getBootstrapKeys() {
const std::vector<LweBootstrapKey> &KeySet::getBootstrapKeys() {
return bootstrapKeys;
}
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>> &
KeySet::getKeyswitchKeys() {
const std::vector<LweKeyswitchKey> &KeySet::getKeyswitchKeys() {
return keyswitchKeys;
}
const std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
&KeySet::getPackingKeys() {
return packingKeys;
const std::vector<PackingKeyswitchKey> &KeySet::getPackingKeyswitchKeys() {
return packingKeyswitchKeys;
}
} // namespace clientlib

View File

@@ -7,6 +7,8 @@
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
@@ -16,216 +18,101 @@
#include <string>
#include <utime.h>
#include "concrete-core-ffi.h"
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
template <class Engine, class Key>
outcome::checked<Key *, StringError>
load(llvm::SmallString<0> &path,
int (*deser)(Engine *, BufferView buffer, Key **), Engine *engine) {
template <class Key>
outcome::checked<Key, StringError> loadKey(llvm::SmallString<0> &path,
Key(deser)(std::istream &istream)) {
std::ifstream in((std::string)path, std::ofstream::binary);
if (in.fail()) {
return StringError("Cannot access " + (std::string)path);
}
std::stringstream sbuffer;
sbuffer << in.rdbuf();
if (in.fail()) {
return StringError("Cannot read " + (std::string)path);
auto key = deser(in);
if (in.bad()) {
return StringError("Cannot load key at path(") << (std::string)path << ")";
}
auto content = sbuffer.str();
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
Key *result = nullptr;
int error_code = deser(engine, buffer, &result);
if (result == nullptr || error_code != 0) {
return StringError("Cannot deserialize " + (std::string)path);
}
return result;
return key;
}
static void writeFile(llvm::SmallString<0> &path, Buffer content) {
template <class Key>
outcome::checked<void, StringError> saveKey(llvm::SmallString<0> &path,
Key key) {
std::ofstream out((std::string)path, std::ofstream::binary);
out.write((const char *)content.pointer, content.length);
if (out.fail()) {
return StringError("Cannot access " + (std::string)path);
}
out << key;
if (out.bad()) {
return StringError("Cannot save key at path(") << (std::string)path << ")";
}
out.close();
}
outcome::checked<LweSecretKey64 *, StringError>
loadSecretKey(llvm::SmallString<0> &path) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
return load(path, default_serialization_engine_deserialize_lwe_secret_key_u64,
engine);
}
outcome::checked<LweKeyswitchKey64 *, StringError>
loadKeyswitchKey(llvm::SmallString<0> &path) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
return load(path,
default_serialization_engine_deserialize_lwe_keyswitch_key_u64,
engine);
}
outcome::checked<LweBootstrapKey64 *, StringError>
loadBootstrapKey(llvm::SmallString<0> &path) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
return load(path,
default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
engine);
}
outcome::checked<LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *,
StringError>
loadPackingKey(llvm::SmallString<0> &path) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
return load(
path,
default_serialization_engine_deserialize_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64,
engine);
}
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey64 *key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer buffer;
CAPI_ASSERT_ERROR(default_serialization_engine_serialize_lwe_secret_key_u64(
engine, key, &buffer));
writeFile(path, buffer);
free(buffer.pointer);
}
void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey64 *key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer buffer;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&buffer));
writeFile(path, buffer);
free(buffer.pointer);
}
void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey64 *key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer buffer;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
&buffer));
writeFile(path, buffer);
free(buffer.pointer);
}
void savePackingKey(
llvm::SmallString<0> &path,
LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *key) {
DefaultSerializationEngine *engine;
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer buffer;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
engine, key, &buffer));
writeFile(path, buffer);
free(buffer.pointer);
return outcome::success();
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb, std::string folderPath) {
// TODO: text dump of all parameter in /hash
auto key_set = std::make_unique<KeySet>();
// Mark the folder as recently use.
// e.g. so the CI can do some cleanup of unused keys.
utime(folderPath.c_str(), nullptr);
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey64 *>>
secretKeys;
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, std::shared_ptr<LweBootstrapKey>>>
bootstrapKeys;
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys;
std::map<LweSecretKeyID, std::pair<PackingKeySwitchParam,
std::shared_ptr<PackingKeyswitchKey>>>
packingKeys;
std::vector<LweSecretKey> secretKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
// Load LWE secret keys
for (auto secretKeyParam : params.secretKeys) {
auto id = secretKeyParam.first;
auto param = secretKeyParam.second;
// Load secret keys
for (auto p : llvm::enumerate(params.secretKeys)) {
// TODO - Check parameters?
// auto param = secretKeyParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "secretKey_" + id);
OUTCOME_TRY(LweSecretKey64 * sk, loadSecretKey(path));
secretKeys[id] = {param, sk};
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readLweSecretKey));
secretKeys.push_back(key);
}
// Load bootstrap keys
for (auto bootstrapKeyParam : params.bootstrapKeys) {
auto id = bootstrapKeyParam.first;
auto param = bootstrapKeyParam.second;
for (auto p : llvm::enumerate(params.bootstrapKeys)) {
// TODO - Check parameters?
// auto param = p.value();
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pbsKey_" + id);
OUTCOME_TRY(LweBootstrapKey64 * bsk, loadBootstrapKey(path));
bootstrapKeys[id] = {param, std::make_shared<LweBootstrapKey>(bsk)};
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readLweBootstrapKey));
bootstrapKeys.push_back(key);
}
// Load keyswitch keys
for (auto keyswitchParam : params.keyswitchKeys) {
auto id = keyswitchParam.first;
auto param = keyswitchParam.second;
for (auto p : llvm::enumerate(params.keyswitchKeys)) {
// TODO - Check parameters?
// auto param = p.value();
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "ksKey_" + id);
OUTCOME_TRY(LweKeyswitchKey64 * ksk, loadKeyswitchKey(path));
keyswitchKeys[id] = {param, std::make_shared<LweKeyswitchKey>(ksk)};
}
// Load packing keys
for (auto packingParam : params.packingKeys) {
auto id = packingParam.first;
auto param = packingParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "fpksKey_" + id);
OUTCOME_TRY(LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 *
ksk,
loadPackingKey(path));
packingKeys[id] = {param, std::make_shared<PackingKeyswitchKey>(ksk)};
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readLweKeyswitchKey));
keyswitchKeys.push_back(key);
}
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys, packingKeys);
for (auto p : llvm::enumerate(params.packingKeyswitchKeys)) {
// TODO - Check parameters?
// auto param = p.value();
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readPackingKeyswitchKey));
packingKeyswitchKeys.push_back(key);
}
OUTCOME_TRYV(key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb));
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
return std::move(key_set);
auto csprng = ConcreteCSPRNG(seed);
OUTCOME_TRY(auto keySet,
KeySet::fromKeys(params, secretKeys, bootstrapKeys, keyswitchKeys,
packingKeyswitchKeys, std::move(csprng)));
return std::move(keySet);
}
outcome::checked<void, StringError> saveKeys(KeySet &key_set,
@@ -239,38 +126,30 @@ outcome::checked<void, StringError> saveKeys(KeySet &key_set,
return StringError("Cannot create directory \"")
<< std::string(folderIncompletePath) << "\": " << err.message();
}
// Save LWE secret keys
for (auto secretKeyParam : key_set.getSecretKeys()) {
auto id = secretKeyParam.first;
auto key = secretKeyParam.second.second;
for (auto p : llvm::enumerate(key_set.getSecretKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "secretKey_" + id);
saveSecretKey(path, key);
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
// Save bootstrap keys
for (auto bootstrapKeyParam : key_set.getBootstrapKeys()) {
auto id = bootstrapKeyParam.first;
auto key = bootstrapKeyParam.second.second;
for (auto p : llvm::enumerate(key_set.getBootstrapKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "pbsKey_" + id);
saveBootstrapKey(path, key->get());
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
// Save keyswitch keys
for (auto keyswitchParam : key_set.getKeyswitchKeys()) {
auto id = keyswitchParam.first;
auto key = keyswitchParam.second.second;
for (auto p : llvm::enumerate(key_set.getKeyswitchKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "ksKey_" + id);
saveKeyswitchKey(path, key->get());
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
// Save packing keys
for (auto keyswitchParam : key_set.getPackingKeys()) {
auto id = keyswitchParam.first;
auto key = keyswitchParam.second.second;
// Save packing keyswitch keys
for (auto p : llvm::enumerate(key_set.getPackingKeyswitchKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "fpksKey_" + id);
savePackingKey(path, key->get());
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
@@ -338,7 +217,14 @@ KeySetCache::loadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
std::cerr << "KeySetCache: miss, regenerating " << std::string(folderPath)
<< "\n";
OUTCOME_TRY(auto key_set, KeySet::generate(params, seed_msb, seed_lsb));
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = ConcreteCSPRNG(seed);
OUTCOME_TRY(auto key_set, KeySet::generate(params, std::move(csprng)));
OUTCOME_TRYV(saveKeys(*key_set, folderPath));
@@ -349,8 +235,13 @@ outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = ConcreteCSPRNG(seed);
return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb)
: KeySet::generate(params, seed_msb, seed_lsb);
: KeySet::generate(params, std::move(csprng));
}
outcome::checked<std::unique_ptr<KeySet>, StringError>

View File

@@ -7,8 +7,6 @@
#include <iostream>
#include <stdlib.h>
#include "concrete-core-ffi.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
@@ -16,139 +14,226 @@
namespace concretelang {
namespace clientlib {
template <typename Engine, typename Result>
Result read_deser(std::istream &istream,
int (*deser)(Engine *, BufferView, Result *),
Engine *engine) {
size_t length;
readSize(istream, length);
// buffer is too big to be allocated on stack
// vector ensures everything is deallocated w.r.t. new
std::vector<uint8_t> buffer(length);
istream.read((char *)buffer.data(), length);
assert(istream.good());
Result result;
CAPI_ASSERT_ERROR(deser(engine, {buffer.data(), length}, &result));
return result;
}
template <typename BufferLike>
std::ostream &writeBufferLike(std::ostream &ostream, BufferLike &buffer) {
writeSize(ostream, buffer.length);
ostream.write((const char *)buffer.pointer, buffer.length);
template <typename Key>
std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) {
writeSize(ostream, (uint64_t)buffer.size());
ostream.write((const char *)buffer.buffer(),
buffer.size() * sizeof(uint64_t));
assert(ostream.good());
return ostream;
}
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey64 *key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_keyswitch_key_u64(engine, key,
&b));
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey64 *key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
default_serialization_engine_serialize_lwe_bootstrap_key_u64(engine, key,
&b))
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::ostream &operator<<(std::ostream &ostream,
const FftFourierLweBootstrapKey64 *key) {
FftSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine));
Buffer b;
CAPI_ASSERT_ERROR(
fft_serialization_engine_serialize_fft_fourier_lwe_bootstrap_key_u64(
engine, key, &b))
writeBufferLike(ostream, b);
free((void *)b.pointer);
b.pointer = nullptr;
return ostream;
}
std::istream &operator>>(std::istream &istream, LweKeyswitchKey64 *&key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
key = read_deser(
istream, default_serialization_engine_deserialize_lwe_keyswitch_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey64 *&key) {
DefaultSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
key = read_deser(
istream, default_serialization_engine_deserialize_lwe_bootstrap_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream,
FftFourierLweBootstrapKey64 *&key) {
FftSerializationEngine *engine;
// No Freeing as it doesn't allocate anything.
CAPI_ASSERT_ERROR(new_fft_serialization_engine(&engine));
key = read_deser(
istream,
fft_serialization_engine_deserialize_fft_fourier_lwe_bootstrap_key_u64,
engine);
return istream;
}
std::istream &operator>>(std::istream &istream,
RuntimeContext &runtimeContext) {
istream >> runtimeContext.evaluationKeys;
std::shared_ptr<std::vector<uint64_t>> &vec) {
// TODO assertion on size?
uint64_t size;
readSize(istream, size);
vec->resize(size);
istream.read((char *)vec->data(), size * sizeof(uint64_t));
assert(istream.good());
return istream;
}
// LweSecretKey ////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) {
writeWord(ostream, param.dimension);
return ostream;
}
std::istream &operator>>(std::istream &istream, LweSecretKeyParam &param) {
readWord(istream, param.dimension);
return istream;
}
// LweSecretKey /////////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
LweSecretKey readLweSecretKey(std::istream &istream) {
LweSecretKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
return LweSecretKey(buffer, param);
}
// KeyswitchKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) {
// TODO keys id
writeWord(ostream, param.level);
writeWord(ostream, param.baseLog);
writeWord(ostream, param.variance);
return ostream;
}
std::istream &operator>>(std::istream &istream, KeyswitchKeyParam &param) {
// TODO keys id
param.outputSecretKeyID = 1234;
param.inputSecretKeyID = 1234;
readWord(istream, param.level);
readWord(istream, param.baseLog);
readWord(istream, param.variance);
return istream;
}
// LweKeyswitchKey //////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) {
KeyswitchKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
return LweKeyswitchKey(buffer, param);
}
// BootstrapKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) {
// TODO keys id
writeWord(ostream, param.level);
writeWord(ostream, param.baseLog);
writeWord(ostream, param.glweDimension);
writeWord(ostream, param.variance);
writeWord(ostream, param.polynomialSize);
writeWord(ostream, param.inputLweDimension);
return ostream;
}
std::istream &operator>>(std::istream &istream, BootstrapKeyParam &param) {
// TODO keys id
readWord(istream, param.level);
readWord(istream, param.baseLog);
readWord(istream, param.glweDimension);
readWord(istream, param.variance);
readWord(istream, param.polynomialSize);
readWord(istream, param.inputLweDimension);
return istream;
}
// LweBootstrapKey //////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
LweBootstrapKey readLweBootstrapKey(std::istream &istream) {
BootstrapKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
return LweBootstrapKey(buffer, param);
}
// PackingKeyswitchKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext) {
ostream << runtimeContext.evaluationKeys;
const PackingKeyswitchKeyParam param) {
// TODO keys id
writeWord(ostream, param.level);
writeWord(ostream, param.baseLog);
writeWord(ostream, param.glweDimension);
writeWord(ostream, param.polynomialSize);
writeWord(ostream, param.inputLweDimension);
writeWord(ostream, param.variance);
return ostream;
}
std::istream &operator>>(std::istream &istream,
PackingKeyswitchKeyParam &param) {
// TODO keys id
param.outputSecretKeyID = 1234;
param.inputSecretKeyID = 1234;
readWord(istream, param.level);
readWord(istream, param.baseLog);
readWord(istream, param.glweDimension);
readWord(istream, param.polynomialSize);
readWord(istream, param.inputLweDimension);
readWord(istream, param.variance);
return istream;
}
// PackingKeyswitchKey //////////////////////////////
std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) {
PackingKeyswitchKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
auto b = PackingKeyswitchKey(buffer, param);
return b;
}
// EvaluationKey ////////////////////////////////
EvaluationKeys readEvaluationKeys(std::istream &istream) {
uint64_t nbKey;
readSize(istream, nbKey);
std::vector<LweBootstrapKey> bootstrapKeys;
for (uint64_t i = 0; i < nbKey; i++) {
bootstrapKeys.push_back(readLweBootstrapKey(istream));
}
readSize(istream, nbKey);
std::vector<LweKeyswitchKey> keyswitchKeys;
for (uint64_t i = 0; i < nbKey; i++) {
keyswitchKeys.push_back(readLweKeyswitchKey(istream));
}
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
readSize(istream, nbKey);
for (uint64_t i = 0; i < nbKey; i++) {
packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream));
}
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
}
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys) {
auto bootstrapKeys = evaluationKeys.getBootstrapKeys();
writeSize(ostream, bootstrapKeys.size());
for (auto bsk : bootstrapKeys) {
ostream << bsk;
}
auto keyswitchKeys = evaluationKeys.getKeyswitchKeys();
writeSize(ostream, keyswitchKeys.size());
for (auto ksk : keyswitchKeys) {
ostream << ksk;
}
auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys();
writeSize(ostream, packingKeyswitchKeys.size());
for (auto pksk : packingKeyswitchKeys) {
ostream << pksk;
}
assert(ostream.good());
return ostream;
}
// TensorData ///////////////////////////////////
template <typename T>
std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) {
writeWord<uint64_t>(ostream, sizeof(T) * 8);
@@ -399,70 +484,5 @@ unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
}
}
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk) {
ostream << wrappedKsk.ksk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, LweKeyswitchKey &wrappedKsk) {
istream >> wrappedKsk.ksk;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk) {
ostream << wrappedBsk.bsk;
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream, LweBootstrapKey &wrappedBsk) {
istream >> wrappedBsk.bsk;
assert(istream.good());
return istream;
}
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys) {
bool has_ksk = (bool)evaluationKeys.sharedKsk;
writeWord(ostream, has_ksk);
if (has_ksk) {
ostream << *evaluationKeys.sharedKsk;
}
bool has_bsk = (bool)evaluationKeys.sharedBsk;
writeWord(ostream, has_bsk);
if (has_bsk) {
ostream << *evaluationKeys.sharedBsk;
}
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream,
EvaluationKeys &evaluationKeys) {
bool has_ksk;
readWord(istream, has_ksk);
if (has_ksk) {
auto sharedKsk = LweKeyswitchKey(nullptr);
istream >> sharedKsk;
evaluationKeys.sharedKsk =
std::make_shared<LweKeyswitchKey>(std::move(sharedKsk));
}
bool has_bsk;
readWord(istream, has_bsk);
if (has_bsk) {
auto sharedBsk = LweBootstrapKey(nullptr);
istream >> sharedBsk;
evaluationKeys.sharedBsk =
std::make_shared<LweBootstrapKey>(std::move(sharedBsk));
}
assert(istream.good());
return istream;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -47,8 +47,7 @@ char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
char memref_encode_plaintext_with_crt[] = "memref_encode_plaintext_with_crt";
char memref_encode_expand_lut_for_bootstrap[] =
"memref_encode_expand_lut_for_bootstrap";
char memref_encode_expand_lut_for_woppbs[] =
"memref_encode_expand_lut_for_woppbs";
char memref_encode_lut_for_crt_woppbs[] = "memref_encode_lut_for_crt_woppbs";
char memref_trace[] = "memref_trace";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
@@ -158,10 +157,16 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
} else if (funcName == memref_wop_pbs_crt_buffer) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{
memref2DType,
memref2DType,
memref2DType,
memref1DType,
memref1DType,
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
@@ -181,11 +186,11 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
{memref1DType, memref1DType, rewriter.getI32Type(),
rewriter.getI32Type(), rewriter.getI1Type()},
{});
} else if (funcName == memref_encode_expand_lut_for_woppbs) {
} else if (funcName == memref_encode_lut_for_crt_woppbs) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, memref1DType, memref1DType, memref1DType,
rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()},
{memref2DType, memref1DType, memref1DType, memref1DType,
rewriter.getI32Type(), rewriter.getI1Type()},
{});
} else if (funcName == memref_trace) {
funcType = mlir::FunctionType::get(
@@ -326,9 +331,32 @@ void wopPBSAddOperands(Concrete::WopPBSCRTLweBufferOp op,
// cbs_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
// ksk_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.keyswitchLevelAttr()));
// ksk_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.keyswitchBaseLogAttr()));
// bsk_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.bootstrapLevelAttr()));
// bsk_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.bootstrapBaseLogAttr()));
// fpksk_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchLevelAttr()));
// fpksk_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchBaseLogAttr()));
// polynomial_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
// context
operands.push_back(getContextArgument(op));
}
@@ -372,9 +400,9 @@ void encodeExpandLutForBootstrapAddOperands(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
}
void encodeExpandLutForWopPBSAddOperands(
Concrete::EncodeExpandLutForWopPBSBufferOp op,
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
void encodeLutForWopPBSAddOperands(Concrete::EncodeLutForCrtWopPBSBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
// crt_decomposition
mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
@@ -414,9 +442,6 @@ void encodeExpandLutForWopPBSAddOperands(
op.getLoc(), (*crtBitsGlobalMemref).type(),
(*crtBitsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
// modulus_product
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.modulusProductAttr()));
@@ -467,10 +492,10 @@ struct ConcreteToCAPIPass : public ConcreteToCAPIBase<ConcreteToCAPIPass> {
ConcreteToCAPICallPattern<Concrete::EncodeExpandLutForBootstrapBufferOp,
memref_encode_expand_lut_for_bootstrap>>(
&getContext(), encodeExpandLutForBootstrapAddOperands);
patterns.add<
ConcreteToCAPICallPattern<Concrete::EncodeExpandLutForWopPBSBufferOp,
memref_encode_expand_lut_for_woppbs>>(
&getContext(), encodeExpandLutForWopPBSAddOperands);
patterns
.add<ConcreteToCAPICallPattern<Concrete::EncodeLutForCrtWopPBSBufferOp,
memref_encode_lut_for_crt_woppbs>>(
&getContext(), encodeLutForWopPBSAddOperands);
if (gpu) {
patterns.add<ConcreteToCAPICallPattern<Concrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_cuda_u64>>(

View File

@@ -560,17 +560,18 @@ struct ApplyLookupTableEintOpPattern
mlir::Value newLut =
rewriter
.create<TFHE::EncodeExpandLutForWopPBSOp>(
.create<TFHE::EncodeLutForCrtWopPBSOp>(
op.getLoc(),
mlir::RankedTensorType::get(
mlir::ArrayRef<int64_t>(loweringParameters.lutSize),
mlir::ArrayRef<int64_t>{
(int64_t)loweringParameters.nMods,
(int64_t)loweringParameters.singleLutSize},
rewriter.getI64Type()),
adaptor.lut(),
rewriter.getI64ArrayAttr(
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
rewriter.getI64ArrayAttr(
mlir::ArrayRef<int64_t>(loweringParameters.bits)),
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
rewriter.getI32IntegerAttr(loweringParameters.modsProd),
rewriter.getBoolAttr(originalInputType.isSigned()))
.getResult();

View File

@@ -664,8 +664,8 @@ void TFHEToConcretePass::runOnOperation() {
mlir::concretelang::Concrete::EncodeExpandLutForBootstrapTensorOp,
true>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::EncodeExpandLutForWopPBSOp,
mlir::concretelang::Concrete::EncodeExpandLutForWopPBSTensorOp, true>,
mlir::concretelang::TFHE::EncodeLutForCrtWopPBSOp,
mlir::concretelang::Concrete::EncodeLutForCrtWopPBSTensorOp, true>,
mlir::concretelang::GenericOneToOneOpConversionPattern<
mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
mlir::concretelang::Concrete::EncodePlaintextWithCrtTensorOp, true>,

View File

@@ -143,10 +143,10 @@ void mlir::concretelang::Concrete::
Concrete::EncodeExpandLutForBootstrapTensorOp::attachInterface<
TensorToMemrefOp<Concrete::EncodeExpandLutForBootstrapTensorOp,
Concrete::EncodeExpandLutForBootstrapBufferOp>>(*ctx);
// encode_expand_lut_for_woppbs_tensor =>
// encode_expand_lut_for_woppbs_buffer
Concrete::EncodeExpandLutForWopPBSTensorOp::attachInterface<
TensorToMemrefOp<Concrete::EncodeExpandLutForWopPBSTensorOp,
Concrete::EncodeExpandLutForWopPBSBufferOp>>(*ctx);
// encode_lut_for_crt_woppbs_tensor =>
// encode_lut_for_crt_woppbs_buffer
Concrete::EncodeLutForCrtWopPBSTensorOp::attachInterface<
TensorToMemrefOp<Concrete::EncodeLutForCrtWopPBSTensorOp,
Concrete::EncodeLutForCrtWopPBSBufferOp>>(*ctx);
});
}

View File

@@ -1,4 +1,6 @@
add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp seeder.cpp)
add_library(ConcretelangRuntime SHARED context.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp)
add_dependencies(ConcretelangRuntime concrete_cpu)
if(CONCRETELANG_DATAFLOW_EXECUTION_ENABLED)
target_link_libraries(ConcretelangRuntime PRIVATE HPX::hpx HPX::iostreams_component)
@@ -25,7 +27,8 @@ if(APPLE)
target_link_libraries(ConcretelangRuntime LINK_PUBLIC ${SECURITY_FRAMEWORK})
endif()
target_link_libraries(ConcretelangRuntime PUBLIC Concrete ConcretelangClientLib pthread m dl
target_include_directories(ConcretelangRuntime PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
target_link_libraries(ConcretelangRuntime PUBLIC concrete_cpu ConcretelangClientLib pthread m dl
$<TARGET_OBJECTS:mlir_c_runner_utils>)
install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime)

View File

@@ -5,29 +5,77 @@
#include "concretelang/Runtime/context.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/seeder.h"
#include <assert.h>
#include <stdio.h>
LweKeyswitchKey64 *
get_keyswitch_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->get_ksk();
namespace clientlib = ::concretelang::clientlib;
namespace mlir {
namespace concretelang {
FFT::FFT(size_t polynomial_size)
: fft(nullptr), polynomial_size(polynomial_size) {
fft = (struct Fft *)aligned_alloc(CONCRETE_FFT_ALIGN, CONCRETE_FFT_SIZE);
concrete_cpu_construct_concrete_fft(fft, polynomial_size);
}
LweBootstrapKey64 *
get_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->get_bsk();
FFT::FFT(FFT &&other) : fft(other.fft), polynomial_size(other.polynomial_size) {
other.fft = nullptr;
}
FftFourierLweBootstrapKey64 *
get_fft_fourier_bootstrap_key_u64(mlir::concretelang::RuntimeContext *context) {
return context->get_fft_fourier_bsk();
FFT::~FFT() {
if (fft != nullptr) {
concrete_cpu_destroy_concrete_fft(fft);
free(fft);
}
}
DefaultEngine *get_engine(mlir::concretelang::RuntimeContext *context) {
return context->get_default_engine();
RuntimeContext::RuntimeContext(clientlib::EvaluationKeys evaluationKeys)
: evaluationKeys(evaluationKeys) {
{
// Initialize for each bootstrap key the fourier one
for (auto bsk : evaluationKeys.getBootstrapKeys()) {
auto param = bsk.parameters();
size_t decomposition_level_count = param.level;
size_t decomposition_base_log = param.baseLog;
size_t glwe_dimension = param.glweDimension;
size_t polynomial_size = param.polynomialSize;
size_t input_lwe_dimension = param.inputLweDimension;
// Create the FFT
FFT fft(polynomial_size);
// Allocate scratch for key conversion
size_t scratch_size;
size_t scratch_align;
concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch(
&scratch_size, &scratch_align, fft.fft);
auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
// Allocate the fourier_bootstrap_key
auto fourier_data = std::make_shared<std::vector<double>>();
fourier_data->resize(bsk.size());
auto bsk_data = bsk.buffer();
// Convert bootstrap_key to the fourier domain
concrete_cpu_bootstrap_key_convert_u64_to_fourier(
bsk_data, fourier_data->data(), decomposition_level_count,
decomposition_base_log, glwe_dimension, polynomial_size,
input_lwe_dimension, fft.fft, scratch, scratch_size);
// Store the fourier_bootstrap_key in the context
fourier_bootstrap_keys.push_back(fourier_data);
ffts.push_back(std::move(fft));
free(scratch);
}
#ifdef CONCRETELANG_CUDA_SUPPORT
bsk_gpu = nullptr;
ksk_gpu = nullptr;
#endif
}
}
FftEngine *get_fft_engine(mlir::concretelang::RuntimeContext *context) {
return context->get_fft_engine();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,52 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Runtime/seeder.h"
#include <cassert>
#include <iostream>
#include "concrete-core-ffi.h"
#include "concretelang/Common/Error.h"
SeederBuilder *get_best_seeder() {
SeederBuilder *builder = NULL;
#if defined(__x86_64__) || defined(_M_X64)
bool rdseed_seeder_available = false;
CAPI_ASSERT_ERROR(rdseed_seeder_is_available(&rdseed_seeder_available));
if (rdseed_seeder_available) {
CAPI_ASSERT_ERROR(get_rdseed_seeder_builder(&builder));
return builder;
}
#endif
#if __APPLE__
bool apple_seeder_available = false;
CAPI_ASSERT_ERROR(
apple_secure_enclave_seeder_is_available(&apple_seeder_available));
if (apple_seeder_available) {
CAPI_ASSERT_ERROR(get_apple_secure_enclave_seeder_builder(&builder));
return builder;
}
#endif
bool unix_seeder_available = false;
CAPI_ASSERT_ERROR(unix_seeder_is_available(&unix_seeder_available));
if (unix_seeder_available) {
// Security depends on /dev/random security
uint64_t secret_high_64 = 0;
uint64_t secret_low_64 = 0;
CAPI_ASSERT_ERROR(
get_unix_seeder_builder(secret_high_64, secret_low_64, &builder));
return builder;
}
std::cout << "No available seeder." << std::endl;
return builder;
}
SeederBuilder *best_seeder = get_best_seeder();

View File

@@ -4,8 +4,8 @@
// for license information.
#include "concretelang/Runtime/wrappers.h"
#include "concrete-cpu.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/seeder.h"
#include <assert.h>
#include <bitset>
#include <cmath>
@@ -16,15 +16,6 @@
#include <string.h>
#include <vector>
static DefaultEngine *levelled_engine = nullptr;
DefaultEngine *get_levelled_engine() {
if (levelled_engine == nullptr) {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &levelled_engine));
}
return levelled_engine;
}
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Runtime/wrappers.h"
@@ -184,18 +175,21 @@ void memref_batched_bootstrap_lwe_cuda_u64(
// Construct the glwe accumulator (on CPU)
// TODO: Should be done outside of the bootstrap call, compile time if
// possible. Refactor in progress
uint64_t glwe_ct_len = poly_size * (glwe_dim + 1);
uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
auto tlu = tlu_aligned + tlu_offset;
CAPI_ASSERT_ERROR(
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), glwe_ct, glwe_ct_len, tlu_aligned + tlu_offset,
poly_size));
// Glwe trivial encryption
for (size_t i = 0; i < poly_size * glwe_dim; i++) {
glwe_ct[i] = 0;
}
for (size_t i = 0; i < poly_size; i++) {
glwe_ct[poly_size * glwe_dim + i] = tlu[i];
}
// Move the glwe accumulator to the GPU
void *glwe_ct_gpu =
alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_len, gpu_idx, stream);
alloc_and_memcpy_async_to_gpu(glwe_ct, 0, glwe_ct_size, gpu_idx, stream);
// Move test vector indexes to the GPU, the test vector indexes is set of 0
uint32_t num_test_vectors = 1, lwe_idx = 0,
@@ -313,11 +307,12 @@ void memref_encode_expand_lut_for_bootstrap(
return;
}
void memref_encode_expand_lut_for_woppbs(
void memref_encode_lut_for_crt_woppbs(
// Output encoded/expanded lut
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size,
uint64_t output_lut_stride,
uint64_t output_lut_offset, uint64_t output_lut_size0,
uint64_t output_lut_size1, uint64_t output_lut_stride0,
uint64_t output_lut_stride1,
// Input lut
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
uint64_t input_lut_offset, uint64_t input_lut_size,
@@ -330,13 +325,22 @@ void memref_encode_expand_lut_for_woppbs(
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
// Crypto parameters
uint32_t poly_size, uint32_t modulus_product, bool is_signed) {
uint32_t modulus_product, bool is_signed) {
assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_encode_expand_lut_woppbs");
assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_encode_expand_lut_woppbs");
assert(modulus_product > input_lut_size);
"memref_encode_lut_woppbs");
assert(output_lut_stride0 == output_lut_size1 &&
"Runtime: out dim stride not equal to in_dim size, check "
"memref_encode_lut_woppbs");
assert(output_lut_stride1 == 1 && "Runtime: stride not equal to 1, check "
"memref_encode_lut_woppbs");
assert(modulus_product >= input_lut_size);
// Initialize lut cases not supposed to be reached
for (uint64_t i = 0; i < output_lut_size0 * output_lut_size1; i++) {
output_lut_aligned[output_lut_offset + i] = 0;
}
// When the woppbs is executed on encrypted signed integers, the index of the
// lut elements must be adapted to fit the way signed are encrypted in CRT
@@ -393,25 +397,40 @@ void memref_encode_expand_lut_for_woppbs(
};
}
uint64_t lut_crt_size = output_lut_size / crt_decomposition_size;
uint64_t log_lut_crt_size = 0;
for (uint64_t index = 0; index < input_lut_size; index++) {
uint64_t index_lut = 0;
uint64_t tmp = 1;
for (size_t in_block = 0; in_block < crt_decomposition_size; in_block++) {
auto bits_count = crt_bits_aligned[crt_bits_offset + in_block];
log_lut_crt_size += bits_count;
}
for (size_t block = 0; block < crt_decomposition_size; block++) {
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
auto bits = crt_bits_aligned[crt_bits_offset + block];
index_lut += (((indexMap(index) % base) << bits) / base) * tmp;
tmp <<= bits;
uint64_t lut_crt_size = 1 << log_lut_crt_size;
assert(lut_crt_size == output_lut_size1);
assert(crt_decomposition_size == output_lut_size0);
for (uint64_t in_index = 0; in_index < input_lut_size; in_index++) {
uint64_t out_index = 0;
{
uint64_t total_bit_count = 0;
for (size_t in_block = 0; in_block < crt_decomposition_size; in_block++) {
auto in_base =
crt_decomposition_aligned[crt_decomposition_offset + in_block];
auto bits_count = crt_bits_aligned[crt_bits_offset + in_block];
out_index += (((indexMap(in_index) % in_base) << bits_count) / in_base)
<< total_bit_count;
total_bit_count += bits_count;
}
}
for (size_t block = 0; block < crt_decomposition_size; block++) {
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
auto v = encode_crt(input_lut_aligned[input_lut_offset + index], base,
modulus_product);
output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] =
v;
for (size_t out_block = 0; out_block < crt_decomposition_size;
out_block++) {
auto out_base =
crt_decomposition_aligned[crt_decomposition_offset + out_block];
auto v = encode_crt(input_lut_aligned[input_lut_offset + in_index],
out_base, modulus_product);
output_lut_aligned[output_lut_offset + out_block * lut_crt_size +
out_index] = v;
}
}
}
@@ -424,11 +443,10 @@ void memref_add_lwe_ciphertexts_u64(
uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
assert(out_size == ct0_size && out_size == ct1_size &&
"size of lwe buffer are incompatible");
size_t lwe_dimension = {out_size - 1};
CAPI_ASSERT_ERROR(
default_engine_discard_add_lwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), out_aligned + out_offset,
ct0_aligned + ct0_offset, ct1_aligned + ct1_offset, lwe_dimension));
size_t lwe_dimension = out_size - 1;
concrete_cpu_add_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset,
ct1_aligned + ct1_offset, lwe_dimension);
}
void memref_add_plaintext_lwe_ciphertext_u64(
@@ -437,11 +455,10 @@ void memref_add_plaintext_lwe_ciphertext_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t plaintext) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
size_t lwe_dimension = {out_size - 1};
CAPI_ASSERT_ERROR(
default_engine_discard_add_lwe_ciphertext_plaintext_u64_raw_ptr_buffers(
get_levelled_engine(), out_aligned + out_offset,
ct0_aligned + ct0_offset, lwe_dimension, plaintext));
size_t lwe_dimension = out_size - 1;
concrete_cpu_add_plaintext_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset,
plaintext, lwe_dimension);
}
void memref_mul_cleartext_lwe_ciphertext_u64(
@@ -450,11 +467,10 @@ void memref_mul_cleartext_lwe_ciphertext_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t cleartext) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
size_t lwe_dimension = {out_size - 1};
CAPI_ASSERT_ERROR(
default_engine_discard_mul_lwe_ciphertext_cleartext_u64_raw_ptr_buffers(
get_levelled_engine(), out_aligned + out_offset,
ct0_aligned + ct0_offset, lwe_dimension, cleartext));
size_t lwe_dimension = out_size - 1;
concrete_cpu_mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
ct0_aligned + ct0_offset,
cleartext, lwe_dimension);
}
void memref_negate_lwe_ciphertext_u64(
@@ -464,24 +480,25 @@ void memref_negate_lwe_ciphertext_u64(
uint64_t ct0_stride) {
assert(out_size == ct0_size && "size of lwe buffer are incompatible");
size_t lwe_dimension = {out_size - 1};
CAPI_ASSERT_ERROR(
default_engine_discard_opp_lwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), out_aligned + out_offset,
ct0_aligned + ct0_offset, lwe_dimension));
concrete_cpu_negate_lwe_ciphertext_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset, lwe_dimension);
}
void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
uint64_t out_offset, uint64_t out_size,
uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset,
uint64_t ct0_size, uint64_t ct0_stride,
uint32_t level, uint32_t base_log,
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
mlir::concretelang::RuntimeContext *context) {
CAPI_ASSERT_ERROR(
default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers(
get_engine(context), get_keyswitch_key_u64(context),
out_aligned + out_offset, ct0_aligned + ct0_offset));
void memref_keyswitch_lwe_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint32_t decomposition_level_count,
uint32_t decomposition_base_log, uint32_t input_dimension,
uint32_t output_dimension, mlir::concretelang::RuntimeContext *context) {
assert(out_stride == 1 && ct0_stride == 1);
// Get keyswitch key - TODO Give a non hardcoded keyID
const uint64_t *keyswitch_key = context->keyswitch_key_buffer(0);
// Get stack parameter
concrete_cpu_keyswitch_lwe_ciphertext_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset, keyswitch_key,
decomposition_level_count, decomposition_base_log, input_dimension,
output_dimension);
}
void memref_batched_keyswitch_lwe_u64(
@@ -507,24 +524,44 @@ void memref_bootstrap_lwe_u64(
uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, uint64_t *tlu_allocated, uint64_t *tlu_aligned,
uint64_t tlu_offset, uint64_t tlu_size, uint64_t tlu_stride,
uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
uint32_t input_lwe_dimension, uint32_t polynomial_size,
uint32_t decomposition_level_count, uint32_t decomposition_base_log,
uint32_t glwe_dimension, uint32_t precision,
mlir::concretelang::RuntimeContext *context) {
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
uint64_t glwe_ct_size = polynomial_size * (glwe_dimension + 1);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
auto tlu = tlu_aligned + tlu_offset;
CAPI_ASSERT_ERROR(
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), glwe_ct, glwe_ct_size,
tlu_aligned + tlu_offset, poly_size));
// Glwe trivial encryption
for (size_t i = 0; i < polynomial_size * glwe_dimension; i++) {
glwe_ct[i] = 0;
}
for (size_t i = 0; i < polynomial_size; i++) {
glwe_ct[polynomial_size * glwe_dimension + i] = tlu[i];
}
// Get fourrier bootstrap key - TODO Give a non hardcoded keyID
size_t keyId = 0;
const auto &fft = context->fft(keyId);
auto bootstrap_key = context->fourier_bootstrap_key_buffer(keyId);
// Get stack parameter
size_t scratch_size;
size_t scratch_align;
concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch(
&scratch_size, &scratch_align, glwe_dimension, polynomial_size, fft);
// Allocate scratch
auto scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
// Bootstrap
concrete_cpu_bootstrap_lwe_ciphertext_u64(
out_aligned + out_offset, ct0_aligned + ct0_offset, glwe_ct,
bootstrap_key, decomposition_level_count, decomposition_base_log,
glwe_dimension, polynomial_size, input_lwe_dimension, fft, scratch,
scratch_size);
CAPI_ASSERT_ERROR(
fft_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
get_fft_engine(context), get_engine(context),
get_fft_fourier_bootstrap_key_u64(context), out_aligned + out_offset,
ct0_aligned + ct0_offset, glwe_ct));
free(glwe_ct);
free(scratch);
}
void memref_batched_bootstrap_lwe_u64(
@@ -563,13 +600,16 @@ void memref_wop_pbs_crt_buffer(
uint64_t in_stride_1,
// clear text lut 1D memref
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
uint64_t lut_ct_offset, uint64_t lut_ct_size0, uint64_t lut_ct_size1,
uint64_t lut_ct_stride0, uint64_t lut_ct_stride1,
// CRT decomposition 1D memref
uint64_t *crt_decomp_allocated, uint64_t *crt_decomp_aligned,
uint64_t crt_decomp_offset, uint64_t crt_decomp_size,
uint64_t crt_decomp_stride,
// Additional crypto parameters
uint32_t lwe_small_size, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size,
// runtime context that hold evluation keys
mlir::concretelang::RuntimeContext *context) {
@@ -585,7 +625,15 @@ void memref_wop_pbs_crt_buffer(
// Check for the size S
assert(out_size_1 == in_size_1);
uint64_t lwe_small_dim = lwe_small_size - 1;
assert(out_size_1 == in_size_1);
uint64_t lwe_big_size = in_size_1;
uint64_t lwe_big_dim = lwe_big_size - 1;
assert(lwe_big_dim % polynomial_size == 0);
assert(lwe_big_dim % polynomial_size == 0);
uint64_t glwe_dim = lwe_big_dim / polynomial_size;
// Compute the numbers of bits to extract for each block and the total one.
uint64_t total_number_of_bits_per_block = 0;
@@ -609,38 +657,83 @@ void memref_wop_pbs_crt_buffer(
new uint64_t[lwe_small_size * total_number_of_bits_per_block]{0};
// We make a private copy to apply a subtraction on the body
auto first_cyphertext = in_aligned + in_offset;
auto first_ciphertext = in_aligned + in_offset;
auto copy_size = crt_decomp_size * lwe_big_size;
std::vector<uint64_t> in_copy(first_cyphertext, first_cyphertext + copy_size);
std::vector<uint64_t> in_copy(first_ciphertext, first_ciphertext + copy_size);
// Extraction of each bit for each block
size_t fftKeyId = 0;
const auto &fft = context->fft(fftKeyId);
size_t bskKeyId = 0;
auto bootstrap_key = context->fourier_bootstrap_key_buffer(bskKeyId);
size_t kskKeyId = 0;
auto keyswicth_key = context->keyswitch_key_buffer(kskKeyId);
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
extract_bits_output_offset += number_of_bits_per_block[i--]) {
auto nb_bits_to_extract = number_of_bits_per_block[i];
auto delta_log = 64 - nb_bits_to_extract;
size_t delta_log = 64 - nb_bits_to_extract;
auto in_block = &in_copy[lwe_big_size * i];
// trick ( ct - delta/2 + delta/2^4 )
uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) -
(uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5));
in_block[lwe_big_size - 1] -= sub;
CAPI_ASSERT_ERROR(
fft_engine_lwe_ciphertext_discarding_bit_extraction_unchecked_u64_raw_ptr_buffers(
context->get_fft_engine(), context->get_default_engine(),
context->get_fft_fourier_bsk(), context->get_ksk(),
&extract_bits_output_buffer[lwe_small_size *
extract_bits_output_offset],
in_block, nb_bits_to_extract, delta_log));
size_t scratch_size;
size_t scratch_align;
concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch(
&scratch_size, &scratch_align, lwe_small_dim, lwe_big_dim, glwe_dim,
polynomial_size, fft);
// Allocate scratch
auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
concrete_cpu_extract_bit_lwe_ciphertext_u64(
&extract_bits_output_buffer[lwe_small_size *
extract_bits_output_offset],
in_block, bootstrap_key, keyswicth_key, lwe_small_dim,
nb_bits_to_extract, lwe_big_dim, nb_bits_to_extract, delta_log,
bsk_level_count, bsk_base_log, glwe_dim, polynomial_size, lwe_small_dim,
ksk_level_count, ksk_base_log, lwe_big_dim, lwe_small_dim, fft, scratch,
scratch_size);
free(scratch);
}
size_t ct_in_count = total_number_of_bits_per_block;
size_t lut_size = 1 << ct_in_count;
size_t ct_out_count = out_size_0;
size_t lut_count = ct_out_count;
assert(lut_ct_size0 == lut_count);
assert(lut_ct_size1 == lut_size);
// Vertical packing
CAPI_ASSERT_ERROR(
fft_engine_lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing_u64_raw_ptr_buffers(
context->get_fft_engine(), context->get_default_engine(),
context->get_fft_fourier_bsk(), out_aligned, lwe_big_size,
crt_decomp_size, extract_bits_output_buffer, lwe_small_size,
total_number_of_bits_per_block, lut_ct_aligned + lut_ct_offset,
lut_ct_size, cbs_level_count, cbs_base_log, context->get_fpksk()));
size_t scratch_size;
size_t scratch_align;
concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch(
&scratch_size, &scratch_align, ct_out_count, lwe_small_dim, ct_in_count,
lut_size, lut_count, glwe_dim, polynomial_size, polynomial_size,
cbs_level_count, fft);
auto *scratch = (uint8_t *)aligned_alloc(scratch_align, scratch_size);
size_t fpkskKeyId = 0;
auto fp_keyswicth_key = context->fp_keyswitch_key_buffer(fpkskKeyId);
concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
out_aligned + out_offset, extract_bits_output_buffer,
lut_ct_aligned + lut_ct_offset, bootstrap_key, fp_keyswicth_key,
lwe_big_dim, ct_out_count, lwe_small_dim, ct_in_count, lut_size,
lut_count, bsk_level_count, bsk_base_log, glwe_dim, polynomial_size,
lwe_small_dim, fpksk_level_count, fpksk_base_log, lwe_big_dim, glwe_dim,
polynomial_size, glwe_dim + 1, cbs_level_count, cbs_base_log, fft,
scratch, scratch_size);
free(scratch);
}
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,

View File

@@ -9,6 +9,7 @@
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/context.h"
#include "concretelang/ServerLib/DynamicArityCall.h"
#include "concretelang/ServerLib/DynamicModule.h"
#include "concretelang/ServerLib/DynamicRankCall.h"
@@ -23,8 +24,8 @@ using concretelang::clientlib::CircuitGate;
using concretelang::clientlib::CircuitGateShape;
using concretelang::clientlib::EvaluationKeys;
using concretelang::clientlib::PublicArguments;
using concretelang::clientlib::RuntimeContext;
using concretelang::error::StringError;
using mlir::concretelang::RuntimeContext;
outcome::checked<ServerLambda, StringError>
ServerLambda::loadFromModule(std::shared_ptr<DynamicModule> module,
@@ -70,8 +71,7 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
args.preparedArgs.end());
RuntimeContext runtimeContext;
runtimeContext.evaluationKeys = evaluationKeys;
RuntimeContext runtimeContext(evaluationKeys);
preparedArgs.push_back((void *)&runtimeContext);
assert(clientParameters.outputs.size() == 1 &&

View File

@@ -37,5 +37,6 @@ add_mlir_library(
${LLVM_PTHREAD_LIB}
ConcretelangRuntime
ConcretelangClientLib
ConcretelangServerLib
Concrete)
ConcretelangServerLib)
target_include_directories(ConcretelangSupport PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})

View File

@@ -3,6 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cassert>
#include <fstream>
#include "boost/outcome.h"
@@ -18,30 +19,29 @@ void CompilationFeedback::fillFromClientParameters(
// Compute the size of secret keys
totalSecretKeysSize = 0;
for (auto sk : params.secretKeys) {
totalSecretKeysSize += sk.second.byteSize();
totalSecretKeysSize += sk.byteSize();
}
// Compute the boostrap keys size
totalBootstrapKeysSize = 0;
for (auto bsk : params.bootstrapKeys) {
auto bskParam = bsk.second;
auto inputKey = params.secretKeys.find(bskParam.inputSecretKeyID);
assert(inputKey != params.secretKeys.end());
auto outputKey = params.secretKeys.find(bskParam.outputSecretKeyID);
assert(outputKey != params.secretKeys.end());
for (auto bskParam : params.bootstrapKeys) {
assert(bskParam.inputSecretKeyID < params.secretKeys.size());
auto inputKey = params.secretKeys[bskParam.inputSecretKeyID];
totalBootstrapKeysSize += bskParam.byteSize(inputKey->second.lweSize(),
outputKey->second.lweSize());
assert(bskParam.outputSecretKeyID < params.secretKeys.size());
auto outputKey = params.secretKeys[bskParam.outputSecretKeyID];
totalBootstrapKeysSize +=
bskParam.byteSize(inputKey.lweSize(), outputKey.lweSize());
}
// Compute the keyswitch keys size
totalKeyswitchKeysSize = 0;
for (auto ksk : params.keyswitchKeys) {
auto kskParam = ksk.second;
auto inputKey = params.secretKeys.find(kskParam.inputSecretKeyID);
assert(inputKey != params.secretKeys.end());
auto outputKey = params.secretKeys.find(kskParam.outputSecretKeyID);
assert(outputKey != params.secretKeys.end());
totalKeyswitchKeysSize += kskParam.byteSize(inputKey->second.lweSize(),
outputKey->second.lweSize());
for (auto kskParam : params.keyswitchKeys) {
assert(kskParam.inputSecretKeyID < params.secretKeys.size());
auto inputKey = params.secretKeys[kskParam.inputSecretKeyID];
assert(kskParam.outputSecretKeyID < params.secretKeys.size());
auto outputKey = params.secretKeys[kskParam.outputSecretKeyID];
totalKeyswitchKeysSize +=
kskParam.byteSize(inputKey.lweSize(), outputKey.lweSize());
}
// Compute the size of inputs
totalInputsSize = 0;

View File

@@ -13,10 +13,11 @@
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include "concretelang/Common/BitsSize.h"
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/logging.h>
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/logging.h"
namespace mlir {
namespace concretelang {
@@ -133,8 +134,7 @@ JITLambda::call(clientlib::PublicArguments &args,
rawArgs[i++] = &arg;
}
RuntimeContext runtimeContext;
runtimeContext.evaluationKeys = evaluationKeys;
mlir::concretelang::RuntimeContext runtimeContext(evaluationKeys);
// Pointer on runtime context, the rawArgs take pointer on actual value that
// is passed to the compiled function.
auto rtCtxPtr = &runtimeContext;

View File

@@ -250,11 +250,10 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
auto dec =
fheContext.value().parameter.largeInteger.value().crtDecomposition;
auto mods = mlir::SmallVector<int64_t>(dec.begin(), dec.end());
auto polySize = fheContext.value().parameter.getPolynomialSize();
addPotentiallyNestedPass(
pm,
mlir::concretelang::createConvertFHEToTFHECrtPass(
mlir::concretelang::CrtLoweringParameters(mods, polySize)),
mlir::concretelang::CrtLoweringParameters(mods)),
enablePass);
} else if (fheContext.hasValue()) {
pipelinePrinting("FHEToTFHEScalar", pm, context);

View File

@@ -2,6 +2,7 @@
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cassert>
#include <map>
#include <llvm/ADT/Optional.h>
@@ -161,59 +162,56 @@ createClientParametersForV0(V0FHEContext fheContext,
Variance keyswitchKeyVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
// Static client parameters from global parameters for v0
ClientParameters c;
c.secretKeys = {
{clientlib::BIG_KEY, {/*.size = */ v0Param.getNBigLweDimension()}},
};
assert(c.secretKeys.size() == clientlib::BIG_KEY);
clientlib::LweSecretKeyParam skParam;
skParam.dimension = v0Param.getNBigLweDimension();
c.secretKeys.push_back(skParam);
bool has_small_key = v0Param.nSmall != 0;
bool has_bootstrap = v0Param.brLevel != 0;
if (has_small_key) {
c.secretKeys.insert({clientlib::SMALL_KEY, {/*.size = */ v0Param.nSmall}});
assert(c.secretKeys.size() == clientlib::SMALL_KEY);
clientlib::LweSecretKeyParam skParam2;
skParam2.dimension = v0Param.nSmall;
c.secretKeys.push_back(skParam2);
}
if (has_bootstrap) {
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
c.bootstrapKeys = {
{
clientlib::BOOTSTRAP_KEY,
{
/*.inputSecretKeyID = */ inputKey,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ v0Param.brLevel,
/*.baseLog = */ v0Param.brLogBase,
/*.glweDimension = */ v0Param.glweDimension,
/*.variance = */ bootstrapKeyVariance,
},
},
};
clientlib::BootstrapKeyParam bskParam;
bskParam.inputSecretKeyID = inputKey;
bskParam.outputSecretKeyID = clientlib::BIG_KEY;
bskParam.level = v0Param.brLevel;
bskParam.baseLog = v0Param.brLogBase;
bskParam.glweDimension = v0Param.glweDimension;
bskParam.variance = bootstrapKeyVariance;
bskParam.polynomialSize = v0Param.getPolynomialSize();
bskParam.inputLweDimension = v0Param.nSmall;
c.bootstrapKeys.push_back(bskParam);
}
if (v0Param.largeInteger.hasValue()) {
clientlib::PackingKeySwitchParam param;
clientlib::PackingKeyswitchKeyParam param;
param.inputSecretKeyID = clientlib::BIG_KEY;
param.outputSecretKeyID = clientlib::BIG_KEY;
param.level = v0Param.largeInteger->wopPBS.packingKeySwitch.level;
param.baseLog = v0Param.largeInteger->wopPBS.packingKeySwitch.baseLog;
param.bootstrapKeyID = clientlib::BOOTSTRAP_KEY;
param.glweDimension = v0Param.glweDimension;
param.polynomialSize = v0Param.getPolynomialSize();
param.inputLweDimension = v0Param.getNBigLweDimension();
param.variance = v0Curve->getVariance(v0Param.glweDimension,
v0Param.getPolynomialSize(), 64);
c.packingKeys = {
{
"fpksk_v0",
param,
},
};
c.packingKeyswitchKeys.push_back(param);
}
if (has_small_key) {
c.keyswitchKeys = {
{
clientlib::KEYSWITCH_KEY,
{
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ v0Param.ksLevel,
/*.baseLog = */ v0Param.ksLogBase,
/*.variance = */ keyswitchKeyVariance,
},
},
};
clientlib::KeyswitchKeyParam kskParam;
kskParam.inputSecretKeyID = clientlib::BIG_KEY;
kskParam.outputSecretKeyID = clientlib::SMALL_KEY;
kskParam.level = v0Param.ksLevel;
kskParam.baseLog = v0Param.ksLogBase;
kskParam.variance = keyswitchKeyVariance;
c.keyswitchKeys.push_back(kskParam);
}
c.functionName = (std::string)functionName;

View File

@@ -1,8 +1,8 @@
// RUN: concretecompiler --optimize-tfhe=false --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)

View File

@@ -2,8 +2,8 @@
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %0 = "TFHE.encode_lut_for_crt_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<128xi64>) -> tensor<5x8192xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x8192xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>

View File

@@ -1,10 +0,0 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
// CHECK-NEXT: return %0 : tensor<40960xi64>
// CHECK-NEXT: }
func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> {
%0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
return %0: tensor<40960xi64>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<5x8192xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_lut_for_crt_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
// CHECK-NEXT: return %0 : tensor<5x8192xi64>
// CHECK-NEXT: }
func.func @main(%arg1: tensor<4xi64>) -> tensor<5x8192xi64> {
%0 = "TFHE.encode_lut_for_crt_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>
return %0: tensor<5x8192xi64>
}

View File

@@ -3,6 +3,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Support/CompilationFeedback.h"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/LibrarySupport.h"
@@ -72,6 +73,13 @@ public:
void testOnce() {
auto evaluationKeys = keySet->evaluationKeys();
/* Serialize and unserialize evaluation keys */
std::stringstream stream;
stream << evaluationKeys;
stream.seekg(0, std::ios::beg);
evaluationKeys = concretelang::clientlib::readEvaluationKeys(stream);
/* Call the server lambda */
auto publicResult =
support.serverCall(serverLambda, *publicArguments, evaluationKeys);

View File

@@ -1,3 +1,4 @@
#include <cassert>
#include <gtest/gtest.h>
#include "concretelang/ClientLib/ClientParameters.h"
@@ -8,38 +9,42 @@ namespace clientlib = concretelang::clientlib;
TEST(Support, client_parameters_json_serde) {
clientlib::ClientParameters params0;
params0.secretKeys = {
{clientlib::SMALL_KEY, {/*.size = */ 12}},
{clientlib::BIG_KEY, {/*.size = */ 14}},
};
params0.bootstrapKeys = {
{clientlib::BOOTSTRAP_KEY,
{/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.glweDimension = */ 3,
/*.variance = */ 0.001}},
{"wtf_bsk_v0",
{
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.glweDimension = */ 1,
/*.variance = */ 0.0001,
}},
};
params0.keyswitchKeys = {{clientlib::KEYSWITCH_KEY,
{
/*.inputSecretKeyID = */
clientlib::BIG_KEY,
/*.outputSecretKeyID = */
clientlib::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
}}};
assert(params0.secretKeys.size() == clientlib::BIG_KEY);
params0.secretKeys.push_back({14});
assert(params0.secretKeys.size() == clientlib::SMALL_KEY);
params0.secretKeys.push_back({12});
params0.bootstrapKeys.push_back({
/*.inputSecretKeyID = */ clientlib::SMALL_KEY,
/*.outputSecretKeyID = */ clientlib::BIG_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.glweDimension = */ 3,
/*.variance = */ 0.001,
/*.polynomialSize = */ 1024,
/*.inputLweDimension = */ 600,
});
params0.bootstrapKeys.push_back({
/*.inputSecretKeyID = */ clientlib::BIG_KEY,
/*.outputSecretKeyID = */ clientlib::SMALL_KEY,
/*.level = */ 3,
/*.baseLog = */ 2,
/*.glweDimension = */ 1,
/*.variance = */ 0.0001,
/*.polynomialSize = */ 1024,
/*.inputLweDimension = */ 600,
});
params0.keyswitchKeys.push_back({
/*.inputSecretKeyID = */
clientlib::BIG_KEY,
/*.outputSecretKeyID = */
clientlib::SMALL_KEY,
/*.level = */ 1,
/*.baseLog = */ 2,
/*.variance = */ 3,
});
params0.inputs = {
{
/*.encryption = */ {

View File

@@ -3,6 +3,7 @@
#include "concrete/curves.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "tests_tools/assert.h"
namespace clientlib = concretelang::clientlib;
@@ -19,9 +20,13 @@ TEST_P(KeySetTest, encrypt_decrypt) {
auto clientParameters = GetParam();
__uint128_t seed = 0;
// Generate the client keySet
ASSERT_ASSIGN_OUTCOME_VALUE(
keySet, clientlib::KeySet::generate(clientParameters, 0, 0));
keySet,
clientlib::KeySet::generate(
clientParameters, concretelang::clientlib::ConcreteCSPRNG(seed)));
// Allocate the ciphertext
uint64_t *ciphertext = nullptr;
@@ -50,13 +55,13 @@ clientlib::ClientParameters generateClientParameterOneScalarOneScalar(
clientlib::CRTDecomposition crtDecomposition) {
// One secret key with the given dimension
clientlib::ClientParameters params;
params.secretKeys.insert({clientlib::SMALL_KEY, {/*.dimension =*/dimension}});
params.secretKeys.push_back({/*.dimension =*/dimension});
// One input and output encryption gate on the same secret key and encoded
// with the same precision
const auto v0Curve = concrete::getSecurityCurve(128, concrete::BINARY);
clientlib::EncryptionGate encryption;
encryption.secretKeyID = clientlib::SMALL_KEY;
encryption.secretKeyID = clientlib::BIG_KEY;
encryption.encoding.precision = precision;
encryption.encoding.crt = crtDecomposition;
encryption.variance = v0Curve->getVariance(1, dimension, 64);